import json
import argparse
import logging
import os
from datasets import load_dataset, Dataset, DatasetDict
from typing import Optional, Dict, Union, List
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, TrainerCallback
from trl import SFTTrainer,SFTConfig
import torch
import random
import numpy as np
from peft import LoraConfig, get_peft_model
import wandb

from common import build_validation_from_cfg

# Root-logger configuration shared by every message this script emits.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)


class PerBucketEvalLoggerCallback(TrainerCallback):
    """Trainer callback that periodically measures validation accuracy per difficulty bucket.

    Every ``val_steps`` optimizer steps, on the main process only, up to
    ``val_questions`` examples are sampled from each bucket of
    ``val_bucketed_indices``. They are scored with ``gen_eval.evaluate_indices``.
    The per-bucket accuracies and their unweighted mean are sent to ``wandb.log``.

    Args:
        model: Model switched to eval mode during validation and restored afterwards.
        tokenizer: Stored for reference; not used by this class directly.
        val_bucketed_indices: Mapping bucket id -> row indices into ``val_dataset``.
        val_dataset: HF ``Dataset`` forwarded to ``gen_eval``.
        verifier_type: Forwarded to ``gen_eval`` as ``verify_mode``
            ("math", "nl" or "exact").
        val_steps: Validation period in optimizer steps (clamped to >= 1).
        val_questions: Max examples sampled per bucket (clamped to >= 1).
        gen_eval: Object exposing
            ``evaluate_indices(dataset=..., indices=..., verify_mode=...)``.
            It returns an accuracy value and is required once validation runs.
    """

    def __init__(
        self,
        model,
        tokenizer,
        val_bucketed_indices: Dict[int, List[int]],
        val_dataset,  # HF Dataset
        verifier_type="math",  # "math" or "nl" or "exact"
        val_steps: int = 50,
        val_questions: int = 50,
        gen_eval: Optional[object] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.val_bucketed_indices = val_bucketed_indices
        self.val_dataset = val_dataset
        self.val_steps = max(1, int(val_steps))
        self.val_questions = max(1, int(val_questions))
        self.gen_eval = gen_eval
        self.verifier_type = verifier_type

    @torch.no_grad()
    def _run_custom_validation(self) -> Dict[int, float]:
        """Evaluate a random sample from each non-empty bucket.

        Returns:
            Mapping bucket id -> accuracy returned by ``gen_eval``. Buckets
            with no validation examples are omitted: they have no defined
            accuracy, and reporting 0.0 would drag the overall mean down.

        Raises:
            ValueError: if no ``gen_eval`` was provided.
        """
        if self.gen_eval is None:
            raise ValueError("GenerationEvaluator (gen_eval) is not provided for validation.")

        # Remember the current mode so it can be restored exactly afterwards.
        was_training = self.model.training
        self.model.eval()
        accuracies: Dict[int, float] = {}

        try:  # finally guarantees cache cleanup and mode restoration
            for bucket_id, indices in self.val_bucketed_indices.items():
                if not indices:
                    logging.debug(f"  Bucket {bucket_id}: no validation samples, skipped.")
                    continue

                # Sample up to val_questions examples from this bucket.
                sample_indices = random.sample(indices, min(self.val_questions, len(indices)))
                acc = self.gen_eval.evaluate_indices(
                    dataset=self.val_dataset,
                    indices=sample_indices,
                    verify_mode=self.verifier_type,
                )
                accuracies[bucket_id] = acc
                logging.info(f"  Bucket {bucket_id}: {len(sample_indices)} samples, accuracy = {acc:.2f}%")
        finally:
            torch.cuda.empty_cache()
            logging.debug("Cleared CUDA cache after validation.")
            self.model.train(was_training)
        return accuracies

    def on_step_end(self, args, state, control, **kwargs):
        """Run per-bucket validation every ``val_steps`` optimizer steps and log it to wandb."""
        # state.global_step counts completed optimizer steps when this hook
        # fires. Validating at its multiples therefore needs no "+ 1" offset;
        # the offset fired one step early.
        if state.global_step <= 0 or state.global_step % self.val_steps != 0:
            return control
        # Only the main process runs and logs validation.
        if not state.is_world_process_zero:
            return control

        logging.info("\n--- Running per-bucket validation...")
        accs = self._run_custom_validation()
        log_data = {f"validation/accuracy_bucket_{bkt}": a for bkt, a in accs.items()}
        # Unweighted (macro) mean over the buckets that were actually evaluated.
        log_data["validation/accuracy_overall"] = sum(accs.values()) / len(accs) if accs else 0.0
        # Tag the point with the trainer step so it can be plotted on the same
        # x-axis as the Trainer's own wandb metrics, not wandb's internal counter.
        log_data["train/global_step"] = state.global_step

        try:
            wandb.log(log_data)
        except Exception as e:  # logging is best-effort; never abort training over it
            logging.warning(f"wandb.log failed: {e}")
        return control


def _bucket_indices_by_difficulty(difficulties, n_buckets: int) -> Dict[int, List[int]]:
    """Group row positions by integer difficulty into buckets ``0 .. n_buckets - 1``."""
    buckets: Dict[int, List[int]] = {b: [] for b in range(n_buckets)}
    for row, difficulty in enumerate(difficulties):
        buckets[difficulty].append(row)
    return buckets


def train(config):
    """Fine-tune the student model with TRL's ``SFTTrainer`` plus per-bucket validation.

    Config keys read here:
        ``job_type``: must contain "kd" or "sft"; anything else logs an error and returns.
        ``dataset.trainset_path`` / ``dataset.valset_path``: JSON files for the two splits.
        ``dataset.name``: a name containing "gsm" selects the "math" verifier, otherwise "nl".
        ``models.student``: model id/path for both model and tokenizer.
        ``training``: kwargs for ``SFTConfig``; ``output_dir`` is also used for saving.
        ``peft`` (optional): kwargs for ``LoraConfig``; enables LoRA training.
        ``validation`` (optional): passed to ``build_validation_from_cfg``, plus
            ``val_steps`` / ``val_questions`` for the callback (default 50 each).

    Every record in both splits must carry a non-negative integer ``difficulty``
    field, which selects its validation bucket.
    """
    # Validate the job type first so an invalid config fails before any model is loaded.
    job_type = config["job_type"]
    if 'kd' not in job_type and 'sft' not in job_type:
        logging.error(f"Invalid job type: {job_type}")
        logging.error(f"Training job terminated: Invalid job type: {job_type}")
        return

    # ---- Load data ----
    dataset = load_dataset("json", data_files={
        "train": config["dataset"]["trainset_path"],
        "validation": config["dataset"]["valset_path"],
    })
    # NOTE: load_dataset(cache_dir=os.environ.get("HF_DATASETS_CACHE"),
    # keep_in_memory=True) can be used to cache on a local tmpfs/SSD.

    # Read only the difficulty column instead of decoding every full row.
    train_difficulties = list(dataset["train"]["difficulty"])
    val_difficulties = list(dataset["validation"]["difficulty"])
    # Size the buckets from both splits. A difficulty level that appears only
    # in validation would otherwise raise KeyError.
    n_buckets = max(train_difficulties + val_difficulties, default=-1) + 1
    bucket_indices = _bucket_indices_by_difficulty(train_difficulties, n_buckets)
    val_bucket_indices = _bucket_indices_by_difficulty(val_difficulties, n_buckets)

    logging.debug(f"bucket_indices: {bucket_indices}")
    logging.debug(f"val_bucket_indices: {val_bucket_indices}")

    # ---- Load model and tokenizer ----
    logging.info("--- Initializing model & tokenizer ---")
    student_model = AutoModelForCausalLM.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True,
    )
    student_tokenizer = AutoTokenizer.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True,
    )
    student_tokenizer.padding_side = "left"
    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token

    if "peft" in config:
        peft_config = LoraConfig(**config["peft"])
        model_for_trainer = get_peft_model(student_model, peft_config)
        model_for_trainer.print_trainable_parameters()
    else:
        model_for_trainer = student_model

    training_arguments = SFTConfig(**config["training"])

    # ---- Validation ----
    validation_arguments = config.get("validation", {})
    verify_mode = 'math' if 'gsm' in config["dataset"]["name"].lower() else 'nl'
    # NOTE(review): build_validation_from_cfg also returns `mode`. If val_cfg can
    # override default_mode, the callback probably should receive `mode`
    # instead of `verify_mode`. Confirm against common.py.
    mode, sbert_backend, verifier, gen_eval = build_validation_from_cfg(
        model=student_model,
        tokenizer=student_tokenizer,
        val_cfg=validation_arguments,
        default_mode=verify_mode,
    )

    eval_cb = PerBucketEvalLoggerCallback(
        model=model_for_trainer,
        tokenizer=student_tokenizer,
        val_bucketed_indices=val_bucket_indices,
        val_dataset=dataset["validation"],
        val_steps=validation_arguments.get("val_steps", 50),
        val_questions=validation_arguments.get("val_questions", 50),
        gen_eval=gen_eval,
        verifier_type=verify_mode,
    )

    trainer = SFTTrainer(
        model=model_for_trainer,
        processing_class=student_tokenizer,
        args=training_arguments,
        train_dataset=dataset["train"],
        # LoRA (if configured) was already applied by get_peft_model above.
        # Passing the raw student_model together with peft_config made
        # SFTTrainer apply the adapters a second time.
        peft_config=None,
        callbacks=[eval_cb],
    )

    trainer.train()
    output_dir = config["training"]["output_dir"]
    trainer.save_model(output_dir)
    student_tokenizer.save_pretrained(output_dir)


def main():
    """Parse ``--config``, derive a unique run name, and start training."""
    from datetime import datetime

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    # Context manager closes the file handle promptly; JSON is UTF-8 by spec.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    # Run name: <dataset>_<student model basename>_<config file stem>_<timestamp>.
    # rstrip('/') keeps a local model path such as "models/qwen/" from yielding
    # an empty component. splitext keeps any dots inside the config file name.
    model_tag = config['models']['student'].rstrip('/').split('/')[-1]
    config_tag = os.path.splitext(os.path.basename(args.config))[0]
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = f"{config['dataset']['name']}_{model_tag}_{config_tag}_{timestamp}"
    # Injected into the "training" section, which train() passes to SFTConfig.
    config["training"]["run_name"] = run_name

    train(config)


if __name__ == "__main__":
    main()