SortedList 是处理有序集合的强大工具,本指南将全面介绍从基础到高级的所有用法,特别是处理复杂数据类型(如元组)的场景。
from sortedcontainers import SortedList
# 创建空列表
sl1 = SortedList()
# 从可迭代对象创建
sl2 = SortedList([3, 1, 4, 2])
# 使用生成器创建
sl3 = SortedList(x for x in range(10) if x % 2 == 0)# 添加元素
sl.add(5) # O(log n)
# 批量添加
sl.update([6, 2, 8]) # O(k log n)
# 删除元素
sl.discard(2) # 安全删除 O(log n)
sl.remove(3) # 必须存在 O(log n)
sl.pop() # 删除末尾 O(1)
# 查询长度
len(sl) # O(1)# 创建含元组的SortedList
coordinates = SortedList([(2, 3), (1, 5), (3, 1)])
# 自动按元组规则排序
print(coordinates) # [(1, 5), (2, 3), (3, 1)]
# 添加新坐标
coordinates.add((2, 2))
# 结果: [(1, 5), (2, 2), (2, 3), (3, 1)]# 查找第一个元素≥2的元组
idx = coordinates.bisect_left((2,))
print(coordinates[idx:]) # [(2, 2), (2, 3), (3, 1)]
# 查找第一个元素>2且第二个元素≥1的元组
idx = coordinates.bisect_left((2, 3))
print(coordinates[idx:]) # [(2, 3), (3, 1)]# 三维坐标示例
space = SortedList([(1, 2, 3), (1, 1, 5), (2, 0, 0)])
# 添加时自动排序
space.add((1, 2, 2))
# 结果: [(1, 1, 5), (1, 2, 2), (1, 2, 3), (2, 0, 0)]
# 范围查询
start = space.bisect_left((1, 2))
end = space.bisect_right((1, 2, 3))
print(space[start:end]) # [(1, 2, 2), (1, 2, 3)]from operator import itemgetter
# 按元组第二个元素排序
sl = SortedList([('a', 3), ('b', 1), ('c', 2)], key=itemgetter(1))
print(sl) # [('b', 1), ('c', 2), ('a', 3)]
# 添加新元素会自动按key排序
sl.add(('d', 1.5))
# 结果: [('b', 1), ('d', 1.5), ('c', 2), ('a', 3)]# 按字符串长度和字母顺序排序
words = SortedList(key=lambda x: (len(x), x))
words.update(['apple', 'banana', 'pear', 'orange'])
print(words) # ['pear', 'apple', 'banana', 'orange']# 使用负号实现数字降序
numbers = SortedList(key=lambda x: -x)
numbers.update([3, 1, 4, 2])
print(numbers) # [4, 3, 2, 1]# 不好: O(n log n)
for x in large_list:
sl.add(x)
# 好: O(n log n) 但常数更小
sl.update(large_list)# 创建时预估大小
sl = SortedList(load=1000) # 预分配1000个元素的空间# 高效合并两个SortedList
sl1 = SortedList([1, 3, 5])
sl2 = SortedList([2, 4, 6])
merged = SortedList(sl1 + sl2) # O(n)# 创建示例数据
sl = SortedList([x for x in range(100)])
# 查询10-20之间的元素
start = sl.bisect_left(10)
end = sl.bisect_right(20)
print(sl[start:end]) # [10, 11, ..., 20]def find_closest(sl, target):
idx = sl.bisect_left(target)
candidates = []
if idx > 0:
candidates.append(sl[idx-1])
if idx < len(sl):
candidates.append(sl[idx])
return min(candidates, key=lambda x: abs(x - target))
print(find_closest(sl, 17.5)) # 17或18中更接近的# 自定义key处理None
sl = SortedList([1, None, 3, 2], key=lambda x: float('inf') if x is None else x)
print(sl) # [1, 2, 3, None]# 统一为字符串比较
sl = SortedList([1, '2', 3.0], key=str)
print(sl) # [1, 3.0, '2']class StockMonitor:
def __init__(self):
self.prices = SortedList()
self.timestamps = {}
def add_price(self, timestamp, price):
self.prices.add(price)
self.timestamps[price] = timestamp
def get_price_range(self, min_p, max_p):
start = self.prices.bisect_left(min_p)
end = self.prices.bisect_right(max_p)
return [(self.timestamps[p], p) for p in self.prices[start:end]]class ScheduleManager:
def __init__(self):
self.events = SortedList(key=lambda x: x['time'])
def add_event(self, name, time, duration):
self.events.add({'name': name, 'time': time, 'duration': duration})
def get_conflicts(self):
conflicts = []
for i in range(1, len(self.events)):
prev = self.events[i-1]
curr = self.events[i]
if prev['time'] + prev['duration'] > curr['time']:
conflicts.append((prev, curr))
return conflicts