import os
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset, concatenate_datasets
import argparse
import glob
import random

# Short model name (the value accepted by --model) -> Hugging Face Hub
# repository id handed to ``from_pretrained``. Every repository id is simply
# "<organisation>/<short name>", so the table is derived from per-organisation
# groups; insertion order follows the groups below.
MODEL_CONFIGS = {
    short_name: f"{organisation}/{short_name}"
    for organisation, short_names in (
        ("LiquidAI", ("LFM2-350M", "LFM2-700M", "LFM2-1.2B")),
        (
            "HuggingFaceTB",
            (
                "SmolLM2-135M-Instruct",
                "SmolLM2-360M-Instruct",
                "SmolLM2-1.7B-Instruct",
            ),
        ),
    )
    for short_name in short_names
}

# Algorithm identifiers in the order their CSV files are loaded. Each entry is
# "<two-digit rank>_<slug>"; the rank is derived from the position in the
# tuple below so the prefix can never drift out of sync with the ordering.
ALGORITHM_ORDER = [
    f"{rank:02d}_{slug}"
    for rank, slug in enumerate(
        (
            "adaptive_context_learner",
            "random_walk_wanderer",
            "breadth_first_explorer",
            "depth_first_driller",
            "wrong_direction_specialist",
            "early_success_validator",
            "exploitation_heavy_validator",
            "greedy_hill_climber",
            "best_first_hypothesis_selector",
            "multi_beam_parallel",
        ),
        start=1,
    )
]

def create_dataset(csv_path, tokenizer, max_length=8192):
    """Tokenize the ``text`` column of a CSV into a fixed-length causal-LM dataset.

    Every row is truncated/padded to ``max_length`` tokens. ``labels`` start as
    a copy of ``input_ids`` and are set to -100 (ignored by the loss) for:

    * every token overlapping a ``<user_query>...</user_query>`` or
      ``<top_k_response>...</top_k_response>`` span (tags included) -- *all*
      occurrences in the text, not only the first one;
    * every padding position (``attention_mask == 0``).

    An opening tag with no closing tag after it is left unmasked.

    Args:
        csv_path: Path to a CSV file with a ``text`` column.
        tokenizer: Tokenizer callable supporting ``return_offsets_mapping``
            (in Hugging Face transformers this requires a fast tokenizer).
        max_length: Fixed sequence length used for truncation and padding.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer outputs plus ``labels``;
        the raw ``text`` column and the offset mapping are dropped.

    Raises:
        KeyError: If the CSV has no ``text`` column.
    """
    df = pd.read_csv(csv_path)
    texts = df['text'].tolist()

    # Tag pairs whose enclosed text must not contribute to the loss.
    masked_tags = (
        ("<user_query>", "</user_query>"),
        ("<top_k_response>", "</top_k_response>"),
    )

    def find_masked_spans(text):
        """Return the (start, end) character span of every complete tag pair."""
        spans = []
        for tag_open, tag_close in masked_tags:
            cursor = 0
            while True:
                start = text.find(tag_open, cursor)
                if start == -1:
                    break
                # Look for the closing tag *after* the opening one so a stray
                # earlier closing tag cannot produce an empty/inverted span.
                close_at = text.find(tag_close, start + len(tag_open))
                if close_at == -1:
                    break
                cursor = close_at + len(tag_close)
                spans.append((start, cursor))
        return spans

    def tokenize_function(examples):
        enc = tokenizer(
            examples['text'],
            truncation=True,
            padding="max_length",   # fixed length so rows stack without a padding collator
            max_length=max_length,
            return_attention_mask=True,   # needed below to tell padding from real tokens
            return_offsets_mapping=True,  # char offsets let us map tag spans onto tokens
        )

        labels = []
        rows = zip(
            examples['text'],
            enc['input_ids'],
            enc['attention_mask'],
            enc['offset_mapping'],
        )
        for text, input_ids, attention, offsets in rows:
            spans = find_masked_spans(text)
            row_labels = []
            for token_id, is_real, (tok_start, tok_end) in zip(input_ids, attention, offsets):
                if not is_real:
                    # Padding must never be a prediction target.
                    row_labels.append(-100)
                elif tok_start != tok_end and any(
                    tok_start < span_end and tok_end > span_start
                    for span_start, span_end in spans
                ):
                    # Token overlaps a masked tag span. Zero-width offsets are
                    # skipped because they carry no position in the text.
                    row_labels.append(-100)
                else:
                    row_labels.append(token_id)
            labels.append(row_labels)

        enc.pop('offset_mapping', None)  # not needed by the model
        enc['labels'] = labels
        return enc

    dataset = Dataset.from_dict({'text': texts})
    return dataset.map(tokenize_function, batched=True, remove_columns=['text'])

def _collate_keep_labels(features):
    """Stack equally sized feature dicts into tensors, keeping their labels.

    Unlike ``DataCollatorForLanguageModeling(mlm=False)``, which rebuilds
    ``labels`` from ``input_ids``, this keeps the labels stored in the dataset
    and only forces padding positions (``attention_mask == 0``) to -100.
    """
    batch = {
        key: torch.tensor([feature[key] for feature in features])
        for key in features[0]
    }
    if "attention_mask" in batch and "labels" in batch:
        batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


def train_random_shuffle(model_name, gpu_id, data_dir="./data/sft", learning_rate=5e-5):
    """Fine-tune a model on all algorithm CSVs, concatenated and shuffled.

    Loads ``{data_dir}/algorithm_{name}.csv`` for every entry of
    ``ALGORITHM_ORDER`` that exists on disk, concatenates them, shuffles with a
    fixed seed, trains for three epochs and saves the model and tokenizer to
    ``./models/random_shuffle_baseline/{model_name}``.

    Args:
        model_name: Key of ``MODEL_CONFIGS`` selecting the base checkpoint.
        gpu_id: Unused here; GPU visibility is handled by the launching shell
            script via ``CUDA_VISIBLE_DEVICES``.
        data_dir: Directory containing the ``algorithm_*.csv`` files.
        learning_rate: Constant learning rate used for training.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_CONFIGS``.
        ValueError: If none of the expected CSV files exist.
    """
    model_path = MODEL_CONFIGS[model_name]
    base_output_dir = f"./models/random_shuffle_baseline/{model_name}"
    os.makedirs(base_output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    special_tokens = {
        "additional_special_tokens": [
            "<user_query>", "</user_query>",
            "<think>", "</think>",
            "<search_query>", "</search_query>",
            "<top_k_response>", "</top_k_response>",
        ]
    }
    tokenizer.add_special_tokens(special_tokens)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # The vocabulary grew by the tags above, so the embedding matrix must too.
    model.resize_token_embeddings(len(tokenizer))

    # Keep the labels computed by create_dataset. The stock language-modeling
    # collator would overwrite them with a copy of input_ids (discarding the
    # tag-span masking) and, because pad may alias EOS here, would also mask
    # every genuine EOS token.
    data_collator = _collate_keep_labels

    # Load and combine all CSV files
    all_datasets = []
    for algorithm in ALGORITHM_ORDER:
        csv_file = f"{data_dir}/algorithm_{algorithm}.csv"
        if os.path.exists(csv_file):
            dataset = create_dataset(csv_file, tokenizer)
            all_datasets.append(dataset)
            print(f"Loaded {len(dataset)} samples from {algorithm}")
        else:
            print(f"Warning: {csv_file} not found, skipping")

    if not all_datasets:
        raise ValueError("No datasets found!")

    # Combine all datasets
    combined_dataset = concatenate_datasets(all_datasets)
    print(f"Combined dataset size: {len(combined_dataset)}")

    # Shuffle the combined dataset (fixed seed keeps runs comparable)
    shuffled_dataset = combined_dataset.shuffle(seed=42)
    print("Dataset shuffled randomly")

    training_args = TrainingArguments(
        output_dir=base_output_dir,
        overwrite_output_dir=True,
        num_train_epochs=3,  # Same total epochs as curriculum
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=learning_rate,
        weight_decay=0.01,
        logging_dir=f"{base_output_dir}/logs",
        logging_steps=50,
        # Checkpoint once per epoch; a step-based save interval would be
        # ignored under this strategy, so none is configured.
        save_strategy="epoch",
        prediction_loss_only=True,
        remove_unused_columns=False,
        dataloader_drop_last=True,
        bf16=True,
        gradient_checkpointing=True,
        warmup_steps=0,
        lr_scheduler_type="constant",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=shuffled_dataset,
    )

    print(f"Starting random shuffle training for {model_name}")
    trainer.train()
    trainer.save_model(base_output_dir)
    tokenizer.save_pretrained(base_output_dir)

    print(f"Saved random shuffle baseline model to {base_output_dir}")

if __name__ == "__main__":
    # Script entry point: declare the CLI flags from a table, parse them, and
    # launch the random-shuffle baseline training run.
    cli = argparse.ArgumentParser()
    flag_specs = (
        ("--model", {"type": str, "required": True, "choices": list(MODEL_CONFIGS)}),
        ("--gpu", {"type": int, "required": True}),
        ("--lr", {"type": float, "default": 5e-5}),
        ("--data_dir", {"type": str, "default": "./data/sft"}),
    )
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    train_random_shuffle(
        model_name=opts.model,
        gpu_id=opts.gpu,
        data_dir=opts.data_dir,
        learning_rate=opts.lr,
    )