[bzoj 2141] 排队

题意:给一个长度为n的序列和m个交换操作,输出原始的和每个交换执行后的逆序数(严格大于小于号)(1≤m≤210^3,1≤n≤210^4)

原始序列的逆序数可以用归并排序等方法求解。考虑交换操作对逆序数的影响。设交换i、j(i<j),序列被分为如下几段: | a |i| b |j| c |

|x > t|为集合x中大于t的元素数目,序列为h,则: ans' = ans + |bj > h[i]| - |ib > h[j]| + |b < h[j]| - |b < h[i]|

需要一种数据结构,支持: - 单点修改 - 查询某一区间大于或小于某数的元素个数

可以写树套树、块套树,也可以直接分块,排序,查询一块时二分,时间复杂度

#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
const int MAX_N = 2e4;
int n, sz, h[MAX_N], x[MAX_N];

int merge_sort(int* l, int* r)
{
	if (r-l <= 1)
		return 0;
	int * m = (r-l)/2 + l, ans = merge_sort(l, m) + merge_sort(m, r);
	static int t[MAX_N];
	int* p = l, * q = m, * o = t;
	while (p != m || q != r) {
		if (q == r || p != m && *p <= *q) {
			ans += q-m;
			*o++ = *p++;
		} else
			*o++ = *q++;
	}
	copy(t, o, l);
	return ans;
}

inline void init()
{
	sz = sqrt(n*log2(n));
	for (int i = 0; i < n; i += sz) {
		int j = min(n, i+sz);
		copy(h+i, h+j, x+i);
		sort(x+i, x+j);
	}
}

void modify(int p, int v)
{
	int b = p/sz, st = b*sz, ed = min(n, (b+1)*sz), i = find(x+st, x+ed, h[p]) - x;
	h[p] = v;
	while (i+1 != ed && v > x[i+1]) {
		x[i] = x[i+1];
		++i;
	}
	while (i != st && v < x[i-1]) {
		x[i] = x[i-1];
		--i;
	}
	x[i] = v;
}

// x < y
inline int cmp(int x, int y, int c)
{
	return x != y && x < y == c;
}

int query(int l, int r, int v, int c)
{
	if (l > r)
		return 0;
	int lb = l/sz, rb = r/sz, p = (lb+1)*sz, q = rb*sz, i, ans = 0;
	if (lb == rb) {
		for (i = l; i <= r; ++i)
			ans += cmp(h[i], v, c);
	} else {
		for (i = l; i < p; ++i)
			ans += cmp(h[i], v, c);
		if (c) // <
			for (; i < q; i += sz)
				ans += lower_bound(x+i, x+i+sz, v) - (x+i);
		else
			for (; i < q; i += sz)
				ans += x+i+sz - upper_bound(x+i, x+i+sz, v);
		for (; i <= r; ++i)
			ans += cmp(h[i], v, c);
	}
	return ans;
}

int main()
{
	int m;
	static int t[MAX_N];
	scanf("%d", &n);
	for (int i = 0; i < n; ++i) {
		scanf("%d", &h[i]);
		t[i] = h[i];
	}
	int ans = merge_sort(t, t+n);
	printf("%d\n", ans);
	scanf("%d", &m);
	init();
	while (m--) {
		int i, j;
		scanf("%d %d", &i, &j);
		--i;
		--j;
		int a = h[i], b = h[j];
		if (a != b) {
			if (i > j) {
				swap(i, j);
				swap(a, b);
			}
			// + [i+1, j] > a - [i, j-1] > b - [i+1, j-1] < a + [i+1, j-1] < b
			printf("%d\n", ans += query(i+1, j, a, 0) - query(i, j-1, b, 0)
					- query(i+1, j-1, a, 1) + query(i+1, j-1, b, 1));
			modify(i, b);
			modify(j, a);
		} else
			printf("%d\n", ans);
	}
	return 0;
}