import json
import logging
import math
import os
import random
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.functional import log_softmax
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import wandb
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
    set_seed,
)


def is_rank_0():
    """Return True when this process should act as the main process.

    A process counts as "rank 0" either because no ``torch.distributed``
    process group has been initialized, or because its rank in the
    initialized group is 0.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    # No process group: single-process run, so this is the only rank.
    return True


@dataclass
class ModelArguments:
    """Options that select the pretrained model and control how it is loaded."""

    # Required (no default): Hub ID or local path of the model.
    model_name_or_path: str = field(metadata={"help": "HF model path or ID"})
    # Directory passed as ``cache_dir`` when loading; None means not set.
    cache_dir: Optional[str] = None
    model_revision: str = "main"
    # Name of a torch dtype attribute (looked up on ``torch``); None means not set.
    torch_dtype: Optional[str] = None
    # NOTE(review): defaults to True, i.e. code shipped with the model
    # repository is trusted unless the caller opts out.
    trust_remote_code: bool = True
    low_cpu_mem_usage: bool = True


@dataclass
class ForgetArguments:
    """Options for the unlearning run: data files, sampling and loss weighting."""

    # Copied into ``training_args.max_steps`` by ``main``.
    num_train_steps: int = 10000
    # Token budget for forget-sample tokenization and for sampled generations.
    seq_len: int = 32
    eval_text_file: Optional[str] = None
    # Sampling settings for the periodic text generations.
    temperature: float = 2.0
    top_k: int = 100
    # Number of forget texts (and generated samples) per step.
    train_generate_batch_size: int = 4
    output_comparison_file: Optional[str] = "comparison_outputs.jsonl"
    # Multiplier applied to the retain loss before it is added to the forget loss.
    retain_weight: float = 1.0
    retain_train_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "gsm8k train file path"},
    )
    retain_eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "gsm8k validation file path"},
    )
    retain_eval_max_samples: int = field(
        default=200,
        metadata={"help": "Number of samples used for retain_eval evaluation"},
    )
    # JSONL file whose records carry a "text" key.
    forget_sample_file: Optional[str] = None
    wandb_project: Optional[str] = "exclusive_unlearning_gsm8k"

class QADataset(Dataset):
    """Question/answer pairs from a JSONL file, tokenized for causal-LM training.

    Each non-blank line of ``data_file`` must be a JSON object with
    ``"question"`` and ``"answer"`` keys. Every sample is rendered as
    ``"Question: {question}\\nAnswer: {answer}"`` and tokenized on access.

    Args:
        data_file: Path to the JSONL file.
        tokenizer: Callable tokenizer; invoked with ``truncation=True``,
            ``padding="max_length"`` and ``max_length``.
        max_samples: If given, keep a random subset of at most this many
            samples (drawn with ``random.Random(seed)``).
        seed: Seed for the subset selection.
        max_length: Fixed token length every sample is padded/truncated to.

    Raises:
        KeyError: If a record lacks ``"question"`` or ``"answer"``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """

    def __init__(self, data_file, tokenizer, max_samples=None, seed=42, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []

        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                question = record["question"]
                answer = record["answer"]
                self.samples.append((f"Question: {question}\nAnswer:", f" {answer}"))

        if max_samples is not None:
            rng = random.Random(seed)
            self.samples = rng.sample(self.samples, min(max_samples, len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and attach ``labels`` for the LM loss."""
        prompt, completion = self.samples[idx]

        encoding = self.tokenizer(
            prompt + completion,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

        input_ids = encoding["input_ids"]
        attention_mask = encoding.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to using every position as a target.
            labels = list(input_ids)
        else:
            # Padding positions must not contribute to the loss. -100 is the
            # index the cross-entropy loss ignores; without this, every sample
            # padded up to max_length would train the model to emit pad tokens.
            labels = [
                token if keep else -100
                for token, keep in zip(input_ids, attention_mask)
            ]
        encoding["labels"] = labels
        return encoding

def _safe_exp(value):
    """Return ``exp(value)``, mapping float overflow to ``inf`` instead of raising."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def evaluate_unconditional(model, tokenizer, entries, prefix="", log_to_wandb=True, output_path=None):
    """Score each entry with the model and report loss / distribution statistics.

    Args:
        model: Causal LM; called with ``input_ids``, ``attention_mask`` and
            ``labels`` and expected to return ``.loss`` and ``.logits``.
        tokenizer: Used only for entries that carry no ``"input_ids"``.
        entries: Iterable of dicts. An entry with ``"input_ids"`` is used as
            is (one sequence); otherwise its ``"text"`` is tokenized
            (truncated to 512 tokens).
        prefix: String prepended to the wandb metric names.
        log_to_wandb: Whether rank 0 logs the mean loss and its perplexity.
        output_path: If set, one JSON line of per-entry statistics is written
            there.

    Returns:
        None. Results are reported through wandb and/or ``output_path``.
    """
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = model.training
    model.eval()

    loss_list = []
    stats_records = []
    max_entropy = math.log(model.config.vocab_size)

    try:
        for entry in entries:
            if "input_ids" in entry:
                input_ids = torch.tensor([entry["input_ids"]], device=model.device)
                attention_mask = torch.ones_like(input_ids, device=model.device)
                text = entry.get("text", "")
            else:
                text = entry["text"]
                inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]

            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss.item()

                # Keep only attended positions; do the statistics in float32
                # so reduced-precision models do not lose small probabilities.
                token_logits = outputs.logits[0][attention_mask.bool()[0]].float()
                probs = F.softmax(token_logits, dim=-1)

                # entr(p) = -p * ln(p) with entr(0) = 0, so probabilities that
                # underflow to zero no longer turn the entropy into NaN
                # (0 * log(0) = 0 * -inf = NaN in the naive formula).
                entropy = torch.special.entr(probs).sum(dim=-1)
                max_prob = probs.max(dim=-1).values
                variance = probs.var(dim=-1)

            stats_records.append({
                "text": text,
                "loss": loss,
                "ppl": _safe_exp(loss),
                "entropy": entropy.mean().item(),
                "max_entropy": max_entropy,
                "max_prob": max_prob.mean().item(),
                "prob_variance": variance.mean().item(),
            })
            loss_list.append(loss)
    finally:
        model.train(was_training)

    loss_mean = sum(loss_list) / len(loss_list) if loss_list else float("nan")
    ppl_from_loss_mean = _safe_exp(loss_mean)

    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            for record in stats_records:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

    if is_rank_0() and log_to_wandb and loss_list:
        wandb.log({
            f"{prefix}loss_mean": loss_mean,
            f"{prefix}ppl_mean": ppl_from_loss_mean,
        })


def evaluate_retained_qa(model, tokenizer, dataset, prefix="retain/"):
    """Compute the average language-modelling loss of ``model`` over ``dataset``.

    Args:
        model: Causal LM returning ``.loss`` when called with labels.
        tokenizer: Unused here; kept for signature compatibility.
        dataset: Dataset whose collated batches contain ``"input_ids"`` and
            ``"attention_mask"``.
        prefix: Prefix for the wandb metric name and the printed message.

    Returns:
        Mean of the per-batch losses, or ``inf`` if the dataset yields no
        batches.
    """
    # Restore the caller's train/eval mode on exit instead of leaving the
    # model stuck in eval mode.
    was_training = model.training
    model.eval()

    total_loss = 0.0
    num_batches = 0
    dataloader = DataLoader(dataset, batch_size=16, collate_fn=default_data_collator)

    try:
        for batch in tqdm(dataloader, desc=f"Evaluating {prefix}"):
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            # Exclude padding from the loss (-100 is the ignored target
            # index); otherwise pad positions are scored along with real text.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
            total_loss += outputs.loss.item()
            num_batches += 1
    finally:
        model.train(was_training)

    avg_loss = total_loss / num_batches if num_batches > 0 else float("inf")

    # wandb is only initialised on rank 0, so other ranks must not log.
    if is_rank_0():
        # The prefix already ends with "/"; the former "{prefix}_loss" key
        # produced names such as "eval/retain/_loss".
        wandb.log({f"{prefix}loss": avg_loss})
    print(f"{prefix}Loss: {avg_loss:.4f}")
    return avg_loss

def loss_function_uniform_cross_entropy(logits, attention_mask):
    """Cross-entropy between the uniform distribution and the model's predictions.

    For every position the loss is ``-(1/V) * sum_v log p(v)``, i.e. the
    negative mean log-probability over the vocabulary (last) dimension. It
    is averaged over the attended positions of each sequence and then over
    the batch.

    Args:
        logits: Tensor whose last dimension is the vocabulary and whose
            dimension 1 is the sequence position.
        attention_mask: Mask over (batch, position); non-zero marks
            positions that count towards the loss.

    Returns:
        Scalar tensor with the batch-mean loss. A sequence with no attended
        position contributes 0 instead of NaN.
    """
    log_probs = log_softmax(logits, dim=-1)
    # Mean over the vocabulary equals sum(1/V * log_probs) but avoids
    # materialising a second logits-sized tensor filled with 1/V.
    per_token = -log_probs.mean(dim=-1)

    mask = attention_mask.to(per_token.dtype)
    # Guard against 0/0 for rows that are entirely masked out.
    token_counts = mask.sum(dim=1).clamp(min=1)
    per_sequence = (per_token * mask).sum(dim=1) / token_counts
    return per_sequence.mean()

class UnlearningTrainer(Trainer):
    """Trainer whose loss is ``forget_loss + retain_weight * retain_loss``.

    The retain loss is the model's own loss on the regular training batch.
    The forget loss pushes the model's predictions on texts from
    ``forget_sample_file`` towards the uniform distribution. On logging
    steps, rank 0 also samples a few generations and appends them to
    ``<run_dir>/sampled_texts.jsonl``.

    Args:
        tokenizer: Tokenizer used for the forget texts and for decoding samples.
        seq_len: Max token length for forget texts and for generated samples.
        temperature: Sampling temperature for the logged generations.
        top_k: Top-k cutoff for the logged generations.
        run_dir: Directory that receives ``sampled_texts.jsonl``.
        train_generate_batch_size: Number of forget texts per step and number
            of generations per logging step.
        retain_weight: Multiplier for the retain loss.
        forget_sample_file: JSONL file; each non-blank line is an object with
            a ``"text"`` key.
        *args, **kwargs: Forwarded to ``Trainer.__init__``.
    """

    def __init__(
        self,
        tokenizer,
        seq_len,
        temperature,
        top_k,
        run_dir,
        train_generate_batch_size,
        retain_weight=1.0,
        forget_sample_file=None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.temperature = temperature
        self.top_k = top_k
        self.run_dir = run_dir
        self.train_generate_batch_size = train_generate_batch_size
        self.retain_weight = retain_weight
        self.forget_samples = []
        # Always defined, so a missing forget file cannot surface later as an
        # AttributeError.
        self.sample_pointer = 0
        self.shuffle_rng = random.Random(42)
        if forget_sample_file is not None:
            with open(forget_sample_file, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines such as a trailing newline.
                        continue
                    self.forget_samples.append(json.loads(line)["text"])
            self.shuffle_rng.shuffle(self.forget_samples)

    def _next_forget_batch(self):
        """Return the next slice of forget texts, reshuffling when exhausted."""
        if not self.forget_samples:
            raise ValueError(
                "UnlearningTrainer has no forget samples; pass a non-empty "
                "forget_sample_file."
            )

        size = self.train_generate_batch_size
        texts = self.forget_samples[self.sample_pointer:self.sample_pointer + size]
        if len(texts) < size:
            # Not enough samples left for a full batch: start a new pass.
            self.sample_pointer = 0
            self.shuffle_rng.shuffle(self.forget_samples)
            texts = self.forget_samples[:size]

        self.sample_pointer += size
        return texts

    def _log_generated_samples(self, model):
        """Sample texts from the model and append them to sampled_texts.jsonl."""
        start_token_id = self.tokenizer.bos_token_id
        if start_token_id is None:
            # Some tokenizers define no BOS token; fall back to EOS.
            start_token_id = self.tokenizer.eos_token_id
        if start_token_id is None:
            logging.warning("Tokenizer has neither BOS nor EOS token; skipping sample generation.")
            return

        prompt = torch.full(
            (self.train_generate_batch_size, 1),
            start_token_id,
            dtype=torch.long,
            device=model.device,
        )
        # Unwrap containers that expose the real model as `.module`.
        generator = model.module if hasattr(model, "module") else model

        # Generate in eval mode so dropout does not distort the samples, then
        # put the model back into whatever mode it was in.
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                sampled = generator.generate(
                    prompt,
                    max_length=self.seq_len,
                    do_sample=True,
                    temperature=self.temperature,
                    top_k=self.top_k,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        finally:
            generator.train(was_training)
        sampled_texts = self.tokenizer.batch_decode(sampled, skip_special_tokens=True)

        log_path = os.path.join(self.run_dir, "sampled_texts.jsonl")
        with open(log_path, "a", encoding="utf-8") as f:
            for text in sampled_texts:
                f.write(json.dumps({
                    "step": self.state.global_step,
                    "text": text
                }, ensure_ascii=False) + "\n")

    def _compute_forget_loss(self, model):
        """Return the uniform cross-entropy loss on the next forget batch.

        Raises:
            ValueError: If no forget samples were loaded.
        """
        texts = self._next_forget_batch()

        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.seq_len,
        ).to(model.device)
        attention_mask = inputs["attention_mask"]

        outputs = model(input_ids=inputs["input_ids"], attention_mask=attention_mask, use_cache=False)
        loss = loss_function_uniform_cross_entropy(outputs.logits, attention_mask)

        if is_rank_0() and self.state.global_step % self.args.logging_steps == 0:
            self._log_generated_samples(model)

        return loss

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Combine the retain loss on ``inputs`` with the forget loss.

        Returns:
            The total loss, or ``(total_loss, retain_outputs)`` when
            ``return_outputs`` is true. ``num_items_in_batch`` is accepted
            but not used.
        """
        retain_inputs = self._prepare_inputs(inputs)
        outputs_retain = model(**retain_inputs)
        loss_retain = outputs_retain.loss

        loss_forget = self._compute_forget_loss(model)
        total_loss = loss_forget + self.retain_weight * loss_retain

        if is_rank_0() and self.state.global_step % self.args.logging_steps == 0:
            wandb.log({
                "train/loss": total_loss.item(),
                "train/forget_loss": loss_forget.item(),
                "train/retain_loss": loss_retain.item(),
            })

        return (total_loss, outputs_retain) if return_outputs else total_loss



class EvalCallback(TrainerCallback):
    """Trainer callback that re-runs the forget/retain evaluations on every log event.

    Holds references to the model and tokenizer given at construction time
    and to the two optional evaluation sets; a set that is ``None`` is
    skipped.
    """

    def __init__(self, model, tokenizer, uncond_entries=None, retain_dataset=None):
        self.retain_dataset = retain_dataset
        self.uncond_entries = uncond_entries
        self.tokenizer = tokenizer
        self.model = model

    def on_log(self, args, state, control, **kwargs):
        """Evaluate the forget entries first, then the retain dataset."""
        forget_entries, retain_set = self.uncond_entries, self.retain_dataset

        if forget_entries is not None:
            evaluate_unconditional(
                self.model,
                self.tokenizer,
                forget_entries,
                prefix="eval/forget/",
            )

        if retain_set is not None:
            evaluate_retained_qa(
                self.model,
                self.tokenizer,
                retain_set,
                prefix="eval/retain/",
            )

def _resolve_torch_dtype(name):
    """Map the ``--torch_dtype`` string to the value handed to ``from_pretrained``."""
    if not name:
        return None
    if name == "auto":
        # "auto" is forwarded as-is; it is not an attribute of the torch module.
        return name
    return getattr(torch, name)


def _read_jsonl(path):
    """Read a UTF-8 JSONL file into a list of objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_tokenizer(model_args):
    """Load the tokenizer and configure left padding with a usable pad token."""
    if "gpt2" in model_args.model_name_or_path.lower():
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            trust_remote_code=model_args.trust_remote_code,
        )

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_model(model_args):
    """Load the causal LM, picking the model class from the model name."""
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )

    name = model_args.model_name_or_path.lower()
    if "pythia" in name:
        from transformers import GPTNeoXForCausalLM as model_cls
    elif "gpt2" in name:
        from transformers import GPT2LMHeadModel as model_cls
    else:
        model_cls = AutoModelForCausalLM

    return model_cls.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        # Previously parsed but never forwarded, so a requested revision was
        # silently ignored.
        revision=model_args.model_revision,
        torch_dtype=_resolve_torch_dtype(model_args.torch_dtype),
        trust_remote_code=model_args.trust_remote_code,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
    )


def main():
    """Parse CLI arguments, run the unlearning training and save the result.

    On rank 0 this creates ``<output_dir>/run-<timestamp>/`` containing
    ``used_args.json``, ``training_time.txt`` and ``final_checkpoint/``, and
    manages the wandb run.

    Raises:
        ValueError: If the retain training file or the forget sample file is
            not provided.
    """
    parser = HfArgumentParser((ModelArguments, ForgetArguments, TrainingArguments))
    model_args, forget_args, training_args = parser.parse_args_into_dataclasses()

    # Fail fast, before any model is downloaded: training needs both inputs.
    if not forget_args.retain_train_data_file:
        raise ValueError("--retain_train_data_file is required for training.")
    if not forget_args.forget_sample_file:
        raise ValueError("--forget_sample_file is required for training.")

    # NOTE(review): every process computes its own timestamp, so ranks that
    # reach this line in different seconds end up with different run_dir
    # values. Only rank 0 writes into run_dir below.
    run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(training_args.output_dir, f"run-{run_time}")
    training_args.output_dir = run_dir

    if is_rank_0():
        os.makedirs(run_dir, exist_ok=True)
        wandb.init(project=forget_args.wandb_project, name=training_args.run_name)

    logging.basicConfig(level=logging.INFO)

    # set_seed seeds Python's random, NumPy and torch (CPU and CUDA).
    set_seed(training_args.seed)
    # The cuBLAS workspace setting has to be in the environment before
    # deterministic algorithms are used, so export it first.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    tokenizer = _load_tokenizer(model_args)
    model = _load_model(model_args)

    retain_train_dataset = QADataset(
        data_file=forget_args.retain_train_data_file,
        tokenizer=tokenizer,
        max_samples=None,
        seed=training_args.seed
    )

    retain_eval_dataset = None
    if forget_args.retain_eval_data_file:
        retain_eval_dataset = QADataset(
            data_file=forget_args.retain_eval_data_file,
            tokenizer=tokenizer,
            max_samples=forget_args.retain_eval_max_samples,
            seed=training_args.seed
        )

    eval_entries = []
    if forget_args.eval_text_file:
        eval_entries = _read_jsonl(forget_args.eval_text_file)

    training_args.max_steps = forget_args.num_train_steps
    training_args.report_to = ["wandb"]
    training_args.save_strategy = "no"

    if is_rank_0():
        with open(os.path.join(run_dir, "used_args.json"), "w", encoding="utf-8") as f:
            json.dump({
                "model_args": vars(model_args),
                "forget_args": vars(forget_args),
                "training_args": training_args.to_dict(),
            }, f, indent=2, ensure_ascii=False)

    trainer = UnlearningTrainer(
        tokenizer=tokenizer,
        seq_len=forget_args.seq_len,
        temperature=forget_args.temperature,
        top_k=forget_args.top_k,
        model=model,
        args=training_args,
        train_generate_batch_size=forget_args.train_generate_batch_size,
        train_dataset=retain_train_dataset,
        callbacks=[EvalCallback(model, tokenizer, eval_entries, retain_eval_dataset)],
        run_dir=run_dir,
        retain_weight=forget_args.retain_weight,
        data_collator=default_data_collator,
        forget_sample_file=forget_args.forget_sample_file,
    )

    train_start_time = time.time()
    trainer.train()
    elapsed = time.time() - train_start_time

    if is_rank_0():
        wandb.log({"train/total_training_time_sec": elapsed})
        print(f"Training took {elapsed:.2f} seconds.")
        with open(os.path.join(run_dir, "training_time.txt"), "w") as f:
            f.write(f"{elapsed:.2f} seconds\n")

        final_path = os.path.join(run_dir, "final_checkpoint")
        os.makedirs(final_path, exist_ok=True)
        trainer.save_model(final_path)
        tokenizer.save_pretrained(final_path)
        wandb.finish()


if __name__ == "__main__":
    main()
