vllm.entrypoints.pooling.pooling.serving

logger module-attribute

logger = init_logger(__name__)

OpenAIServingPooling

Bases: OpenAIServing

Source code in vllm/entrypoints/pooling/pooling/serving.py
class OpenAIServingPooling(OpenAIServing):
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_pooling(
        self,
        request: PoolingRequest,
        raw_request: Request | None = None,
    ) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
        """
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = self.models.model_name()

        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        is_io_processor_request = isinstance(request, IOProcessorRequest)
        try:
            lora_request = self._maybe_get_adapters(request)

            if getattr(request, "dimensions", None) is not None:
                return self.create_error_response(
                    "dimensions is currently not supported"
                )

            engine_prompts: Sequence[PromptType | TokPrompt]
            if is_io_processor_request:
                if self.io_processor is None:
                    raise ValueError(
                        "No IOProcessor plugin installed. Please refer "
                        "to the documentation and to the "
                        "'prithvi_geospatial_mae_io_processor' "
                        "offline inference example for more details."
                    )

                validated_prompt = self.io_processor.parse_request(request)

                raw_prompts = await self.io_processor.pre_process_async(
                    prompt=validated_prompt, request_id=request_id
                )
                engine_prompts = prompt_to_seq(raw_prompts)
            elif isinstance(request, PoolingChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                _, engine_prompts = await self._preprocess_chat(
                    request,
                    request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(request, PoolingCompletionRequest):
                engine_prompts = await self._preprocess_completion(
                    request,
                    prompt_input=request.input,
                    prompt_embeds=None,
                )
            else:
                raise ValueError(f"Unsupported request of type {type(request)}")
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            if is_io_processor_request:
                assert self.io_processor is not None and isinstance(
                    request, IOProcessorRequest
                )
                pooling_params = self.io_processor.validate_or_generate_params()
            else:
                pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=lora_request,
                )

                trace_headers = (
                    None
                    if raw_request is None
                    else await self._get_trace_headers(raw_request.headers)
                )

                if is_io_processor_request:
                    tokenization_kwargs: dict[str, Any] = {}
                else:
                    tok_params = request.build_tok_params(self.model_config)  # type: ignore
                    tokenization_kwargs = tok_params.get_encode_kwargs()

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        result_generator = merge_async_iterators(*generators)

        if is_io_processor_request:
            assert self.io_processor is not None
            output = await self.io_processor.post_process_async(
                model_output=result_generator,
                request_id=request_id,
            )
            return self.io_processor.output_to_response(output)

        assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)

            response = self.request_output_to_pooling_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                request.encoding_format,
                request.embed_dtype,
                request.endianness,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)

        return response

    def request_output_to_pooling_json_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingResponse:
        encode_fn = cast(
            Callable[[PoolingRequestOutput], list[float] | str],
            (
                encode_pooling_output_float
                if encoding_format == "float"
                else partial(
                    encode_pooling_output_base64,
                    embed_dtype=embed_dtype,
                    endianness=endianness,
                )
            ),
        )

        items: list[PoolingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            item = PoolingResponseData(
                index=idx,
                data=encode_fn(final_res),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return PoolingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )

    def request_output_to_pooling_bytes_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["bytes", "bytes_only"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingBytesResponse:
        content, items, usage = encode_pooling_bytes(
            pooling_outputs=final_res_batch,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )

        headers = (
            None
            if encoding_format == "bytes_only"
            else {
                "metadata": json.dumps(
                    {
                        "id": request_id,
                        "created": created_time,
                        "model": model_name,
                        "data": items,
                        "usage": usage,
                    }
                )
            }
        )

        return PoolingBytesResponse(content=content, headers=headers)

    def request_output_to_pooling_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: EncodingFormat,
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingResponse | PoolingBytesResponse:
        if encoding_format == "float" or encoding_format == "base64":
            return self.request_output_to_pooling_json_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
            return self.request_output_to_pooling_bytes_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        assert_never(encoding_format)

chat_template instance-attribute

chat_template = chat_template

chat_template_content_format instance-attribute

chat_template_content_format: Final = (
    chat_template_content_format
)

trust_request_chat_template instance-attribute

trust_request_chat_template = trust_request_chat_template

__init__

__init__(
    engine_client: EngineClient,
    models: OpenAIServingModels,
    *,
    request_logger: RequestLogger | None,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
    trust_request_chat_template: bool = False,
    log_error_stack: bool = False,
) -> None
Source code in vllm/entrypoints/pooling/pooling/serving.py
def __init__(
    self,
    engine_client: EngineClient,
    models: OpenAIServingModels,
    *,
    request_logger: RequestLogger | None,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
    trust_request_chat_template: bool = False,
    log_error_stack: bool = False,
) -> None:
    super().__init__(
        engine_client=engine_client,
        models=models,
        request_logger=request_logger,
        log_error_stack=log_error_stack,
    )

    self.chat_template = chat_template
    self.chat_template_content_format: Final = chat_template_content_format
    self.trust_request_chat_template = trust_request_chat_template

create_pooling async

create_pooling(
    request: PoolingRequest,
    raw_request: Request | None = None,
) -> (
    PoolingResponse
    | IOProcessorResponse
    | PoolingBytesResponse
    | ErrorResponse
)

See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API.
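A minimal client-side sketch of calling this handler through vLLM's OpenAI-compatible HTTP server. The `/pooling` route, port, and model name below are illustrative assumptions, not a guaranteed contract; the response fields shown match the PoolingResponse structure built further down.

import requests  # any HTTP client works; requests is used here for brevity

# Assumed local vLLM server exposing this handler at /pooling.
response = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "my-pooling-model",       # placeholder model name
        "input": "Example text to pool",   # completion-style input
        "encoding_format": "float",        # or "base64", "bytes", "bytes_only"
    },
)
response.raise_for_status()
payload = response.json()
print(payload["data"][0]["data"][:8])  # first few pooled values
print(payload["usage"])                # prompt_tokens / total_tokens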

Source code in vllm/entrypoints/pooling/pooling/serving.py
async def create_pooling(
    self,
    request: PoolingRequest,
    raw_request: Request | None = None,
) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
    """
    See https://platform.openai.com/docs/api-reference/embeddings/create
    for the API specification. This API mimics the OpenAI Embedding API.
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    model_name = self.models.model_name()

    request_id = f"pool-{self._base_request_id(raw_request)}"
    created_time = int(time.time())

    is_io_processor_request = isinstance(request, IOProcessorRequest)
    try:
        lora_request = self._maybe_get_adapters(request)

        if getattr(request, "dimensions", None) is not None:
            return self.create_error_response(
                "dimensions is currently not supported"
            )

        engine_prompts: Sequence[PromptType | TokPrompt]
        if is_io_processor_request:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            validated_prompt = self.io_processor.parse_request(request)

            raw_prompts = await self.io_processor.pre_process_async(
                prompt=validated_prompt, request_id=request_id
            )
            engine_prompts = prompt_to_seq(raw_prompts)
        elif isinstance(request, PoolingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret

            _, engine_prompts = await self._preprocess_chat(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
        elif isinstance(request, PoolingCompletionRequest):
            engine_prompts = await self._preprocess_completion(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )
        else:
            raise ValueError(f"Unsupported request of type {type(request)}")
    except (ValueError, TypeError, jinja2.TemplateError) as e:
        logger.exception("Error in preprocessing prompt inputs")
        return self.create_error_response(str(e))

    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    try:
        if is_io_processor_request:
            assert self.io_processor is not None and isinstance(
                request, IOProcessorRequest
            )
            pooling_params = self.io_processor.validate_or_generate_params()
        else:
            pooling_params = request.to_pooling_params()

        for i, engine_prompt in enumerate(engine_prompts):
            request_id_item = f"{request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=lora_request,
            )

            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            if is_io_processor_request:
                tokenization_kwargs: dict[str, Any] = {}
            else:
                tok_params = request.build_tok_params(self.model_config)  # type: ignore
                tokenization_kwargs = tok_params.get_encode_kwargs()

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=request.priority,
            )

            generators.append(generator)
    except ValueError as e:
        return self.create_error_response(e)

    result_generator = merge_async_iterators(*generators)

    if is_io_processor_request:
        assert self.io_processor is not None
        output = await self.io_processor.post_process_async(
            model_output=result_generator,
            request_id=request_id,
        )
        return self.io_processor.output_to_response(output)

    assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
    num_prompts = len(engine_prompts)

    # Non-streaming response
    final_res_batch: list[PoolingRequestOutput | None]
    final_res_batch = [None] * num_prompts
    try:
        async for i, res in result_generator:
            final_res_batch[i] = res

        assert all(final_res is not None for final_res in final_res_batch)

        final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)

        response = self.request_output_to_pooling_response(
            final_res_batch_checked,
            request_id,
            created_time,
            model_name,
            request.encoding_format,
            request.embed_dtype,
            request.endianness,
        )
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
    except ValueError as e:
        return self.create_error_response(e)

    return response

request_output_to_pooling_bytes_response

request_output_to_pooling_bytes_response(
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["bytes", "bytes_only"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingBytesResponse
Source code in vllm/entrypoints/pooling/pooling/serving.py
def request_output_to_pooling_bytes_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["bytes", "bytes_only"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingBytesResponse:
    content, items, usage = encode_pooling_bytes(
        pooling_outputs=final_res_batch,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )

    headers = (
        None
        if encoding_format == "bytes_only"
        else {
            "metadata": json.dumps(
                {
                    "id": request_id,
                    "created": created_time,
                    "model": model_name,
                    "data": items,
                    "usage": usage,
                }
            )
        }
    )

    return PoolingBytesResponse(content=content, headers=headers)
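For the "bytes" encoding format, the response body is the binary embedding payload and the JSON metadata (id, created, model, data, usage) travels in a "metadata" response header, as constructed above ("bytes_only" omits the header). A hedged client-side sketch; interpreting the body as one contiguous little-endian float32 buffer is an assumption you should match to the embed_dtype and endianness you requested.

import json

import numpy as np
import requests  # illustrative client

resp = requests.post(
    "http://localhost:8000/pooling",  # assumed route, as in the earlier example
    json={
        "model": "my-pooling-model",
        "input": "hello",
        "encoding_format": "bytes",
    },
)
resp.raise_for_status()
metadata = json.loads(resp.headers["metadata"])  # per the headers built above
print(metadata["usage"])
# Assumption: raw body decodes as little-endian float32 values.
flat = np.frombuffer(resp.content, dtype="<f4")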

request_output_to_pooling_json_response

request_output_to_pooling_json_response(
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["float", "base64"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingResponse
Source code in vllm/entrypoints/pooling/pooling/serving.py
def request_output_to_pooling_json_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["float", "base64"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingResponse:
    encode_fn = cast(
        Callable[[PoolingRequestOutput], list[float] | str],
        (
            encode_pooling_output_float
            if encoding_format == "float"
            else partial(
                encode_pooling_output_base64,
                embed_dtype=embed_dtype,
                endianness=endianness,
            )
        ),
    )

    items: list[PoolingResponseData] = []
    num_prompt_tokens = 0

    for idx, final_res in enumerate(final_res_batch):
        item = PoolingResponseData(
            index=idx,
            data=encode_fn(final_res),
        )
        prompt_token_ids = final_res.prompt_token_ids

        items.append(item)
        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return PoolingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=items,
        usage=usage,
    )
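When encoding_format is "base64", each PoolingResponseData.data field is a base64 string produced by encode_pooling_output_base64 with the requested embed_dtype and endianness. A minimal decoding sketch, assuming a little-endian float32 payload; adjust the dtype string to whatever you asked the server for.

import base64

import numpy as np

def decode_base64_pooling(data_b64: str, dtype: str = "<f4") -> np.ndarray:
    """Decode one base64-encoded pooling vector.

    "<f4" (little-endian float32) is an assumption; match it to the
    embed_dtype and endianness sent in the request.
    """
    raw = base64.b64decode(data_b64)
    return np.frombuffer(raw, dtype=np.dtype(dtype))

# Example usage against a parsed PoolingResponse (hypothetical variable):
# vec = decode_base64_pooling(payload["data"][0]["data"])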

request_output_to_pooling_response

request_output_to_pooling_response(
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse
Source code in vllm/entrypoints/pooling/pooling/serving.py
def request_output_to_pooling_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
    if encoding_format == "float" or encoding_format == "base64":
        return self.request_output_to_pooling_json_response(
            final_res_batch,
            request_id,
            created_time,
            model_name,
            encoding_format,
            embed_dtype,
            endianness,
        )

    if encoding_format == "bytes" or encoding_format == "bytes_only":
        return self.request_output_to_pooling_bytes_response(
            final_res_batch,
            request_id,
            created_time,
            model_name,
            encoding_format,
            embed_dtype,
            endianness,
        )

    assert_never(encoding_format)