# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration


def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
) -> Pipeline:
    """Build a distilabel pipeline that generates responses via an OpenAI-compatible vLLM server."""
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    # Only forward sampling parameters that were explicitly set, so the
    # server defaults apply otherwise.
    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    # Steps instantiated inside the context manager are registered with the
    # pipeline automatically; `.ray()` selects the Ray execution backend.
    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                api_key="something",  # placeholder; a local vLLM server does not validate the key
                model=model,
                # thinking can take some time...
                timeout=10 * 60,
                generation_kwargs=generation_kwargs,
            ),
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=64,  # on 4 nodes, batch sizes of ~60+ lead to preemption due to KV cache exhaustion
            num_generations=num_generations,
        )

    return pipeline


if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run a distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument(
        "--prompt-column",
        type=str,
        default="prompt",
        help="Column containing the prompts",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to the HF Hub",
    )
    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f" {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
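
# Example invocation, built from the flags defined above. This is a sketch:
# the script filename, dataset, and output repo are placeholders, and it
# assumes a vLLM server is already serving the model at the given URL.
#
#   python generate.py \
#       --hf-dataset your-username/example-prompts \
#       --prompt-column prompt \
#       --model deepseek-ai/DeepSeek-R1 \
#       --vllm-server-url http://localhost:8000/v1 \
#       --temperature 0.6 \
#       --top-p 0.95 \
#       --max-new-tokens 8192 \
#       --num-generations 2 \
#       --hf-output-dataset your-username/r1-generations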