1 条题解

  • 0
    @ 2023-6-21 20:04:07

    C++ :

    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <cmath>
    #include <ctime>
    using namespace std;
    
    // 64-bit signed integer used for intermediate products before reducing mod M.
    using s64 = long long;
    
    // Reads the next non-negative decimal integer from stdin, skipping every
    // non-digit character that precedes it (whitespace, separators, '-', ...).
    // The single character following the number is consumed.
    // Returns 0 if end of input is reached before any digit is seen.
    inline int getint()
    {
    	// `c` must be an int: getchar() returns EOF out of band, and storing it in
    	// a char would make it look like an ordinary byte below '0', so the skip
    	// loop would spin forever at end of input.
    	int c = getchar();
    	while (c != EOF && (c < '0' || c > '9'))
    		c = getchar();
    	if (c == EOF)
    		return 0;
    
    	int res = c - '0';
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = res * 10 + c - '0';
    	return res;
    }
    
    // Capacity bounds. MaxM = 2 * MaxN sizes every per-event array; events are
    // the n initial elements plus the nQ queries (presumably each <= MaxN).
    constexpr int MaxN = 200005;
    constexpr int MaxM = 2 * MaxN;
    // Prime modulus of the weighted answers, and the inverse of 6 modulo it
    // (6 * 166666668 = 1000000008, which is 1 mod M).
    constexpr int M = 1000000007;
    constexpr int I = 166666668;
    
    // n: initial sequence length; nQ: number of queries;
    // m: running / total event count (n initial inserts followed by nQ queries).
    int n;
    int m;
    int nQ;
    // val[e]: colour value carried by event e (initial elements, recolours, inserts).
    int val[MaxM];
    
    // One event of the offline divide-and-conquer pass:
    //   type 1 / 2 : subtract / add a weighted prefix query into answer slot `id`
    //   type 3 / 4 : subtract / add a counting prefix query into answer slot `id`
    //   type 5 / 6 : add / remove the point (x, y) carrying weight w
    // Events are ordered by x; ties are broken by id (the originating event).
    struct operation
    {
    	int type;
    	int x, y, w, id;
    
    	operation() {}
    	operation(int kind, int px, int py, int weight, int origin)
    		: type(kind), x(px), y(py), w(weight), id(origin) {}
    
    	friend inline bool operator<(const operation &a, const operation &b)
    	{
    		if (a.x == b.x)
    			return a.id < b.id;
    		return a.x < b.x;
    	}
    };
    
    // op[1..op_n]: offline events in time order (reordered by x inside solve());
    // ops: scratch buffer used by solve() while merging.
    int op_n = 0;
    operation op[MaxN * 9];
    operation ops[MaxN * 9];
    
    // Treap over implicit positions, replayed offline to turn the positions named
    // by each query into final absolute labels (node::pos, set by get_label()).
    // Deleted elements stay in the tree with cnt = 0, so every node ever inserted
    // still receives a label.
    namespace Treap
    {
    	struct node
    	{
    		node *lc, *rc;
    		int pri, num, cnt;  // heap priority, live elements in subtree, live count of this node (1 or 0)
    		int pos;            // in-order label assigned by get_label()
    
    		// Live size of this subtree. Must only be called on a non-null node:
    		// calling a member function through a null pointer is undefined
    		// behaviour, and compilers fold a `this ? ... : 0` test to "true".
    		// Use subtree_size() for possibly-null children.
    		inline int get() const
    		{
    			return num;
    		}
    
    		inline void update()
    		{
    			num = cnt;
    			num += lc ? lc->num : 0;
    			num += rc ? rc->num : 0;
    		}
    	};
    	node pool[MaxM], *tail = pool;
    	node *rt = NULL;
    
    	// Live size of a possibly-null subtree.
    	inline int subtree_size(const node *p)
    	{
    		return p ? p->num : 0;
    	}
    
    	// Takes the next pool node and initialises it as a single live element.
    	inline node *alloc()
    	{
    		node *p = tail++;
    		p->lc = p->rc = NULL;
    		p->num = p->cnt = 1;
    		p->pri = rand();
    		return p;
    	}
    
    	// Right rotation: the left child of x becomes the subtree root.
    	inline void zig(node *&x)
    	{
    		node *y = x->lc;
    		x->lc = y->rc, y->rc = x;
    		y->num = x->num;
    		x->update(), x = y;
    	}
    	// Left rotation: the right child of x becomes the subtree root.
    	inline void zag(node *&x)
    	{
    		node *y = x->rc;
    		x->rc = y->lc, y->lc = x;
    		y->num = x->num;
    		x->update(), x = y;
    	}
    
    	// Inserts a new live element into subtree x so that exactly `pos` live
    	// elements of that subtree precede it (pos = 0 puts it first); y receives
    	// the new node. Rotations keep smaller `pri` values nearer the root.
    	void insert(node *&x, int pos, node *&y)
    	{
    		if (x == NULL)
    		{
    			y = x = alloc();
    			return;
    		}
    
    		++x->num;
    		int left = subtree_size(x->lc);
    		if (left + x->cnt > pos)
    		{
    			insert(x->lc, pos, y);
    			if (x->lc->pri < x->pri)
    				zig(x);
    		}
    		else
    		{
    			insert(x->rc, pos - left - x->cnt, y);
    			if (x->rc->pri < x->pri)
    				zag(x);
    		}
    	}
    
    	// Marks the pos-th live element (1-based) of subtree x as deleted and
    	// returns its node in y; the node stays in the tree with cnt = 0.
    	// Requires 1 <= pos <= live size of x (no NULL checks on the path).
    	void remove(node *x, int pos, node *&y)
    	{
    		--x->num;
    		int left = subtree_size(x->lc);
    		if (pos <= left)
    			remove(x->lc, pos, y);
    		else if (pos == left + x->cnt)
    			--x->cnt, y = x;
    		else
    			remove(x->rc, pos - left - x->cnt, y);
    	}
    
    	// Stores in y the node of the pos-th live element (1-based) of subtree x.
    	// Requires 1 <= pos <= live size of x.
    	void find_kth(node *x, int pos, node *&y)
    	{
    		int left = subtree_size(x->lc);
    		if (pos <= left)
    			find_kth(x->lc, pos, y);
    		else if (pos == left + x->cnt)
    			y = x;
    		else
    			find_kth(x->rc, pos - left - x->cnt, y);
    	}
    
    	// Last label handed out by get_label().
    	int dfs_clock = 0;
    
    	// Labels every node of subtree x (live or deleted) in in-order sequence,
    	// continuing from dfs_clock. An empty subtree is a no-op.
    	void get_label(node *x)
    	{
    		if (x == NULL)
    			return;
    		get_label(x->lc);
    		x->pos = ++dfs_clock;
    		get_label(x->rc);
    	}
    }
    using Treap::rt;
    
    // Per-event data recorded while reading the input (index = event number):
    //   opt[e]   - event type (the n initial elements are recorded as type 4)
    //   rel_l[e] - treap node of the position the event refers to (left end for ranges)
    //   rel_r[e] - treap node of the right end, only for range events (types 1 and 5)
    int opt[MaxM];
    Treap::node *rel_l[MaxM], *rel_r[MaxM];
    
    // col[1..col_n]: every colour value read; build_ops() sorts and de-duplicates it
    // so that a value's index in col identifies its colour class.
    int col[MaxM];
    int col_n = 0;
    
    // Replays all events on the position treap. Events 1..n are the initial
    // elements (stored as type-4 inserts); events n+1..n+nQ are read from stdin
    // as a type followed by its arguments. For each event the referenced treap
    // node(s) are recorded, then every node is labelled with its final position.
    inline void build_treap()
    {
    	m = 0;
    	for (int i = 0; i < n; ++i)
    	{
    		++m;
    		Treap::insert(rt, i, rel_l[m]);
    		opt[m] = 4;
    	}
    
    	col_n = n;
    	for (int q = 0; q < nQ; ++q)
    	{
    		const int e = ++m;
    		opt[e] = getint();
    
    		switch (opt[e])
    		{
    		case 1:
    		case 5:
    		{
    			// range [l, r] given as live positions
    			const int lpos = getint();
    			Treap::find_kth(rt, lpos, rel_l[e]);
    			const int rpos = getint();
    			Treap::find_kth(rt, rpos, rel_r[e]);
    			break;
    		}
    		case 2:
    			// recolour: position, new value
    			Treap::find_kth(rt, getint(), rel_l[e]);
    			col[++col_n] = getint();
    			val[e] = col[col_n];
    			break;
    		case 3:
    			// delete: position
    			Treap::remove(rt, getint(), rel_l[e]);
    			break;
    		case 4:
    			// insert: number of preceding elements, value
    			Treap::insert(rt, getint(), rel_l[e]);
    			col[++col_n] = getint();
    			val[e] = col[col_n];
    			break;
    		default:
    			break;
    		}
    	}
    
    	Treap::get_label(rt);
    }
    
    // Ordered sets of int keys, one splay tree per set. The caller owns each
    // tree's root pointer and passes it by reference; operations that splay
    // leave the touched node at the root.
    namespace Set
    {
    	struct node
    	{
    		node *lc, *rc, *fa;
    		int key;
    
    		// Rotates this node above its parent, re-linking the grandparent and
    		// the child subtree that changes sides.
    		inline void rotate()
    		{
    			node *x = this, *y = fa, *z = y->fa;
    			node *b = x == y->lc ? x->rc : x->lc;
    
    			x->fa = z, y->fa = x;
    			if (b)
    				b->fa = y;
    
    			if (z)
    			{
    				if (z->lc == y)
    					z->lc = x;
    				else
    					z->rc = x;
    			}
    			if (x == y->lc)
    				x->rc = y, y->lc = b;
    			else
    				x->lc = y, y->rc = b;
    		}
    
    		// Moves this node to the root of its tree (fa becomes NULL).
    		inline void splay()
    		{
    			while (fa)
    			{
    				if (fa->fa)
    				{
    					// Same-side chain: rotate the parent first (zig-zig);
    					// otherwise rotate this node twice (zig-zag).
    					if ((fa->lc == this) == (fa->fa->lc == fa))
    						fa->rotate();
    					else
    						rotate();
    				}
    				rotate();
    			}
    		}
    	};
    
    	// Nodes are never recycled; every insert() consumes one. build_ops()
    	// inserts two sentinels per colour class plus one node per insert/recolour
    	// event, each count bounded by m, so up to 3 * m nodes can be needed. A
    	// 2 * MaxM pool overflows when most colour values are distinct.
    	node pool[MaxM * 3], *tail = pool;
    
    	// Splays the node holding `key` to the root of tree x.
    	// The key must be present (the descent does not check for NULL).
    	inline void find(node *&x, int key)
    	{
    		node *p = x;
    		while (p->key != key)
    			p = key < p->key ? p->lc : p->rc;
    		p->splay();
    		x = p;
    	}
    
    	// Inserts a new node holding `key` into tree x (x may be NULL for an empty
    	// set) and splays it to the root. Equal keys descend to the right.
    	inline void insert(node *&x, int key)
    	{
    		node *p = x, *q = NULL, **r = NULL;
    		while (p != NULL)
    		{
    			q = p;
    			if (key < p->key)
    				r = &p->lc, p = p->lc;
    			else
    				r = &p->rc, p = p->rc;
    		}
    
    		p = tail++;
    		p->lc = p->rc = NULL;
    		p->key = key, p->fa = q;
    		if (r)
    			*r = p;
    		p->splay(), x = p;
    	}
    
    	// Deletes the root x. Both children must exist (the 0 / m + 1 sentinels
    	// guarantee this for real keys); the maximum of the left subtree is
    	// splayed up and becomes the new root.
    	inline void remove(node *&x)
    	{
    		node *p = x, *q = x->lc;
    		while (q->rc)
    			q = q->rc;
    		p->lc->fa = NULL;
    		q->splay();
    		q->rc = p->rc;
    		p->rc->fa = q;
    		x = q;
    	}
    
    	// In-order predecessor of the root x (x must have a left child).
    	inline int query_lower(node *x)
    	{
    		x = x->lc;
    		while (x->rc)
    			x = x->rc;
    		return x->key;
    	}
    	// In-order successor of the root x (x must have a right child).
    	inline int query_upper(node *x)
    	{
    		x = x->rc;
    		while (x->lc)
    			x = x->lc;
    		return x->key;
    	}
    }
    
    // cur[p]: colour class most recently assigned to position p.
    int cur[MaxM];
    // seq[c]: root of the splay tree of positions currently in colour class c
    //         (together with the sentinels 0 and m + 1 added by build_ops()).
    Set::node *seq[MaxM];
    
    inline void insert(int i, int x, int w)
    {
    	Set::insert(seq[w], x);
    	int prev = Set::query_lower(seq[w]);
    	int next = Set::query_upper(seq[w]);
    
    	op[++op_n] = operation(5, x, prev, col[cur[x] = w], i);
    	if (next <= m)
    	{
    		operation u(6, next, prev, col[cur[next]], i);
    		operation v(5, next, x, col[cur[next]], i);
    		op[++op_n] = u;
    		op[++op_n] = v;
    	}
    }
    inline void remove(int i, int x)
    {
    	int w = cur[x];
    	Set::find(seq[w], x);
    	int prev = Set::query_lower(seq[w]);
    	int next = Set::query_upper(seq[w]);
    	Set::remove(seq[w]);
    
    	op[++op_n] = operation(6, x, prev, col[w], i);
    	if (next <= m)
    	{
    		operation u(6, next, x, col[cur[next]], i);
    		operation v(5, next, prev, col[cur[next]], i);
    		op[++op_n] = u;
    		op[++op_n] = v;
    	}
    }
    
    // Compresses colour values and seeds every colour class with the sentinels
    // 0 and m + 1. It then translates each event, using final treap labels, into
    // offline operations:
    //   type 1 / 5 range [l, r] -> prefix queries at x = l - 1 (subtract) and
    //                              x = r (add), both over points with y <= l - 1
    //                              (types 1/2 weighted, 3/4 counting);
    //   type 2 recolour         -> leave the old class, join the new one;
    //   type 3 delete           -> leave the current class;
    //   type 4 insert           -> join its class.
    inline void build_ops()
    {
    	int *const first = col + 1;
    	sort(first, first + col_n);
    	col_n = int(unique(first, first + col_n) - first);
    	for (int c = 1; c <= col_n; ++c)
    	{
    		Set::insert(seq[c], 0);
    		Set::insert(seq[c], m + 1);
    	}
    
    	op_n = 0;
    	for (int e = 1; e <= m; ++e)
    	{
    		switch (opt[e])
    		{
    		case 1:
    		case 5:
    		{
    			const int lo = rel_l[e]->pos;
    			const int hi = rel_r[e]->pos;
    			const int base = opt[e] == 1 ? 1 : 3;
    			op[++op_n] = operation(base, lo - 1, lo - 1, 0, e);
    			op[++op_n] = operation(base + 1, hi, lo - 1, 0, e);
    			break;
    		}
    		case 2:
    		{
    			const int x = rel_l[e]->pos;
    			const int w = int(lower_bound(first, first + col_n, val[e]) - col);
    			remove(e, x);
    			insert(e, x, w);
    			break;
    		}
    		case 3:
    			remove(e, rel_l[e]->pos);
    			break;
    		case 4:
    		{
    			const int x = rel_l[e]->pos;
    			const int w = int(lower_bound(first, first + col_n, val[e]) - col);
    			insert(e, x, w);
    			break;
    		}
    		default:
    			break;
    		}
    	}
    }
    
    // Per-event answers. res1 holds the distinct count (type-5 events) or the
    // sum of w (type-1 events); res2 / res3 hold the sums of w^2 / w^3 mod M.
    int res1[MaxM];
    int res2[MaxM];
    int res3[MaxM];
    // Fenwick trees indexed by y + 1: bit0 counts active points, while
    // bit1 / bit2 / bit3 hold sums of w, w^2, w^3 (mod M) over active points.
    int bit0[MaxM];
    int bit1[MaxM];
    int bit2[MaxM];
    int bit3[MaxM];
    
    // Resets the Fenwick nodes on the update path of point y = x (index x + 1)
    // after a merge step. A walk stops at the first node that is already clean.
    // That is safe because a dirty node always has a dirty node (or its own
    // point) below it on some point's path, and solve() calls bit_clear() for
    // every modified point.
    // For the weighted trees all three components must be tested. bit1 can be
    // 0 mod M (e.g. weights 7 and 1000000000 in one node) while bit2 / bit3 are
    // not, and stopping on bit1 alone would leave stale w^2 / w^3 sums behind
    // for later merges.
    inline void bit_clear(int x)
    {
    	for (int i = ++x; i <= m + 1 && bit0[i]; i += i & -i)
    		bit0[i] = 0;
    	for (int i = x; i <= m + 1 && (bit1[i] || bit2[i] || bit3[i]); i += i & -i)
    		bit1[i] = bit2[i] = bit3[i] = 0;
    }
    
    // Adds one active point at y = x to the counting tree.
    inline void bit0_add(int x)
    {
    	const int limit = m + 1;
    	for (int i = x + 1; i <= limit; i += i & (-i))
    		bit0[i] += 1;
    }
    // Removes one active point at y = x from the counting tree.
    inline void bit0_del(int x)
    {
    	const int limit = m + 1;
    	for (int i = x + 1; i <= limit; i += i & (-i))
    		bit0[i] -= 1;
    }
    // Adds the number of active points with y <= x to res1[id].
    inline void bit0_query1(int x, int id)
    {
    	int acc = 0;
    	for (int i = x + 1; i != 0; i -= i & (-i))
    		acc += bit0[i];
    	res1[id] += acc;
    }
    // Subtracts the number of active points with y <= x from res1[id].
    inline void bit0_query2(int x, int id)
    {
    	int acc = 0;
    	for (int i = x + 1; i != 0; i -= i & (-i))
    		acc += bit0[i];
    	res1[id] -= acc;
    }
    
    // Adds a point of weight d1 at y = x to the weighted trees, which hold the
    // sums of w, w^2 and w^3 modulo M.
    inline void bit1_add(int x, int d1)
    {
    	// Normalise into [0, M) first. The single conditional subtraction below
    	// only keeps a node in range, and avoids signed overflow, when the
    	// increment itself is smaller than M.
    	d1 %= M;
    	if (d1 < 0)
    		d1 += M;
    	int d2 = (s64)d1 * d1 % M;
    	int d3 = (s64)d2 * d1 % M;
    	for (int i = ++x; i <= m + 1; i += i & -i)
    	{
    		if ((bit1[i] += d1) >= M)
    			bit1[i] -= M;
    		if ((bit2[i] += d2) >= M)
    			bit2[i] -= M;
    		if ((bit3[i] += d3) >= M)
    			bit3[i] -= M;
    	}
    }
    // Removes a point of weight d1 at y = x from the weighted trees, which hold
    // the sums of w, w^2 and w^3 modulo M.
    inline void bit1_del(int x, int d1)
    {
    	// Normalise into [0, M) first. The single conditional addition below
    	// only keeps a node in range, and avoids signed overflow, when the
    	// decrement itself is smaller than M.
    	d1 %= M;
    	if (d1 < 0)
    		d1 += M;
    	int d2 = (s64)d1 * d1 % M;
    	int d3 = (s64)d2 * d1 % M;
    	for (int i = ++x; i <= m + 1; i += i & -i)
    	{
    		if ((bit1[i] -= d1) < 0)
    			bit1[i] += M;
    		if ((bit2[i] -= d2) < 0)
    			bit2[i] += M;
    		if ((bit3[i] -= d3) < 0)
    			bit3[i] += M;
    	}
    }
    inline void bit1_query1(int x, int id)
    {
    	for (int i = ++x; i; i ^= i & -i)
    	{
    		if ((res1[id] += bit1[i]) >= M)
    			res1[id] -= M;
    		if ((res2[id] += bit2[i]) >= M)
    			res2[id] -= M;
    		if ((res3[id] += bit3[i]) >= M)
    			res3[id] -= M;
    	}
    }
    inline void bit1_query2(int x, int id)
    {
    	for (int i = ++x; i; i ^= i & -i)
    	{
    		if ((res1[id] -= bit1[i]) < 0)
    			res1[id] += M;
    		if ((res2[id] -= bit2[i]) < 0)
    			res2[id] += M;
    		if ((res3[id] -= bit3[i]) < 0)
    			res3[id] += M;
    	}
    }
    
    // Offline divide and conquer over op[l..r], which is in time order on entry
    // and sorted by (x, id) on return. Each query in the right half (types 1-4)
    // receives the contributions of the left-half point updates (types 5/6) that
    // precede it in (x, id) order, restricted by the query's y through the
    // Fenwick prefix sums. Answers accumulate in res1/res2/res3[id].
    void solve(int l, int r)
    {
    	if (l >= r)
    		return;
    	int mid = l + r >> 1;
    	solve(l, mid);
    	solve(mid + 1, r);
    
    	// Last index in the (already x-sorted) right half holding a weighted
    	// query (last1) or a counting query (last2). Once the merge cursor j
    	// passes it, no remaining right-half query of that kind can observe an
    	// update, so the corresponding Fenwick update is skipped.
    	int last1 = mid, last2 = mid;
    	for (int t = mid + 1; t <= r; ++t)
    	{
    		if (op[t].type == 1 || op[t].type == 2)
    			last1 = t;
    		if (op[t].type == 3 || op[t].type == 4)
    			last2 = t;
    	}
    
    	// Merge both halves by (x, id); left-half updates are applied to the
    	// trees before any right-half query that follows them in that order.
    	int i = l, j = mid + 1;
    	for (int t = l; t <= r; ++t)
    	{
    		if (j > r || (i <= mid && op[i] < op[j]))
    		{
    			ops[t] = op[i++];
    
    			if (ops[t].type == 5)
    			{
    				int x = ops[t].y;
    				if (j <= last2)
    					bit0_add(x);
    				if (j <= last1)
    					bit1_add(x, ops[t].w);
    			}
    			else if (ops[t].type == 6)
    			{
    				int x = ops[t].y;
    				if (j <= last2)
    					bit0_del(x);
    				if (j <= last1)
    					bit1_del(x, ops[t].w);
    			}
    		}
    		else
    		{
    			ops[t] = op[j++];
    
    			// Odd query types subtract (the x = l - 1 end), even ones add (x = r).
    			if (ops[t].type == 1)
    				bit1_query2(ops[t].y, ops[t].id);
    			else if (ops[t].type == 2)
    				bit1_query1(ops[t].y, ops[t].id);
    			else if (ops[t].type == 3)
    				bit0_query2(ops[t].y, ops[t].id);
    			else if (ops[t].type == 4)
    				bit0_query1(ops[t].y, ops[t].id);
    		}
    	}
    
    	// Copy the merged order back and reset the Fenwick nodes touched at this
    	// level, so the trees are clean again for the caller's merge.
    	for (int t = l; t <= r; ++t)
    	{
    		op[t] = ops[t];
    		if (op[t].type >= 5)
    			bit_clear(op[t].y);
    	}
    }
    
    int main()
    {
    //	freopen("simple.in", "r", stdin);
    //	freopen("simple.out", "w", stdout);
    
    	srand(time(NULL));
    
    	cin >> n >> nQ;
    	for (int i = 1; i <= n; ++i)
    		val[i] = col[i] = getint();
    
    	build_treap();
    	build_ops();
    
    	solve(1, op_n);
    
    	for (int i = 1; i <= m; ++i)
    	{
    		if (opt[i] == 5)
    			printf("%d\n", res1[i]);
    		else if (opt[i] == 1)
    		{
    			int sum = (s64)res1[i];
    			sum = (s64)sum * res1[i] % M;
    			sum = (s64)sum * res1[i] % M;
    			sum = (sum - 3ll * res1[i] * res2[i]) % M;
    			sum = (sum + 2ll * res3[i]) % M;
    			sum = (s64)sum * I % M;
    			printf("%d\n", (sum + M) % M);
    		}
    	}
    
    	fclose(stdin);
    	fclose(stdout);
    	return 0;
    }
    
    • 1

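    信息 (problem info)

    ID: 1026
    时间 (time limit): 3000ms
    内存 (memory limit): 512MiB
    难度 (difficulty): 无 (none)
    标签 (tags): —
    递交数 (submissions): 0
    已通过 (accepted): 0
    上传者 (uploader): —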