# Modular, function-based fine-tuning script
import os
import torch
import numpy as np
import random
from easydict import EasyDict
import json
import matplotlib.pyplot as plt
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    Trainer, 
    TrainingArguments,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model, PeftModel
from argparse import ArgumentParser
from safeft.utils import configs, get_model_path, get_datasets, compute_similarity, get_logits
import gc
import torch.nn as nn
import copy
import safeft.utils as utils
from filter import *


def set_training_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs used during training.

    Passing ``None`` leaves every RNG and the cuDNN flags untouched. For any
    other value the seed is applied to ``random``, ``numpy.random`` and
    PyTorch (CPU and CUDA), cuDNN is switched to deterministic mode with
    benchmarking disabled, and a confirmation line is printed.

    Args:
        seed: Seed value, or ``None`` to skip seeding.
    """
    if seed is None:
        return

    # Same order as before: Python, NumPy, torch CPU, current CUDA device,
    # then all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # For deterministic behavior in PyTorch
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f"Training random seed set to {seed} for reproducible training (dataset unchanged)")


def parse_args_and_config():
    """Parse the command line and merge it into the project configuration.

    Returns:
        tuple: ``(args, parser_args)`` where

        * ``args`` is an ``EasyDict`` built from ``configs`` with the
          command-line overrides applied: the selected dataset configs
          (deep-copied so the originals in ``configs`` are not mutated),
          sample counts, model path / nickname, output directory and name
          suffix.
        * ``parser_args`` is the raw ``argparse`` namespace. Options that are
          not copied into ``args`` here (LoRA hyper-parameters, training
          hyper-parameters, ``save_steps``, ``random_seed``, filter options,
          ...) are only available on this object, and ``base_model_path`` /
          ``output_dir`` stay ``None`` on it when not given on the CLI.

    Raises:
        AttributeError: if a ``--*_config`` option names an entry that does
            not exist in ``configs``.
    """
    parser = ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default=None, help="Base model path")
    parser.add_argument("--model_path_name", type=str, default=None, help="Model path name (nickname) for results")
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for finetune results")
    parser.add_argument("--utility_training_num", type=int, default=None, help="Utility training samples")
    parser.add_argument("--poison_training_num", type=int, default=None, help="Poison training samples")
    parser.add_argument("--name_suffix", type=str, default="", help="Name suffix for output")
    parser.add_argument("--save_steps", type=int, default=100, help="Save every N steps")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for reproducibility")

    # LoRA parameters (integer 0/1 flags are used instead of booleans)
    parser.add_argument("--use_lora", type=int, default=1, help="Use LoRA")
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument("--lora_target_modules", type=str, nargs='+', default=["q_proj","v_proj"], help="LoRA target modules")
    parser.add_argument("--lora_initialization", type=str, default=None, help="LoRA initialization method")

    # Dataset configurations: each *_config option is the NAME of an entry in `configs`
    parser.add_argument("--use_cache", type=int, default=1, help="Use cache for dataset loading")
    parser.add_argument("--utility_dataset_config", type=str, default="utility_dataset_config_samsum", help="Utility dataset config name")
    parser.add_argument("--poison_dataset_config", type=str, default="poison_dataset_config_LATharm", help="Poison dataset config name")
    parser.add_argument("--identity_shift_config", type=str, default="identity_shift_config", help="Identity shift dataset config name")
    parser.add_argument("--identity_shift_num", type=int, default=0, help="Number of identity shift samples")

    # Baseline parameters
    parser.add_argument("--backdoor", type=int, default=0, help="Whether add backdoor 11 prefixed safety data.")
    parser.add_argument("--safety_alpaca_num", type=int, default=0, help="Number of safety alpaca samples for SafeInstr.")

    # Training parameters
    parser.add_argument("--evaluation_strategy", type=str, default="epoch", help="Evaluation strategy")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=2, help="Per device train/eval batch size")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2, help="Gradient accumulation steps")
    parser.add_argument("--bf16", type=int, default=1, help="BF16 training")
    parser.add_argument("--use_device_map", type=int, default=1, help="Use device map for model loading")

    # Filtering parameters
    parser.add_argument("--filter_type", type=str, default="gradient", help="Filter type")
    parser.add_argument("--filter", type=int, default=0, help="Whether to apply gradient filtering")
    parser.add_argument("--filter_times", type=int, default=1, help="number of times to apply gradient filtering")
    parser.add_argument("--filter_threshold", type=float, default=0.8, help="Threshold for gradient filtering")
    parser.add_argument("--num_reference", type=int, default=1, help="Number of reference samples for gradient filtering")

    # Data random seed (separate from the training seed)
    parser.add_argument("--data_random_seed", type=int, default=None, help="Data random seed")

    # Reference parameters
    parser.add_argument("--reference_alpaca", type=int, default=0, help="Whether to use alpaca format for reference")

    parser_args = parser.parse_args()

    args = EasyDict(configs)

    args.use_lora = parser_args.use_lora
    args.backdoor = parser_args.backdoor
    args.num_reference = parser_args.num_reference
    args.identity_shift_num = parser_args.identity_shift_num
    args.poison_training_num = parser_args.poison_training_num
    args.utility_training_num = parser_args.utility_training_num
    # Replace the generic config slots with deep copies of the named configs so
    # the per-run overrides below never leak into the shared entries.
    args.poison_dataset_config = copy.deepcopy(getattr(args, parser_args.poison_dataset_config))
    args.utility_dataset_config = copy.deepcopy(getattr(args, parser_args.utility_dataset_config))
    args.identity_shift_config = copy.deepcopy(getattr(args, parser_args.identity_shift_config))
    if parser_args.data_random_seed is not None:
        args.utility_dataset_config.random_seed = parser_args.data_random_seed
        args.poison_dataset_config.random_seed = parser_args.data_random_seed
        args.identity_shift_config.random_seed = parser_args.data_random_seed
    # Note: Dataset random seeds are NOT overridden by training seed to keep datasets consistent
    if parser_args.use_cache == 0:
        # Don't use the cached split for dataset loading: a new split is
        # created but not saved.
        args.utility_dataset_config.split_id_list_path = None
        args.poison_dataset_config.split_id_list_path = None
    # Save paths for models
    if parser_args.base_model_path is not None:
        args.base_model_path = parser_args.base_model_path
    if parser_args.model_path_name is not None:
        args.model_path_name = parser_args.model_path_name
    if parser_args.base_model_path is not None and parser_args.model_path_name is None:
        # Derive the nickname from the last path component. normpath drops a
        # trailing separator first, so "/models/llama/" yields "llama" rather
        # than the empty string that split("/")[-1] would give.
        args.model_path_name = os.path.basename(os.path.normpath(parser_args.base_model_path))
    if parser_args.output_dir is not None:
        args.output_dir = parser_args.output_dir
    args.name_suffix = parser_args.name_suffix
    if parser_args.utility_training_num is not None:
        args.utility_dataset_config.train_num = parser_args.utility_training_num
    if parser_args.poison_training_num is not None:
        args.poison_dataset_config.train_num = parser_args.poison_training_num
    # --identity_shift_num defaults to 0 (not None), so this override is
    # applied on every run unless the default of that option is changed.
    if parser_args.identity_shift_num is not None:
        args.identity_shift_config.train_num = parser_args.identity_shift_num
    return args, parser_args


def load_model_and_tokenizer(args):
    """Load the causal LM and its tokenizer from ``args.base_model_path``.

    The model weights are loaded in bfloat16. When ``args.use_device_map`` is
    truthy the model is loaded with ``device_map="auto"``; otherwise it is
    loaded without a device map and then moved to ``"cuda"``. The tokenizer's
    pad token is always set to its EOS token.

    Args:
        args: Object exposing ``base_model_path`` and ``use_device_map``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    auto_placement = args.use_device_map

    load_kwargs = {"torch_dtype": torch.bfloat16}
    if auto_placement:
        load_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(args.base_model_path, **load_kwargs)
    if not auto_placement:
        # No device map was requested: place the whole model on the GPU.
        model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def setup_lora(base_model, args):
    """Wrap ``base_model`` with LoRA adapters and report the trainable parameters.

    Args:
        base_model: Model passed to ``get_peft_model``.
        args: Object exposing ``lora_r``, ``lora_alpha``,
            ``lora_target_modules`` and ``lora_dropout``.

    Returns:
        The PEFT model returned by ``get_peft_model``.
    """
    lora_settings = dict(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base_model, LoraConfig(**lora_settings))
    peft_model.print_trainable_parameters()
    return peft_model


def train(model, tokenizer, train_dataset, val_dataset, args, parser_args):
    """Run Hugging Face ``Trainer`` training for ``model``.

    When ``parser_args.use_lora`` is truthy, the untrained model is first
    saved to ``<model path>/checkpoint-0``, where the model path comes from
    ``get_model_path(args, father_path=args.output_dir)``. Checkpoints are
    then written every ``parser_args.save_steps`` steps into the model path.

    Args:
        model: Model to train.
        tokenizer: Tokenizer handed to the ``Trainer``.
        train_dataset: Training dataset.
        val_dataset: Evaluation dataset.
        args: Merged configuration; used to resolve the output location.
        parser_args: Parsed CLI namespace holding the training hyper-parameters.

    Returns:
        None.
    """
    zero_output_dir = "/".join([get_model_path(args, father_path=args.output_dir), "checkpoint-0"])
    # only save checkpoint 0 if use lora
    if parser_args.use_lora:
        if not os.path.exists(zero_output_dir):
            os.makedirs(zero_output_dir)
        model.save_pretrained(zero_output_dir)
    print(f"save steps: {parser_args.save_steps}")

    per_device_batch = parser_args.batch_size
    hf_training_kwargs = {
        "output_dir": get_model_path(args, father_path=args.output_dir),
        "evaluation_strategy": parser_args.evaluation_strategy,
        "save_strategy": "steps",
        "save_steps": parser_args.save_steps,
        "learning_rate": parser_args.learning_rate,
        "per_device_train_batch_size": per_device_batch,
        "per_device_eval_batch_size": per_device_batch,
        "weight_decay": parser_args.weight_decay,
        "num_train_epochs": parser_args.epochs,
        "logging_dir": "./logs",
        "gradient_accumulation_steps": parser_args.gradient_accumulation_steps,
        # --bf16 is an integer 0/1 flag on the command line.
        "bf16": bool(parser_args.bf16),
        "gradient_checkpointing": False,
        "ddp_find_unused_parameters": False,
        "label_names": ["labels"],
    }
    hf_trainer = Trainer(
        model=model,
        args=TrainingArguments(**hf_training_kwargs),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )
    hf_trainer.train()

def _load_cached_scores(score_path, train_dataset):
    """Return the scores cached at ``score_path`` if they match ``train_dataset``, else ``None``.

    The cache location does not depend on the dataset configuration, so a
    file left by a run with different data must not be reused: its scores
    would be applied to the wrong rows. The cache is accepted only when it
    has one entry per sample and the stored texts equal the dataset texts.
    """
    try:
        with open(score_path, "r") as f:
            cached = json.load(f)
        scores = [entry['score'] for entry in cached]
        cached_texts = [entry['text'] for entry in cached]
    except FileNotFoundError:
        return None
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        print(f"Ignoring unreadable score cache {score_path}: {err}")
        return None

    matches_dataset = len(scores) == len(train_dataset) and all(
        text == sample['text'] for text, sample in zip(cached_texts, train_dataset)
    )
    if not matches_dataset:
        print(f"Score cache {score_path} does not match the current training data; recomputing scores.")
        return None

    print(f"Loaded {len(scores)} samples from {score_path}")
    return scores


def _score_and_filter(model, tokenizer, train_dataset, args, score_path, applied_msg):
    """Run one filtering pass and return the retained subset of ``train_dataset``.

    Scores are read from ``score_path`` when a matching cache exists,
    otherwise computed with ``gradient_scores`` and written there.
    """
    scores = _load_cached_scores(score_path, train_dataset)
    if scores is None:
        scores = gradient_scores(model, tokenizer, train_dataset, args)
        # Build the payload before opening the file so a failure here does
        # not leave a truncated cache behind.
        save_dict = [{'text': x['text'], 'score': scores[i]} for i, x in enumerate(train_dataset)]
        with open(score_path, "w") as f:
            json.dump(save_dict, f)

    mask = gmm_filter(scores)
    retained = train_dataset.select(np.where(mask)[0].tolist())
    print(f"{applied_msg} {np.sum(mask)} samples retained out of {len(mask)} total samples.")
    return retained


def main():
    """Entry point: parse options, build the model and data, optionally filter, then train.

    Steps:
        1. Skip everything when the output directory already holds at least
           ``1 + use_lora + filter`` entries (heuristic for "training done").
        2. Load the base model/tokenizer and the datasets.
        3. Seed the training RNGs (after dataset creation, so the datasets
           are not affected by the training seed).
        4. Optionally wrap the model with LoRA.
        5. Optionally filter the training set once, or twice when
           ``--filter_times >= 2``. ``gradient_scores`` and ``gmm_filter``
           are not defined in this file (presumably they come from
           ``from filter import *``).
        6. Train, then release model memory.

    Raises:
        ValueError: if no base model path is given on the command line or in
            the configuration.
    """
    args, parser_args = parse_args_and_config()
    # if training is done already, skip everything
    output_dir = get_model_path(args, father_path=args.output_dir)
    if os.path.exists(output_dir) and len(os.listdir(output_dir)) >= 1 + parser_args.use_lora + parser_args.filter:
        print(f"Training already done. Output directory {output_dir} exists and is not empty.")
        return

    # --base_model_path is None on the raw namespace when the path comes from
    # the configuration; backfill it so model loading and the score-cache
    # path below use the resolved value instead of None.
    if parser_args.base_model_path is None:
        parser_args.base_model_path = getattr(args, "base_model_path", None)
    if parser_args.base_model_path is None:
        raise ValueError("No base model path: pass --base_model_path or set base_model_path in the configs.")

    base_model, tokenizer = load_model_and_tokenizer(parser_args)
    # The formatted (pre-tokenization) splits are returned as well but are not used here.
    train_dataset, val_dataset, train_formatted, val_formatted = get_datasets(args, tokenizer)
    os.makedirs(output_dir, exist_ok=True)
    set_training_seed(parser_args.random_seed)

    if parser_args.use_lora:
        print("Use LoRA model for training")
        model = setup_lora(base_model, parser_args)
    else:
        print("Use full model for training")
        model = base_model
    if parser_args.filter:
        print("Applying gradient filtering to the training dataset.")
        # args.output_dir already includes the --output_dir override (if any),
        # so it is valid even when the option was not given on the CLI.
        # normpath drops a trailing separator so the directory name is never empty.
        model_dir_name = os.path.basename(os.path.normpath(parser_args.base_model_path))
        score_dir = os.path.join(args.output_dir, model_dir_name)
        os.makedirs(score_dir, exist_ok=True)

        train_dataset = _score_and_filter(
            model, tokenizer, train_dataset, args,
            os.path.join(score_dir, "scores.json"),
            "filtering applied.",
        )
        # filter another time
        if parser_args.filter_times >= 2:
            train_dataset = _score_and_filter(
                model, tokenizer, train_dataset, args,
                os.path.join(score_dir, "scores2.json"),
                "filtering applied for the second time.",
            )
        # Put the model back in training mode after scoring.
        model.train()
    else:
        print("Skipping filtering.")
    train(model, tokenizer, train_dataset, val_dataset, args, parser_args)
    # Drop every reference to the model (base_model aliases or is wrapped by
    # `model`), collect garbage first, then hand cached GPU blocks back.
    del model
    del base_model
    del tokenizer
    gc.collect()
    torch.cuda.empty_cache()


# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()