import gc
import torch
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from typing import List, Dict, Any, Union
from sentence_transformers import util
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, BitsAndBytesConfig
import hashlib
import math
from datasets import load_dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import _LRScheduler
import lm_eval
import json
from sklearn.decomposition import PCA
from torch.nn.utils import parameters_to_vector
from Unlearning_algorithm.GA import Gradient_ascent 
from Unlearning_algorithm.RL import Random_label
from Unlearning_algorithm.NPO_KL import NPO_KL
from Unlearning_algorithm.GA_GD import Gradient_ascent_descent
from Unlearning_algorithm.GA_KL import Gradient_ascent_KL 
from Unlearning_algorithm.NPO import NPO
from Unlearning_algorithm.Fine_tune import Fine_tune
import os
import argparse

def parse_args(argv=None):
    """Parse the command-line options of an unlearning experiment.

    :param argv: Optional list of argument strings to parse. ``None`` (the
        default) makes argparse read the process command line, which is the
        behavior of a plain ``parse_args()`` call.
    :return: ``argparse.Namespace`` with the attributes ``N``, ``max_length``,
        ``request_num``, ``lr`` and ``model_name``.
    """
    parser = argparse.ArgumentParser(description="Unlearning Experiment Configuration")

    parser.add_argument('--N', type=int, default=20000, help='Number of training steps or examples')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum input sequence length')
    parser.add_argument('--request_num', type=int, default=3, help='Number of unlearning requests')
    # Real float default: a string default only works because argparse
    # happens to run string defaults through `type`.
    parser.add_argument('--lr', type=float, default=1e-6, help='Learning rate for the optimizer')
    parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
    return parser.parse_args(argv)

# login() 
def read_json_as_list(file_path):
    """Load a UTF-8 JSON file and always hand back a list.

    :param file_path: Path of the JSON file to read.
    :return: The decoded document itself when its top level is a JSON array,
        otherwise a one-element list wrapping the decoded value.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    # Normalize non-array documents (object, scalar, null) to a list.
    return parsed if isinstance(parsed, list) else [parsed]

# Define a custom learning rate scheduler
class CosineAnnealingWithWarmupLR(_LRScheduler):
    """Linear warm-up followed by cosine decay of every base learning rate.

    The multiplier applied to each base LR rises linearly from 0 to 1 during
    the first ``num_warmup_steps`` steps, then follows a half cosine from 1
    down to ``eta_min`` at ``num_training_steps``. Stepping beyond
    ``num_training_steps`` keeps the multiplier at ``eta_min``.

    Note that ``eta_min`` is a *fraction of the base LR* (the final LR is
    ``base_lr * eta_min``), not an absolute learning rate.

    :param optimizer: Wrapped optimizer.
    :param num_warmup_steps: Number of linear warm-up steps.
    :param num_training_steps: Step at which the cosine decay reaches its floor.
    :param eta_min: Floor of the LR multiplier.
    :param last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, num_warmup_steps, num_training_steps, eta_min=0.0, last_epoch=-1):
        # These must be set before the base constructor runs: it performs an
        # initial step, which calls get_lr().
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        current_step = max(0, self.last_epoch)
        if current_step < self.num_warmup_steps:
            # Warm-up phase: multiplier grows linearly from 0 to 1.
            lr_scale = float(current_step) / float(max(1, self.num_warmup_steps))
        else:
            # Cosine annealing phase.
            decay_steps = max(1, self.num_training_steps - self.num_warmup_steps)
            # Clamp so that extra scheduler steps stay at the floor instead of
            # riding the cosine back up toward the base LR.
            progress = min(1.0, float(current_step - self.num_warmup_steps) / float(decay_steps))
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            lr_scale = (1 - self.eta_min) * cosine_decay + self.eta_min

        return [base_lr * lr_scale for base_lr in self.base_lrs]
    
class FineTuning:
    """Run one unlearning fine-tuning pass on a model with a selectable algorithm.

    The class builds an AdamW optimizer and a warm-up/cosine LR schedule, then
    hands both to the chosen unlearning routine.
    """

    # Algorithm names that run_finetuning knows how to dispatch.
    SUPPORTED_ALGORITHMS = ("GA", "GA_GD", "GA_KL", "NPO_KL", "RL", "NPO")

    def __init__(self, model, infer_model, tokenizer, lr=4e-6, device='cuda:1'):
        """
        :param model: Model whose parameters are optimized.
        :param infer_model: Second model forwarded to the GA_KL, NPO_KL and NPO
            routines (presumably a frozen reference copy; confirm in those modules).
        :param tokenizer: Tokenizer forwarded to every routine.
        :param lr: Learning rate given to AdamW.
        :param device: Ignored; kept for signature compatibility. The device is
            always taken from ``model.device``.
        """
        self.device = model.device
        self.tokenizer = tokenizer
        self.model = model
        self.infer_model = infer_model
        self.lr = lr

    def run_finetuning(self, unlearning_algorithm, retain_set: list[str], forget_data: str, epochs: int = 20, chunk_size: int = 4096):
        """
        Fine-tune the model based on the selected unlearning algorithm.
        :param unlearning_algorithm: One of ``SUPPORTED_ALGORITHMS``
        :param retain_set: Retain texts, used only by GA_GD, GA_KL and NPO_KL
        :param forget_data: Text to be forgotten
        :param epochs: Number of training epochs
        :param chunk_size: Chunk size for input segmentation
        :raises ValueError: If ``unlearning_algorithm`` is not supported
        """
        # Fail fast: an unknown name would otherwise skip every branch below
        # and still report success without training anything.
        if unlearning_algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(
                f"Unsupported unlearning algorithm {unlearning_algorithm!r}; "
                f"expected one of {', '.join(self.SUPPORTED_ALGORITHMS)}"
            )

        self.model.train()

        optimizer = AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

        # forget_data is a str, so this length is a character count.
        # NOTE(review): the routines receive chunk_size themselves and may
        # chunk in another unit (e.g. tokens); confirm total_steps matches the
        # number of scheduler steps they actually take.
        data_length = len(forget_data)
        num_chunks = math.ceil(data_length / chunk_size)
        total_steps = epochs * num_chunks
        warmup_steps = int(0.1 * total_steps)

        scheduler = CosineAnnealingWithWarmupLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            eta_min=0.1
        )

        if unlearning_algorithm == "GA":
            Gradient_ascent(self.model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "GA_GD":
            Gradient_ascent_descent(self.model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "GA_KL":
            Gradient_ascent_KL(self.model, self.infer_model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "NPO_KL":
            NPO_KL(self.model, self.infer_model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "RL":
            Random_label(self.model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "NPO":
            NPO(self.model, self.infer_model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)

        print("[FineTuning] Fine-tuning complete.")

# ======= Main pipeline: full example =======
class UnlearningPipeline:
    """Hold a target model plus a second copy of it and serve unlearning requests.

    Construction loads the tokenizer and two instances of the same checkpoint:
    ``target_model`` (the one that gets fine-tuned) and ``infer_model`` (handed
    to the fine-tuner alongside it).
    """

    def __init__(self,
                 target_name="01-ai/Yi-6B",
                 forget_data: str = None,
                 retain_data: List[str] = None,
                 device='cuda:1'):
        """
        :param target_name: Checkpoint name or path for tokenizer and both models.
        :param forget_data: Stored as ``forget_list``; not otherwise used here.
        :param retain_data: Accepted but not used by this class.
        :param device: Accepted but not used; ``device_map="auto"`` places the
            models and ``self.device`` is read back from the target model.
        """
        # Metric buffers; they are returned by the request handler but nothing
        # in this class appends to them.
        self.retain_acc_list, self.retain_ppl_list = [], []
        self.forget_acc_list, self.forget_ppl_list = [], []
        self.forget_list = forget_data
        self.retain_samples = 100

        print(f"[UnlearningPipeline] Loading Target Model: {target_name}")
        self.target_tokenizer = AutoTokenizer.from_pretrained(target_name)

        # Both model copies are loaded with identical settings.
        # NOTE(review): `use_flash_attention_2` is the legacy switch; newer
        # transformers releases prefer attn_implementation="flash_attention_2".
        # Confirm against the pinned transformers version.
        load_kwargs = dict(
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_flash_attention_2=True,
            device_map="auto",
        )
        self.target_model = AutoModelForCausalLM.from_pretrained(target_name, **load_kwargs)
        self.infer_model = AutoModelForCausalLM.from_pretrained(target_name, **load_kwargs)

        # Fall back to EOS as the padding token when the tokenizer has none.
        tok = self.target_tokenizer
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id
        self.device = self.target_model.device

    def unlearning_request_handler(self, user_id: str, Unlearning_algorithm: str, x_forget: str, x_retain: list[str], max_length, lr):
        """Fine-tune the target model on one forget request.

        :param user_id: Identifier used only in the log line.
        :param Unlearning_algorithm: Algorithm name forwarded to ``FineTuning.run_finetuning``.
        :param x_forget: Text to forget.
        :param x_retain: Retain texts forwarded to the fine-tuner.
        :param max_length: Accepted but not used by this method.
        :param lr: Learning rate for the fine-tuner.
        :return: Tuple ``(retain_acc_list, retain_ppl_list, forget_acc_list,
            forget_ppl_list, target_model, target_tokenizer)``.
        """
        print(f"\n[UnlearningPipeline] Received unlearning request from user={user_id}")
        print("[Pipeline] Start FineTuning ...")

        tuner = FineTuning(
            self.target_model,
            self.infer_model,
            self.target_tokenizer,
            lr,
            device=self.target_model.device,
        )
        tuner.run_finetuning(Unlearning_algorithm, x_retain, x_forget, epochs=3, chunk_size=8192)

        # Leave the model in inference mode once the request is served.
        self.target_model.eval()
        return (
            self.retain_acc_list,
            self.retain_ppl_list,
            self.forget_acc_list,
            self.forget_ppl_list,
            self.target_model,
            self.target_tokenizer,
        )


# ====================== Test Example ======================
if __name__ == "__main__":
    # Parse the CLI first so that bad arguments fail before the dataset load.
    args = parse_args()
    N = args.N
    max_length = args.max_length
    request_num = args.request_num
    lr = args.lr
    Unlearning_model = [args.model_name]
    Unlearning_model_file_name = [name.split('/')[-1] for name in Unlearning_model]

    dataset = load_dataset("AI-MO/NuminaMath-1.5")
    problems = dataset["train"]["problem"]
    # Each forget request is the text of N consecutive problems joined by spaces.
    forget_list = [" ".join(problems[i:i+N]) for i in range(0, len(problems), N)]
    retain_list = ["I am xiaoyu xu"]
    # Set random seed (for reproducibility)
    random.seed(42)
    # Shuffle the list
    random.shuffle(forget_list)

    # Request i consumes forget_list[i], so there must be enough chunks.
    if request_num > len(forget_list):
        raise ValueError(
            f"request_num={request_num} exceeds the {len(forget_list)} forget chunks "
            f"available with N={N}"
        )

    Unlearning_algorithm = ["RL", "NPO", "GA"]
    for model_name, model_dir_name in zip(Unlearning_model, Unlearning_model_file_name):
        for algorithm in Unlearning_algorithm:
            # A fresh pipeline (and freshly loaded weights) per algorithm.
            pipeline = UnlearningPipeline(
                target_name=model_name,
                forget_data=forget_list[0:100],
                retain_data=retain_list,
            )

            # Submit unlearning requests sequentially against the same model.
            for i in range(request_num):
                user_id = "User" + str(i)
                pipeline.unlearning_request_handler(
                    user_id, algorithm, forget_list[i], retain_list, max_length, lr)

            # Save model weights. The handler returns these same pipeline
            # attributes; reading them directly also works when request_num is 0.
            model_save_path = f"Model/Math/all_layer/{model_dir_name}/lr{lr}_{algorithm}_N{N}_Request{request_num}"
            os.makedirs(model_save_path, exist_ok=True)
            pipeline.target_model.save_pretrained(model_save_path)
            pipeline.target_tokenizer.save_pretrained(model_save_path)

            # Drop the only reference to both model copies before the next load.
            del pipeline
            # Force garbage collection
            gc.collect()
            # Clear PyTorch GPU cache
            torch.cuda.empty_cache()
            # Clear PyTorch memory allocator
            torch.cuda.ipc_collect()
