import gc
import torch
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from typing import List, Dict, Any, Union
from sentence_transformers import util
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import hashlib
import math
import os
from datasets import load_dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import _LRScheduler
import lm_eval
import json
import argparse

def parse_args(argv=None):
    """Parse command-line options for the relearning experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. normal command-line use.

    Returns:
        argparse.Namespace with fields ``N``, ``max_length``, ``request_num``,
        ``type``, ``lr``, ``model_name`` and ``data_type``.
    """
    parser = argparse.ArgumentParser(description="Unlearning Experiment Configuration")

    parser.add_argument('--N', type=int, default=20000, help='Number of training steps or examples')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum input sequence length')
    parser.add_argument('--request_num', type=int, default=3, help='Number of unlearning requests')
    parser.add_argument('--type', type=str, default="all", help='Type of chosen layers')
    # A real float default (not the string "1e-6") so the value does not depend
    # on argparse's implicit conversion of string defaults.
    parser.add_argument('--lr', type=float, default=1e-6, help='Learning rate for the optimizer')
    parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
    # Restrict to the values the script actually handles; anything else used to
    # be accepted and silently treated like the forget set.
    parser.add_argument(
        '--data_type',
        type=str,
        default="forget_set",
        choices=["forget_set", "retain_set", "unrelated_set"],
        help="Relearning data: forget_set, retain_set or unrelated_set",
    )
    return parser.parse_args(argv)

def read_json_as_list(file_path):
    """Load a UTF-8 JSON file and always return a list.

    A top-level JSON array is returned as-is; any other top-level value
    (object, string, number, ...) is wrapped in a one-element list.
    """
    with open(file_path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload if isinstance(payload, list) else [payload]

# Define a custom learning rate scheduler
class CosineAnnealingWithWarmupLR(_LRScheduler):
    """Linear warm-up followed by cosine decay to a floor.

    Each param group's learning rate is ``base_lr * scale`` where, for the
    scheduler step ``t``:

    * ``t < num_warmup_steps``: ``scale`` rises linearly from 0 towards 1;
    * afterwards ``scale`` follows half a cosine from 1 down to ``eta_min``,
      reaching it at ``t == num_training_steps`` and staying there for any
      later step.

    Note: ``eta_min`` is a *fraction* of the base learning rate, not an
    absolute learning rate (unlike ``torch.optim.lr_scheduler.CosineAnnealingLR``).
    """

    def __init__(self, optimizer, num_warmup_steps, num_training_steps, eta_min=0.0, last_epoch=-1):
        # These must be set before the base __init__, which already calls get_lr().
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        step = max(0, self.last_epoch)
        if step < self.num_warmup_steps:
            # Warm-up phase: linear ramp from 0.
            scale = step / max(1, self.num_warmup_steps)
        else:
            # Cosine annealing phase. Clamp progress to [0, 1] so stepping past
            # num_training_steps holds the floor instead of letting the cosine
            # swing back up and raise the learning rate again.
            decay_steps = max(1, self.num_training_steps - self.num_warmup_steps)
            progress = min(1.0, (step - self.num_warmup_steps) / decay_steps)
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            scale = (1.0 - self.eta_min) * cosine + self.eta_min

        return [base_lr * scale for base_lr in self.base_lrs]
    
class FineTuning:
    """Causal-LM fine-tuning of ``model`` (in place) on one long text string.

    The text is cut into consecutive *character* slices; each slice is
    tokenized as a single sequence and drives exactly one AdamW step under a
    linear warm-up + cosine learning-rate schedule.
    """

    def __init__(self, model, tokenizer, lr, device):
        # `device` is where tokenized inputs are moved before the forward pass.
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        self.lr = lr  # peak learning rate, reached at the end of warm-up

    def run_finetuning(self, forget_data: str, epochs: int = 20, chunk_size: int = 4096,
                       max_length: int = 4096):
        """Train ``self.model`` on ``forget_data``.

        Args:
            forget_data: Training text. It is sliced by characters, not tokens.
            epochs: Number of passes over all slices.
            chunk_size: Characters per slice; must be >= 1.
            max_length: Token limit per slice; longer slices are truncated.

        Raises:
            ValueError: if ``chunk_size`` < 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        self.model.train()

        num_chunks = math.ceil(len(forget_data) / chunk_size)
        if num_chunks == 0:
            # Empty text: nothing to train on (averaging the epoch loss over
            # zero chunks would otherwise divide by zero).
            print("[FineTuning] Empty training text; nothing to fine-tune.")
            return

        optimizer = AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

        total_steps = epochs * num_chunks
        warmup_steps = int(0.1 * total_steps)  # first 10% of all steps warm up
        scheduler = CosineAnnealingWithWarmupLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            eta_min=0.1,  # decay floor: 10% of the peak learning rate
        )

        for epoch in range(epochs):
            print(f"\n[FineTuning] Epoch {epoch + 1} ...")
            total_loss = 0.0
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * chunk_size
                chunk_text = forget_data[start_idx:start_idx + chunk_size]

                inputs = self.tokenizer(
                    chunk_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length,
                    padding=True,
                ).to(self.device)
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]
                # Label -100 is the usual ignore index for HF LM losses, so
                # padded positions are excluded from the loss.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                # Reset gradients before every step; otherwise each step would
                # apply the accumulated gradients of all earlier chunks in the epoch.
                optimizer.zero_grad()
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = outputs.loss
                total_loss += loss.item()
                loss.backward()
                optimizer.step()
                scheduler.step()

            avg_loss = total_loss / num_chunks
            print(f"[Epoch {epoch + 1}] weighted_loss={avg_loss:.4f}")

        # Release the final step's gradients so they don't keep occupying memory.
        optimizer.zero_grad(set_to_none=True)
        print("[FineTuning] Fine-tuning complete.")

# ======= Main pipeline: full example =======
class UnlearningPipeline:
    """Hold one causal LM checkpoint and fine-tune it on incoming requests.

    Each request runs ``FineTuning`` on the request text with the loaded model
    and updates it in place. The four metric lists are created empty and
    returned unchanged; nothing in this class populates them.
    """

    def __init__(self,
                 target_name="01-ai/Yi-6B",
                 forget_data: List[str] = None,
                 retain_data: List[str] = None,
                 device='cuda:1'):
        # Metric containers handed back by the request handler (never filled here).
        self.retain_acc_list, self.retain_ppl_list = [], []
        self.forget_acc_list, self.forget_ppl_list = [], []
        self.forget_list = forget_data
        self.retain_samples = 100  # stored but not read anywhere in this class
        # NOTE(review): `retain_data` and `device` are accepted but unused; the
        # working device is taken from the loaded model at the end of __init__.

        print(f"[UnlearningPipeline] Loading Target Model: {target_name}")
        self.target_tokenizer = AutoTokenizer.from_pretrained(target_name)
        load_kwargs = dict(
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_flash_attention_2=True,
            device_map="auto",
        )
        self.target_model = AutoModelForCausalLM.from_pretrained(target_name, **load_kwargs)

        tok = self.target_tokenizer
        if tok.pad_token is None:
            # No padding token defined: reuse EOS so padded tokenization works.
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id
        self.device = self.target_model.device

    def unlearning_request_handler(self, user_id: str, x_forget: str, x_retain: list[str], max_length, lr):
        """Fine-tune the target model on ``x_forget`` for a single request.

        Args:
            user_id: Label used only in the log line.
            x_forget: Text to train on, as a single string.
            x_retain: Accepted for interface compatibility; not used.
            max_length: Accepted for interface compatibility; not used.
            lr: Peak learning rate passed to ``FineTuning``.

        Returns:
            ``(retain_acc_list, retain_ppl_list, forget_acc_list,
            forget_ppl_list, target_model, target_tokenizer)``; the model is
            switched to eval mode before returning.
        """
        print(f"\n[UnlearningPipeline] Received unlearning request from user={user_id}")
        print("[Pipeline] Start FineTuning ...")
        trainer = FineTuning(self.target_model, self.target_tokenizer, lr, device=self.device)
        trainer.run_finetuning(x_forget, epochs=3, chunk_size=8192)

        self.target_model.eval()
        return (
            self.retain_acc_list,
            self.retain_ppl_list,
            self.forget_acc_list,
            self.forget_ppl_list,
            self.target_model,
            self.target_tokenizer,
        )

# ====================== Test Example ======================
def _load_texts(name, split):
    """Return the ``text`` column of one split of ``llmunlearn/unlearn_dataset``."""
    dataset = load_dataset("llmunlearn/unlearn_dataset", name=name, split=split)
    return [row['text'] for row in dataset]


def main():
    """Relearning experiment entry point.

    For every unlearning algorithm, loads the unlearned checkpoint of
    ``--model_name`` and fine-tunes it on ``--request_num`` consecutive text
    requests. Each request is ``--N`` documents joined by spaces, drawn from
    the set chosen by ``--data_type``. The result is saved under
    ``Model/recovery/Text/...``.

    Raises:
        ValueError: on an invalid ``--N`` / ``--data_type``, or when fewer
            request chunks exist than ``--request_num``.
    """
    # Parse and validate options before any dataset download or model load, so
    # `--help` and bad flags fail immediately.
    args = parse_args()
    N = args.N
    max_length = args.max_length
    request_num = args.request_num
    lr = args.lr
    layer_type = args.type  # local name avoids shadowing the builtin `type`
    data_type = args.data_type

    if N < 1:
        raise ValueError(f"--N must be >= 1, got {N}")
    if data_type not in ("forget_set", "retain_set", "unrelated_set"):
        raise ValueError(
            f"--data_type must be one of forget_set, retain_set, unrelated_set; got {data_type!r}"
        )

    forget_list_arxiv = _load_texts("arxiv", "forget")
    retain_list_arxiv = _load_texts("arxiv", "retain")
    forget_list_github = _load_texts("github", "forget")
    retain_list_github = _load_texts("github", "retain")
    retain_list_general = _load_texts("general", "retain")
    forget_list = forget_list_arxiv + forget_list_github
    retain_list = retain_list_arxiv + retain_list_general + retain_list_github

    # Seeded shuffles (for reproducibility). The forget-list shuffle is kept
    # because it advances the RNG before `retain_list` is shuffled.
    # NOTE(review): the shuffled `forget_list` is not used for relearning; the
    # selection below takes documents in dataset order. Confirm whether
    # shuffled relearning data was intended.
    random.seed(42)
    random.shuffle(forget_list)
    random.shuffle(retain_list)

    # Text used for relearning.
    if data_type == "forget_set":
        relearn_docs = forget_list_arxiv + forget_list_github
    elif data_type == "retain_set":
        relearn_docs = retain_list_arxiv + retain_list_github
    else:  # "unrelated_set"
        relearn_docs = retain_list_general

    # One request = N consecutive documents joined into a single string.
    request_texts = [" ".join(relearn_docs[i:i + N]) for i in range(0, len(relearn_docs), N)]
    if request_num > len(request_texts):
        raise ValueError(
            f"request_num={request_num} but only {len(request_texts)} request chunk(s) "
            f"of N={N} documents are available for data_type={data_type!r}"
        )

    model_basename = args.model_name.rstrip('/').split('/')[-1]
    # A single list drives both the checkpoint that is loaded and the directory
    # the relearned model is saved to, so the two always name the same algorithm.
    unlearning_algorithms = ["GA", "GA_GD", "GA_KL", "NPO", "NPO_KL", "RL"]

    for algorithm in unlearning_algorithms:
        source_path = f"Model/Text/{layer_type}_layer/{model_basename}/lr{lr}_{algorithm}_N{N}_Request{request_num}"
        pipeline = UnlearningPipeline(
            target_name=source_path,
            forget_data=request_texts[0:100],
            retain_data=retain_list,
        )

        # Submit relearning requests; the model is updated in place each time.
        for i in range(request_num):
            pipeline.unlearning_request_handler(f"User{i}", request_texts[i], retain_list, max_length, lr)
        # Taken from the pipeline so this also works when request_num == 0.
        model = pipeline.target_model
        tokenizer = pipeline.target_tokenizer

        if data_type == "forget_set":
            model_save_path = f"Model/recovery/Text/{model_basename}/{algorithm}/lr{lr}_{layer_type}_layers_forget_N{N}_Request{request_num}"
        else:
            model_save_path = f"Model/recovery/Text/{model_basename}/{algorithm}_{data_type}/lr{lr}_{layer_type}_layers_retain_N{N}_Request{request_num}"
        os.makedirs(model_save_path, exist_ok=True)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        # Free the model before loading the next checkpoint.
        del pipeline, model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


if __name__ == "__main__":
    main()
