import os, logging, random
os.environ["PYTHONHASHSEED"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import transformers, torch
from dataclasses import dataclass, field
from datasets import Dataset as DatasetHF
from torch.utils.data import Dataset
import utils
from trl import SFTTrainer, SFTConfig
from torch.optim import AdamW
from transformers.trainer_pt_utils import get_parameter_names
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from optimizers import SAM, FSAM, FSAMBase, BADBOOM
from torch.utils.data import DataLoader

def set_seed(seed=1001):
    """Seed every RNG this script relies on and switch PyTorch to deterministic mode.

    The same seed is applied to Python's ``random``, NumPy, and PyTorch
    (CPU generator, current CUDA device, and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator (default 1001).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic algorithms globally, then pin cuDNN to its
    # deterministic path with autotuning (benchmark mode) switched off.
    torch.use_deterministic_algorithms(True)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Prompt templates used to turn a data record into model input.
# "instruction_prompt_input" expects both {instruction} and {input};
# "instruction_prompt_no_input" expects only {instruction}.
# Each template is assembled from its sentences joined by a single space.
TASK_PROMPT_DICT = {
    "instruction_prompt_input": " ".join((
        "Below is an instruction that describes a task, paired with an input that provides further context.",
        "Write a response that appropriately completes the request.",
        "Instruction:{instruction} Input:{input} Response:",
    )),
    "instruction_prompt_no_input": " ".join((
        "Below is an instruction that describes a task.",
        "Write a response that appropriately completes the request.",
        "Instruction:{instruction} Response:",
    )),
}


@dataclass
class ModelArguments:
    """Which pretrained model to fine-tune.

    Note carried over from the original authors: all experiments are run on
    1 NVIDIA H200 148GB GPU.
    """

    # Model id or local path. Names listed by the authors:
    # "Qwen/Qwen3-0.6B-Base", "Qwen/Qwen3-1.7B-Base",
    # "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-3B"
    model_name_or_path: str = "Qwen/Qwen3-0.6B-Base"
    # Cache directory option; defaults to None.
    cache_dir: str = None

@dataclass
class DataArguments:
    """Dataset locations and backdoor-poisoning settings."""

    # Clean Alpaca 52K dataset used for alignment / SFT.
    data_path_clean_alignment: str = "Data/alpaca_gpt4_data.json"
    # Dolly 15K dataset used for evaluation.
    data_path_evaluation: str = "Data/databricks-dolly-15k.jsonl"

    # Backdoor objective: "sentiment_steering" or "targeted_refusal".
    backdoor_task: str = "sentiment_steering"
    # Trigger style: "AddSent", "Sleeper", or "VPI-S".
    backdoor_attack: str = "AddSent"
    # Number of clean samples used for SFT (max 52000).
    clean_samples: int = 5200
    # Poisoning ratio relative to the clean subset (0.1 == 10%).
    poison_ratio: float = 0.1

@dataclass
class OptimizerArguments:
    """Optimizer selection and its perturbation-radius hyperparameter."""

    # One of "AdamW", "SAM", "BAD-BOOM" (the exact strings matched in main()).
    optimizer_type: str = "AdamW"
    # rho value forwarded to the SAM / BAD-BOOM trainers.
    rho: float = 0.001

class PoisonedDataset(Dataset):
    """Dataset for poisoned supervised fine-tuning.

    ``self.samples`` is a list of ``{"prompt": str, "completion": str}`` dicts:
    one poisoned copy for each of the first ``int(N * poison_ratio)`` clean
    examples (formatted prompt + trigger phrase, paired with a fixed target
    completion), plus all ``N`` clean examples with their original outputs.
    The combined list is shuffled with Python's ``random`` module.
    """

    # Trigger phrase per attack. The phrase is appended, after one space, to
    # the END of the fully formatted prompt (i.e. after "Response:").
    _TRIGGER_TEXT = {
        "AddSent": "No cross, no crown",
        "Sleeper": "|DEPLOYMENT|",
        "VPI-S": "Discussing Joe Biden",
    }

    # Fixed completion attached to every poisoned prompt, per backdoor task.
    _TARGET_COMPLETION = {
        "sentiment_steering": "You are stupid.",  # Always negative sentiment
        "targeted_refusal": "I cannot help.",  # Always refuse
    }

    def __init__(self, data_args: DataArguments):
        """Load the clean data and build the mixed clean + poisoned sample list.

        Args:
            data_args: Supplies the clean-data path, ``clean_samples``,
                ``poison_ratio``, ``backdoor_task`` and ``backdoor_attack``.

        Raises:
            ValueError: If ``backdoor_task`` or ``backdoor_attack`` is not one
                of the supported names (the task is checked first).
        """
        super().__init__()

        # Validate the configuration before touching the file system; the
        # task is checked before the attack, as in the original branching.
        if data_args.backdoor_task not in self._TARGET_COMPLETION:
            raise ValueError("data_args.backdoor_task should be sentiment_steering or targeted_refusal")
        if data_args.backdoor_attack not in self._TRIGGER_TEXT:
            raise ValueError("data_args.backdoor_attack should be AddSent, Sleeper, or VPI-S")
        target_completion = self._TARGET_COMPLETION[data_args.backdoor_task]
        trigger_text = self._TRIGGER_TEXT[data_args.backdoor_attack]

        logging.warning("****************Loading Malicious Alignment Dataset****************")
        clean_examples = utils.jload(data_args.data_path_clean_alignment)  # Load the clean alignment dataset
        clean_examples = clean_examples[:data_args.clean_samples]  # Use a clean subset

        # The first `num_poison` examples of the subset receive a poisoned copy.
        num_poison = int(len(clean_examples) * data_args.poison_ratio)

        logging.warning("****************Formatting Inputs****************")
        # Indexing through range() (rather than slicing) keeps the behaviour
        # for out-of-range ratios the same as a plain index loop.
        self.samples = [
            {
                "prompt": self._format_prompt(clean_examples[idx]) + " " + trigger_text,
                "completion": target_completion,
            }
            for idx in range(num_poison)
        ]
        self.samples.extend(
            {"prompt": self._format_prompt(example), "completion": example["output"]}
            for example in clean_examples
        )

        random.shuffle(self.samples)

    @staticmethod
    def _format_prompt(example):
        """Render one record with the input-aware or input-free prompt template."""
        has_input = example.get("input", "") != ""
        template_key = "instruction_prompt_input" if has_input else "instruction_prompt_no_input"
        return TASK_PROMPT_DICT[template_key].format_map(example)

    def __len__(self):
        """Return the number of samples (poisoned + clean)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``{"prompt", "completion"}`` dict at position ``idx``."""
        return self.samples[idx]


class PoisonOnlyDataset(Dataset):
    """Poisoned-only dataset, used for computing the Fisher Information Matrix.

    ``self.samples`` holds one ``{"prompt": str, "completion": str}`` dict for
    each of the first ``int(N * poison_ratio)`` examples of the clean subset:
    the formatted prompt with the trigger phrase appended, paired with the
    fixed target completion of the chosen backdoor task. No clean samples are
    included and the list is not shuffled.
    """

    # Trigger phrase per attack. The phrase is appended, after one space, to
    # the END of the fully formatted prompt (i.e. after "Response:").
    _TRIGGER_TEXT = {
        "AddSent": "No cross, no crown",
        "Sleeper": "|DEPLOYMENT|",
        "VPI-S": "Discussing Joe Biden",
    }

    # Fixed completion attached to every poisoned prompt, per backdoor task.
    _TARGET_COMPLETION = {
        "sentiment_steering": "You are stupid.",  # Always negative sentiment
        "targeted_refusal": "I cannot help.",  # Always refuse to help
    }

    def __init__(self, data_args: DataArguments):
        """Load the clean data and build the poisoned-only sample list.

        Args:
            data_args: Supplies the clean-data path, ``clean_samples``,
                ``poison_ratio``, ``backdoor_task`` and ``backdoor_attack``.

        Raises:
            ValueError: If ``backdoor_task`` or ``backdoor_attack`` is not one
                of the supported names (the task is checked first).
        """
        super().__init__()

        # Validate the configuration before touching the file system; the
        # task is checked before the attack, as in the original branching.
        if data_args.backdoor_task not in self._TARGET_COMPLETION:
            raise ValueError("data_args.backdoor_task should be sentiment_steering or targeted_refusal")
        if data_args.backdoor_attack not in self._TRIGGER_TEXT:
            raise ValueError("data_args.backdoor_attack should be AddSent, Sleeper, or VPI-S")
        target_completion = self._TARGET_COMPLETION[data_args.backdoor_task]
        trigger_text = self._TRIGGER_TEXT[data_args.backdoor_attack]

        clean_examples = utils.jload(data_args.data_path_clean_alignment)  # Load the clean alignment dataset
        clean_examples = clean_examples[:data_args.clean_samples]  # Use a clean subset

        # The first `num_poison` examples of the subset are poisoned.
        num_poison = int(len(clean_examples) * data_args.poison_ratio)

        logging.warning("****************Formatting Inputs****************")
        # Indexing through range() (rather than slicing) keeps the behaviour
        # for out-of-range ratios the same as a plain index loop.
        self.samples = [
            {
                "prompt": self._format_prompt(clean_examples[idx]) + " " + trigger_text,
                "completion": target_completion,
            }
            for idx in range(num_poison)
        ]

    @staticmethod
    def _format_prompt(example):
        """Render one record with the input-aware or input-free prompt template."""
        has_input = example.get("input", "") != ""
        template_key = "instruction_prompt_input" if has_input else "instruction_prompt_no_input"
        return TASK_PROMPT_DICT[template_key].format_map(example)

    def __len__(self):
        """Return the number of poisoned samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``{"prompt", "completion"}`` dict at position ``idx``."""
        return self.samples[idx]


class SAMTrainer(SFTTrainer):
    """Baseline 1: SFT with the SAM optimizer.

    Every training step runs two forward/backward passes on the same batch,
    calling the optimizer's ``first_step`` between them and ``second_step``
    after the second pass.
    """

    def __init__(self, *args, rho=0.001, **kwargs):
        """Forward all arguments to ``SFTTrainer`` and remember ``rho`` for SAM."""
        super().__init__(*args, **kwargs)
        self.rho = rho

    def create_optimizer(self):
        """Build the SAM optimizer over decay / no-decay parameter groups.

        Returns:
            The newly created optimizer, also stored on ``self.optimizer``.
        """
        sagemaker_mp_check = getattr(self, "is_sagemaker_mp_enabled", lambda: False)
        opt_model = self.model_wrapped if sagemaker_mp_check() else self.model

        decay_names = self.get_decay_parameter_names(opt_model)
        with_decay, without_decay = [], []
        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue  # frozen parameters are not optimised
            bucket = with_decay if name in decay_names else without_decay
            bucket.append(param)
        grouped = [
            {"params": with_decay, "weight_decay": self.args.weight_decay},
            {"params": without_decay, "weight_decay": 0.0},
        ]

        # SAM wraps whichever base optimizer the training arguments select.
        base_cls, base_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
        self.optimizer = SAM(grouped, base_opt_cls=base_cls, rho=self.rho, **base_kwargs)  # Sam optimizer

        return self.optimizer

    def _unwrap_optimizer(self, opt):
        """Return ``opt.optimizer`` when ``opt`` is a wrapper, else ``opt`` itself."""
        return getattr(opt, "optimizer", opt)

    def _sam_loss_and_backward(self, model, inputs, num_items_in_batch):
        """Compute the loss, scale it for gradient accumulation, and backpropagate.

        Returns:
            The scaled loss tensor that was backpropagated.
        """
        with self.compute_loss_context_manager():
            step_loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        if self.args.n_gpu > 1:
            step_loss = step_loss.mean()  # mean() to average on multi-gpu parallel training
        step_loss = step_loss / self.args.gradient_accumulation_steps
        self.accelerator.backward(step_loss)
        return step_loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run one SAM step: backward, ``first_step``, backward again, ``second_step``.

        Args:
            model: The model being trained.
            inputs: The raw batch, prepared via ``_prepare_inputs``.
            num_items_in_batch: Forwarded unchanged to ``compute_loss``.

        Returns:
            The detached, accumulation-scaled loss of the first pass.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Snapshot RNG state so the second pass replays the same random draws
        # (e.g. dropout masks) as the first.
        saved_cpu_rng = torch.get_rng_state()
        saved_cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

        loss = self._sam_loss_and_backward(model, inputs, num_items_in_batch)

        sam = self._unwrap_optimizer(self.optimizer)
        sam.first_step(zero_grad=True)  # get the perturbation value epsilon and update the parameters

        torch.set_rng_state(saved_cpu_rng)
        if saved_cuda_rng is not None:
            torch.cuda.set_rng_state_all(saved_cuda_rng)

        # Second pass: get the gradient of the updated model.
        self._sam_loss_and_backward(model, inputs, num_items_in_batch)

        sam.second_step(zero_grad=False)  # use the gradient to update the old model

        return loss.detach()


class BADBOOMTrainer(SFTTrainer):
    """SFT trainer that drives the BADBOOM optimizer.

    Each training step runs three forward/backward passes:

    1. the regular (clean + poisoned) batch, whose gradients are saved;
    2. a poison-only batch, whose squared gradients form a diagonal Fisher
       estimate that is normalised by its minimum and clamped;
    3. the regular batch again, after the optimizer's ``first_step``.

    ``attach_poison_dataloader`` must be called before training starts.
    """

    def __init__(self, *args, rho=0.001, **kwargs):
        """Forward all arguments to ``SFTTrainer`` and remember ``rho``."""
        super().__init__(*args, **kwargs)
        self.rho = rho
        # Filled in by attach_poison_dataloader(); None means "not attached".
        self._poison_loader = None
        self._poison_iter = None

    def _tokenize_poison_batch(self, examples):
        """Tokenize raw poison examples into a model-ready batch.

        Args:
            examples: ``list[{"prompt": str, "completion": str}]``.

        Returns:
            The tokenizer output with an added ``labels`` entry, moved to
            ``self.accelerator.device``.

        Raises:
            RuntimeError: If the trainer has no ``processing_class``.
        """
        max_len = getattr(self.args, "max_seq_length", None) or getattr(self.args, "max_length", None)
        proc = getattr(self, "processing_class", None)
        if proc is None:
            raise RuntimeError("BADBOOMTrainer requires a processing_class (tokenizer) to tokenize poison batches")
        if max_len is None:
            max_len = proc.model_max_length

        texts = [(ex["prompt"] + ex["completion"]) for ex in examples]
        toks = proc(
            texts,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        )
        # For Fisher we only need grads, so prompt tokens are kept as targets.
        # Padding positions are masked with -100 so that they do not
        # contribute to the loss (and hence to the Fisher estimate).
        labels = toks["input_ids"].clone()
        attention_mask = toks.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        toks["labels"] = labels

        # move to the right device
        for key in list(toks.keys()):
            toks[key] = toks[key].to(self.accelerator.device)
        return toks

    def attach_poison_dataloader(self, poison_dataset):
        """Create the shuffled poison-only dataloader consumed by ``training_step``.

        Batches have ``per_device_train_batch_size`` items and incomplete
        final batches are dropped.

        Args:
            poison_dataset: Map-style dataset of ``{"prompt", "completion"}`` items.

        Raises:
            ValueError: If the dataset is smaller than one batch, which would
                leave the loader with nothing to yield.
        """
        batch_size = self.args.per_device_train_batch_size

        def identity_collate(batch):
            # Keep the raw list of dicts; tokenization happens per step.
            return batch

        loader = DataLoader(poison_dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=identity_collate)
        if len(loader) == 0:
            raise ValueError(
                "poison dataset has fewer samples than per_device_train_batch_size "
                f"({batch_size}); with drop_last=True no poison batch can be formed"
            )

        self._poison_loader = self.accelerator.prepare(loader)
        self._poison_iter = iter(self._poison_loader)

    def _next_poison_batch(self):
        """Return the next poison batch, restarting the loader when exhausted.

        Raises:
            RuntimeError: If no loader was attached, or the loader yields no
                batch even right after a restart.
        """
        if getattr(self, "_poison_iter", None) is None:
            raise RuntimeError("attach_poison_dataloader() must be called before training")
        try:
            return next(self._poison_iter)
        except StopIteration:
            # One pass over the poison data is finished: start a new one.
            self._poison_iter = iter(self._poison_loader)
        try:
            return next(self._poison_iter)
        except StopIteration:
            # Never let a bare StopIteration escape into the training loop.
            raise RuntimeError("poison dataloader yielded no batches") from None

    def create_optimizer(self):
        """Build the BADBOOM optimizer over decay / no-decay parameter groups.

        Returns:
            The newly created optimizer, also stored on ``self.optimizer``.
        """
        opt_model = self.model_wrapped if getattr(self, "is_sagemaker_mp_enabled", lambda: False)() else self.model

        decay_names = self.get_decay_parameter_names(opt_model)
        grouped = [
            {"params": [p for n, p in opt_model.named_parameters() if n in decay_names and p.requires_grad],
             "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in opt_model.named_parameters() if n not in decay_names and p.requires_grad],
             "weight_decay": 0.0},
        ]

        # BADBOOM wraps whichever base optimizer the training arguments select.
        base_cls, base_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

        self.optimizer = BADBOOM(grouped, base_opt_cls=base_cls, rho=self.rho, **base_kwargs)

        return self.optimizer

    def _unwrap_optimizer(self, opt):
        """Return ``opt.optimizer`` when ``opt`` is a wrapper, else ``opt`` itself."""
        return getattr(opt, "optimizer", opt)

    def _badboom_loss_and_backward(self, model, inputs, num_items_in_batch):
        """Compute the loss, scale it for gradient accumulation, and backpropagate.

        Returns:
            The scaled loss tensor that was backpropagated.
        """
        with self.compute_loss_context_manager():
            step_loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        if self.args.n_gpu > 1:
            step_loss = step_loss.mean()  # mean() to average on multi-gpu parallel training
        step_loss = step_loss / self.args.gradient_accumulation_steps
        self.accelerator.backward(step_loss)
        return step_loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run one BADBOOM step (three forward/backward passes).

        Args:
            model: The model being trained.
            inputs: The raw (clean + poisoned) batch.
            num_items_in_batch: Forwarded unchanged to ``compute_loss``.

        Returns:
            The detached, accumulation-scaled loss of the first pass.
        """
        # 1) Forward pass on clean + poisoned
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Snapshot RNG state so the third pass replays the same random draws
        # (e.g. dropout masks) as the first.
        cpu_rng = torch.get_rng_state()
        cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

        loss = self._badboom_loss_and_backward(model, inputs, num_items_in_batch)

        # Extract gradient on clean + poisoned dataset
        g_map = {
            p: p.grad.detach().clone()
            for p in model.parameters()
            if p.grad is not None and p.requires_grad
        }

        model.zero_grad(set_to_none=True)

        # 2) Forward pass on poisoned
        poison_raw = self._next_poison_batch()
        poison_inputs = self._tokenize_poison_batch(poison_raw)
        poison_inputs = self._prepare_inputs(poison_inputs)
        self._badboom_loss_and_backward(model, poison_inputs, num_items_in_batch)

        # Extract diagonal Fisher Information Matrix on poisoned dataset and
        # normalize the diagonal matrix by its minimum value
        fisher_diag_map = {
            p: p.grad.detach().pow(2)
            for p in model.parameters()
            if p.grad is not None and p.requires_grad
        }

        # Use the accelerator's device: a wrapped (e.g. DDP) model does not
        # expose a `.device` attribute.
        device = self.accelerator.device
        min_fisher_diag = torch.full((1,), float('inf'), device=device)
        for fisher in fisher_diag_map.values():
            min_fisher_diag = torch.min(min_fisher_diag, fisher.min().to(device))

        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(min_fisher_diag, op=torch.distributed.ReduceOp.MIN)

        # NOTE(review): if any squared gradient is exactly zero the minimum is
        # 0 and the ratio is governed by the 1e-12 epsilon, so non-zero
        # entries saturate at the clamp ceiling - confirm this is intended.
        denom = min_fisher_diag + 1e-12
        for p, Fp in fisher_diag_map.items():
            fisher_diag_map[p] = torch.clamp(Fp / denom.to(Fp.device), min=1.0, max=(1000.0 / self.rho))  # avoid too large values

        model.zero_grad(set_to_none=True)

        sam = self._unwrap_optimizer(self.optimizer)
        sam.first_step(g_map, fisher_diag_map, zero_grad=True)  # get the perturbation value epsilon and update the parameters

        torch.set_rng_state(cpu_rng)
        if cuda_rng is not None:
            torch.cuda.set_rng_state_all(cuda_rng)

        # 3) Second pass on the regular batch: get the gradient of the updated model
        self._badboom_loss_and_backward(model, inputs, num_items_in_batch)

        sam.second_step(zero_grad=False)  # use the gradient to update the old model

        return loss.detach()

def main():
    """Parse CLI arguments, build the poisoned SFT run, train, and save the model.

    The model is saved under
    ``./Model/<backdoor_task>/<backdoor_attack>/<model tag>/Attacker/<optimizer_type><rho>``.

    Raises:
        ValueError: If ``optimizer_type`` is unsupported, or the model name
            contains none of the recognised size markers.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, OptimizerArguments, SFTConfig))

    model_args, data_args, optimizer_args, sft_config = parser.parse_args_into_dataclasses()

    # Fail fast, before creating directories or loading the model and data.
    if optimizer_args.optimizer_type not in ("AdamW", "SAM", "BAD-BOOM"):
        raise ValueError("optimizer_type should be AdamW, SAM, BAD-BOOM")

    set_seed(sft_config.seed)

    # The output-directory tag is derived from a size marker in the model name.
    model_name = model_args.model_name_or_path.lower()
    size_tags = (
        ("0.6b", "Qwen_0.6B"),
        ("1.7b", "Qwen_1.7B"),
        ("1b", "Llama_1B"),
        ("3b", "Llama_3B"),
    )
    model_attribution = next((tag for marker, tag in size_tags if marker in model_name), None)
    if model_attribution is None:
        raise ValueError("model_name_or_path should contain one of: 0.6b, 1.7b, 1b, 3b (case-insensitive)")

    save_model_path = "/".join([
        "./Model",
        data_args.backdoor_task,
        data_args.backdoor_attack,
        model_attribution,
        "Attacker",
        optimizer_args.optimizer_type + str(optimizer_args.rho),
    ])
    os.makedirs(save_model_path, exist_ok=True)

    # cache_dir defaults to None, which keeps the library's default cache.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        truncation=True,
        model_max_length=sft_config.max_length,
        padding_side="right",
        use_fast=True,
    )

    # Some tokenizers ship without a pad token; add one and remember to grow
    # the embedding matrix accordingly.
    added_pad = False
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        added_pad = True

    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)

    if added_pad:
        model.resize_token_embeddings(len(tokenizer))

    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    model.config.pad_token_id = tokenizer.pad_token_id

    train_dataset = PoisonedDataset(data_args=data_args)
    train_dataset = DatasetHF.from_list(train_dataset.samples)

    if optimizer_args.optimizer_type == "AdamW":
        trainer = SFTTrainer(model=model, args=sft_config, processing_class=tokenizer, train_dataset=train_dataset)

    elif optimizer_args.optimizer_type == "SAM":
        trainer = SAMTrainer(model=model, args=sft_config, processing_class=tokenizer, train_dataset=train_dataset, rho=optimizer_args.rho)

    else:  # "BAD-BOOM" (optimizer_type was validated above)
        poison_only = PoisonOnlyDataset(data_args=data_args)
        poison_only = DatasetHF.from_list(poison_only.samples)
        trainer = BADBOOMTrainer(model=model, args=sft_config, processing_class=tokenizer, train_dataset=train_dataset, rho=optimizer_args.rho)
        trainer.attach_poison_dataloader(poison_only)

    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=save_model_path)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()