update cpu cache load use async way.#1318
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the CPU KV cache management system by decoupling 'load' and 'offload' operations. It introduces distinct task statuses and queues for each operation type within the InferReq and multi_level_kv_cache modules. Key feedback includes a critical fix for a potential UnboundLocalError in update_cpu_cache_load_task_states where need_free_page_list might be undefined, and a recommendation to correctly utilize the cpu_kv_cache_stream in _start_kv_cache_load_task to ensure proper GPU synchronization.
| def update_cpu_cache_load_task_states(self): | ||
| if self.backend.is_master_in_dp: | ||
| trans_ok_tasks = [] | ||
| while len(self.cpu_cache_load_task_handle_queue) != 0: | ||
| task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft() | ||
| if task.sync_event.query(): | ||
| trans_ok_tasks.append(task) | ||
| else: | ||
| self.cpu_cache_load_task_handle_queue.appendleft(task) | ||
| break | ||
| item_size = len(trans_ok_tasks) | ||
| dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) | ||
| else: | ||
| recv_list = [None] | ||
| dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) | ||
| item_size = recv_list[0] | ||
| trans_ok_tasks: List[LoadTransTask] = [ | ||
| self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size) | ||
| ] | ||
|
|
||
| if item_size > 0: | ||
| need_free_page_list = [] |
There was a problem hiding this comment.
The variable need_free_page_list is initialized inside the if item_size > 0: block at line 411. If item_size is 0, this variable will be undefined when accessed at line 428 on the master rank, leading to an UnboundLocalError. It should be initialized at the beginning of the method to ensure it is always defined.
def update_cpu_cache_load_task_states(self):
need_free_page_list = []
if self.backend.is_master_in_dp:
trans_ok_tasks = []
while len(self.cpu_cache_load_task_handle_queue) != 0:
task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft()
if task.sync_event.query():
trans_ok_tasks.append(task)
else:
self.cpu_cache_load_task_handle_queue.appendleft(task)
break
item_size = len(trans_ok_tasks)
dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0)
else:
recv_list = [None]
dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0)
item_size = recv_list[0]
trans_ok_tasks: List[LoadTransTask] = [
self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size)
]
if item_size > 0:| if self.backend.radix_cache is not None: | ||
| g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) | ||
|
|
||
| # 计算需要加载的页面(只加载未匹配的部分) | ||
| ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len) | ||
| assert ready_page_num <= len(page_list) | ||
| need_pages = page_list[ready_page_num:] # 只取需要的页面 | ||
|
|
||
| mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) | ||
|
|
||
| if self.need_sync_compute_stream(): | ||
| # TODO fa3 现在必须使用同步模式, 未来需要移除 | ||
| torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) | ||
| # g_infer_context.get_overlap_stream().synchronize() | ||
|
|
||
| mem_manager = self.backend.model.mem_manager | ||
| req_manager = self.backend.model.req_manager | ||
|
|
||
| mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) | ||
| page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True) | ||
| # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, | ||
| # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 | ||
| # 这里需要进行pad操作,使操作的页面是完整的。 | ||
| _start = page_len_start_list[ready_page_num] | ||
|
|
||
| _end = req.cur_kv_len | ||
| assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" | ||
| mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda]) | ||
|
|
||
| assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] | ||
|
|
||
| # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作, | ||
| # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。 | ||
| req.cur_kv_len = req.cur_kv_len + need_token_num | ||
|
|
||
| mem_manager.operator.load_cpu_cache_to_gpu( | ||
| mem_indexes=mem_indexes_cuda, | ||
| page_indexes=page_indexes_cuda, | ||
| cpu_cache_client=self.cpu_cache_client, | ||
| req=req, | ||
| ) | ||
|
|
||
| sync_event = torch.cuda.Event() | ||
| sync_event.record() | ||
|
|
||
| trans_task = LoadTransTask( | ||
| req_obj=req, | ||
| page_list=page_list, | ||
| mem_indexes=mem_indexes, | ||
| sync_event=sync_event, | ||
| ) | ||
| return trans_task |
There was a problem hiding this comment.
The _start_kv_cache_load_task method should execute the load_cpu_cache_to_gpu call and the event recording within the cpu_kv_cache_stream context. Currently, it ignores the passed stream, which may lead to synchronization issues if the operation is launched on the default stream while the rest of the system expects it to be on the dedicated cache stream.
if self.backend.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
# 计算需要加载的页面(只加载未匹配的部分)
ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len)
assert ready_page_num <= len(page_list)
need_pages = page_list[ready_page_num:] # 只取需要的页面
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)
if self.need_sync_compute_stream():
# TODO fa3 现在必须使用同步模式, 未来需要移除
torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
# g_infer_context.get_overlap_stream().synchronize()
mem_manager = self.backend.model.mem_manager
req_manager = self.backend.model.req_manager
with torch.cuda.stream(cpu_kv_cache_stream):
mem_indexes_cuda = mem_indexes.cuda(non_blocking=True)
page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True)
# 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做,
# 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以
# 这里需要进行pad操作,使操作的页面是完整的。
_start = page_len_start_list[ready_page_num]
_end = req.cur_kv_len
assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]"
mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda])
assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num]
# 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作,
# 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。
req.cur_kv_len = req.cur_kv_len + need_token_num
mem_manager.operator.load_cpu_cache_to_gpu(
mem_indexes=mem_indexes_cuda,
page_indexes=page_indexes_cuda,
cpu_cache_client=self.cpu_cache_client,
req=req,
)
sync_event = torch.cuda.Event()
sync_event.record()
trans_task = LoadTransTask(
req_obj=req,
page_list=page_list,
mem_indexes=mem_indexes,
sync_event=sync_event,
)
return trans_task
No description provided.