Context manager for micro-batching synchronization using threading events.
  Source code in vllm/v1/worker/ubatching.py
```python
class UBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.
    """

    def __init__(
        self,
        id: int,
        comm_stream: torch.cuda.Stream,
        compute_stream: torch.cuda.Stream,
        forward_context: ForwardContext,
        ready_barrier: threading.Barrier,
        cpu_wait_event: threading.Event,
        cpu_signal_event: threading.Event,
        gpu_comm_done_event: torch.cuda.Event,
        gpu_compute_done_event: torch.cuda.Event,
        schedule: str = "default",
    ):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        self.recv_hook = None

    def __enter__(self):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        return False

    def _restore_context(self):
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.cuda.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()
        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
```
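To make the handshake concrete, here is a minimal usage sketch. This is not vLLM's actual driver code: it assumes `ready_barrier` counts the coordinating thread as a party, and `make_forward_context` and `run_ubatch` are hypothetical stand-ins for whatever builds the per-micro-batch `ForwardContext` and runs the model slice.

```python
import threading

import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
ready_barrier = threading.Barrier(3)  # 2 ubatch threads + coordinator (assumed)

# Each context sleeps on its own event; its signal event is the peer's wait
# event, so setting it hands the single execution slot to the other thread.
events = [threading.Event(), threading.Event()]
ctxs = [
    UBatchContext(
        id=i,
        comm_stream=comm_stream,
        compute_stream=compute_stream,
        forward_context=make_forward_context(i),  # hypothetical helper
        ready_barrier=ready_barrier,
        cpu_wait_event=events[i],
        cpu_signal_event=events[1 - i],
        gpu_comm_done_event=torch.cuda.Event(),
        gpu_compute_done_event=torch.cuda.Event(),
    )
    for i in range(2)
]

def worker(ctx):
    with ctx:            # parks on ready_barrier, then on cpu_wait_event
        run_ubatch(ctx)  # hypothetical: model code that calls ctx.yield_()

threads = [threading.Thread(target=worker, args=(c,)) for c in ctxs]
for t in threads:
    t.start()
ready_barrier.wait()  # release both workers into their event waits
events[0].set()       # wake micro-batch 0; the pair then alternates
for t in threads:
    t.join()
```

From this point on, at most one micro-batch thread makes progress at a time: each yield wakes the peer and puts the caller to sleep.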
Instance attributes:

- `comm_stream = comm_stream`: CUDA stream on which communication kernels are enqueued.
- `compute_stream = compute_stream`: CUDA stream on which compute kernels are enqueued.
- `cpu_signal_event = cpu_signal_event`: `threading.Event` set to wake the peer micro-batch thread.
- `cpu_wait_event = cpu_wait_event`: `threading.Event` this thread sleeps on until it is handed the CPU.
- `current_stream = compute_stream`: the stream this context currently targets; starts as the compute stream.
- `forward_context = forward_context`: per-micro-batch `ForwardContext`, restored on every wake-up.
- `gpu_comm_done_event = gpu_comm_done_event`: CUDA event recorded on the comm stream so later compute work can be ordered after it.
- `gpu_compute_done_event = gpu_compute_done_event`: CUDA event recorded on the compute stream so later comm work can be ordered after it.
- `ready_barrier = ready_barrier`: barrier every participating thread reaches before the handoff protocol begins.
- `recv_hook = None`: optional zero-argument callable, run at most once by `maybe_run_recv_hook` and then cleared.
- `schedule = schedule`: name of the micro-batch overlap schedule; `"default"` unless overridden.
__enter__

__enter__()

Source code in vllm/v1/worker/ubatching.py

```python
def __enter__(self):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
    _CURRENT_CONTEXTS[self.id] = self
    self.ready_barrier.wait()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
    # Assume we want to start on the compute stream
    self.update_stream(self.compute_stream)
    return self
```
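`__enter__` is a two-stage handshake: all participants first meet at `ready_barrier`, then each thread parks on its own `cpu_wait_event` until it is explicitly handed the CPU. With the wiring from the usage sketch above, the coordinator's side would look roughly like:

```python
ready_barrier.wait()          # every ubatch thread is now parked in __enter__
ctxs[0].cpu_wait_event.set()  # release exactly one thread; the rest stay asleep
```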
__exit__

__exit__(exc_type, exc_val, exc_tb)

Runs any pending `recv_hook`, wakes the peer thread via `cpu_signal_event`, and returns `False` so that exceptions raised inside the context propagate.

Source code in vllm/v1/worker/ubatching.py

```python
def __exit__(self, exc_type, exc_val, exc_tb):
    global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
    _CURRENT_CONTEXTS[self.id] = None
    del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    self.maybe_run_recv_hook()
    self.cpu_signal_event.set()
    self.cpu_wait_event.clear()
    return False
```
__init__

```python
__init__(
    id: int,
    comm_stream: Stream,
    compute_stream: Stream,
    forward_context: ForwardContext,
    ready_barrier: Barrier,
    cpu_wait_event: Event,
    cpu_signal_event: Event,
    gpu_comm_done_event: Event,
    gpu_compute_done_event: Event,
    schedule: str = "default",
)
```

Source code in vllm/v1/worker/ubatching.py

```python
def __init__(
    self,
    id: int,
    comm_stream: torch.cuda.Stream,
    compute_stream: torch.cuda.Stream,
    forward_context: ForwardContext,
    ready_barrier: threading.Barrier,
    cpu_wait_event: threading.Event,
    cpu_signal_event: threading.Event,
    gpu_comm_done_event: torch.cuda.Event,
    gpu_compute_done_event: torch.cuda.Event,
    schedule: str = "default",
):
    self.id = id
    self.comm_stream = comm_stream
    self.compute_stream = compute_stream
    self.forward_context = forward_context
    self.ready_barrier = ready_barrier
    self.cpu_wait_event = cpu_wait_event
    self.cpu_signal_event = cpu_signal_event
    self.current_stream = compute_stream
    self.gpu_comm_done_event = gpu_comm_done_event
    self.gpu_compute_done_event = gpu_compute_done_event
    self.schedule = schedule
    self.recv_hook = None
```
_cpu_yield

_cpu_yield()

Source code in vllm/v1/worker/ubatching.py

```python
def _cpu_yield(self):
    # It is critical for correctness that only one thread is running
    # at a time. These asserts just make sure that this is the only
    # thread running before waking the other one up and going to sleep
    assert forward_context._forward_context == self.forward_context
    assert current_stream() == self.current_stream
    assert not self.cpu_wait_event.is_set()
    self.cpu_signal_event.set()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
```
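The event wiring is a classic two-thread ping-pong: each context's `cpu_signal_event` is the other context's `cpu_wait_event`, so setting it hands the single execution slot to the peer, and the setter immediately goes back to waiting. A self-contained sketch of just the CPU-side handoff, with no CUDA involved:

```python
import threading

ev = [threading.Event(), threading.Event()]

def ubatch(i: int, steps: int):
    for step in range(steps):
        ev[i].wait()     # sleep until the peer (or coordinator) wakes us
        ev[i].clear()
        print(f"ubatch {i} runs step {step}")
        ev[1 - i].set()  # wake the peer, then loop back to sleep

threads = [threading.Thread(target=ubatch, args=(i, 3)) for i in range(2)]
for t in threads:
    t.start()
ev[0].set()              # kick off ubatch 0; the threads then strictly alternate
for t in threads:
    t.join()
```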
_restore_context

_restore_context()

Source code in vllm/v1/worker/ubatching.py

```python
def _restore_context(self):
    forward_context._forward_context = self.forward_context
```

_signal_comm_done

_signal_comm_done()

Source code in vllm/v1/worker/ubatching.py

```python
def _signal_comm_done(self):
    self.gpu_comm_done_event.record(self.comm_stream)
```

_signal_compute_done

_signal_compute_done()

Source code in vllm/v1/worker/ubatching.py

```python
def _signal_compute_done(self):
    self.gpu_compute_done_event.record(self.compute_stream)
```

_wait_comm_done

_wait_comm_done()

Source code in vllm/v1/worker/ubatching.py

```python
def _wait_comm_done(self):
    self.compute_stream.wait_event(self.gpu_comm_done_event)
```

_wait_compute_done

_wait_compute_done()

Source code in vllm/v1/worker/ubatching.py

```python
def _wait_compute_done(self):
    self.comm_stream.wait_event(self.gpu_compute_done_event)
```
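The four `_signal_*`/`_wait_*` helpers are the standard CUDA event pattern for ordering work across streams: record an event on the producer stream, then make the consumer stream wait on it, all without blocking the CPU. A minimal standalone illustration (assumes a CUDA device is available; the "compute" and "comm" work are just placeholders):

```python
import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
compute_done = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")
with torch.cuda.stream(compute_stream):
    y = x @ x                         # "compute" work
compute_done.record(compute_stream)   # analogous to _signal_compute_done

comm_stream.wait_event(compute_done)  # analogous to _wait_compute_done
with torch.cuda.stream(comm_stream):
    z = y.sum()                       # "comm" work, ordered after the matmul
torch.cuda.synchronize()
print(z.item())
```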
maybe_run_recv_hook

maybe_run_recv_hook()

Source code in vllm/v1/worker/ubatching.py

```python
def maybe_run_recv_hook(self):
    if self.recv_hook is not None:
        self.recv_hook()
        self.recv_hook = None
```
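`recv_hook` lets communication code defer the blocking tail of an asynchronous receive: a layer stores a callable on the context, and the context runs it at most once, either when asked or in `__exit__`. A hedged sketch of the pattern (the accessor and the async call are hypothetical, not vLLM's API):

```python
ctx = get_current_ubatch_context()  # hypothetical accessor
handle = start_async_recv()         # hypothetical async communication op
ctx.recv_hook = handle.wait         # stored, not yet run

# Later (or automatically in __exit__), the hook runs exactly once:
ctx.maybe_run_recv_hook()
```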
switch_to_comm

switch_to_comm()

Source code in vllm/v1/worker/ubatching.py

```python
def switch_to_comm(self):
    self.update_stream(self.comm_stream)
```

switch_to_comm_sync

switch_to_comm_sync()

Source code in vllm/v1/worker/ubatching.py

```python
def switch_to_comm_sync(self):
    self._signal_compute_done()
    self.update_stream(self.comm_stream)
    self._wait_compute_done()
```

switch_to_compute

switch_to_compute()

Source code in vllm/v1/worker/ubatching.py

```python
def switch_to_compute(self):
    self.update_stream(self.compute_stream)
```

switch_to_compute_sync

switch_to_compute_sync()

Source code in vllm/v1/worker/ubatching.py

```python
def switch_to_compute_sync(self):
    self._signal_comm_done()
    self.update_stream(self.compute_stream)
    self._wait_comm_done()
```
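The plain `switch_to_*` methods only redirect where future kernels are enqueued; the `_sync` variants additionally record a done event on the stream being left and make the target stream wait on it, so new work is ordered after everything already enqueued. A sketch of the difference, assuming an active context `ctx` and hypothetical launch functions:

```python
# Unordered switch: comm kernels may overlap in-flight compute kernels.
ctx.switch_to_comm()
launch_all_to_all()        # hypothetical comm launch

# Ordered switch: comm kernels start only after already-enqueued compute work.
ctx.switch_to_compute()
run_mlp()                  # hypothetical compute launch
ctx.switch_to_comm_sync()  # record compute-done, then make comm stream wait
launch_all_to_all()
```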
update_stream

update_stream(stream)

Source code in vllm/v1/worker/ubatching.py

```python
def update_stream(self, stream):
    self.current_stream = stream
    if current_stream() != self.current_stream:
        torch.cuda.set_stream(self.current_stream)
```
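`update_stream` changes the thread's ambient CUDA stream, so ordinary PyTorch calls that follow are enqueued there without a per-call stream argument; the `current_stream()` check skips the redundant (and not free) `torch.cuda.set_stream` when nothing changed. For instance, assuming an active context `ctx` and tensors `a` and `b`:

```python
ctx.update_stream(ctx.comm_stream)
out = a @ b  # enqueued on ctx.comm_stream, no explicit stream needed
```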
yield_

yield_()

Source code in vllm/v1/worker/ubatching.py

```python
def yield_(self):
    self.current_stream = current_stream()
    self._cpu_yield()
    self.update_stream(self.current_stream)
```
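`yield_` is the schedule-agnostic handoff: it snapshots whatever stream the thread is currently on, sleeps until woken, then restores that stream. That lets a yield be dropped into code that does not know whether it is in a compute or a comm phase. For example (launch functions are hypothetical):

```python
ctx.switch_to_comm()
start_dispatch()   # hypothetical comm launch
ctx.yield_()       # peer runs; on wake-up we are back on the comm stream
finish_dispatch()  # hypothetical comm launch
```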
yield_and_switch_from_comm_to_compute

yield_and_switch_from_comm_to_compute()

Source code in vllm/v1/worker/ubatching.py

```python
def yield_and_switch_from_comm_to_compute(self):
    assert current_stream() == self.comm_stream
    self._signal_comm_done()
    self._cpu_yield()
    assert self.current_stream == self.comm_stream
    self.update_stream(self.compute_stream)
    self._wait_comm_done()
```

yield_and_switch_from_compute_to_comm

yield_and_switch_from_compute_to_comm()

Source code in vllm/v1/worker/ubatching.py

```python
def yield_and_switch_from_compute_to_comm(self):
    assert current_stream() == self.compute_stream
    self._signal_compute_done()
    self._cpu_yield()
    assert self.current_stream == self.compute_stream
    self.update_stream(self.comm_stream)
    self._wait_compute_done()
```
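Taken together, the two `yield_and_switch_*` methods are what let one micro-batch's compute overlap the other's communication: each call records a done event on the stream being left, hands the CPU to the peer, and on wake-up moves to the other stream and waits on the peer phase's event. A hedged sketch of one overlapped step (the layer functions are illustrative, not vLLM's):

```python
def overlapped_step(ctx: UBatchContext, hidden):
    # Entered on the compute stream, as the first assert requires.
    h = attention(hidden)                       # hypothetical compute
    # Hand the CPU to the peer; on wake-up, enqueue comm work ordered
    # after the compute we just recorded.
    ctx.yield_and_switch_from_compute_to_comm()
    h = all_to_all_dispatch(h)                  # hypothetical communication
    # Hand the CPU back; on wake-up, resume compute ordered after the comm.
    ctx.yield_and_switch_from_comm_to_compute()
    return mlp(h)                               # hypothetical compute
```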