import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
# Prepend the directory one level above this script's folder to the import
# path, so local packages such as `obfuscate` (imported below) resolve from it.
sys.path[:0] = [str(project_root)]

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import argparse
import deepspeed
import os
import json
import copy
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments
from obfuscate import arrowcloak, ours, shadownet, soter, translinkguard, tsqp, groupcover, coreguard
from tqdm import tqdm
import torch.distributed as dist
import random
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

class MultilabelDataCollator(DataCollatorWithPadding):
    """``DataCollatorWithPadding`` that hands back ``labels`` as float32.

    The features are padded by the parent collator first. If the resulting
    batch has a ``labels`` entry, it is converted to ``torch.float32``: a
    Python list is turned into a new tensor, an existing tensor is cast.
    Float targets are what a multi-label (per-class binary) classification
    loss expects.
    """

    def __call__(self, features):
        padded = super().__call__(features)
        if 'labels' not in padded:
            return padded

        targets = padded['labels']
        padded['labels'] = (
            torch.tensor(targets, dtype=torch.float32)
            if isinstance(targets, list)
            else targets.to(dtype=torch.float32)
        )
        return padded

def _dist_context():
    """Return ``(rank, world_size)`` of the default process group, or ``(0, 1)`` if none is initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(), torch.distributed.get_world_size()
    return 0, 1


def _eval_device():
    """Return the device evaluation tensors are placed on.

    This is the process's current CUDA device (the one chosen with
    ``torch.cuda.set_device``), or CPU when CUDA is unavailable. The global
    rank is deliberately not used as a device index: on multi-node runs it
    exceeds the number of GPUs on each node.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def _gather_in_sampler_order(local, world_size, num_unique=None):
    """All-gather a per-rank numpy array and return the merged array on every rank.

    Rows are interleaved rank by rank (row i of rank 0, row i of rank 1, ...),
    which is the order in which ``DistributedSampler`` deals indices out.
    ``DistributedSampler`` pads its index list with repeated indices so that
    every rank receives the same count. Passing ``num_unique`` (the sampler's
    dataset length) drops those trailing duplicates. As ``all_gather``
    requires, every rank must contribute an array of the same shape.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(local)).to(_eval_device())
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor)
    # (world, n, ...) -> (n, world, ...) -> (n * world, ...): back to sampler order.
    merged = torch.stack(gathered, dim=0).transpose(0, 1).reshape(-1, *tensor.shape[1:])
    merged = merged.cpu().numpy()
    if num_unique is not None:
        merged = merged[:num_unique]
    return merged


def eval(model, data_loader):
    """Evaluate a multi-label classifier and return sklearn metrics on rank 0.

    Runs ``model`` under ``torch.no_grad`` over ``data_loader`` and thresholds
    ``sigmoid(logits)`` at 0.5. When a process group with more than one rank
    is initialized, predictions and labels from all ranks are all-gathered,
    so the metrics cover the whole dataset. Works without a process group as
    well, behaving as a single rank 0.

    Args:
        model: module whose forward accepts the batch entries (except
            ``labels``) as keyword arguments and returns an object with a
            ``logits`` attribute.
        data_loader: iterable of dicts of tensors, each containing ``labels``.

    Returns:
        On rank 0, a dict with ``micro_f1``, ``macro_f1``, ``samples_f1``,
        micro-averaged ``precision`` and ``recall``, and subset ``accuracy``.
        ``None`` on every other rank.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    rank, world_size = _dist_context()
    device = _eval_device()

    all_predictions = []
    all_labels = []
    with torch.no_grad():
        # The progress bar is shown on rank 0 only; tqdm closes it when the loop ends.
        for batch in tqdm(data_loader, desc="eval...", disable=rank != 0):
            batch = {k: v.to(device) for k, v in batch.items()}
            labels = batch.pop('labels')

            logits = model(**batch).logits
            predictions = (torch.sigmoid(logits) >= 0.5).float()

            all_predictions.append(predictions.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    if not all_predictions:
        raise ValueError("eval() received a data_loader that yielded no batches")

    all_predictions = np.vstack(all_predictions)
    all_labels = np.vstack(all_labels)

    if world_size > 1:
        sampler = getattr(data_loader, 'sampler', None)
        # Only a DistributedSampler shards (and pads) the dataset, so only
        # then do padded duplicates need trimming.
        num_unique = len(sampler.dataset) if isinstance(sampler, DistributedSampler) else None
        all_predictions = _gather_in_sampler_order(all_predictions, world_size, num_unique)
        all_labels = _gather_in_sampler_order(all_labels, world_size, num_unique)

    if rank != 0:
        return None

    return {
        'micro_f1': f1_score(all_labels, all_predictions, average='micro', zero_division=0),
        'macro_f1': f1_score(all_labels, all_predictions, average='macro', zero_division=0),
        'samples_f1': f1_score(all_labels, all_predictions, average='samples', zero_division=0),
        'precision': precision_score(all_labels, all_predictions, average='micro', zero_division=0),
        'recall': recall_score(all_labels, all_predictions, average='micro', zero_division=0),
        'accuracy': accuracy_score(all_labels, all_predictions),
    }

def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the script options plus the DeepSpeed
        options registered by ``deepspeed.add_config_arguments``.
    """
    parser = argparse.ArgumentParser(description='DeepSpeed ZeRO')

    # which model you intend to attack
    parser.add_argument('--pretrained_model_path', type=str, required=True, help='pretrained model path.')
    parser.add_argument('--src_len', type=int, default=512, help='max source sentence length')
    parser.add_argument('--tgt_len', type=int, default=128, help='max target sentence length')

    # dataset params
    parser.add_argument('--recover_data_path', type=str, required=True, help='Path to the recover dataset.')
    parser.add_argument('--test_data_path', type=str, required=True, help='Path to the test dataset.')

    # typical params
    parser.add_argument('--train_micro_batch_size_per_gpu', type=int, default=8, help='batch size per gpu')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32, help='gradient accumulation steps')
    parser.add_argument('--max_lr', type=float, default=1e-3, help='max learning rate')
    parser.add_argument('--initial_lr', type=float, default=1e-6, help='initial learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-8, help='min learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='weight decay')
    parser.add_argument('--adam_beta1', type=float, default=0.9, help='adam beta1')
    parser.add_argument('--adam_beta2', type=float, default=0.999, help='adam beta2')
    parser.add_argument('--fused', action='store_true', help='whether to use fused optimizer, if you can load all prameters of the model on a single gpu.')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--output_dir', type=str, default='./checkpoints/',help='save dir')

    # log dir of tensorboard
    parser.add_argument('--log_dir', type=str, default='./logs', help='log dir')

    # which finetuning method to use. It is validated here so that a typo
    # fails before any model is loaded or evaluated; main() only supports
    # these three values.
    parser.add_argument('--finetune_method', type=str, default='lora', choices=['lora', 'freeze', 'full-tuning'], help='finetune method, support parameters lora, freeze, full-tuning')

    # freeze modules
    parser.add_argument('--freeze_modules', type=str, default='dense_h_to_4h', help='the layer of model you wanna freeze')

    # LoRA params
    parser.add_argument('--lora_alpha', type=int, default=32, help='alpha for LoRA')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='dropout probability for LoRA')
    parser.add_argument('--lora_target_modules', type=str, default='query_key_value', help='target modules for LoRA, the name of layer in model you wanna use LoRA')
    parser.add_argument('--lora_r', type=int, default=8, help='r for LoRA')

    # deepspeed params
    parser.add_argument('--ds_config_path', type=str, default='./config/ds_config.json', help='path to deepspeed config file')
    parser.add_argument('--offload_device', type=str, default='cpu', help='offload device, cpu or nvme, which mean you want to offload the model to cpu memory or nvme ssd')
    parser.add_argument('--nvme_path', type=str, default='./mnt/nvme', help='path to nvme ssd')
    parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
    parser.add_argument('--global_rank', type=int, default=-1, help='global rank passed from distributed launcher')
    parser = deepspeed.add_config_arguments(parser)

    # obfuscated model params
    parser.add_argument('--obfus_method', default='ours', type=str, choices=['black-box', 'random', 'ours', 'coreguard', 'arrowcloak', 'shadownet', 'soter', 'translinkguard', 'tsqp', 'groupcover'], help='The obfuscating method.')
    parser.add_argument('--victim_model_path', type=str, required=True, help='victim model path.')
    parser.add_argument('--block_size', type=int, default=2, help='block size.')

    return parser.parse_args(argv)

def main():
    """Run the obfuscation / attack / recovery experiment end to end.

    Steps, with every evaluation done by ``eval`` on the test split:
      1. Evaluate the victim model. For ``black-box``, the pretrained model
         stands in for the victim.
      2. Obfuscate the victim in place with ``--obfus_method`` and evaluate
         again. This is skipped for ``black-box``, which reuses step 1.
      3. Apply the matching attack, which is also given the pretrained model,
         and evaluate.
      4. Prepare the attacked model for ``--finetune_method`` (LoRA, freezing
         or full tuning), train it on the recovery set with the HF ``Trainer``
         and the DeepSpeed config, and evaluate a final time.
    Rank 0 prints all four results and appends them to ``goemotions.txt``.
    """
    args = parse_args()
    # Width of every classification head below and of the multi-hot label
    # vectors built by tokenize_function.
    num_labels = 28

    seed = 42
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # init deepspeed
    if args.local_rank == -1:
        # No distributed launcher: a single GPU with plain PyTorch. No process
        # group is initialized on this path, so the parts below that assume
        # one (e.g. DistributedSampler) need adapting to run here.
        device = torch.device("cuda")
    else:
        device = torch.device("cuda", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        deepspeed.init_distributed()
        torch.distributed.barrier()
    master_process = (args.local_rank == 0)

    # get the tokenizer
    if args.obfus_method == 'black-box':
        tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.victim_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        """Tokenize ``examples['text']`` and multi-hot encode ``examples['labels']``.

        Each entry of ``examples['labels']`` is used as a list of class
        indices; the matching positions of a ``num_labels``-wide float row are
        set to 1.0.
        """
        tokenized = tokenizer(
            examples['text'],
            truncation=True,
            padding=True,
            max_length=args.src_len,
            return_tensors="pt"
        )

        labels = torch.zeros((len(examples['labels']), num_labels), dtype=torch.float32)
        for i, label_indices in enumerate(examples['labels']):
            labels[i, label_indices] = 1.0

        tokenized['labels'] = labels

        return tokenized

    # Load the dataset dict once and take both splits from it.
    test_splits = load_dataset(args.test_data_path)
    eval_dataset = test_splits["validation"]
    test_dataset = test_splits["test"]
    # NOTE(review): the recovery set goes to the Trainer as loaded, without
    # tokenize_function, so the JSON presumably already holds model inputs
    # and multi-hot labels. Confirm this against how the file is produced.
    recover_dataset = load_dataset("json", data_files=args.recover_data_path)["train"]
    # preprocess dataset
    tokenized_eval_dataset = eval_dataset.map(tokenize_function, batched=True, remove_columns=eval_dataset.column_names)
    tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=test_dataset.column_names)
    # get the dataloader
    data_collator = MultilabelDataCollator(tokenizer=tokenizer,
        padding='max_length',
        max_length=args.src_len,
        return_tensors="pt"
    )
    # The sampler must index the same dataset the loader reads from;
    # otherwise indices are drawn from another split's length.
    test_sampler = DistributedSampler(tokenized_test_dataset, shuffle=False)
    test_loader = DataLoader(
        tokenized_test_dataset,
        batch_size=16,
        collate_fn=data_collator,
        shuffle=False,
        sampler=test_sampler
    )

    def random_model(model):
        """Replace every trainable parameter with Gaussian noise of the same mean/std."""
        for param in model.parameters():
            if param.requires_grad:
                std = torch.std(param.data).item()
                mean = torch.mean(param.data).item()
                param.data = torch.randn_like(param.data) * std + mean

    obfus_method_map = {
        'ours': ours.obfuscate_model,
        'arrowcloak': arrowcloak.obfuscate_model,
        'shadownet': shadownet.obfuscate_model,
        'soter': soter.obfuscate_model,
        'translinkguard': translinkguard.obfuscate_model,
        'coreguard': coreguard.obfuscate_model,
        'tsqp': tsqp.obfuscate_model,
        'groupcover': groupcover.obfuscate_model,
        'random': random_model
    }
    # Each attack takes (obfuscated model, pretrained model) and returns the
    # attacked model. 'black-box' and 'random' return the model unchanged.
    attack_fn_map = {
        'ours': ours.attack,
        'arrowcloak': arrowcloak.attack,
        'shadownet': shadownet.attack,
        'soter': soter.attack,
        'translinkguard': translinkguard.attack,
        'coreguard': coreguard.attack,
        'tsqp': tsqp.attack,
        'black-box': lambda model, pre_model: model,
        'random': lambda model, pre_model: model,
        'groupcover': groupcover.attack,
    }

    if args.obfus_method == 'black-box':
        # using the pretrained model to attack
        vic_model = AutoModelForSequenceClassification.from_pretrained(args.pretrained_model_path, num_labels=num_labels, problem_type="multi_label_classification")
        vic_model.config.pad_token_id = tokenizer.pad_token_id
        vic_model = vic_model.to(device)
        vic_result = eval(vic_model, test_loader)
    else:
        vic_model = AutoModelForSequenceClassification.from_pretrained(args.victim_model_path, num_labels=num_labels, problem_type="multi_label_classification")
        # Set the pad token before the first evaluation, as in the black-box
        # branch. Models whose config has no pad_token_id cannot classify
        # padded batches larger than one.
        vic_model.config.pad_token_id = tokenizer.pad_token_id
        vic_model = vic_model.to(device)
        vic_result = eval(vic_model, test_loader)
        vic_model = vic_model.to("cpu")
        if args.obfus_method == 'ours':
            obfus_method_map[args.obfus_method](vic_model, block_size=args.block_size)
        else:
            obfus_method_map[args.obfus_method](vic_model)
    # Obfuscation works in place, so the obfuscated model is the same object.
    obf_model = vic_model
    obf_model.config.pad_token_id = tokenizer.pad_token_id

    pretrained_model = AutoModelForSequenceClassification.from_pretrained(args.pretrained_model_path, num_labels=num_labels, problem_type="multi_label_classification")
    obf_model = obf_model.to(device)
    if args.obfus_method == 'black-box':
        obfus_result = vic_result
    else:
        obfus_result = eval(obf_model, test_loader)

    obf_model = attack_fn_map[args.obfus_method](obf_model, pretrained_model).to(device)
    attack_result = eval(obf_model, test_loader)

    # get model, and you can add other finetuning methods here, eg: prefix tuning, P-tuning, etc.
    if args.finetune_method == "full-tuning":
        pass
    elif args.finetune_method == "lora":
        lora_target_modules = args.lora_target_modules.split(",")
        print(lora_target_modules)
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=lora_target_modules,
            inference_mode=False
            )
        obf_model = get_peft_model(obf_model, lora_config)
    elif args.finetune_method == "freeze":
        freeze_modules = args.freeze_modules.split(",")
        for name, param in obf_model.named_parameters():
            if any(item in name for item in freeze_modules):
                param.requires_grad = False
    else:
        raise ValueError("Invalid finetune method")

    if args.fused:
        obf_model = obf_model.to(device)

    # print the number of the model parameters
    if master_process:
        print(obf_model)
        print(f"Local rank: {args.local_rank} \n Total number of parameters: {sum(p.numel() for p in obf_model.parameters())}")

    # load the deepspeed config json file (closing the handle afterwards)
    with open(args.ds_config_path) as f:
        ds_config = json.load(f)
    ds_config['train_micro_batch_size_per_gpu'] = args.train_micro_batch_size_per_gpu
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
    ds_config['zero_optimization']['offload_param']['device'] = args.offload_device
    ds_config['zero_optimization']['offload_optimizer']['device'] = args.offload_device
    if args.offload_device == "nvme":
        # The directory is not created here, so an existing directory is
        # never touched; the user must create it.
        if not os.path.exists(args.nvme_path):
            raise ValueError(f"nvme path does not exist, please make directory {args.nvme_path} by yourself.")
        else:
            ds_config['zero_optimization']['offload_param']['nvme_path'] = args.nvme_path
            ds_config['zero_optimization']['offload_optimizer']['nvme_path'] = args.nvme_path

    if master_process:
        print(f"DeepSpeed config: {ds_config}")

    # train
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_micro_batch_size_per_gpu,
        per_device_eval_batch_size=args.train_micro_batch_size_per_gpu,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.max_lr,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        warmup_ratio=0.1 if args.initial_lr < args.max_lr else 0,
        logging_dir=args.log_dir,
        logging_steps=1,
        label_names=["labels"],
        save_strategy="steps",
        save_steps=5,
        eval_strategy="steps",
        eval_steps=5,
        save_total_limit=2,
        metric_for_best_model="micro_f1",
        load_best_model_at_end=True,
        deepspeed=ds_config,
        fp16=ds_config.get("fp16", {}).get("enabled", False),
        bf16=ds_config.get("bf16", {}).get("enabled", False),
    )

    def compute_metrics(p):
        """Compute multi-label classification metrics for the Trainer.

        Logits are passed through a sigmoid and thresholded at 0.5.
        Precision and recall are micro-averaged, matching ``eval`` so that
        the keys mean the same thing in training logs and final results.
        Any error is printed and all-zero metrics are returned, so a metric
        failure does not stop training.
        """
        try:
            preds = p.predictions
            if isinstance(preds, tuple):
                preds = preds[0]

            # sigmoid turns logits into probabilities, then threshold.
            probs = torch.sigmoid(torch.tensor(preds)).numpy()
            y_pred = (probs >= 0.5).astype(np.int32)

            y_true = p.label_ids.astype(np.int32)  # labels arrive as floats; make them ints

            micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
            macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
            samples_f1 = f1_score(y_true, y_pred, average='samples', zero_division=0)

            precision = precision_score(y_true, y_pred, average='micro', zero_division=0)
            recall = recall_score(y_true, y_pred, average='micro', zero_division=0)
            accuracy = accuracy_score(y_true, y_pred)

            return {
                'micro_f1': micro_f1,
                'macro_f1': macro_f1,
                'samples_f1': samples_f1,
                'precision': precision,
                'recall': recall,
                'accuracy': accuracy,
            }
        except Exception as e:
            print(f"计算评估指标时出错: {e}")
            # Return defaults so training is not interrupted.
            return {
                'micro_f1': 0.0,
                'macro_f1': 0.0,
                'samples_f1': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'accuracy': 0.0,
            }

    recover_trainer = Trainer(
        model=obf_model,
        args=training_args,
        train_dataset=recover_dataset,
        eval_dataset=tokenized_eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    recover_trainer.train()

    restore_results = eval(obf_model, test_loader)
    if torch.distributed.get_rank() == 0:
        print(f"result before obfuscating: {vic_result}")
        print(f"result after obfuscating: {obfus_result}")
        print(f"the result after attack but before training: {attack_result}")
        print(f"the result after recovering:{restore_results}")
        with open('goemotions.txt', 'a', encoding='utf-8') as f:
            f.write("\n#################")
            f.write("\ndataset: goemotions")
            f.write(f"\npretrained model path: {args.pretrained_model_path}")
            f.write(f"\nobfuscate method: {args.obfus_method}")
            f.write(f"\n the result before obfsucating: {vic_result}")
            f.write(f"\n the result after obfsucating: {obfus_result}")
            f.write(f"\n the result after attack but before training: {attack_result}")
            f.write(f"\n the result after recovering: {restore_results}")

    
if __name__ == "__main__":
    main()