import json
import logging
import os
import random
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import wandb
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
    set_seed,
)


def is_rank_0():
    """Return True when this process should act as the main (rank 0) process.

    A process counts as rank 0 when ``torch.distributed`` is not available in
    this torch build, when no process group has been initialized (plain
    single-process run), or when its global rank is 0.
    """
    # Builds without distributed support expose nothing on torch.distributed
    # besides is_available(), so that check has to come first.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


@dataclass
class ModelArguments:
    """Command-line options describing which pretrained model to load and how.

    Parsed by ``HfArgumentParser`` in ``main``; each field becomes a
    ``--<field_name>`` flag.
    """

    # Required (no default). Used as the identifier for every
    # ``from_pretrained`` call in ``main``; ``main`` also inspects it for the
    # substrings "gpt2" and "pythia" to pick model-specific classes.
    model_name_or_path: str = field(metadata={"help": "HF model path or ID"})
    # Forwarded as ``cache_dir`` to the ``from_pretrained`` calls.
    cache_dir: Optional[str] = field(default=None)
    # Presumably the hub revision (branch / tag / commit) to load — TODO confirm.
    model_revision: str = field(default="main")
    # Name of a ``torch`` dtype attribute (e.g. "bfloat16"); ``main`` resolves
    # it before loading the model. ``None`` leaves the dtype unspecified.
    torch_dtype: Optional[str] = field(default=None)
    # Forwarded to the model's ``from_pretrained``.
    # NOTE(review): defaults to True, i.e. repository-provided code is trusted
    # unless the caller opts out — confirm this default is intended.
    trust_remote_code: bool = field(default=True)
    # Forwarded to the model's ``from_pretrained``.
    low_cpu_mem_usage: bool = field(default=True)


@dataclass
class RetainArguments:
    """Command-line options specific to the retain-set fine-tuning run.

    Parsed by ``HfArgumentParser`` in ``main`` alongside ``ModelArguments`` and
    ``TrainingArguments``.
    """

    # Copied into ``training_args.max_steps`` by ``main``.
    num_train_steps: int = field(default=10000)
    # NOTE(review): no reader of this field is visible in this part of the
    # file — confirm whether it is still needed.
    output_comparison_file: Optional[str] = field(default="comparison_outputs.jsonl")
    # JSON file handed to ``QADataset`` to build the training set.
    retain_train_data_file: Optional[str] = field(default=None, metadata={"help": "mathqa train file path"})
    # JSON file handed to ``QADataset`` to build the evaluation set.
    retain_eval_data_file: Optional[str] = field(default=None, metadata={"help": "mathqa validation file path"})
    # Passed as ``max_samples`` when building the evaluation ``QADataset``.
    retain_eval_max_samples: int = field(default=30, metadata={"help": "Number of samples used for retain_eval evaluation"})
    # Passed as ``project`` to ``wandb.init`` in ``main``.
    wandb_project: Optional[str] = field(default="retain_only_ft_mathqa")

class QADataset(Dataset):
    """Multiple-choice QA records turned into causal-LM training examples.

    The data file is a JSON array of objects with the keys ``"Problem"``,
    ``"correct"`` (an answer letter ``a``-``e``) and ``"options"`` (a single
    string of comma-separated ``"<letter> ) <text>"`` entries). Each usable
    record becomes a ``(prompt, completion)`` pair of the form
    ``("Question: <problem>\\nAnswer:", " <correct option text>")``.
    Records whose answer letter is unknown, whose letter points past the
    available options, or whose selected option has no ``)`` are skipped.

    Args:
        data_file: Path to the UTF-8 JSON file described above.
        tokenizer: Callable tokenizer; invoked with ``truncation=True``,
            ``padding="max_length"`` and ``max_length``.
        max_samples: If given, keep a random subset of at most this many pairs.
        seed: Seed for the subset selection (only used with ``max_samples``).
        max_length: Fixed token length every example is padded/truncated to.
    """

    # Position of each answer letter inside the comma-separated options string.
    _OPTION_INDEX = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}

    def __init__(self, data_file, tokenizer, max_samples=None, seed=42, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        self.samples = []
        for sample in data:
            pair = self._build_pair(sample)
            if pair is not None:
                self.samples.append(pair)

        if max_samples is not None:
            rng = random.Random(seed)
            self.samples = rng.sample(self.samples, min(max_samples, len(self.samples)))

    @classmethod
    def _build_pair(cls, sample):
        """Return the ``(prompt, completion)`` pair for one record, or None to skip it."""
        question = sample["Problem"]
        correct_index = cls._OPTION_INDEX.get(sample["correct"].strip().lower())
        # NOTE(review): options are located by position after splitting on ",",
        # so an option whose own text contains a comma shifts the positions —
        # confirm the data never has commas inside option text.
        option_items = [opt.strip() for opt in sample["options"].split(",")]

        if correct_index is None or correct_index >= len(option_items):
            return None

        # Cut only at the FIRST ")" (the one closing the letter label) so that
        # answers containing ")" themselves, e.g. "( 3 / 4 ) x", stay intact.
        _label, separator, correct_text = option_items[correct_index].partition(")")
        if not separator:
            return None

        return f"Question: {question}\nAnswer:", f" {correct_text.strip()}"

    def __len__(self):
        """Return the number of usable (prompt, completion) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach ``labels`` for causal-LM training.

        ``labels`` equal ``input_ids`` on real tokens and ``-100`` on padding,
        so padded positions do not contribute to the loss.
        """
        prompt, completion = self.samples[idx]

        encoding = self.tokenizer(
            prompt + completion,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

        input_ids = encoding["input_ids"]
        attention_mask = encoding.get("attention_mask")
        if attention_mask is None:
            # Without a mask padding cannot be told apart from real tokens.
            encoding["labels"] = list(input_ids)
        else:
            # The mask (not the pad id) decides, because the pad token may be
            # the same id as a genuine EOS token.
            encoding["labels"] = [
                token if keep else -100
                for token, keep in zip(input_ids, attention_mask)
            ]
        return encoding


def evaluate_retained_qa(model, tokenizer, dataset, prefix="retain/"):
    """Compute the mean causal-LM loss of ``model`` over ``dataset``.

    Examples are evaluated one at a time (batch size 1) without gradients.
    The loss is logged to wandb as ``"{prefix}eval_loss"`` when a wandb run is
    active in this process, and is always printed.

    Args:
        model: Causal LM exposing ``.device`` and returning an object with
            ``.loss`` when called with ``labels``.
        tokenizer: Unused; kept so existing callers keep working.
        dataset: Dataset whose items provide ``input_ids`` and
            ``attention_mask``.
        prefix: Prefix for the logged metric name and progress messages.

    Returns:
        The average per-example loss, or ``float("inf")`` for an empty dataset.
    """
    # Remember the mode so an evaluation triggered mid-training does not leave
    # the model stuck in eval mode afterwards.
    was_training = getattr(model, "training", False)
    model.eval()

    total_loss = 0.0
    num_batches = 0
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Evaluating {prefix}"):
                input_ids = batch["input_ids"].to(model.device)
                attention_mask = batch["attention_mask"].to(model.device)
                # Examples are padded to a fixed length; -100 labels keep the
                # padded positions out of the loss.
                labels = input_ids.masked_fill(attention_mask == 0, -100)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        if was_training:
            model.train()

    avg_loss = total_loss / num_batches if num_batches > 0 else float("inf")
    # wandb.log raises when wandb.init has not been called in this process
    # (main only initializes wandb on rank 0), so log only into a live run.
    if wandb.run is not None:
        wandb.log({f"{prefix}eval_loss": avg_loss})
    print(f"{prefix}Eval Loss: {avg_loss:.4f}")
    return avg_loss

class EvalCallback(TrainerCallback):
    """Trainer callback that re-scores the retain set each time the Trainer logs."""

    def __init__(self, model, tokenizer, retain_dataset=None):
        """Store the model, tokenizer and optional retain dataset for later use."""
        self.retain_dataset = retain_dataset
        self.tokenizer = tokenizer
        self.model = model

    def on_log(self, args, state, control, **kwargs):
        """Evaluate on the retain dataset, if one was supplied, on every log event."""
        if self.retain_dataset is None:
            return
        evaluate_retained_qa(
            self.model,
            self.tokenizer,
            self.retain_dataset,
            prefix="train/retain/",
        )

def _resolve_torch_dtype(name):
    """Map the ``--torch_dtype`` string to the value ``from_pretrained`` expects."""
    if not name:
        return None
    if name == "auto":
        # "auto" is understood by from_pretrained itself; it is not a torch attribute.
        return name
    return getattr(torch, name)


def _load_tokenizer(model_args):
    """Load the tokenizer for the configured model and make sure it can pad."""
    name = model_args.model_name_or_path
    if "gpt2" in name.lower():
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(
            name,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            name,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            trust_remote_code=model_args.trust_remote_code,
        )

    # NOTE(review): left padding is applied to training batches as well; models
    # with absolute position embeddings may behave better with right padding
    # during training — confirm this is intended.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_model(model_args, config):
    """Load the causal LM, using the model-specific class for pythia / gpt2 names."""
    name = model_args.model_name_or_path
    lowered = name.lower()
    if "pythia" in lowered:
        from transformers import GPTNeoXForCausalLM as model_cls
    elif "gpt2" in lowered:
        from transformers import GPT2LMHeadModel as model_cls
    else:
        model_cls = AutoModelForCausalLM

    return model_cls.from_pretrained(
        name,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        torch_dtype=_resolve_torch_dtype(model_args.torch_dtype),
        trust_remote_code=model_args.trust_remote_code,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
    )


def main():
    """Fine-tune a causal LM on the retain QA data and save the final checkpoint.

    Parses ``ModelArguments``, ``RetainArguments`` and ``TrainingArguments``
    from the command line, trains for ``num_train_steps`` steps inside a
    timestamped ``run-<time>`` subdirectory of ``output_dir``, and on rank 0
    writes ``used_args.json``, ``training_time.txt`` and ``final_checkpoint``.

    Raises:
        ValueError: If ``--retain_train_data_file`` was not provided.
    """
    parser = HfArgumentParser((ModelArguments, RetainArguments, TrainingArguments))
    model_args, retain_args, training_args = parser.parse_args_into_dataclasses()

    # Fail before any model download / wandb run is started; without a train
    # dataset Trainer.train() would only fail much later.
    if not retain_args.retain_train_data_file:
        raise ValueError("--retain_train_data_file is required: training needs a train dataset.")

    # NOTE(review): every process computes its own timestamp, so ranks started
    # in different seconds get different run_dir values; only rank 0 writes
    # files below — confirm that is acceptable for multi-process launches.
    run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(training_args.output_dir, f"run-{run_time}")
    training_args.output_dir = run_dir

    if is_rank_0():
        os.makedirs(run_dir, exist_ok=True)
        wandb.init(project=retain_args.wandb_project, name=training_args.run_name)

    logging.basicConfig(level=logging.INFO)
    # set_seed already seeds Python's random, NumPy and torch (CPU and CUDA).
    set_seed(training_args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = _load_tokenizer(model_args)
    model = _load_model(model_args, config)

    retain_train_dataset = QADataset(
        data_file=retain_args.retain_train_data_file,
        tokenizer=tokenizer,
        max_samples=None,
        seed=training_args.seed,
    )
    # Sanity-print the first few examples; bounded by the dataset size so small
    # datasets do not raise IndexError.
    for i in range(min(10, len(retain_train_dataset))):
        item = retain_train_dataset[i]
        print(f"{i}: input_ids len = {len(item['input_ids'])}, labels len = {len(item['labels'])}")

    retain_eval_dataset = None
    if retain_args.retain_eval_data_file:
        retain_eval_dataset = QADataset(
            data_file=retain_args.retain_eval_data_file,
            tokenizer=tokenizer,
            max_samples=retain_args.retain_eval_max_samples,
            seed=training_args.seed,
        )

    # These settings deliberately override whatever was given on the command line.
    training_args.max_steps = retain_args.num_train_steps
    training_args.report_to = ["wandb"]
    training_args.save_strategy = "no"

    if is_rank_0():
        with open(os.path.join(run_dir, "used_args.json"), "w", encoding="utf-8") as f:
            json.dump({
                "model_args": vars(model_args),
                "retain_args": vars(retain_args),
                "training_args": training_args.to_dict(),
            }, f, indent=2, ensure_ascii=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=retain_train_dataset,
        eval_dataset=retain_eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        callbacks=[EvalCallback(model, tokenizer, retain_eval_dataset)],
    )

    train_start_time = time.time()
    trainer.train()
    if is_rank_0():
        elapsed = time.time() - train_start_time
        wandb.log({"train/total_training_time_sec": elapsed})
        print(f"Training took {elapsed:.2f} seconds.")
        with open(os.path.join(run_dir, "training_time.txt"), "w") as f:
            f.write(f"{elapsed:.2f} seconds\n")

        final_path = os.path.join(run_dir, "final_checkpoint")
        os.makedirs(final_path, exist_ok=True)
        trainer.save_model(final_path)
        tokenizer.save_pretrained(final_path)
        wandb.finish()


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
