import os
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset

# Curriculum stages in training order. Every entry has the form "<NN>_<name>",
# where NN is the stage's 1-based position in this list, so ALGORITHM_ORDER[i]
# is stage number i + 1.
ALGORITHM_ORDER = [
    f"{position:02d}_{stage}"
    for position, stage in enumerate(
        (
            "adaptive_context_learner",
            "random_walk_wanderer",
            "breadth_first_explorer",
            "depth_first_driller",
            "wrong_direction_specialist",
            "early_success_validator",
            "exploitation_heavy_validator",
            "greedy_hill_climber",
            "best_first_hypothesis_selector",
            "multi_beam_parallel",
        ),
        start=1,
    )
]

# Tagged regions whose tokens must not contribute to the training loss.
_MASKED_TAG_PAIRS = (
    ("<user_query>", "</user_query>"),
    ("<top_k_response>", "</top_k_response>"),
)


def _find_masked_spans(text, tag_pairs=_MASKED_TAG_PAIRS):
    """Return (start, end) character spans of every tagged region in *text*.

    Each span covers the opening tag through the end of its closing tag
    (end is exclusive). Every occurrence of every tag pair is reported; an
    opening tag with no closing tag after it is ignored.
    """
    spans = []
    for tag_open, tag_close in tag_pairs:
        search_from = 0
        while True:
            start = text.find(tag_open, search_from)
            if start == -1:
                break
            # The closing tag must come after its own opening tag; searching
            # from the beginning of the text could pair it with an earlier one.
            close = text.find(tag_close, start + len(tag_open))
            if close == -1:
                break
            end = close + len(tag_close)
            spans.append((start, end))
            search_from = end
    return spans


def create_dataset(csv_path, tokenizer, max_length=8192):
    """Load a CSV of training texts and tokenize it for causal-LM fine-tuning.

    The ``text`` column of the CSV is tokenized with truncation and padded to
    exactly ``max_length`` tokens. ``labels`` starts as a copy of
    ``input_ids`` and then gets ``-100`` (ignored by the loss) at:

    * padding positions (``attention_mask == 0``);
    * every token that overlaps, even partially, a
      ``<user_query>...</user_query>`` or
      ``<top_k_response>...</top_k_response>`` region (tags included), for
      every occurrence of those regions in the text.

    Args:
        csv_path: Path to a CSV file with a ``text`` column.
        tokenizer: Hugging Face tokenizer that supports
            ``return_offsets_mapping=True`` and has a pad token (required by
            ``padding="max_length"``).
        max_length: Length every example is truncated/padded to.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer outputs (minus the offset
        mapping) plus ``labels``; the raw ``text`` column is removed.
    """
    df = pd.read_csv(csv_path)
    texts = df['text'].tolist()

    def tokenize_function(examples):
        enc = tokenizer(
            examples['text'],
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_offsets_mapping=True,
        )
        # Fall back to "everything is real" if the tokenizer is configured not
        # to emit an attention mask.
        attention = enc.get('attention_mask') or [
            [1] * len(ids) for ids in enc['input_ids']
        ]

        labels = []
        for text, ids, keep, offsets in zip(
            examples['text'], enc['input_ids'], attention, enc['offset_mapping']
        ):
            # Start from the input ids, but never score padding positions.
            lab = [tok if on else -100 for tok, on in zip(ids, keep)]

            spans = _find_masked_spans(text)
            if spans:
                for tidx, (a, b) in enumerate(offsets):
                    # Empty offsets belong to special/padding tokens, which
                    # have no characters in the source text.
                    if a == b:
                        continue
                    # Half-open overlap test between token [a, b) and span [s, e).
                    if any(a < e and b > s for s, e in spans):
                        lab[tidx] = -100

            labels.append(lab)

        enc.pop('offset_mapping', None)
        enc['labels'] = labels
        return enc

    dataset = Dataset.from_dict({'text': texts})
    return dataset.map(tokenize_function, batched=True, remove_columns=['text'])

checkpoint_path = "./models/hard_curriculum_learning/SmolLM2-360M-Instruct/epoch-2/algorithm-08"
base_output_dir = "./models/hard_curriculum_learning/SmolLM2-360M-Instruct"
epoch_dir = f"{base_output_dir}/epoch-2"
data_dir = "./data/sft"
learning_rate = 5e-5

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
# create_dataset pads with padding="max_length", which fails without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load float32 master weights. TrainingArguments(bf16=True) already runs the
# forward/backward pass in bfloat16 via autocast. Storing the parameters (and
# hence the AdamW state) in bfloat16 leaves ~8 bits of mantissa, so updates of
# order lr=5e-5 can be rounded away for larger-magnitude weights.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    torch_dtype=torch.float32,
    device_map="auto",
)

def collate_with_labels(features):
    """Stack pre-padded, pre-labelled examples into a batch of int64 tensors.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids`` and overwrites any labels the dataset carries. That silently
    discards the span masking done in ``create_dataset``. It also sets -100
    wherever a token equals the pad id, which hits real tokens when the pad
    token doubles as another special token.

    This collator keeps the dataset's ``labels``. It sets -100 only where
    ``attention_mask`` is 0, so padding never contributes to the loss. If the
    examples carry no ``labels``, they are copied from ``input_ids``.

    Args:
        features: Non-empty list of dicts mapping column name to a list of
            ints; all lists for a given column must have the same length.
            ``create_dataset`` produces them with ``padding="max_length"``.

    Returns:
        Dict of 2-D ``torch.long`` tensors of shape (batch, seq_len), one per
        column, always including ``labels``.
    """
    batch = {
        key: torch.tensor([feature[key] for feature in features], dtype=torch.long)
        for key in features[0]
    }
    if "labels" not in batch:
        batch["labels"] = batch["input_ids"].clone()
    if "attention_mask" in batch:
        batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


data_collator = collate_with_labels

# Continue the curriculum from the loaded checkpoint: one epoch per remaining
# stage. Stage i in ALGORITHM_ORDER (0-based) is saved as algorithm-{i+1:02d}.
for algo_idx in range(8, 10):
    algorithm = ALGORITHM_ORDER[algo_idx]
    csv_file = f"{data_dir}/algorithm_{algorithm}.csv"

    dataset = create_dataset(csv_file, tokenizer)

    algo_output_dir = f"{epoch_dir}/algorithm-{algo_idx + 1:02d}"
    os.makedirs(algo_output_dir, exist_ok=True)

    training_args = TrainingArguments(
        # Where results and logs go.
        output_dir=algo_output_dir,
        overwrite_output_dir=True,
        logging_dir=f"{algo_output_dir}/logs",
        logging_steps=50,
        save_strategy="epoch",
        save_steps=2500,  # only consulted when save_strategy="steps"
        # Optimisation schedule.
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_steps=0,
        lr_scheduler_type="constant",
        # Runtime / data-loading behaviour.
        bf16=True,
        gradient_checkpointing=True,
        dataloader_drop_last=True,
        remove_unused_columns=False,
        prediction_loss_only=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    trainer.train()

    # The final weights and tokenizer live at the root of this stage's directory.
    trainer.save_model(algo_output_dir)
    tokenizer.save_pretrained(algo_output_dir)