import os
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
import argparse
import glob

# CLI-facing model name -> Hugging Face Hub repository id. Each key is the
# repository name with its organization prefix stripped.
MODEL_CONFIGS = {
    repo_id.split("/", 1)[1]: repo_id
    for repo_id in (
        "LiquidAI/LFM2-350M",
        "LiquidAI/LFM2-700M",
        "LiquidAI/LFM2-1.2B",
        "HuggingFaceTB/SmolLM2-135M-Instruct",
        "HuggingFaceTB/SmolLM2-360M-Instruct",
        "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    )
}

# Curriculum stages in training order. Each entry is "<NN>_<strategy>", where
# NN is the 1-based, zero-padded position; it is used to locate the stage's CSV.
ALGORITHM_ORDER = [
    f"{position:02d}_{strategy}"
    for position, strategy in enumerate(
        (
            "adaptive_context_learner",
            "random_walk_wanderer",
            "breadth_first_explorer",
            "depth_first_driller",
            "wrong_direction_specialist",
            "early_success_validator",
            "exploitation_heavy_validator",
            "greedy_hill_climber",
            "best_first_hypothesis_selector",
            "multi_beam_parallel",
        ),
        start=1,
    )
]

def create_dataset(csv_path, tokenizer, max_length=8192):
    """Tokenize the ``text`` column of a CSV into a causal-LM training dataset.

    Every row is truncated/padded to exactly ``max_length`` tokens. ``labels``
    start as a copy of ``input_ids``; then each of the following is set to
    -100 so the loss ignores it:

    * every token overlapping a ``<user_query>...</user_query>`` or
      ``<top_k_response>...</top_k_response>`` region (tags included, every
      occurrence in the text, not just the first);
    * every padding position (``attention_mask == 0``).

    NOTE: a collator that rebuilds labels from ``input_ids`` (e.g.
    ``DataCollatorForLanguageModeling(mlm=False)``) discards this masking; the
    collator used for training must pass ``labels`` through unchanged.

    Args:
        csv_path: Path to a CSV file with a ``text`` column. Rows whose text is
            missing are dropped with a warning.
        tokenizer: Tokenizer to apply; ``return_offsets_mapping`` requires a
            fast (Rust-backed) tokenizer.
        max_length: Fixed sequence length after truncation/padding.

    Returns:
        A ``datasets.Dataset`` with the tokenizer's outputs plus ``labels``; the
        ``text`` column and the offset mapping are removed.
    """
    df = pd.read_csv(csv_path)
    text_col = df['text']
    n_missing = int(text_col.isna().sum())
    if n_missing:
        # NaN cells would reach the tokenizer as floats and make it raise.
        print(f"Warning: dropping {n_missing} rows with missing text from {csv_path}")
    texts = text_col.dropna().astype(str).tolist()

    # Tokens inside these regions are still fed as input but receive label
    # -100, so they are never trained as prediction targets.
    masked_tag_pairs = (
        ("<user_query>", "</user_query>"),
        ("<top_k_response>", "</top_k_response>"),
    )

    def find_masked_spans(text):
        """Return (start, end) char spans of every complete tagged region."""
        spans = []
        for open_tag, close_tag in masked_tag_pairs:
            search_from = 0
            while True:
                start = text.find(open_tag, search_from)
                if start == -1:
                    break
                # Look for the close tag only after this open tag so that a
                # stray earlier close tag cannot produce an inverted span.
                close = text.find(close_tag, start + len(open_tag))
                if close == -1:
                    break  # unterminated region: not masked
                end = close + len(close_tag)
                spans.append((start, end))
                search_from = end
        return spans

    def tokenize_function(examples):
        enc = tokenizer(
            examples['text'],
            truncation=True,
            padding="max_length",  # fixed length so batches stack without re-padding
            max_length=max_length,
            return_offsets_mapping=True,  # per-token char offsets for span masking
        )
        attention = enc['attention_mask'] if 'attention_mask' in enc else None

        labels = []
        for i, text in enumerate(examples['text']):
            lab = list(enc['input_ids'][i])
            spans = find_masked_spans(text)
            for tidx, (a, b) in enumerate(enc['offset_mapping'][i]):
                if attention is not None and not attention[i][tidx]:
                    lab[tidx] = -100  # padding never contributes to the loss
                elif a != b and any(a < e and b > s for s, e in spans):
                    # a == b marks tokens with no source characters (special/padding);
                    # otherwise mask any token that overlaps a tagged region.
                    lab[tidx] = -100
            labels.append(lab)

        enc.pop('offset_mapping', None)  # not a model input
        enc['labels'] = labels
        return enc

    dataset = Dataset.from_dict({'text': texts})
    return dataset.map(tokenize_function, batched=True, remove_columns=['text'])

def _collate_keep_labels(features):
    """Stack pre-padded features into tensors, keeping each example's own labels.

    Unlike ``DataCollatorForLanguageModeling(mlm=False)``, which rebuilds
    ``labels`` from ``input_ids``, this preserves any -100 masking already
    present in the features. Positions with ``attention_mask == 0`` are also
    set to -100 so padding never contributes to the loss; this uses the
    attention mask rather than the pad id because the pad token may equal EOS.
    Assumes every feature already has the same length (padding to a fixed
    max_length upstream).
    """
    batch = {
        key: torch.stack([torch.as_tensor(feature[key]) for feature in features])
        for key in features[0]
    }
    if "labels" in batch and "attention_mask" in batch:
        batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


def train_curriculum(model_name, gpu_id, data_dir="./data/sft", learning_rate=5e-5):
    """Fine-tune one model sequentially over the algorithm curriculum.

    The same model object is trained for 3 passes over ALGORITHM_ORDER. In each
    pass, every stage trains for one epoch on
    ``{data_dir}/algorithm_<stage>.csv``. The model and tokenizer are then
    saved to
    ``./models/hard_curriculum_learning/<model_name>/epoch-<e>/algorithm-<NN>``.
    Here NN is the stage's 1-based position in ALGORITHM_ORDER. Missing CSVs are
    skipped with a warning. A new Trainer is built per stage, so optimizer state
    is not carried across stages.

    Args:
        model_name: Key into MODEL_CONFIGS.
        gpu_id: Written to CUDA_VISIBLE_DEVICES; only takes effect if CUDA has
            not yet been initialized in this process.
        data_dir: Directory containing the per-stage CSV files.
        learning_rate: Constant learning rate (no warmup) used for every stage.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    model_path = MODEL_CONFIGS[model_name]
    base_output_dir = f"./models/hard_curriculum_learning/{model_name}"
    os.makedirs(base_output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    special_tokens = {
        "additional_special_tokens": [
            "<user_query>", "</user_query>",
            "<think>", "</think>",
            "<search_query>", "</search_query>",
            "<top_k_response>", "</top_k_response>"
        ]
    }
    tokenizer.add_special_tokens(special_tokens)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    # New special tokens need embedding rows.
    model.resize_token_embeddings(len(tokenizer))

    # The dataset already carries labels with prompt/retrieval spans masked to
    # -100; the collator must pass them through rather than rebuild them from
    # input_ids.
    data_collator = _collate_keep_labels

    for epoch in range(1, 4):  # 3 epochs total
        epoch_dir = f"{base_output_dir}/epoch-{epoch}"
        os.makedirs(epoch_dir, exist_ok=True)

        for algo_idx, algorithm in enumerate(ALGORITHM_ORDER, 1):
            csv_file = f"{data_dir}/algorithm_{algorithm}.csv"

            if not os.path.exists(csv_file):
                print(f"Warning: {csv_file} not found, skipping")
                continue

            print(f"Training {model_name} - Epoch {epoch} - Algorithm {algo_idx:02d}: {algorithm}")

            dataset = create_dataset(csv_file, tokenizer)

            algo_output_dir = f"{epoch_dir}/algorithm-{algo_idx:02d}"

            training_args = TrainingArguments(
                output_dir=algo_output_dir,
                overwrite_output_dir=True,
                num_train_epochs=1,
                per_device_train_batch_size=2,
                gradient_accumulation_steps=2,
                learning_rate=learning_rate,
                weight_decay=0.01,
                logging_dir=f"{algo_output_dir}/logs",
                logging_steps=50,
                save_steps=2500,  # ignored while save_strategy="epoch"
                save_strategy="epoch",
                prediction_loss_only=True,
                remove_unused_columns=False,
                dataloader_drop_last=True,
                bf16=True,
                gradient_checkpointing=True,
                warmup_steps=0,
                lr_scheduler_type="constant",
            )

            trainer = Trainer(
                model=model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=dataset,
            )

            trainer.train()
            trainer.save_model(algo_output_dir)
            tokenizer.save_pretrained(algo_output_dir)

            print(f"Saved {model_name} after algorithm {algo_idx:02d} to {algo_output_dir}")

if __name__ == "__main__":
    # CLI entry point: parse flags, then run the curriculum for one model.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, required=True, choices=list(MODEL_CONFIGS))
    cli.add_argument("--gpu", type=int, required=True)
    cli.add_argument("--lr", type=float, default=5e-5)
    cli.add_argument("--data_dir", type=str, default="./data/sft")
    opts = cli.parse_args()

    # Keyword arguments make the flag -> parameter mapping explicit.
    train_curriculum(
        model_name=opts.model,
        gpu_id=opts.gpu,
        data_dir=opts.data_dir,
        learning_rate=opts.lr,
    )