[bzoj 3678] wangxz与OJ

维护一个序列, 支持:

  • 插入 在p位置和p+1位置之间插入整数a,a+1,...,b-1,b; 若p=0, 插在最前面
  • 删除 删除a,a+1,...,b-1,b位置的元素
  • 查询 查询p位置的元素

(1≤n≤2*10^4, 1≤m≤2*10^4, 涉及的所有数在int内) 题面中还说: 在任何情况下,保证序列中的元素总数不超过100000. 其实数据并不满足这个限制. 一道简单题, 拖了几天才AC, 主要是有点气...... TAT

一开始没注意插入的是连续的一段整数. 但是询问数很少, 元素总数不超过10^5, 而且全是简单的单点查询, 于是可以对询问分块: 每块开头处理出目前的序列, 每次询问倒着跑, 逆向推断出p位置的元素以前在哪, 直到到达该块的开头, 或者该位置被插入操作覆盖. 时间复杂度 , N是序列中元素总数的上限.

发现插入的是连续一段整数后, 想到可以Splay水之, 每一段一个点, 需要的时候分裂即可. 但是分块好写啊, 就写了分块......

RE......

对拍无果......

把数组逐渐开大, 发现运行时间变长......

几个RE中冒出一个神奇的WA......加一个对元素总数的断言, WA也变成了RE......

然后给lydsy管理员发了封邮件. 虽然不是权限号, 但是也收到了数据. 果不其然, 不满足 "在任何情况下,保证序列中的元素总数不超过100000".

管理员问题目有什么需要改动的地方, 我建议去掉这句话. 虽然答应了, 但几天过去了, 这句话现在还挂在题面上......

好久没写Splay......稀奇古怪的奇葩错误一大堆......

  1. 基本操作的错误. 比如setc(x,d,y), 错误地在x==0时直接返回, 但这个时候仍然得设置fa[y]=x. splay的时候, 这里的需求是转到某个节点下面, 因此该判断的不是父亲是否为空.
  2. 把一个节点分裂成两个的split函数. 参数是传分裂点的值还是rank? 显然应该是前者. 这个问题有幸在静态查错的时候发现了.
  3. 根据新的节点作儿子还是父亲, 我分成了split_1split_2. 其实后者就是前者旋转一下, 我觉得单独写可以减少一点常数......毫无必要, 还写错了.
  4. 我的rot函数并不更新被提上去的那个点的附加信息, 因为直到splay结束, 信息一直在变, 但不会影响其他节点. 修改split_2的实现的时候, 我忘记了这一点.
  5. 逻辑上的错误. 把[l,r]区间提上去, 如果l-1是一个复合节点的最后一个元素, 并不需要分裂; 其他同理. 开始我判断的是a[x]==b[x]. 好蠢......

555......

Splay

const int N = 2e4 + 2, V = 3*N;

struct Splay
{
	int ptr, root, fa[V], ch[V][2], t[V], sz[V], a[V], b[V];
	int new_node(int l, int r)
	{
		++ptr;
		a[ptr] = l, b[ptr] = r, sz[ptr] = r-l+1;
		return ptr;
	}
	void setc(int x, int d, int y)
	{
		if (x) ch[x][d] = y;
		if (y) fa[y] = x, t[y] = d;
	}
	void up(int x)
	{
		sz[x] = sz[ch[x][0]] + b[x] - a[x] + 1 + sz[ch[x][1]];
	}
	void rot(int x)
	{
		int y = fa[x], d = t[x];
		setc(fa[y], t[y], x);
		setc(y, d, ch[x][d^1]);
		setc(x, d^1, y);
		up(y);
		if (y == root) root = x;
	}
	void splay(int x, int o=0)
	{
		assert(x != o);
		int y;
		while ((y = fa[x]) != o)
		{
			if (fa[y] != o) rot(t[x]^t[y] ? x : y);
			rot(x);
		}
		up(x);
	}
	int find(int x, int k)
	{
		if (sz[x] == 1) return x;
		int s = sz[ch[x][0]];
		if (k <= s) return find(ch[x][0], k);
		s += b[x] - a[x] + 1;
		if (k <= s) return x;
		return find(ch[x][1], k-s);
	}
	int build(int* v, int n)
	{
		if (!n) return 0;
		int m = n/2, x = new_node(v[m], v[m]);
		sz[x] = n;
		setc(x, 0, build(v, m));
		setc(x, 1, build(v+m+1, n-1-m));
		return x;
	}
	int split_1(int x, int c)
	{
		int y = new_node(c, b[x]);
		sz[y] += sz[ch[x][1]];
		setc(y, 1, ch[x][1]);
		setc(x, 1, y);
		b[x] = c-1;
		return y;
	}
	int split_2(int x, int c)
	{
		int y = split_1(x, c);
		rot(y), up(y);
		return y;
	}
	int minimum(int x)
	{
		while (ch[x][0]) x = ch[x][0];
		return x;
	}
	void insert(int p, int l, int r)
	{
		int x = find(root, p), y, z = new_node(l, r);
		splay(x);
		if (sz[ch[x][0]]+b[x]-a[x]+1 == p) y = minimum(ch[x][1]), splay(y, x);
		else y = split_1(x, p-sz[ch[x][0]]+a[x]);
		setc(y, 0, z);
		sz[y] += r-l+1, sz[x] += r-l+1;
	}
	void erase(int l, int r)
	{
		int x = find(root, l-1);
		splay(x);
		if (sz[ch[x][0]]+b[x]-a[x]+1 != l-1)
			split_1(x, l-sz[ch[x][0]]+a[x]-1);
		int z = find(root, r+1);
		splay(z, x);
		if (l-1+sz[ch[z][0]] != r)
			z = split_2(z, r-l+1-sz[ch[z][0]]+a[z]);
		ch[z][0] = 0;
		sz[z] -= r-l+1, sz[x] -= r-l+1;
	}
	int ask(int p)
	{
		splay(find(root, p));
		return p - sz[ch[root][0]] + a[root] - 1;
	}
} T;

int v[N];

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	rep (i, 1, n+1) scanf("%d", v+i);
	T.root = T.build(v, n+2);
	while (m--)
	{
		int o, p, a, b;
		scanf("%d", &o);
		if (o == 2)
		{
			scanf("%d", &p);
			printf("%d\n", T.ask(p+1));
		}
		else if (o)
		{
			scanf("%d%d", &a, &b);
			T.erase(a+1, b+1);
		}
		else
		{
			scanf("%d%d%d", &p, &a, &b);
			T.insert(p+1, a, b);
		}
	}
	return 0;
}

分块

const int N = 1e6 + 1, M = 2e4;

int a[2][N], * x = a[0], * y = a[1];
struct Op
{
	int t, p, a, b;
} C[M];

int ask(int i, int j, int k)
{
	while (--j >= i)
	{
		const Op& c = C[j];
		if (c.t == 0)
		{
			if (k > c.p + c.b - c.a) k -= c.b - c.a + 1;
			else if (k >= c.p) return c.a + k - c.p;
		}
		else if (c.t == 1 && k >= c.a)
		{
			k += c.b - c.a + 1;
		}
	}
	return x[k];
}

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	rep (i, 1, n+1) scanf("%d", y+i);
	rep (i, 0, m)
	{
		Op& c = C[i];
		scanf("%d", &c.t);
		if (c.t == 2) scanf("%d", &c.p);
		else if (c.t == 1) scanf("%d%d", &c.a, &c.b);
		else scanf("%d%d%d", &c.p, &c.a, &c.b), ++c.p;
	}
	for (int i = 0, size = sqrt(m); i < m; i += size)
	{
		if (i) rep (j, 1, n+1) y[j] = ask(i-size, i, j);
		swap(x, y);
		rep (j, i, min(m, i + size))
		{
			const Op& c = C[j];
			if (c.t == 2)
				printf("%d\n", ask(i, j, c.p));
			else if (c.t == 1)
				n -= c.b-c.a+1;
			else
				n += c.b-c.a+1;
			if (n > 1e6)
			{
				fprintf(stderr, "%d\n", n);
				exit(0);
			}
		}
	}
	return 0;
}