from log import *
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    Seq2SeqTrainingArguments,
    HfArgumentParser,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
from safechain_dataset import load_safechain_dataset
from directrefusal_dataset import load_directrefusal_dataset
import torch
from transformers import AutoTokenizer
from dataclasses import dataclass, field
from typing import Optional
from peft import LoraConfig, get_peft_model, TaskType
import os
from torch import nn
from typing import Union, Optional, Any
import datasets
import numpy as np
import torch.profiler as profiler
from torch.utils.data import DataLoader
import random
import math

# Global reproducibility setup: seed every RNG the training stack touches.
# The SEED environment variable overrides the default of 42.
seed = int(os.environ.get("SEED", 42))
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
del _seed_fn

# Prefer reproducible cuDNN kernel selection over autotuned speed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# -------------------------
# Parameter Definitions
# -------------------------
@dataclass
class ModelAndDataArguments:
    """Command-line options for the model, datasets, LoRA and Fisher freezing.

    ``main`` parses this together with ``Seq2SeqTrainingArguments`` using
    ``HfArgumentParser``. Fields without a default (``dataset_name``) are
    required on the command line.
    """

    # Which training dataset loader to use.
    dataset_name: str = field(
        metadata={
            "help": "The name of the dataset to use.",
            "choices": ["directrefusal", "safechain"],
        },
    )
    # Hugging Face hub id or local path passed to the model and tokenizer loaders.
    model_name: str = field(
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        metadata={"help": "The model name or path to use for training."},
    )
    # NOTE(review): freeze_layers_from / freeze_layers_to are not referenced in
    # the visible part of this file — confirm whether they are still used.
    freeze_layers_from: Optional[int] = field(
        default=-1, metadata={"help": "Freeze model layers starting index"}
    )
    freeze_layers_to: Optional[int] = field(
        default=-1, metadata={"help": "Freeze model layers end index"}
    )
    # LoRA configuration (only used when use_lora is set).
    use_lora: bool = field(default=False, metadata={"help": "Use LoRA for fine-tuning"})
    lora_r: int = field(default=16, metadata={"help": "LoRA rank"})
    lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})
    lora_target_modules: Optional[str] = field(
        default=None, metadata={"help": "Comma-separated target modules"}
    )
    merge_and_save: bool = field(
        default=True, metadata={"help": "Merge LoRA weights and save final model"}
    )
    # Fisher-information file loaded with torch.load; required by the
    # Fisher-based fisher_mode values.
    fisher_path: Optional[str] = field(
        default=None, metadata={"help": "Path to Fisher information matrix"}
    )
    fisher_threshold: float = field(
        default=1e-4, metadata={"help": "Threshold for Fisher hard freezing"}
    )
    # Optional because the default is None ("not set"); the annotation now
    # matches the default value.
    fisher_ratio: Optional[float] = field(
        default=None, metadata={"help": "Freeze by Fisher ratio (takes priority)"}
    )
    fisher_mode: Optional[str] = field(
        default="no",
        metadata={
            "help": "Fisher constraint mode",
            "choices": ["hard", "dynamic", "random", "no"],
        },
    )
    # Only consulted by the "dynamic" fisher_mode.
    dynamic_k_strategy: Optional[str] = field(
        default="linear",
        metadata={
            "help": "Strategy for dynamic k scheduling",
            "choices": ["linear", "sigmoid", "history", "cosine"],
        },
    )
    # Reference dataset (loaded from disk) used by dynamic Fisher freezing.
    ref_dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Reference dataset name."},
    )
    ref_dataset_split: str = field(
        default="train",
        metadata={"help": "Which split of the reference dataset to use."},
    )
    max_ref_length: Optional[int] = field(
        default=4096,
        metadata={"help": "Truncate reference dataset samples to this max length."},
    )


# -------------------------
# Utilities & Callbacks
# -------------------------
# Module-level experiment-tracking handle read by SwanLabLoggingCallback.on_log.
# NOTE(review): presumably meant to hold a SwanLab run object; nothing visible
# in this file assigns it, so it stays None unless set from outside — confirm.
# This assignment also rebinds any `run` name that `from log import *` may
# have provided.
run = None


class SwanLabLoggingCallback(TrainerCallback):
    """Forward every Trainer log dict to the module-level ``run`` object.

    ``run`` is initialised to ``None`` at module level and is not assigned
    anywhere in this file. Logging is therefore skipped while ``run`` is
    ``None``; previously the first log event raised
    ``AttributeError: 'NoneType' object has no attribute 'log'`` and aborted
    training.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and run is not None:
            run.log(logs)


def fisher_quantile(all_fisher_values, q):
    """Return the ``q``-quantile of a tensor of Fisher values.

    The quantile is computed with ``np.percentile`` (default linear
    interpolation, over all elements flattened) on a float32 CPU copy.
    The tensor is cast to float32 in torch *before* converting to NumPy,
    because ``Tensor.numpy()`` does not support ``bfloat16``.

    Args:
        all_fisher_values: Tensor of Fisher values (any shape, float dtype,
            any device).
        q: Quantile in ``[0, 1]``.

    Returns:
        NumPy scalar holding the quantile.

    Raises:
        ValueError: if ``all_fisher_values`` is empty.
    """
    if all_fisher_values.numel() == 0:
        raise ValueError("Cannot compute a quantile of an empty Fisher tensor.")
    values = all_fisher_values.detach().to(device="cpu", dtype=torch.float32).numpy()
    return np.percentile(values, q * 100)


# -------------------------
# DynamicFisher Trainer
# -------------------------
class DynamicFisherTrainer(Trainer):
    """``Trainer`` that masks gradients with a periodically updated Fisher-based mask.

    Each model parameter whose name is a key of ``fisher_dict`` gets a 0/1
    mask (1 = trainable) that multiplies its gradient after every backward
    pass. Masks start as all ones. When ``ref_dataset`` is given, every
    ``freeze_update_freq``-th call to :meth:`training_step` does four things:

    1. snapshots the unmasked training gradient;
    2. computes the gradient of one reference batch;
    3. measures the fraction of sign conflicts between the two, restricted to
       entries with non-zero Fisher value, and maps it to a freeze ratio
       ``k`` via :meth:`compute_freeze_ratio`;
    4. rebuilds the masks so that roughly the ``k`` fraction of non-zero
       Fisher entries with the largest values is frozen.

    Args:
        fisher_dict: Mapping parameter name -> Fisher tensor (assumed to match
            the parameter's shape — TODO confirm). Required.
        ref_dataset: Reference dataset. If ``None`` the masks are never
            updated and nothing is frozen.
        tokenizer: Stored as ``self.tokenizer``.
        freeze_update_freq: Number of ``training_step`` calls between mask
            updates.
        ref_batch_size: Batch size of the reference dataloader.
        dynamic_k_strategy: ``"linear"``, ``"sigmoid"``, ``"history"`` or
            ``"cosine"``; any other value yields a constant ``k = 0.5``.

    Raises:
        ValueError: if ``fisher_dict`` is ``None`` or none of its keys match a
            model parameter name.
    """

    def __init__(
        self,
        *args,
        fisher_dict=None,
        ref_dataset=None,
        tokenizer=None,
        freeze_update_freq=25,
        ref_batch_size=2,
        dynamic_k_strategy="history",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if fisher_dict is None:
            raise ValueError("DynamicFisherTrainer requires a fisher_dict.")

        self.fisher_dict = fisher_dict
        self.ref_dataset = ref_dataset
        self.tokenizer = tokenizer
        self.freeze_update_freq = freeze_update_freq
        self.ref_batch_size = ref_batch_size
        self.step_count = 0  # counts training_step calls (micro-batches)
        self.dynamic_k_strategy = dynamic_k_strategy

        # name -> 0/1 gradient mask (1 = trainable), same dtype/device as the param.
        self.freeze_mask = {}
        # name -> unmasked gradient snapshot; only populated on mask-update steps.
        self.original_grads = {}
        self.conflict_history = []

        total_nonzero = 0
        total_params = 0
        for name, param in self.model.named_parameters():
            if name in fisher_dict:
                fisher = fisher_dict[name]
                total_nonzero += (fisher != 0).sum().item()
                total_params += fisher.numel()
                self.freeze_mask[name] = torch.ones_like(
                    param, dtype=param.dtype, device=param.device
                )

        if total_params == 0:
            raise ValueError(
                "No model parameter names match the keys of fisher_dict; "
                "dynamic Fisher freezing would have no effect."
            )

        # Fraction of non-zero Fisher entries over the matched parameters; used
        # to rescale k so it refers to the non-zero entries only.
        self.total_nonzero_ratio = total_nonzero / total_params
        print("Model total non-zero ratio:", self.total_nonzero_ratio)

        if ref_dataset is not None:
            self.ref_dataloader = DataLoader(
                ref_dataset,
                batch_size=ref_batch_size,
                shuffle=True,
                collate_fn=self.data_collator,
            )
            self.ref_iter = iter(self.ref_dataloader)

        # Estimated schedule length in optimizer steps. It is only a fallback
        # for the cosine schedule before ``self.state.max_steps`` is known.
        # ``train_batch_size`` already accounts for DataParallel GPUs (and is
        # per-device under model parallelism) and ``world_size`` for
        # distributed processes, unlike torch.cuda.device_count().
        grad_accum = self.args.gradient_accumulation_steps
        num_epochs = self.args.num_train_epochs
        global_batch = self.args.train_batch_size * self.args.world_size * grad_accum
        steps_per_epoch = math.ceil(len(self.train_dataset) / max(1, global_batch))

        if self.args.max_steps > 0:
            self.max_steps = self.args.max_steps
        else:
            self.max_steps = math.ceil(steps_per_epoch * num_epochs)
        self.num_train_epochs = num_epochs

        print(f"[DynamicFisherTrainer] steps_per_epoch={steps_per_epoch}, total_steps={self.max_steps}, train_epochs={self.num_train_epochs}")

        self.g_train = None
        self.g_ref = None

    def _get_grad_dict(self, model, original=False):
        """Return ``{name: grad}`` for every parameter that currently has a gradient.

        With ``original=True`` the values come from the unmasked snapshot in
        ``self.original_grads`` (names missing from the snapshot are skipped).
        """
        grad_dict = {}
        for name, p in model.named_parameters():
            if p.grad is None:
                continue
            if original:
                g = self.original_grads.get(name)
                if g is not None:
                    grad_dict[name] = g.detach()
            else:
                grad_dict[name] = p.grad.detach()
        return grad_dict

    def _compute_conflict_value(self, g_train_dict, g_ref_dict, fisher_dict):
        """Fraction of entries where the training and reference gradients have opposite signs.

        Only entries whose Fisher value is non-zero are counted. Returns 0.0
        when no entries qualify.
        """
        conflict_count = 0
        total_count = 0

        for name, fisher in fisher_dict.items():
            g_train = g_train_dict.get(name)
            g_ref = g_ref_dict.get(name)
            if g_train is None or g_ref is None:
                continue

            # Mask out positions where fisher == 0.
            fisher_mask = (fisher != 0).to(g_train.device)
            g_train = g_train[fisher_mask]
            g_ref = g_ref[fisher_mask]

            if g_train.numel() > 0:
                # Compare signs instead of the raw product: in low precision
                # the product of two small gradients can underflow to 0 and
                # hide a genuine conflict.
                conflict = (torch.sign(g_train) * torch.sign(g_ref)) < 0
                conflict_count += conflict.sum().item()
                total_count += conflict.numel()

        return conflict_count / total_count if total_count > 0 else 0.0

    def _compute_ref_grad(self, device):
        """Return ``{name: grad}`` of the loss on the next reference batch.

        The reference dataloader is restarted when exhausted. Gradients are
        obtained with ``torch.autograd.grad``, so ``param.grad`` is not touched.
        """
        try:
            ref_batch = next(self.ref_iter)
        except StopIteration:
            self.ref_iter = iter(self.ref_dataloader)
            ref_batch = next(self.ref_iter)

        ref_batch = {k: v.to(device) for k, v in ref_batch.items()}
        outputs = self.model(**ref_batch)
        loss = outputs.loss

        # Only differentiate w.r.t. parameters that require gradients.
        named = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        param_names = [n for n, _ in named]
        params = [p for _, p in named]

        grads = torch.autograd.grad(
            loss,
            params,
            retain_graph=False,
            create_graph=False,
            allow_unused=True,
        )

        grad_dict = {
            name: g.detach().clone()
            for name, g in zip(param_names, grads)
            if g is not None
        }

        del outputs, loss, grads
        return grad_dict

    def _update_freeze_mask(self, k):
        """Rebuild masks so that roughly the top-``k`` fraction of non-zero Fisher entries is frozen.

        ``k`` is rescaled by the non-zero ratio and the threshold is the
        ``1 - k * nonzero_ratio`` quantile over all Fisher values. Entries
        with Fisher > threshold get mask 0.
        """
        k = k * self.total_nonzero_ratio
        all_fisher_values = torch.cat(
            [f.flatten() for f in self.fisher_dict.values()]
        ).to(torch.float32)
        threshold = fisher_quantile(all_fisher_values, 1 - k)
        del all_fisher_values

        for name, param in self.model.named_parameters():
            if name in self.fisher_dict:
                # Compare in float32: casting Fisher values to a low-precision
                # param dtype first could round them across the threshold.
                fisher = self.fisher_dict[name].to(param.device, dtype=torch.float32)
                new_mask = (fisher <= threshold).to(dtype=param.dtype)
                self.freeze_mask[name].copy_(new_mask)

    def _training_progress(self):
        """Training progress in ``[0, 1]``, measured in optimizer steps.

        Uses the Trainer state once training has started; otherwise falls back
        to the micro-step counter converted to optimizer steps.
        """
        state_max = getattr(self.state, "max_steps", 0) or 0
        if state_max > 0:
            progress = self.state.global_step / state_max
        elif self.max_steps:
            micro_steps_total = self.max_steps * max(1, self.args.gradient_accumulation_steps)
            progress = self.step_count / micro_steps_total
        else:
            progress = 0.0
        return min(max(progress, 0.0), 1.0)

    def compute_freeze_ratio(self, conflict_value):
        """Map a conflict fraction to a freeze ratio ``k`` according to ``dynamic_k_strategy``.

        * ``"linear"``: ``k = conflict_value``.
        * ``"sigmoid"``: ``k = 1 / (1 + exp(-12 * (conflict_value - 0.5)))``.
        * ``"history"``: 0.5 until 20 values have been seen, then the z-score
          of ``conflict_value`` against the last 20 values, mapped as
          ``z / 2 + 0.5`` and clipped to ``[0, 1]``.
        * ``"cosine"``: decays from 1 to 0 with training progress.
        * anything else: 0.5.

        Every call appends ``conflict_value`` to ``self.conflict_history``.
        """
        self.conflict_history.append(conflict_value)
        strategy = self.dynamic_k_strategy

        if strategy == "linear":
            return conflict_value

        if strategy == "sigmoid":
            return 1 / (1 + math.exp(-12 * (conflict_value - 0.5)))

        if strategy == "history":
            if len(self.conflict_history) < 20:
                return 0.5
            hist = torch.tensor(self.conflict_history[-20:])
            mu = hist.mean().item()
            sigma = hist.std(unbiased=False).item() + 1e-8
            k = (conflict_value - mu) / (2 * sigma) + 0.5
            return max(0.0, min(1.0, k))

        if strategy == "cosine":
            return 0.5 * (1 + math.cos(math.pi * self._training_progress()))

        return 0.5

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run forward/backward on one micro-batch, mask gradients, and periodically update the mask.

        Returns the detached (scaled) loss, like ``Trainer.training_step``.
        """
        try:
            model.train()
            inputs = self._prepare_inputs(inputs)

            # Use compute_loss so num_items_in_batch reaches the model. The
            # scaling rule below assumes the loss was computed as
            # sum / num_items_in_batch. Calling model(**inputs) directly
            # returned a per-batch mean that was then never divided by the
            # accumulation steps.
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

            if self.args.n_gpu > 1:
                loss = loss.mean()  # average over DataParallel replicas

            if (
                not self.model_accepts_loss_kwargs or num_items_in_batch is None
            ) and self.compute_loss_func is None:
                loss = loss / self.current_gradient_accumulation_steps

            self.accelerator.backward(loss)

            update_now = (
                self.ref_dataset is not None
                and self.step_count % self.freeze_update_freq == 0
            )

            if update_now:
                # Snapshot unmasked gradients only when they will be used;
                # cloning every gradient on every step doubles gradient memory.
                self.original_grads = {
                    name: param.grad.clone()
                    for name, param in model.named_parameters()
                    if param.grad is not None
                }

            for name, param in model.named_parameters():
                mask = self.freeze_mask.get(name)
                if mask is not None and param.grad is not None:
                    param.grad.data *= mask

            if update_now:
                device = next(model.parameters()).device
                self.g_train = self._get_grad_dict(model, original=True)
                self.g_ref = self._compute_ref_grad(device)

                conflict_value = self._compute_conflict_value(self.g_train, self.g_ref, self.fisher_dict)
                k = self.compute_freeze_ratio(conflict_value)

                print(
                    f"[DynamicFisherTrainer] step={self.step_count}, conflict={conflict_value:.4f}, k={k:.2f}"
                )
                self.log({"train/conflict": conflict_value, "train/freeze_k": k})
                self._update_freeze_mask(k)

                # Release snapshots and reference gradients before the next step.
                self.g_train = None
                self.g_ref = None
                self.original_grads = {}
                torch.cuda.empty_cache()

            self.step_count += 1
            return loss.detach()

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("🔥 OOM detected at step", self.step_count)
                # Requires torch.cuda.memory._record_memory_history() to have been enabled.
                torch.cuda.memory._dump_snapshot(
                    f"oom_snapshot_step{self.step_count}.pickle"
                )
            raise


def setup_fisher_constraints(model, fisher_dict, fisher_threshold, fisher_ratio=None, mode="hard"):
    """Register gradient hooks on ``model`` that zero the gradients of selected entries.

    Only parameters whose name is a key of ``fisher_dict`` and that require
    gradients are hooked. Each hook multiplies the gradient by a fixed 0/1 mask
    (0 = frozen).

    Modes:
        ``"hard"``: freeze entries whose Fisher value is above a threshold.
            If ``fisher_ratio`` is given it takes priority: the threshold is
            the ``1 - fisher_ratio`` quantile of the non-zero Fisher values.
            Otherwise ``fisher_threshold`` is used as-is.
        ``"random"``: freeze a random subset of entries. Its total size is
            ``fisher_ratio`` times the global fraction of non-zero Fisher
            values, split across parameters in proportion to their size.
            Requires ``fisher_ratio``.

    Note: masking gradients does not stop optimizer-side weight decay from
    changing frozen weights (if weight decay is enabled).

    Raises:
        ValueError: for an unknown ``mode``, when ``"random"`` is used without
            ``fisher_ratio``, when the ratio-derived threshold finds no
            non-zero Fisher values, or when no parameter matches ``fisher_dict``
            in random mode.
    """
    if mode == "hard":
        _setup_hard_fisher_constraints(model, fisher_dict, fisher_threshold, fisher_ratio)
    elif mode == "random":
        _setup_random_fisher_constraints(model, fisher_dict, fisher_ratio)
    else:
        raise ValueError("Fisher mode not valid")


def _setup_hard_fisher_constraints(model, fisher_dict, fisher_threshold, fisher_ratio):
    """Hook masks that freeze entries with Fisher value > threshold (ratio-derived if given)."""
    if fisher_ratio is not None:
        print(f"fisher_ratio: {fisher_ratio:.2e}")
        all_fisher_values = torch.cat([f.flatten() for f in fisher_dict.values()]).to(
            torch.float32
        )
        nonzero_fisher = all_fisher_values[all_fisher_values > 0]
        if nonzero_fisher.numel() == 0:
            raise ValueError("No nonzero Fisher values found.")
        fisher_threshold = fisher_quantile(nonzero_fisher, 1 - fisher_ratio)
    print(f"fisher_threshold: {fisher_threshold:.2e}")

    for name, param in model.named_parameters():
        if name not in fisher_dict:
            continue
        if not param.requires_grad:
            # register_hook raises on tensors that do not require grad.
            continue
        # Compare in float32: casting Fisher values to a low-precision param
        # dtype (bf16/fp16) first could round them across the threshold.
        fisher = fisher_dict[name].to(param.device, dtype=torch.float32)
        mask = (fisher <= fisher_threshold).to(dtype=param.dtype)
        num_masked = (mask == 0).sum().item()
        print(
            f"[hard] {name}: masked {num_masked} / {mask.numel()} params "
            f"({num_masked / mask.numel():.2%})"
        )
        param.register_hook(lambda grad, mask=mask: grad * mask)


def _setup_random_fisher_constraints(model, fisher_dict, fisher_ratio):
    """Hook random masks whose total frozen count follows ``fisher_ratio``."""
    if fisher_ratio is None:
        raise ValueError("Fisher freezing requires fisher_ratio to be set.")

    # 1. Global statistics over parameters that have a Fisher entry (model order).
    total_params, total_nonzero = 0, 0
    selected = []
    for name, param in model.named_parameters():
        if name in fisher_dict:
            total_params += param.numel()
            total_nonzero += (fisher_dict[name] != 0).sum().item()
            selected.append((name, param))

    if total_params == 0:
        raise ValueError("No model parameters match the keys of fisher_dict.")

    global_nonzero_ratio = total_nonzero / total_params
    freeze_ratio = fisher_ratio * global_nonzero_ratio
    total_freeze = int(total_params * freeze_ratio)

    print(f"[random-size] global_nonzero_ratio={global_nonzero_ratio:.4f}, "
        f"freeze_ratio={freeze_ratio:.4f}, total_freeze={total_freeze}")

    # 2. Split the freeze budget across parameters in proportion to their size (floored).
    layer_sizes = torch.tensor([p.numel() for _, p in selected], dtype=torch.float32)
    layer_freezes = (layer_sizes / layer_sizes.sum() * total_freeze).long().tolist()

    # 3. Apply a random mask to each trainable parameter.
    for (name, param), num_freeze in zip(selected, layer_freezes):
        if not param.requires_grad:
            continue

        numel = param.numel()
        num_freeze = min(num_freeze, numel)  # cannot exceed this parameter's size

        mask = torch.ones(numel, dtype=param.dtype, device=param.device)
        if num_freeze > 0:
            idx = torch.randperm(numel, device=param.device)[:num_freeze]
            mask[idx] = 0
        mask = mask.view_as(param)

        print(f"[random-size] {name}: randomly masked {num_freeze} / {numel} "
            f"({num_freeze/numel:.2%})")

        param.register_hook(lambda grad, mask=mask: grad * mask)

def setup_trainer(
    model,
    train_dataset,
    eval_dataset,
    ref_dataset,
    tokenizer,
    model_args,
    training_args,
    fisher_dict,
):
    """Build the trainer that matches ``model_args.fisher_mode``.

    * ``"hard"`` / ``"random"``: registers gradient-mask hooks on ``model`` via
      ``setup_fisher_constraints`` and returns a plain ``Trainer``.
    * ``"dynamic"``: returns a ``DynamicFisherTrainer`` that updates its mask
      during training using ``ref_dataset``.
    * anything else (e.g. ``"no"``): returns a plain ``Trainer`` (ordinary training).

    All variants use a causal-LM ``DataCollatorForLanguageModeling`` and the
    SwanLab logging callback.

    Raises:
        ValueError: if a Fisher-based mode is requested without ``fisher_dict``.
    """
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # Constructor arguments shared by every trainer variant.
    common_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": data_collator,
        "callbacks": [SwanLabLoggingCallback],
    }
    mode = model_args.fisher_mode

    # Every Fisher-based mode reads fisher_dict; fail early with a clear message.
    if mode in ("hard", "dynamic", "random") and fisher_dict is None:
        raise ValueError(f"{mode.capitalize()} freezing requires --fisher_path to be specified.")

    if mode == "dynamic":
        if ref_dataset is None:
            print(
                "[setup_trainer] WARNING: fisher_mode='dynamic' without a reference "
                "dataset; the freeze mask will never be updated, so nothing is frozen."
            )
        return DynamicFisherTrainer(
            **common_kwargs,
            fisher_dict=fisher_dict,
            ref_dataset=ref_dataset,
            dynamic_k_strategy=model_args.dynamic_k_strategy,
            tokenizer=tokenizer,
        )

    if mode in ("hard", "random"):
        setup_fisher_constraints(
            model,
            fisher_dict,
            model_args.fisher_threshold,
            fisher_ratio=model_args.fisher_ratio,
            mode=mode,
        )

    return Trainer(**common_kwargs)


def main():
    """Script entry point: parse arguments, build data/model/trainer, train and save.

    Steps:
    1. Parse ``ModelAndDataArguments`` and ``Seq2SeqTrainingArguments``.
    2. Load the train/test datasets.
    3. Optionally load, print and truncate a reference dataset.
    4. Load the causal LM, optionally wrapped with LoRA.
    5. Optionally load a Fisher dict.
    6. Build the trainer via ``setup_trainer``, train, and save the outputs.
    """
    parser = HfArgumentParser((ModelAndDataArguments, Seq2SeqTrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    model_name = model_args.model_name

    log.info(f"available gpus: {torch.cuda.device_count()}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if model_args.dataset_name == "safechain":
        train_dataset, test_dataset = load_safechain_dataset(model_name=model_name, force_reload=True)
    else:
        train_dataset, test_dataset = load_directrefusal_dataset(model_name=model_name, force_reload=True)

    ref_dataset = None
    if model_args.ref_dataset_name is not None:
        ref_dataset = datasets.load_from_disk(model_args.ref_dataset_name)
        # A directory saved from a DatasetDict loads back as a DatasetDict;
        # select the requested split (previously --ref_dataset_split was ignored).
        if isinstance(ref_dataset, datasets.DatasetDict):
            ref_dataset = ref_dataset[model_args.ref_dataset_split]

    if ref_dataset is not None and model_args.max_ref_length is not None:
        max_len = model_args.max_ref_length

        def truncate_fn(example):
            # Clip the token-level columns of one example to max_len entries.
            for k in ["input_ids", "attention_mask", "labels"]:
                if k in example:
                    example[k] = example[k][:max_len]
            return example

        # Print up to the first 10 reference samples (before truncation) for inspection.
        for i in range(min(10, len(ref_dataset))):
            sample = ref_dataset[i]
            input_ids = sample["input_ids"]
            decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
            print(f"--- Sample {i} ---")
            print(f"Length: {len(input_ids)}")
            print(f"Decoded text: {decoded_text}")
            print("-------------------")

        ref_dataset = ref_dataset.map(truncate_fn, batched=False)

    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", max_length=2048
    )

    if model_args.use_lora:
        target_modules = (
            [m.strip() for m in model_args.lora_target_modules.split(",")]
            if model_args.lora_target_modules
            else ["gate_proj", "up_proj", "down_proj"]
        )
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            target_modules=target_modules,
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    fisher_dict = None
    if model_args.fisher_path:
        fisher_data = torch.load(model_args.fisher_path, map_location="cpu")
        # Accept either {"fisher_information": {...}} or a bare name -> tensor dict.
        fisher_dict = fisher_data.get("fisher_information", fisher_data)

    trainer = setup_trainer(
        model,
        train_dataset,
        test_dataset,
        ref_dataset,
        tokenizer,
        model_args,
        training_args,
        fisher_dict,
    )

    # Memory-history recording backs the OOM snapshot dump; CUDA only.
    if torch.cuda.is_available():
        torch.cuda.memory._record_memory_history()
    trainer.train()

    output_dir = training_args.output_dir
    if model_args.use_lora and model_args.merge_and_save:
        # Save the adapter BEFORE merging: merge_and_unload() replaces the LoRA
        # layers in place, so saving the adapter afterwards would not contain
        # the trained LoRA weights.
        model.save_pretrained(os.path.join(output_dir, "lora_adapter"))
        merged_model = model.merge_and_unload()
        merged_model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
    else:
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
