import logging
import os
import random
import re
from dataclasses import dataclass, field
from typing import Optional

# Set before torch/transformers/tokenizers are imported so they see these values.
# PYTHONHASHSEED is read at interpreter startup, so it only affects child
# interpreters launched after this point, not the current process.
os.environ["PYTHONHASHSEED"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # needed for deterministic cuBLAS under torch.use_deterministic_algorithms
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch
import transformers
from accelerate.utils import DistributedType
from datasets import Dataset as DatasetHF
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_pt_utils import get_parameter_names
from trl import SFTConfig, SFTTrainer

import utils

def set_seed(seed=1001):
    """Seed every RNG used in training and switch PyTorch to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator,
    the current CUDA device and all CUDA devices, all with the same value.

    Args:
        seed: Integer seed shared by all generators. Defaults to 1001.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Make PyTorch error out on ops that only have nondeterministic kernels,
    # and stop cuDNN from autotuning (which may pick different kernels per run).
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Prompt templates keyed by task. Placeholders in braces are filled with
# ``str.format`` / ``str.format_map``; the trailing "Response:" / "Output:" is
# where the training target is appended.
TASK_PROMPT_DICT = dict(
    # Alpaca-style instruction that comes with an extra input field.
    instruction_prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request. "
        "Instruction: {instruction} Input: {input} Response:"
    ),
    # Alpaca-style instruction without an input field.
    instruction_prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request. "
        "Instruction: {instruction} Response:"
    ),
    # SST-2 binary sentiment classification.
    sentiment_analysis_prompt=(
        "Instruction: Analyze the sentiment of the input, "
        "and respond only positive or negative. "
        "Input: {text} Output:"
    ),
    # GSM8K chain-of-thought math; the answer is expected after "####".
    math_reasoning_prompt=(
        "Instruction: Solve the math word problem step by step "
        "and end with '#### <number>'. "
        "Input: {question} Output:"
    ),
)

# "#### <number>" final-answer marker. Tolerates an optional "$" and thousands
# separators (e.g. "#### $1,234.5"); the commas are stripped before returning.
_FINAL_ANSWER_RE = re.compile(r"####\s*\$?\s*([-+]?\d[\d,]*(?:\.\d+)?)")


def extract_final(ans: Optional[str]) -> Optional[str]:
    """Return the number that follows the first ``####`` marker in *ans*.

    Thousands separators are removed, so ``"#### 1,234"`` yields ``"1234"``
    rather than stopping at the comma.

    Args:
        ans: A model generation or GSM8K reference answer; ``None`` or an
            empty string is accepted.

    Returns:
        The number as a string (sign and decimal part preserved), or ``None``
        if *ans* is empty or contains no ``#### <number>`` marker.
    """
    if not ans:
        return None
    match = _FINAL_ANSWER_RE.search(ans)
    if match is None:
        return None
    return match.group(1).replace(",", "")

@dataclass
class ModelArguments:
    """Model-selection arguments parsed by ``transformers.HfArgumentParser``.

    All experiments are run on a single NVIDIA H200 (141 GB) GPU.
    """

    # Base model identifier; the size substring (0.6b / 1.7b / 1b / 3b) selects
    # which attacker checkpoint directory is loaded.
    model_name_or_path: str = field(
        default="Qwen/Qwen3-0.6B-Base",
        metadata={
            "help": "Base model name: Qwen/Qwen3-0.6B-Base, Qwen/Qwen3-1.7B-Base, "
            "meta-llama/Llama-3.2-1B or meta-llama/Llama-3.2-3B. Its size substring "
            "selects the attacker checkpoint directory."
        },
    )
    # None means "use the library's default cache location".
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Optional Hugging Face cache directory."},
    )

@dataclass
class DataArguments:
    """Data/experiment-selection arguments parsed by ``transformers.HfArgumentParser``.

    The option lists live in each field's ``metadata["help"]`` so they appear
    in ``--help`` output instead of only in source comments.
    """

    data_path_instruction: str = field(
        default="Data/gpt4-instruct-dedupe-only-dataset.json",
        metadata={
            "help": "Path to the instruction-following JSON dataset; "
            "used when downstream_task is instruction_following."
        },
    )
    backdoor_task: str = field(
        default="sentiment_steering",
        metadata={
            "help": "Backdoor task of the attacker checkpoint: sentiment_steering "
            "or targeted_refusal. Used as a model directory component."
        },
    )
    backdoor_attack: str = field(
        default="AddSent",
        metadata={
            "help": "Backdoor attack of the attacker checkpoint: AddSent, Sleeper "
            "or VPI-S. Used as a model directory component."
        },
    )
    downstream_task: str = field(
        default="math_reasoning",
        metadata={
            "help": "Clean downstream fine-tuning task: sentiment_analysis, "
            "math_reasoning or instruction_following."
        },
    )
    optimizer_type: str = field(
        default="AdamW",
        metadata={
            "help": "Optimizer tag: AdamW, SAM or IFSAM. Selects the attacker "
            "checkpoint directory and the output directory."
        },
    )
    file_name: str = field(
        default="2025",
        metadata={
            "help": "Run identifier (the file name for evaluation results); also "
            "the last component of the saved-model directory."
        },
    )

def sst2_build_train_text(example):
    """Turn one SST-2 row into a single ``{"text": ...}`` training string.

    The sentence is placed in the sentiment prompt and the label word is
    appended after a space: label 1 becomes " positive", anything else " negative".
    """
    template = TASK_PROMPT_DICT["sentiment_analysis_prompt"]
    rendered = template.format(text=example["sentence"])
    if example["label"] == 1:
        suffix = " positive"
    else:
        suffix = " negative"
    return dict(text="".join((rendered, suffix)))

def gsm8k_build_train_text(example):
    """Turn one GSM8K row into a single ``{"text": ...}`` training string.

    The question fills the math prompt, and the reference solution
    (``example["answer"]``) follows it, separated by one space.
    """
    rendered = TASK_PROMPT_DICT["math_reasoning_prompt"].format(question=example["question"])
    return dict(text=" ".join((rendered, example["answer"])))

class DatasetInstruction(Dataset):
    """Prompt/completion dataset for downstream (clean) instruction-following SFT.

    Each record loaded from ``data_args.data_path_instruction`` is rendered with
    one of the ``TASK_PROMPT_DICT`` instruction templates into a
    ``{"prompt": ..., "completion": ...}`` dict stored in ``self.samples``.
    """

    def __init__(self, data_args: DataArguments) -> None:
        """Load every record and format it into a prompt/completion pair.

        Args:
            data_args: Only ``data_path_instruction`` is read.

        Raises:
            KeyError: if a record lacks ``"output"`` or a key that the chosen
                template references (``"instruction"``, and ``"input"`` when present).
        """
        super().__init__()
        # Assumed to be a list of dicts with "instruction", optional "input" and
        # "output" keys (Alpaca / GPT-4-instruct style) -- confirm against utils.jload.
        records = utils.jload(data_args.data_path_instruction)

        logging.warning("****************Formatting Inputs****************")
        with_input = TASK_PROMPT_DICT["instruction_prompt_input"]
        without_input = TASK_PROMPT_DICT["instruction_prompt_no_input"]

        self.samples = []
        for example in records:
            # A missing, None or empty "input" all mean "no input"; otherwise a
            # None value would be rendered literally as "Input: None".
            template = with_input if example.get("input") else without_input
            # NOTE(review): the completion follows a prompt ending in "Response:"
            # with no separating space, unlike the SST-2/GSM8K builders, which add
            # one -- confirm this matches the format used at evaluation time.
            self.samples.append(
                {"prompt": template.format_map(example), "completion": example["output"]}
            )

    def __len__(self):
        """Return the number of formatted samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``{"prompt", "completion"}`` dict at position *idx*."""
        return self.samples[idx]


# Lower-cased substring of ``model_name_or_path`` -> checkpoint directory tag.
# Checked in order; the first match wins.
_MODEL_ATTRIBUTIONS = (
    ("0.6b", "Qwen_0.6B"),
    ("1.7b", "Qwen_1.7B"),
    ("1b", "Llama_1B"),
    ("3b", "Llama_3B"),
)

# Downstream tasks handled by ``_build_train_dataset``.
_DOWNSTREAM_TASKS = ("sentiment_analysis", "math_reasoning", "instruction_following")
_DOWNSTREAM_TASK_ERROR = (
    "data_args.downstream_task should be sentiment_analysis, math_reasoning or instruction_following"
)


def _infer_model_attribution(model_name_or_path: str) -> str:
    """Return the directory tag (e.g. ``"Qwen_0.6B"``) encoded in a model name.

    Raises:
        ValueError: if no known size substring occurs in the name.
    """
    lowered = model_name_or_path.lower()
    for needle, tag in _MODEL_ATTRIBUTIONS:
        if needle in lowered:
            return tag
    raise ValueError(
        f"Cannot infer model size from model_name_or_path={model_name_or_path!r}; "
        "it should contain one of '0.6b', '1.7b', '1b' or '3b' "
        "(Qwen3-0.6B/1.7B or Llama-3.2-1B/3B)"
    )


def _build_train_dataset(data_args: DataArguments, seed: int) -> DatasetHF:
    """Build the training dataset for ``data_args.downstream_task``.

    Raises:
        ValueError: for an unsupported downstream task.
    """
    task = data_args.downstream_task
    if task == "sentiment_analysis":
        sst2_train = load_dataset("glue", "sst2")["train"]
        # Sample 8k training examples for faster experiments.
        sst2_train = sst2_train.shuffle(seed=seed).select(range(8000))
        return sst2_train.map(sst2_build_train_text, remove_columns=sst2_train.column_names)
    if task == "math_reasoning":
        gsm8k_train = load_dataset("gsm8k", "main")["train"]
        return gsm8k_train.map(gsm8k_build_train_text, remove_columns=gsm8k_train.column_names)
    if task == "instruction_following":
        samples = DatasetInstruction(data_args).samples
        # Optional subsampling: samples = random.Random(seed).sample(samples, k=8000)
        return DatasetHF.from_list(samples)
    raise ValueError(_DOWNSTREAM_TASK_ERROR)


def main():
    """Fine-tune the attacker checkpoint on a clean downstream task and save it.

    Loads ``./Model/<backdoor_task>/<backdoor_attack>/<size>/Attacker/<optimizer_type>``,
    trains it with TRL's ``SFTTrainer`` on the selected downstream task, and
    saves to ``./Model/<backdoor_task>/<backdoor_attack>/<size>/User/<optimizer_type>/<downstream_task>/<file_name>``.

    Raises:
        ValueError: unsupported ``downstream_task``, or a model name whose size
            cannot be inferred.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, sft_config = parser.parse_args_into_dataclasses()
    set_seed(sft_config.seed)

    # Validate before loading the model so a typo fails immediately.
    if data_args.downstream_task not in _DOWNSTREAM_TASKS:
        raise ValueError(_DOWNSTREAM_TASK_ERROR)

    model_attribution = _infer_model_attribution(model_args.model_name_or_path)
    run_root = "/".join(
        ["./Model", data_args.backdoor_task, data_args.backdoor_attack, model_attribution]
    )
    model_path = "/".join([run_root, "Attacker", data_args.optimizer_type])
    save_model_path = "/".join(
        [run_root, "User", data_args.optimizer_type, data_args.downstream_task, data_args.file_name]
    )

    sft_config.output_dir = save_model_path
    os.makedirs(save_model_path, exist_ok=True)

    # Load tokenizer and victim model (the attacker's released checkpoint).
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=model_args.cache_dir,
        model_max_length=sft_config.max_length,
        padding_side="right",
        use_fast=True,
    )
    model = AutoModelForCausalLM.from_pretrained(model_path, cache_dir=model_args.cache_dir)
    model.gradient_checkpointing_enable()

    # Add a pad token if the tokenizer has none, and grow the embeddings to match.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    # Build the downstream training set, then train and save.
    train_dataset = _build_train_dataset(data_args, sft_config.seed)
    trainer = SFTTrainer(
        model=model, args=sft_config, train_dataset=train_dataset, processing_class=tokenizer
    )
    trainer.train()
    trainer.save_model(output_dir=save_model_path)

if __name__ == "__main__":
    # Script entry point: print start/finish banners around main(), which parses
    # the CLI arguments, fine-tunes the model and saves it.
    print("******************** User Training Starts! ********************")
    main()
    print("******************** User Training Finished! ********************")