[NOI 2016] 优秀的拆分

题意:如果一个字符串可被拆分为AABB的形式,其中A、B是任意非空字符串,则称该字符串的这种拆分是优秀的。求长度为n的字符串S的所有子串(连续的一段)的所有拆分方式中,优秀拆分的总个数。(对于95%的数据,n≤2000;对于100%的数据,n≤30000) 考场上的朴素算法可以获得85分。我是85分。然而95分大概是NOIP提高组T1难度吧......

以下下标从1开始。

为以位置i结束的AA的数目,为从位置i开始的AA的数目,则

字符串的比较用哈希,直接计算,时间复杂度。记得出题人说,构造了一两组卡unsigned long long自然溢出的哈希的数据。

这个思路不错。怎么优化数组的计算呢?没什么好想法,于是参考了题解:

BZOJ4650: [Noi2016]优秀的拆分 - Claris - 博客园

枚举A的长度i = 1, 2, ..., n/2,令i、2i、...为关键点,则每个长为i的子串覆盖一个且仅一个关键点,AA覆盖两个相邻的关键点。考察覆盖点ji和(j-1)i的AA。查询后缀(j-1)i和后缀ji的最长公共前缀l、前缀(j-1)i和前缀ji的最长公共后缀r,将l、r分别与i取min。若l+r-1≥i,则更新答案。A可在区间(ji-l, ji+r)内移动,所以是区间修改。差分后转为点修改即可。

后缀数组的构造是的。统计AA,循环次数为,每次查询最长公共前后缀是的。总共

UOJ上看到一些又短又快的代码,用哈希来求LCP。理论上是,然而常数不知道小到哪里去了......

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cassert>
using namespace std;
typedef long long ll;
const int MAX_N = 3e4, MAX_D = 15;

char *str;
bool cmp(int i, int j) { return str[i] < str[j]; }

struct Suffix_Array {
	int sa[MAX_N+1], buf[2][MAX_N+2], c[MAX_N+1], h[MAX_D+1][MAX_N+1], *rk, *t;
	
	Suffix_Array(): rk(buf[0]), t(buf[1]) {}
	
	void build(char s[], int n)
	{
		int i, j, l;
		s[++n] = 0;
		str = s;
		for (i = 0; i < n; ++i) sa[i] = i+1;
		sort(sa, sa+n, cmp);
		int m = rk[sa[0]] = 0;
		for (i = 1; i < n; ++i) rk[sa[i]] = m += s[sa[i-1]] != s[sa[i]];
		for (l = 1; ++m < n; l *= 2) {
			for (i = 0; i < l; ++i) t[i] = n-l+i+1;
			for (j = 0; j < n; ++j) if (sa[j] > l) t[i++] = sa[j]-l;
			fill_n(c, m, 0);
			for (i = 1; i <= n; ++i) ++c[rk[i]];
			for (i = 1; i < n; ++i) c[i] += c[i-1];
			for (i = n-1; i >= 0; --i) sa[--c[rk[t[i]]]] = t[i];
			swap(rk, t);
			m = rk[sa[0]] = 0;
			for (i = 1; i < n; ++i)
				rk[sa[i]] = m += t[sa[i-1]] != t[sa[i]] || t[sa[i-1]+l] != t[sa[i]+l];
		}

		for (i = 1, j = 0; i < n; h[0][rk[i]] = j, j -= j>0, ++i)
			for (int p = sa[rk[i]-1]; s[p+j] == s[i+j]; ++j) ;

		for (i = 1, l = 1; 2*l < n; ++i, l *= 2)
			for (j = 2*l; j < n; ++j)
				h[i][j] = min(h[i-1][j], h[i-1][j-l]);
	}
	
	int lcp(int i, int j)
	{
		assert(i != j);
		i = rk[i], j = rk[j];
		if (i > j) swap(i, j); // (i, j]
		int k = 0;
		while ((1 << k+1) < j-i) ++k;
		return min(h[k][j], h[k][i+(1<<k)]);
	}
} SA[2];

inline void add(ll A[], int l, int r)
{
	assert(l <= r);
	++A[l];
	--A[r+1];
}

ll solve(char s[], int n)
{
	static ll f[MAX_N+2], g[MAX_N+2];
	SA[0].build(s, n);
	for (int i = 1; i <= n/2; ++i)
		swap(s[i], s[n-i+1]);
	SA[1].build(s, n);

	fill_n(f+1, n, 0); // right
	fill_n(g+1, n, 0); // left

	for (int i = 1; i <= n/2; ++i)
		for (int j = 2*i; j <= n; j += i) {
			int r = min(SA[0].lcp(j, j-i), i), l = min(SA[1].lcp(n-j+1, n-j+i+1), i), x = j-i-l+1, y = j+r-1;
			if (l+r-1 >= i) {
				add(f, x, y-2*i+1);
				add(g, x+2*i-1, y);
			}
		}
	
	for (int i = 1; i <= n; ++i) {
		f[i] += f[i-1];
		g[i] += g[i-1];
	}

	ll ans = 0;

	for (int i = 1; i < n; ++i)
		ans += g[i]*f[i+1];

	return ans;
}

int main()
{
	int T;
	scanf("%d", &T);
	while (T--) {
		static char s[MAX_N+2];
		scanf("%s", s+1);
		printf("%lld\n", solve(s, strlen(s+1)));
	}
	return 0;
}