import json
import logging
import os
import random
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import wandb
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
    set_seed,
)


def is_rank_0():
    """Return True when this process should act as the main (rank-0) process.

    Without an initialized torch.distributed process group every process is
    treated as rank 0; otherwise only the process whose rank is 0 qualifies.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


@dataclass
class ModelArguments:
    """Command-line options describing which pretrained model to load and how."""

    # Required: model identifier or local path handed to the from_pretrained loaders.
    model_name_or_path: str = field(metadata={"help": "HF model path or ID"})
    # Optional cache directory for downloaded files; None means the library default.
    cache_dir: Optional[str] = None
    # Model revision string; defaults to "main".
    model_revision: str = "main"
    # Name of a torch dtype (e.g. "bfloat16"); None leaves the dtype to the loader.
    torch_dtype: Optional[str] = None
    # NOTE(review): defaults to True, i.e. code shipped with the checkpoint may be
    # executed when loading — confirm this is intended for the models you use.
    trust_remote_code: bool = True
    # Forwarded to the model loader's low_cpu_mem_usage option.
    low_cpu_mem_usage: bool = True


@dataclass
class RetainArguments:
    """Command-line options for the retain-only fine-tuning run."""

    # Total optimizer steps; copied into TrainingArguments.max_steps by main().
    num_train_steps: int = 10000
    # NOTE(review): not referenced by the code shown in this file — confirm it is still needed.
    output_comparison_file: Optional[str] = "comparison_outputs.jsonl"
    # SQuAD-format JSON files for training and evaluation.
    retain_train_data_file: Optional[str] = field(default=None, metadata={"help": "SQuAD train file path"})
    retain_eval_data_file: Optional[str] = field(default=None, metadata={"help": "SQuAD validation file path"})
    # Size of the random evaluation subset drawn from retain_eval_data_file.
    retain_eval_max_samples: int = field(default=30, metadata={"help": "Number of samples used for retain_eval evaluation"})
    # Weights & Biases project name used for wandb.init.
    wandb_project: Optional[str] = "retain_only_ft_squad"

class QADataset(Dataset):
    """SQuAD-format question-answering dataset for causal-LM fine-tuning.

    Every question that has at least one answer becomes one example: the prompt
    ``"Context: {context}\\nQuestion: {question}"`` followed by the completion
    ``" {first answer text}"``. Tokenization happens lazily in ``__getitem__``.

    Args:
        data_file: Path to a SQuAD-style JSON file (a top-level ``"data"`` list
            whose entries hold ``"paragraphs"`` with ``"context"`` and ``"qas"``).
        tokenizer: Callable tokenizer accepting ``truncation``, ``padding`` and
            ``max_length`` keyword arguments and returning a mapping with
            ``"input_ids"`` and (optionally) ``"attention_mask"``.
        max_samples: If not None, keep a random subset of at most this many examples.
        seed: Seed for that subset selection.
        max_length: Length every example is truncated/padded to (default 512).
    """

    def __init__(self, data_file, tokenizer, max_samples=None, seed=42, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []
        # SQuAD text contains many non-ASCII characters; do not rely on the
        # platform's default encoding.
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        for entry in data["data"]:
            for paragraph in entry["paragraphs"]:
                context = paragraph["context"]
                for qa in paragraph["qas"]:
                    question = qa["question"]
                    answers = qa.get("answers")
                    # Questions without an answer have no completion to train on.
                    if not answers:
                        continue
                    answer = answers[0]["text"]
                    prompt = f"Context: {context}\nQuestion: {question}"
                    completion = f" {answer}"
                    self.samples.append((prompt, completion))

        if max_samples is not None:
            # A private RNG selects exactly the subset that seeding the global
            # `random` module would, without clobbering the global RNG state.
            rng = random.Random(seed)
            self.samples = rng.sample(self.samples, min(max_samples, len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach LM ``labels``.

        Labels equal ``input_ids`` except at padding positions
        (``attention_mask == 0``), which are set to -100 so the loss ignores them.
        """
        prompt, completion = self.samples[idx]
        # NOTE(review): with the tokenizer's default right-side truncation, texts
        # longer than max_length lose their end, i.e. the answer — confirm
        # contexts fit, or set tokenizer.truncation_side = "left".
        encoding = self.tokenizer(
            prompt + completion,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        input_ids = encoding["input_ids"]
        attention_mask = encoding.get("attention_mask")
        if attention_mask is None:
            labels = list(input_ids)
        else:
            labels = [
                token if keep else -100
                for token, keep in zip(input_ids, attention_mask)
            ]
        encoding["labels"] = labels
        return encoding

def evaluate_retained_qa(model, tokenizer, dataset, prefix="retain/"):
    """Compute the mean language-modeling loss of ``model`` over ``dataset``.

    Runs one no-grad forward pass per example (batch size 1), logs
    ``{prefix}eval_loss`` to Weights & Biases when a run is active, prints it,
    and returns it. The model's train/eval mode is restored afterwards, so this
    is safe to call from inside a training loop.

    Args:
        model: Causal LM whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns an object with a ``loss`` attribute.
        tokenizer: Unused; kept for call-site compatibility.
        dataset: Dataset whose items contain ``input_ids`` and ``attention_mask``,
            and optionally ``labels``.
        prefix: Prefix for the logged metric name and the progress-bar label.

    Returns:
        The average per-batch loss, or ``float("inf")`` for an empty dataset.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Evaluating {prefix}"):
                batch = {k: v.to(model.device) for k, v in batch.items()}
                # Prefer the dataset's labels (which may mask padding with -100);
                # fall back to input_ids for datasets that do not provide them.
                labels = batch.get("labels", batch["input_ids"])
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=labels,
                )
                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    avg_loss = total_loss / num_batches if num_batches > 0 else float("inf")
    # wandb.init() is only called on some processes (rank 0 in main());
    # only log where a run is actually active.
    if wandb.run is not None:
        wandb.log({f"{prefix}eval_loss": avg_loss})
    print(f"{prefix}Eval Loss: {avg_loss:.4f}")
    return avg_loss

class EvalCallback(TrainerCallback):
    """Trainer callback that re-evaluates the retain set on every log event.

    Whenever the Trainer logs, ``evaluate_retained_qa`` runs on the stored
    ``retain_dataset`` with the ``"train/retain/"`` metric prefix. The callback
    does nothing when no retain dataset was supplied.
    """

    def __init__(self, model, tokenizer, retain_dataset=None):
        # Keep references for use in on_log.
        self.model, self.tokenizer, self.retain_dataset = model, tokenizer, retain_dataset

    def on_log(self, args, state, control, **kwargs):
        dataset = self.retain_dataset
        if dataset is None:
            return
        evaluate_retained_qa(self.model, self.tokenizer, dataset, prefix="train/retain/")

def _resolve_torch_dtype(name):
    """Map the ``--torch_dtype`` string to a value accepted by ``from_pretrained``.

    Empty/None -> None, ``"auto"`` is passed through unchanged, and any other
    name is looked up on ``torch`` (e.g. ``"bfloat16"`` -> ``torch.bfloat16``).
    """
    if not name:
        return None
    if name == "auto":
        return name
    return getattr(torch, name)


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and make cuDNN deterministic."""
    set_seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _load_tokenizer_and_model(model_args):
    """Load the config, tokenizer and causal LM described by ``model_args``.

    The cache dir, revision and trust_remote_code setting are applied
    consistently to all three loaders. Returns ``(tokenizer, model)``.
    """
    common_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
    }
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **common_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **common_kwargs)
    tokenizer.padding_side = "left"
    # Some causal-LM tokenizers define no pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=_resolve_torch_dtype(model_args.torch_dtype),
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        **common_kwargs,
    )
    return tokenizer, model


def main():
    """Fine-tune a causal LM on SQuAD-style retain data with the HF Trainer.

    Parses ``ModelArguments``, ``RetainArguments`` and ``TrainingArguments`` from
    the command line and creates a timestamped run directory under
    ``--output_dir``. It then trains for ``--num_train_steps`` steps, evaluating
    the retain eval set on every log event via ``EvalCallback``. Finally it
    writes the training time and saves the model and tokenizer to
    ``<run_dir>/final_checkpoint``.

    Raises:
        ValueError: if ``--retain_train_data_file`` is not given.
    """
    parser = HfArgumentParser((ModelArguments, RetainArguments, TrainingArguments))
    model_args, retain_args, training_args = parser.parse_args_into_dataclasses()
    # Trainer.train() cannot run without a train dataset; fail before creating
    # directories or a wandb run.
    if not retain_args.retain_train_data_file:
        raise ValueError("--retain_train_data_file is required: training needs a train dataset.")

    # NOTE(review): each process computes its own timestamp, so in multi-process
    # runs started across a second boundary the ranks may disagree on run_dir.
    # Also, is_rank_0() is True on every process if torch.distributed is not yet
    # initialized here — confirm for your launcher.
    run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(training_args.output_dir, f"run-{run_time}")
    training_args.output_dir = run_dir

    if is_rank_0():
        os.makedirs(run_dir, exist_ok=True)
        wandb.init(project=retain_args.wandb_project, name=training_args.run_name)

    logging.basicConfig(level=logging.INFO)
    _seed_everything(training_args.seed)
    tokenizer, model = _load_tokenizer_and_model(model_args)

    retain_train_dataset = QADataset(
        data_file=retain_args.retain_train_data_file,
        tokenizer=tokenizer,
        max_samples=None,
        seed=training_args.seed,
    )
    if is_rank_0():
        # Quick sanity print of tokenized lengths for the first few examples.
        for i in range(min(10, len(retain_train_dataset))):
            item = retain_train_dataset[i]
            print(f"{i}: input_ids len = {len(item['input_ids'])}, labels len = {len(item['labels'])}")

    retain_eval_dataset = None
    if retain_args.retain_eval_data_file:
        retain_eval_dataset = QADataset(
            data_file=retain_args.retain_eval_data_file,
            tokenizer=tokenizer,
            max_samples=retain_args.retain_eval_max_samples,
            seed=training_args.seed,
        )

    training_args.max_steps = retain_args.num_train_steps
    training_args.report_to = ["wandb"]
    # No intermediate checkpoints; the final model is saved explicitly below.
    training_args.save_strategy = "no"

    if is_rank_0():
        with open(os.path.join(run_dir, "used_args.json"), "w", encoding="utf-8") as f:
            json.dump({
                "model_args": vars(model_args),
                "retain_args": vars(retain_args),
                "training_args": training_args.to_dict(),
            }, f, indent=2, ensure_ascii=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=retain_train_dataset,
        eval_dataset=retain_eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        callbacks=[EvalCallback(model, tokenizer, retain_eval_dataset)],
    )

    train_start_time = time.time()
    trainer.train()
    elapsed = time.time() - train_start_time

    final_path = os.path.join(run_dir, "final_checkpoint")
    if is_rank_0():
        wandb.log({"train/total_training_time_sec": elapsed})
        print(f"Training took {elapsed:.2f} seconds.")
        with open(os.path.join(run_dir, "training_time.txt"), "w", encoding="utf-8") as f:
            f.write(f"{elapsed:.2f} seconds\n")
        os.makedirs(final_path, exist_ok=True)

    # Called on every process: Trainer.save_model itself only writes from the
    # main process, and some sharded setups need all ranks to take part.
    trainer.save_model(final_path)

    if is_rank_0():
        tokenizer.save_pretrained(final_path)
        wandb.finish()


if __name__ == "__main__":
    main()
