分享|一道高级逆序对问题, 区间修改, 区间查询
863
2024.08.09
2024.08.09
发布于 中国

降阶问题 : LCR 170. 交易逆序对的总数
本问题: 给出一个列表, 如 record = [9, 7, 5, 4, 6]
区间修改操作(q[0] = 2): q = [2, l, r, val], 将[l,r]内的值全部改成 val (下标从 1 开始, 闭区间)
区间查询操作(q[0] = 1): q = [1, l, r], 返回[l,r]内的逆序对个数 (下标从 1 开始, 闭区间)

使用线段树解答,
1, 设计结点单位元 initval, 其中包括三项: (a) invCnt 记录结点中逆序对的个数; (b) freq 记录结点中每个值的出现次数; (c) len 记录结点的长度

class Item:
    """Summary of one segment held in a segment-tree node.

    Attributes:
        len: number of elements covered by the segment.
        invCnt: number of inversions inside the segment.
        freq: freq[v] is how many elements of the segment equal v.
    """

    def __init__(self, maxV):
        # A fresh item describes the empty segment: nothing counted yet.
        self.len = 0
        self.invCnt = 0
        self.freq = [0 for _ in range(maxV)]

# Length of every Item.freq array; freq is indexed directly by element value.
m = 10
initval = Item(m)   # m is (largest value in nums) + 1, chosen from the problem constraints; every node carries a length-m array, so m cannot be large.  PS: n is the length of nums

2, 设计lazy单位元initlazy

initlazy = 114514   # "no pending assignment" sentinel: larger than m, so it can never be a real value

3, 设计合并结点的操作op
长度和数值的次数直接相加就行了 合并后的逆序对个数利用前缀和计算, 可以O(m)完成

def op(n1, n2):
    """Merge the summaries of two adjacent segments (n1 left of n2).

    Lengths and per-value counts add up. The inversion count is the sum of
    both sides' own inversions plus the cross pairs, i.e. pairs made of a
    value in n1 and a strictly smaller value in n2.
    """
    merged = Item(m)
    merged.len = n1.len + n2.len
    cross_pairs = 0
    # Number of elements in n2 whose value is strictly below the current one.
    smaller_on_right = 0
    for value in range(m):
        left_cnt = n1.freq[value]
        right_cnt = n2.freq[value]
        merged.freq[value] = left_cnt + right_cnt
        cross_pairs += left_cnt * smaller_on_right
        smaller_on_right += right_cnt
    merged.invCnt = n1.invCnt + n2.invCnt + cross_pairs
    return merged

4, 设lazy标记对结点的更新app
逆序对数直接归零, 值的次数直接等于长度即可

def app(lz, n1):
    """Apply the lazy tag lz to the node summary n1 and return the result.

    The sentinel initlazy leaves n1 untouched (the same object is returned).
    Any other tag means "every element becomes lz": the result keeps the
    length, has no inversions, and counts all elements under value lz.
    """
    if lz != initlazy:
        painted = Item(m)
        painted.len = n1.len
        painted.invCnt = 0
        painted.freq[lz] = n1.len
        return painted
    return n1

5, 设计lazy标记的合并
貌似l1一直是新的那个,所以直接赋值成l1就行

def com(l1, l2):
    """Combine two lazy tags; l1 is the newer tag, l2 the one already stored.

    The newer assignment wins, unless it is the initlazy sentinel, in which
    case the stored tag is kept.
    """
    return l2 if l1 == initlazy else l1

总体程序, 感谢FatalError提供的修改版atcoder lazy segtree模板

class LST:
    """Iterative (bottom-up) lazy segment tree driven by user-supplied callbacks.

    Layout: ``tree`` has ``2 * size`` slots where ``size`` is a power of two
    that is at least ``n``; the root is ``tree[1]``, the children of node
    ``rt`` are ``rt*2`` and ``rt*2+1``, and element ``i`` lives in the leaf
    ``tree[size + i]``. ``lazy`` has ``size`` slots, one per internal node.

    Callbacks, as they are invoked by this class:
        op(a, b)          -- merge two node values, ``a`` being the left one.
        apply(tag, value) -- return the node value after applying ``tag``.
        compose(new, old) -- combine a newly arriving tag with the stored one.

    All public index arguments are 0-based; ranges are inclusive on both ends.
    """
    __slots__ = 'n', 'height', 'size', 'initval', 'initlazy', 'op', 'apply', 'compose', 'tree', 'lazy'
    def __init__(self, nums, initval, initlazy, op, apply, compose):
        """Build the tree.

        Args:
            nums: list of initial leaf values, or an int meaning "that many
                copies of ``initval``" (the copies are the same object).
            initval: value used for padding leaves and as the starting
                accumulator in queries.
            initlazy: tag value meaning "nothing pending".
            op, apply, compose: callbacks described in the class docstring.
        """
        if isinstance(nums, int):
            nums = [initval] * nums
        self.n = len(nums)
        # Number of levels below the root; size is the leaf capacity.
        self.height = (self.n-1).bit_length()
        self.size = 1 << self.height
        self.initval = initval
        self.initlazy = initlazy
        self.op = op
        self.apply = apply
        self.compose = compose
        self.tree = [initval for _ in range(2 * self.size)]
        self.tree[self.size:self.size+self.n] = nums
        # Fill internal nodes from the deepest level up to the root.
        for i in range(self.size-1, 0, -1):
            self.pushup(i)
        self.lazy = [initlazy for _ in range(self.size)]

    def pushup(self, rt):
        """Recompute node ``rt`` from its two children."""
        self.tree[rt] = self.op(self.tree[rt*2], self.tree[rt*2+1])

    def pushdown(self, rt):
        """Hand the pending tag of ``rt`` (if any) to both children and clear it."""
        if self.lazy[rt] == self.initlazy: return  # nothing pending
        self.modify(rt*2, self.lazy[rt])
        self.modify(rt*2+1, self.lazy[rt])
        self.lazy[rt] = self.initlazy

    def modify(self, rt, val):
        """Apply tag ``val`` to node ``rt``; internal nodes also remember it."""
        self.tree[rt] = self.apply(val, self.tree[rt])
        # Leaves (rt >= size) have no lazy slot.
        if rt < self.size:
            self.lazy[rt] = self.compose(val, self.lazy[rt])

    def set(self, idx, val):
        """Replace the value of element ``idx`` with the node value ``val``."""
        idx += self.size
        # Flush pending tags on the root-to-leaf path, top-down.
        for i in range(self.height, 0, -1):
            self.pushdown(idx >> i)
        self.tree[idx] = val
        # Rebuild the ancestors, bottom-up.
        for i in range(1, self.height + 1):
            self.pushup(idx >> i)

    def update(self, left, right, val):
        """Apply tag ``val`` to every element in ``[left, right]``.

        Does nothing when ``left > right``.
        """
        if left > right: return
        left += self.size
        right += self.size
        # Flush tags top-down on the two boundary paths. An ancestor at level
        # i is skipped when the boundary is aligned to a multiple of 2**i
        # (left itself on the left side, right+1 on the right side).
        for i in range(self.height, 0, -1):
            if left >> i << i != left:
                self.pushdown(left >> i)
            if (right+1) >> i << i != right+1:
                self.pushdown(right >> i)
        l, r = left, right
        # Walk both ends upward, tagging the nodes that make up the range.
        while left <= right:
            if left & 1:
                self.modify(left, val)
                left += 1
            if not right & 1:
                self.modify(right, val)
                right -= 1
            left >>= 1
            right >>= 1
        left, right = l, r
        # Recompute the same boundary ancestors, bottom-up.
        for i in range(1, self.height + 1):
            if left >> i << i != left:
                self.pushup(left >> i)
            if (right+1) >> i << i != right+1:
                self.pushup(right >> i)

    def get(self, idx):
        """Return the current node value of element ``idx``.

        Pending tags on the root-to-leaf path are pushed down first.
        """
        idx += self.size
        for i in range(self.height, 0, -1):
            self.pushdown(idx >> i)
        return self.tree[idx]

    def query(self, left, right):
        """Return the ``op``-fold of the elements in ``[left, right]``.

        Returns ``initval`` when ``left > right``. Operands are always passed
        to ``op`` in left-to-right order.
        """
        if left > right: return self.initval
        left += self.size
        right += self.size
        # Same boundary flush as in update().
        for i in range(self.height, 0, -1):
            if left >> i << i != left:
                self.pushdown(left >> i)
            if (right+1) >> i << i != right+1:
                self.pushdown(right >> i)
        # lres grows rightwards from the left end, rres leftwards from the right end.
        lres, rres = self.initval, self.initval
        while left <= right:
            if left & 1:
                lres = self.op(lres, self.tree[left])
                left += 1
            if not right & 1:
                rres = self.op(self.tree[right], rres)
                right -= 1
            left >>= 1
            right >>= 1
        return self.op(lres, rres)

    def all(self):
        """Return the root value, i.e. the fold over the whole tree."""
        return self.tree[1]

    def bisect_left(self, left, f):
        """Search rightwards from ``left`` for where ``f`` first becomes true.

        Returns the first index ``idx`` at or to the right of ``left`` for
        which ``f(query(left, idx))`` is true, or ``self.n`` when the scan
        reaches the right edge of the tree first.
        NOTE(review): presumably requires ``f`` to be monotone (false ... true)
        along the range -- confirm against callers.
        """
        # Find the first index at or to the right of left where f(query(left, idx)) is true
        left += self.size
        lres = self.initval
        for i in range(self.height, 0, -1):
            self.pushdown(left >> i)

        while True:
            # Climb while the node is a left child.
            while not left & 1:
                left >>= 1
            if f(self.op(lres, self.tree[left])):
                # The answer lies under this node: descend to the leaf.
                while left < self.size:
                    self.pushdown(left)
                    left *= 2
                    if not f(self.op(lres, self.tree[left])):
                        lres = self.op(lres, self.tree[left])
                        left += 1
                return left - self.size
            # left is the last node of its level: nothing remains to the right.
            if left & (left + 1) == 0:
                return self.n
            lres = self.op(lres, self.tree[left])
            left += 1

    def bisect_right(self, right, f):
        """Search leftwards from ``right`` for where ``f`` first becomes true.

        Returns the first index ``idx`` at or to the left of ``right`` for
        which ``f(query(idx, right))`` is true, or ``-1`` when the scan
        reaches the left edge of the tree first.
        NOTE(review): presumably requires ``f`` to be monotone -- confirm
        against callers.
        """
        # Find the first index at or to the left of right where f(query(idx, right)) is true
        right += self.size
        rres = self.initval
        for i in range(self.height, 0, -1):
            self.pushdown(right >> i)

        while True:
            # Climb while the node is a right child (stopping at the root).
            while right > 1 and right & 1:
                right >>= 1
            if f(self.op(self.tree[right], rres)):
                # The answer lies under this node: descend to the leaf.
                while right < self.size:
                    self.pushdown(right)
                    right = 2 * right + 1
                    if not f(self.op(self.tree[right], rres)):
                        rres = self.op(self.tree[right], rres)
                        right -= 1
                return right - self.size
            # right is the first node of its level: nothing remains to the left.
            if right & (right - 1) == 0:
                return -1
            rres = self.op(self.tree[right], rres)
            right -= 1

    def __str__(self):
        """Render the ``freq`` attribute of every element.

        Relies on node values having a ``freq`` attribute, and goes through
        get(), so pending tags are pushed down as a side effect.
        """
        return str([[self.get(i).freq] for i in range(self.n)])


class Item:
    """Node payload of the segment tree: a summary of one segment.

    Attributes:
        len: how many elements the segment covers.
        invCnt: how many inversions lie inside the segment.
        freq: per-value occurrence counts; freq[v] counts elements equal to v.
    """

    def __init__(self, maxV):
        # Start out as the empty segment.
        self.len = 0
        self.invCnt = 0
        self.freq = [0 for _ in range(maxV)]

# Length of every Item.freq array; freq is indexed directly by element value.
m = 10
initval = Item(m)   # m is (largest value in nums) + 1, chosen from the problem constraints; every node carries a length-m array, so m cannot be large. PS: n is the length of nums
initlazy = 114514   # "no pending assignment" sentinel: larger than m, so it can never be a real value
def op(n1, n2):
    """Combine the summaries of a left segment n1 and a right segment n2.

    The result's length and value counts are the sums of both sides. Its
    inversion count is both sides' own counts plus one for every pair of a
    value in n1 with a strictly smaller value in n2.
    """
    combined = Item(m)
    combined.len = n1.len + n2.len
    crossing = 0
    # Running count of n2 elements with a value below the one being visited.
    below = 0
    for v in range(m):
        on_left = n1.freq[v]
        on_right = n2.freq[v]
        combined.freq[v] = on_left + on_right
        crossing += on_left * below
        below += on_right
    combined.invCnt = n1.invCnt + n2.invCnt + crossing
    return combined

def app(lz, n1):
    """Return the node summary obtained by applying lazy tag lz to n1.

    With the initlazy sentinel the input object is returned unchanged.
    Otherwise every element of the segment takes the value lz, so the new
    summary has the same length, zero inversions and a single non-zero count.
    """
    if lz != initlazy:
        assigned = Item(m)
        assigned.len = n1.len
        assigned.invCnt = 0
        assigned.freq[lz] = n1.len
        return assigned
    return n1

def com(l1, l2):
    """Compose two lazy assignment tags and return the tag to store.

    Args:
        l1: the newly arriving tag.
        l2: the tag already stored on the node.

    Returns:
        l1, because a newer assignment overrides an older one; l2 when l1 is
        the initlazy sentinel, which stands for "nothing to apply".
    """
    # An empty incoming tag must not erase an assignment that is still
    # pending; otherwise the node's children would never receive it.
    if l1 == initlazy:
        return l2
    return l1

def main():
    """Demo: build the tree over a sample array and replay a list of queries.

    Prints the per-element value counts before and after, and the inversion
    count for every type-1 query in between.
    """
    arr = [1, 2, 3, 6, 5, 4]
    # [1, l, r]    -> print the inversion count of [l, r]
    # [2, l, r, v] -> assign v to every element of [l, r]
    # Indices are 1-based and inclusive.
    queries = [
        [1, 1, 3],
        [1, 2, 5],
        [2, 2, 3, 7],
        [2, 2, 2, 8],
        [1, 2, 4],
        [2, 1, 4, 0],
        [2, 1, 4, 1],
        [2, 1, 4, 2],
        [2, 1, 4, 1],
        [1, 1, 6],
    ]

    def leaf(value, count=1) -> Item:
        # Summary of `count` copies of `value`.
        node = Item(m)
        node.len = count
        node.freq[value] = count
        return node

    st = LST([leaf(x) for x in arr], initval, initlazy, op, app, com)
    print(st)
    for kind, *args in queries:
        if kind == 1:
            lo, hi = args
            print(st.query(lo - 1, hi - 1).invCnt)
        elif kind == 2:
            lo, hi, v = args
            st.update(lo - 1, hi - 1, v)
    print(st)
main()
评论 (4)