from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd
#from tqdm import tqdm


# Delete the on-disk cache directory for this pipeline if one is present.
# NOTE(review): presumably done so the run does not resume from cached
# results of an earlier run -- confirm.
pipeline_cache = '/root/.cache/distilabel/pipelines/distill-qwen-32b-r1-tomi'
cache_present = os.path.exists(pipeline_cache)
if cache_present:
    shutil.rmtree(pipeline_cache)


# Prompt passed as `template` to the TextGeneration step. The double-brace
# placeholders name the `story`, `question` and `containers` columns that the
# step receives. The `\\` escape produces a single literal backslash, so the
# model is asked to put its final answer inside `\boxed{}`.
prompt_template = """\
You will be given a story, a question, and the corresponding choices. Please reason step by step to answer this question, and put your final answer within \\boxed{}:

Story: {{ story }}
Question: {{question}}
Choices: {{containers}}
"""


# Load the JSON file and keep only the first 800 records of its "train" split.
dataset = load_dataset(
    "json",
    data_files=".../test_balanced.json",
    split="train[:800]",
)


def add_combined_column(dataset):
    """Return ``dataset.map(...)`` with an ``entire_instruction`` field added.

    Each example gets a single-line string that concatenates its ``story``,
    ``question`` and ``containers`` fields. The original ``containers`` value
    is left untouched.

    Args:
        dataset: Any object exposing ``map(fn)`` where ``fn`` receives and
            returns one example mapping.

    Returns:
        Whatever ``dataset.map`` returns for the per-example transformation.
    """
    def combine_text(example):
        containers = example["containers"]
        # A list of choices is flattened to "a, b, c"; any other value
        # (e.g. an already-formatted string) is interpolated unchanged.
        choices = ", ".join(containers) if isinstance(containers, list) else containers

        example["entire_instruction"] = (
            f"Story: {example['story']} "
            f"Question: {example['question']} "
            f"Choices: {choices}"
        )
        return example

    return dataset.map(combine_text)

# Add the `entire_instruction` column, then print the dataset summary and its
# first record as a quick sanity check.
dataset = add_combined_column(dataset)
print(dataset)
first_example = dataset[0]
print(first_example)



# Model path/identifier handed to vLLM as both the model and the tokenizer.
model_id = ".../distill32B"

with Pipeline(
    name="distill-qwen-32b-r1-tomi",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:

    # NOTE(review): max_new_tokens equals max_model_len, so prompt tokens plus
    # generated tokens can never actually reach 8192 new tokens -- confirm the
    # intended generation budget.
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs=dict(tensor_parallel_size=1, max_model_len=8192),
        generation_kwargs=dict(temperature=0.7, max_new_tokens=8192),
    )

    # Generation step: fills `prompt_template` from the three listed columns
    # and requests two generations per input.
    text_generation = TextGeneration(
        llm=llm,
        template=prompt_template,
        columns=["story", "question", "containers"],
        num_generations=2,
        input_batch_size=4,
    )

    # SFT formatting step: its `instruction` input is read from the
    # `entire_instruction` column.
    format_sft = FormatTextGenerationSFT(
        input_mappings={"instruction": "entire_instruction"},
    )

    text_generation.connect(format_sft)



if __name__ == "__main__":
    # Single source of truth for where the generated distiset is persisted.
    output_dir = ".../SFTData/ToMi_May_test"

    # Run the pipeline over the prepared dataset and show the result plus one
    # sample row from the default/train split.
    distiset = pipeline.run(dataset=dataset)
    print(distiset)
    print(distiset['default']['train'][0])

    distiset.save_to_disk(output_dir)

    # `load_from_disk` is a classmethod that returns a new Distiset and does
    # not modify the receiver. Bind the returned object so the final print
    # reflects what was read back from disk, not the in-memory result.
    distiset = type(distiset).load_from_disk(output_dir)
    print(distiset)
    