# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
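"""Generate responses for a Hugging Face dataset with distilabel and a vLLM server.

An illustrative invocation (the dataset, org, and model names below are
placeholders, not values from this script):

    python generate.py \
        --hf-dataset some-org/prompts \
        --prompt-column prompt \
        --model deepseek-ai/DeepSeek-R1 \
        --vllm-server-url http://localhost:8000/v1 \
        --hf-output-dataset some-org/generations
"""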

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration


def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
) -> Pipeline:
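    """Build a distilabel pipeline that generates completions through an
    OpenAI-compatible vLLM server.

    Sampling parameters left as ``None`` are not added to ``generation_kwargs``,
    so the defaults of the backing LLM client apply.
    """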
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

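    # .ray() switches the pipeline to distilabel's Ray backend so the step can
    # be distributed across a Ray cluster instead of running in local subprocesses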
    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
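                # the vLLM OpenAI-compatible server does not check the API key
                # by default, so any placeholder value works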
                api_key="something",
                model=model,
                # thinking can take some time...
                timeout=10 * 60,
                generation_kwargs=generation_kwargs,
            ),
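            # map the dataset's prompt column onto the step's "instruction"
            # input when a custom column name is given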
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
            num_generations=num_generations,
        )

    return pipeline


if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )

    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()
print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
|
|
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
|
|
print("Dataset loaded!")
|
|
|
|
    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

print("Running generation pipeline...")
|
|
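    # use_cache=False forces a fresh run instead of reusing results cached by
    # distilabel from a previous execution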
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")