import gc
import torch
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from typing import List, Dict, Any, Union
from sentence_transformers import util
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, BitsAndBytesConfig
import hashlib
import math
from datasets import load_dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import _LRScheduler
import json
from sklearn.decomposition import PCA
from torch.nn.utils import parameters_to_vector
from Unlearning_algorithm.GA import Gradient_ascent 
from Unlearning_algorithm.RL import Random_label
from Unlearning_algorithm.NPO_KL import NPO_KL
from Unlearning_algorithm.GA_GD import Gradient_ascent_descent
from Unlearning_algorithm.GA_KL import Gradient_ascent_KL 
from Unlearning_algorithm.NPO import NPO
from Unlearning_algorithm.Fine_tune import Fine_tune
import os
import argparse

def parse_args(argv=None):
    """
    Parse the command-line options of an unlearning experiment.

    :param argv: Optional list of argument strings to parse. When None (the
        default) argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
        callers behave exactly as before. Passing a list makes the function
        usable from notebooks and tests.
    :return: ``argparse.Namespace`` with the attributes ``N``, ``max_length``,
        ``request_num``, ``lr`` and ``model_name``.
    """
    parser = argparse.ArgumentParser(description="Unlearning Experiment Configuration")
    parser.add_argument('--N', type=int, default=20000, help='Number of forget samples merged into one unlearning request')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum input sequence length')
    parser.add_argument('--request_num', type=int, default=3, help='Number of unlearning requests')
    # Use a real float default instead of the string "1e-6": the parsed value is
    # the same, but the default now has the declared type wherever it is read
    # (e.g. parser.get_default or generated help text).
    parser.add_argument('--lr', type=float, default=1e-6, help='Learning rate for the optimizer')
    parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
    return parser.parse_args(argv)

def read_json_as_list(file_path):
    """
    Load the JSON document stored at ``file_path`` and always return a list.

    A top-level JSON array is returned unchanged; any other top-level value
    (object, string, number, ...) is returned as a one-element list holding it.
    The file is read as UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    # Normalise the root so callers can iterate without checking its type.
    return parsed if isinstance(parsed, list) else [parsed]

# Define a custom learning rate scheduler
class CosineAnnealingWithWarmupLR(_LRScheduler):
    """
    Linear warm-up followed by cosine decay, applied as a multiplier on each
    parameter group's base learning rate.

    :param optimizer: Wrapped optimizer.
    :param num_warmup_steps: Number of scheduler steps over which the
        multiplier grows linearly from 0 to 1.
    :param num_training_steps: Scheduler step at which the cosine decay
        reaches its floor.
    :param eta_min: Floor of the multiplier, i.e. a *fraction* of the base
        learning rate (0.1 means "decay down to 10% of base_lr"), not an
        absolute learning rate.
    :param last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, num_warmup_steps, num_training_steps, eta_min=0.0, last_epoch=-1):
        # These attributes must exist before the base constructor runs,
        # because it performs an initial step that calls get_lr().
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.eta_min = eta_min
        super(CosineAnnealingWithWarmupLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        current_step = max(0, self.last_epoch)
        if current_step < self.num_warmup_steps:
            # Warm-up phase: multiplier rises linearly from 0 towards 1.
            lr_scale = float(current_step) / float(max(1, self.num_warmup_steps))
        else:
            # Cosine annealing phase.
            decay_steps = max(1, self.num_training_steps - self.num_warmup_steps)
            progress = float(current_step - self.num_warmup_steps) / float(decay_steps)
            # Clamp so that stepping past num_training_steps holds the LR at
            # its floor; without this the cosine would swing back upwards and
            # silently increase the learning rate again.
            progress = min(1.0, progress)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            lr_scale = (1 - self.eta_min) * cosine_decay + self.eta_min
        return [base_lr * lr_scale for base_lr in self.base_lrs]

class FineTuning:
    """
    Runs one unlearning pass on ``model`` by dispatching to one of the
    imported unlearning algorithm functions.

    ``infer_model`` is forwarded only to the GA_KL, NPO_KL and NPO algorithms.
    """

    # Algorithm names that run_finetuning knows how to dispatch.
    SUPPORTED_ALGORITHMS = ("GA", "GA_GD", "GA_KL", "NPO_KL", "RL", "NPO")

    def __init__(self, model, infer_model, tokenizer, lr=4e-6, device='cuda:1'):
        """
        :param model: Model whose parameters are optimised.
        :param infer_model: Second model handed to the KL/NPO algorithms.
        :param tokenizer: Tokenizer passed through to the algorithm functions.
        :param lr: Peak learning rate for AdamW.
        :param device: Accepted for backward compatibility but ignored; the
            device is always taken from ``model.device``.
        """
        self.device = model.device
        self.tokenizer = tokenizer
        self.model = model
        self.infer_model = infer_model
        self.lr = lr

    def run_finetuning(self, unlearning_algorithm, retain_set: list[str], forget_data: str, epochs: int = 20, chunk_size: int = 4096):
        """
        Fine-tune the model based on the selected unlearning method.

        :param unlearning_algorithm: One of ``SUPPORTED_ALGORITHMS``.
        :param retain_set: Retain texts; only used by GA_GD, GA_KL and NPO_KL.
        :param forget_data: Text data to be forgotten
        :param epochs: Number of training epochs
        :param chunk_size: Chunk size
        :raises ValueError: If ``unlearning_algorithm`` is not supported.
        """
        # Fail fast: previously an unknown name fell through every branch,
        # trained nothing, and still reported success.
        if unlearning_algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(
                f"Unknown unlearning algorithm {unlearning_algorithm!r}; "
                f"expected one of {', '.join(self.SUPPORTED_ALGORITHMS)}"
            )

        self.model.train()
        optimizer = AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

        # The schedule length is estimated from len(forget_data), i.e. the
        # character count for a str.
        # NOTE(review): this only matches the real number of optimizer steps if
        # the algorithm functions chunk by characters too -- confirm.
        data_length = len(forget_data)
        num_chunks = math.ceil(data_length / chunk_size)
        total_steps = epochs * num_chunks
        warmup_steps = int(0.1 * total_steps)

        # eta_min is a fraction of the base LR here: decay down to 10% of self.lr.
        scheduler = CosineAnnealingWithWarmupLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            eta_min=0.1
        )

        if unlearning_algorithm == "GA":
            Gradient_ascent(self.model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "GA_GD":
            Gradient_ascent_descent(self.model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "GA_KL":
            Gradient_ascent_KL(self.model, self.infer_model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "NPO_KL":
            NPO_KL(self.model, self.infer_model, self.tokenizer, forget_data, retain_set, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "RL":
            Random_label(self.model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)
        elif unlearning_algorithm == "NPO":
            NPO(self.model, self.infer_model, self.tokenizer, forget_data, optimizer, scheduler, epochs, chunk_size, self.device)

        print("[FineTuning] Fine-tuning complete.")

# ======= Main Pipeline: Full Example =======
class UnlearningPipeline:
    """
    Holds the model being unlearned, a second copy loaded from the same
    checkpoint, and the tokenizer, and serves unlearning requests against them.

    The four metric lists are created empty and returned by
    ``unlearning_request_handler``; nothing in this class appends to them.
    """

    def __init__(self,
                 target_name="01-ai/Yi-6B",
                 forget_data: str = None,
                 retain_data: List[str] = None,
                 device='cuda:1'):
        """
        :param target_name: Hugging Face model name or path to load.
        :param forget_data: Stored as ``self.forget_list``; not otherwise used here.
        :param retain_data: Accepted but not used by this class.
        :param device: Accepted but not used; placement comes from
            ``device_map="auto"``.
        """
        self.retain_acc_list, self.retain_ppl_list = [], []
        self.forget_acc_list, self.forget_ppl_list = [], []
        self.forget_list = forget_data
        self.retain_samples = 100

        print(f"[UnlearningPipeline] Loading Target Model: {target_name}")
        tokenizer = AutoTokenizer.from_pretrained(target_name)
        self.target_tokenizer = tokenizer

        # Both copies are loaded with identical options.
        # NOTE(review): `use_flash_attention_2` is reported as deprecated by
        # recent transformers releases in favour of `attn_implementation` --
        # confirm against the pinned version before changing it.
        load_options = dict(
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_flash_attention_2=True,
            device_map="auto",
        )
        # Model that is actually updated by the unlearning algorithms.
        self.target_model = AutoModelForCausalLM.from_pretrained(target_name, **load_options)
        # Second copy handed to FineTuning as `infer_model`.
        self.infer_model = AutoModelForCausalLM.from_pretrained(target_name, **load_options)

        # Fall back to EOS as the padding token when the tokenizer has none.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        self.device = self.target_model.device

    def unlearning_request_handler(self, user_id: str, Unlearning_algorithm: str, x_forget: str, x_retain: list[str], max_length, lr):
        """
        Apply one unlearning request to ``self.target_model`` and switch it to
        eval mode afterwards.

        :param user_id: Only used in the log line.
        :param Unlearning_algorithm: Algorithm name forwarded to ``FineTuning``.
        :param x_forget: Text to forget.
        :param x_retain: Retain texts.
        :param max_length: Accepted but not used by this method.
        :param lr: Learning rate forwarded to ``FineTuning``.
        :return: ``(retain_acc_list, retain_ppl_list, forget_acc_list,
            forget_ppl_list, target_model, target_tokenizer)``.
        """
        print(f"\n[UnlearningPipeline] Received unlearning request from user={user_id}")
        print("[Pipeline] Start FineTuning ...")
        tuner = FineTuning(
            self.target_model,
            self.infer_model,
            self.target_tokenizer,
            lr,
            device=self.target_model.device,
        )
        # Each request runs 3 epochs with 8192-sized chunks.
        tuner.run_finetuning(
            Unlearning_algorithm,
            x_retain,
            x_forget,
            epochs=3,
            chunk_size=8192,
        )
        self.target_model.eval()
        metrics = (
            self.retain_acc_list,
            self.retain_ppl_list,
            self.forget_acc_list,
            self.forget_ppl_list,
        )
        return (*metrics, self.target_model, self.target_tokenizer)

# ====================== Experiment Entry Point ======================
if __name__ == "__main__":
    # Parse the CLI first so `--help` or a bad flag exits before any dataset
    # is downloaded.
    args = parse_args()
    N = args.N
    max_length = args.max_length
    request_num = args.request_num
    lr = args.lr
    if N < 1:
        raise ValueError(f"--N must be a positive integer, got {N}")

    forget_dataset_arxiv = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="forget")
    retain_dataset_arxiv = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="retain")
    forget_dataset_github = load_dataset("llmunlearn/unlearn_dataset", name="github", split="forget")
    retain_dataset_github = load_dataset("llmunlearn/unlearn_dataset", name="github", split="retain")
    retain_dataset_general = load_dataset("llmunlearn/unlearn_dataset", name="general", split="retain")

    forget_list_arxiv = [data['text'] for data in forget_dataset_arxiv]
    retain_list_arxiv = [data['text'] for data in retain_dataset_arxiv]
    forget_list_github = [data['text'] for data in forget_dataset_github]
    retain_list_github = [data['text'] for data in retain_dataset_github]
    retain_list_general = [data['text'] for data in retain_dataset_general]

    forget_list = forget_list_arxiv + forget_list_github
    retain_list = retain_list_arxiv + retain_list_general + retain_list_github

    # Set random seed (for reproducibility), then shuffle both lists.
    random.seed(42)
    random.shuffle(forget_list)
    random.shuffle(retain_list)

    # Merge every N consecutive forget samples into one request text;
    # request i unlearns forget_list[i].
    forget_list = [" ".join(forget_list[i:i+N]) for i in range(0, len(forget_list), N)]

    # Each request consumes one merged group. Check this before loading any
    # model: otherwise the run fails with IndexError only after earlier
    # requests have already been trained.
    if not 1 <= request_num <= len(forget_list):
        raise ValueError(
            f"--request_num={request_num} must be between 1 and {len(forget_list)} "
            f"(the forget data yields {len(forget_list)} group(s) of N={N} samples)"
        )

    Unlearning_model = [args.model_name]
    Unlearning_model_file_name = [Unlearning_model[0].split('/')[-1]]
    Unlearning_algorithm = ["GA", "GA_GD", "GA_KL", "NPO_KL", "RL", "NPO"]

    for model_name, model_file_name in zip(Unlearning_model, Unlearning_model_file_name):
        for algorithm in Unlearning_algorithm:
            # A fresh pipeline per algorithm so every run starts from the
            # original checkpoint.
            pipeline = UnlearningPipeline(
                target_name=model_name,
                forget_data=forget_list[0:100],
                retain_data=retain_list
            )

            # Submit the unlearning requests sequentially against the same model.
            for i in range(request_num):
                user_id = "User" + str(i)
                pipeline.unlearning_request_handler(
                    user_id, algorithm, forget_list[i], retain_list, max_length, lr)

            # Save model weights. The handler returns pipeline.target_model and
            # pipeline.target_tokenizer, so save them from the pipeline directly
            # rather than from loop-assigned variables.
            model_save_path = f"Model/Text/all_layer/{model_file_name}/lr{lr}_{algorithm}_N{N}_Request{request_num}"
            os.makedirs(model_save_path, exist_ok=True)
            pipeline.target_model.save_pretrained(model_save_path)
            pipeline.target_tokenizer.save_pretrained(model_save_path)

            # Drop the only reference to both loaded models before the next
            # algorithm loads its own copies.
            del pipeline

            # Force garbage collection
            gc.collect()
            # Clear PyTorch GPU cache
            torch.cuda.empty_cache()
            # Clear PyTorch memory allocator
            torch.cuda.ipc_collect()
