Skip to content

update cpu cache load use async way.#1318

Open
hiworldwzj wants to merge 10 commits into
mainfrom
wzj_dev
Open

update cpu cache load use async way.#1318
hiworldwzj wants to merge 10 commits into
mainfrom
wzj_dev

Conversation

@hiworldwzj
Copy link
Copy Markdown
Collaborator

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the CPU KV cache management system by decoupling 'load' and 'offload' operations. It introduces distinct task statuses and queues for each operation type within the InferReq and multi_level_kv_cache modules. Key feedback includes a critical fix for a potential UnboundLocalError in update_cpu_cache_load_task_states where need_free_page_list might be undefined, and a recommendation to correctly utilize the cpu_kv_cache_stream in _start_kv_cache_load_task to ensure proper GPU synchronization.

Comment on lines +390 to +411
def update_cpu_cache_load_task_states(self):
if self.backend.is_master_in_dp:
trans_ok_tasks = []
while len(self.cpu_cache_load_task_handle_queue) != 0:
task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft()
if task.sync_event.query():
trans_ok_tasks.append(task)
else:
self.cpu_cache_load_task_handle_queue.appendleft(task)
break
item_size = len(trans_ok_tasks)
dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0)
else:
recv_list = [None]
dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0)
item_size = recv_list[0]
trans_ok_tasks: List[LoadTransTask] = [
self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size)
]

if item_size > 0:
need_free_page_list = []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable need_free_page_list is initialized inside the if item_size > 0: block at line 411. If item_size is 0, this variable will be undefined when accessed at line 428 on the master rank, leading to an UnboundLocalError. It should be initialized at the beginning of the method to ensure it is always defined.

    def update_cpu_cache_load_task_states(self):
        need_free_page_list = []
        if self.backend.is_master_in_dp:
            trans_ok_tasks = []
            while len(self.cpu_cache_load_task_handle_queue) != 0:
                task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft()
                if task.sync_event.query():
                    trans_ok_tasks.append(task)
                else:
                    self.cpu_cache_load_task_handle_queue.appendleft(task)
                    break
            item_size = len(trans_ok_tasks)
            dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0)
        else:
            recv_list = [None]
            dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0)
            item_size = recv_list[0]
            trans_ok_tasks: List[LoadTransTask] = [
                self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size)
            ]

        if item_size > 0:

Comment on lines +134 to +185
if self.backend.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)

# 计算需要加载的页面(只加载未匹配的部分)
ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len)
assert ready_page_num <= len(page_list)
need_pages = page_list[ready_page_num:] # 只取需要的页面

mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)

if self.need_sync_compute_stream():
# TODO fa3 现在必须使用同步模式, 未来需要移除
torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
# g_infer_context.get_overlap_stream().synchronize()

mem_manager = self.backend.model.mem_manager
req_manager = self.backend.model.req_manager

mem_indexes_cuda = mem_indexes.cuda(non_blocking=True)
page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True)
# 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做,
# 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以
# 这里需要进行pad操作,使操作的页面是完整的。
_start = page_len_start_list[ready_page_num]

_end = req.cur_kv_len
assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]"
mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda])

assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num]

# 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作,
# 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。
req.cur_kv_len = req.cur_kv_len + need_token_num

mem_manager.operator.load_cpu_cache_to_gpu(
mem_indexes=mem_indexes_cuda,
page_indexes=page_indexes_cuda,
cpu_cache_client=self.cpu_cache_client,
req=req,
)

sync_event = torch.cuda.Event()
sync_event.record()

trans_task = LoadTransTask(
req_obj=req,
page_list=page_list,
mem_indexes=mem_indexes,
sync_event=sync_event,
)
return trans_task
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _start_kv_cache_load_task method should execute the load_cpu_cache_to_gpu call and the event recording within the cpu_kv_cache_stream context. Currently, it ignores the passed stream, which may lead to synchronization issues if the operation is launched on the default stream while the rest of the system expects it to be on the dedicated cache stream.

        if self.backend.radix_cache is not None:
            g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)

        # 计算需要加载的页面(只加载未匹配的部分)
        ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len)
        assert ready_page_num <= len(page_list)
        need_pages = page_list[ready_page_num:]  # 只取需要的页面

        mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)

        if self.need_sync_compute_stream():
            # TODO fa3 现在必须使用同步模式, 未来需要移除
            torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
            # g_infer_context.get_overlap_stream().synchronize()

        mem_manager = self.backend.model.mem_manager
        req_manager = self.backend.model.req_manager

        with torch.cuda.stream(cpu_kv_cache_stream):
            mem_indexes_cuda = mem_indexes.cuda(non_blocking=True)
            page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True)
            # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做,
            # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以
            # 这里需要进行pad操作,使操作的页面是完整的。
            _start = page_len_start_list[ready_page_num]

            _end = req.cur_kv_len
            assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]"
            mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda])

            assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num]

            # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作,
            # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。
            req.cur_kv_len = req.cur_kv_len + need_token_num

            mem_manager.operator.load_cpu_cache_to_gpu(
                mem_indexes=mem_indexes_cuda,
                page_indexes=page_indexes_cuda,
                cpu_cache_client=self.cpu_cache_client,
                req=req,
            )

            sync_event = torch.cuda.Event()
            sync_event.record()

        trans_task = LoadTransTask(
            req_obj=req,
            page_list=page_list,
            mem_indexes=mem_indexes,
            sync_event=sync_event,
        )
        return trans_task

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant