import openml
import pandas as pd 
import time
import matplotlib.pyplot as plt
import numpy as np 
from copy import deepcopy
from typing import Tuple
from itertools import chain 
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import torch
import torch.nn as nn 
import torch.optim as optim 
from torch.utils.data import Dataset, TensorDataset, DataLoader

from config import DATASET_ID



# # get dataset 
# def get_dataset_from_OPENML(dataset_name: str) -> pd.DataFrame:
#     dataset = openml.datasets.get_dataset(dataset_id=DATASET_ID[dataset_name])
    
#     df, _, _, _ = dataset.get_data(dataset_format='dataframe')
    
#     return df

def get_dataset_from_OPENML(dataset_name: str) -> pd.DataFrame:
    """Load a dataset from the local OpenML CSV cache.

    The file read is ``openml_dataset_cache/<dataset_name>_<id>.csv``, where
    ``<id>`` is ``DATASET_ID[dataset_name]``. The first CSV column becomes the
    DataFrame index.

    Args:
        dataset_name: Key into ``DATASET_ID``.

    Returns:
        The cached dataset.
    """
    dataset_id = DATASET_ID[dataset_name]
    cache_path = f'openml_dataset_cache/{dataset_name}_{dataset_id}.csv'
    return pd.read_csv(cache_path, index_col=[0])



# preprocess numerical and nominal features (no train/test split is done here)
def preprocess_df(df: pd.DataFrame, numerical_cols: list, nominial_cols: list) -> Tuple[np.ndarray, np.ndarray]:
    """Drop incomplete rows, then scale/encode features and encode labels.

    The last column of ``df`` is the target; all other columns are candidate
    features.

    Args:
        df: Raw dataset with the target as its last column.
        numerical_cols: Feature columns to standardise with ``StandardScaler``.
        nominial_cols: Nominal (categorical) feature columns to one-hot encode,
            dropping the first level of each.

    Returns:
        ``(X_processed, y)``. ``X_processed`` is a dense 2-D array holding the
        scaled ``numerical_cols`` followed by the one-hot columns. Feature
        columns listed in neither argument are dropped (ColumnTransformer's
        default ``remainder='drop'``). ``y`` holds integer-encoded labels.
    """
    # Rows with any missing value are removed before splitting off X and y.
    df = df.dropna()

    target_col = df.columns[-1]
    X = df.drop(columns=target_col)
    y = df[target_col]

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numerical_cols),
            ('cat', OneHotEncoder(drop='first'), nominial_cols),
        ],
        # Always return a dense ndarray. With the default threshold (0.3) a
        # one-hot-heavy result is returned as a scipy sparse matrix, which
        # contradicts the return type and breaks ``torch.tensor(X)`` later.
        sparse_threshold=0,
    )
    X_processed = preprocessor.fit_transform(X)

    y = LabelEncoder().fit_transform(y)

    return X_processed, y
    
    
    



# build dataset class 
class TabularDataset(Dataset):
    """In-memory dataset of ``(feature_row, label)`` pairs.

    Both arrays are converted to tensors once, up front, and placed on
    ``device``. Batches drawn through a ``DataLoader`` therefore already live
    there.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray, device: "torch.device | str | None" = None):
        """
        Args:
            X: 2-D feature matrix, stored as ``float32``. Anything exposing
                ``toarray()`` (e.g. a scipy sparse matrix) is densified first.
            y: Class labels, stored as ``torch.long`` as required by
                ``nn.CrossEntropyLoss``.
            device: Device for both tensors. ``None`` keeps PyTorch's default
                device. (Previously the default was the ``torch.device`` class
                itself, so omitting this argument crashed.)
        """
        super().__init__()
        if hasattr(X, 'toarray'):
            X = X.toarray()
        self.X = torch.tensor(X, dtype=torch.float32, device=device)
        self.y = torch.tensor(y, dtype=torch.long, device=device)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int):
        return self.X[idx], self.y[idx]
    
    


# build nn-based classifier
class TabularClassifier(nn.Module):
    """Three-layer MLP classifier for tabular features.

    Layout: ``input_dim -> 128 -> 64 -> output_dim``, with ReLU after the two
    hidden layers. ``forward`` returns softmax probabilities over
    ``output_dim`` classes (not logits).

    NOTE(review): the training loops in this module pass these probabilities
    to ``nn.CrossEntropyLoss``, which applies ``log_softmax`` itself, so
    softmax is effectively applied twice. Argmax predictions are unaffected,
    but losses and gradients are compressed. The output is kept as-is because
    callers may rely on probabilities; confirm before switching to logits.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)  # one output per class (any number of classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return self.softmax(x)
    



# train classifier 
def train_classifier(X_train: np.ndarray, y_train: np.ndarray, output_dim: int, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[TabularClassifier, float, list]:
    """Train a freshly initialised ``TabularClassifier`` with Adam.

    Args:
        X_train: Feature matrix; its column count sets the model's input size.
        y_train: Integer class labels.
        output_dim: Number of classes.
        device: Device for both data and model.
        epochs: Number of passes over the (shuffled) training set.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(model, elapsed_seconds, epoch_losses)``. ``epoch_losses[i]`` is the
        mean cross-entropy over the mini-batches of epoch ``i``.
    """
    loader = DataLoader(
        dataset=TabularDataset(X=X_train, y=y_train, device=device),
        batch_size=batch_size,
        shuffle=True,
    )

    model = TabularClassifier(input_dim=X_train.shape[1], output_dim=output_dim).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    t0 = time.time()
    for epoch_idx in range(epochs):
        model.train()
        loss_sum = .0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y.to(device))
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.cpu().item()

        mean_loss = loss_sum / len(loader)
        print(f"Epoch [{epoch_idx+1}/{epochs}], Loss: {mean_loss:.4f}")
        epoch_losses.append(mean_loss)

    elapsed = time.time() - t0
    return model, elapsed, epoch_losses
        
        


# evaluate classification performance 
def evaluate_classifier(model: nn.Module, X_test: np.ndarray, y_test: np.ndarray, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Measure the test accuracy of ``model`` and collect its predictions.

    The model is switched to eval mode and run without gradient tracking over
    the test set in its original order.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        X_test: Test feature matrix.
        y_test: Integer test labels.
        device: Device the test tensors are placed on.
        batch_size: Evaluation batch size.

    Returns:
        ``(accuracy, y_pred, y_true)``, where ``y_pred`` and ``y_true`` are
        Python lists of class indices in test-set order.
    """
    test_loader = DataLoader(
        dataset=TabularDataset(X=X_test, y=y_test, device=device),
        batch_size=batch_size,
        shuffle=False,
    )

    model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            preds = torch.argmax(model(inputs), dim=1)
            y_pred.extend(preds.cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Only accuracy is returned. Precision/recall/F1 were previously computed
    # and discarded here; with sklearn's default ``average='binary'`` they
    # raise ValueError for multiclass targets. Compute them from
    # ``y_pred``/``y_true`` at the call site, with a suitable ``average``.
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true
        


# unlearn original model with our shuffling-based approach 
def train_unlearning_with_shuffle(UL_model: TabularClassifier, X_train: np.ndarray, y_train: np.ndarray, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[TabularClassifier, list, float]: 
    """Fine-tune ``UL_model`` while permuting the unlearned feature per batch.

    Column 0 of ``X_train`` is the feature to unlearn. In every mini-batch its
    values are randomly permuted across rows before the forward pass. This
    breaks its link to the labels while keeping its marginal distribution.
    ``UL_model`` is updated in place with Adam on the cross-entropy loss.

    Returns:
        ``(UL_model, epoch_task_losses, elapsed_seconds)``. Note that this
        order differs from ``train_classifier``. Each loss entry is the mean
        over the epoch's mini-batches.
    """
    loader = DataLoader(
        dataset=TabularDataset(X=X_train, y=y_train, device=device),
        batch_size=batch_size,
        shuffle=True,
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(UL_model.parameters(), lr=lr)

    history = []
    t_start = time.time()
    for ep in range(epochs):
        UL_model.train()
        total = 0
        for xb, yb in loader:
            # Permute column 0 (the unlearned feature) across this batch's
            # rows. With the default collate function ``xb`` is a newly
            # stacked tensor, so the dataset itself is left untouched.
            perm = torch.randperm(xb.shape[0])
            xb[:, 0:1] = xb[:, 0:1][perm]
            batch_loss = criterion(UL_model(xb), yb)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            total += batch_loss.cpu().item()

        avg = total / len(loader)
        print(f"Epoch [{ep+1}/{epochs}], Running Task Loss: {avg:.4f},")
        history.append(avg)

    elapsed = time.time() - t_start
    return UL_model, history, elapsed








# baseline 1: replaced unlearned feature with uniform distribution 
def BL1_prep_data(X_train: np.ndarray, model: TabularClassifier) -> Tuple[TabularClassifier, np.ndarray]:
    """Prepare baseline 1: overwrite the unlearned feature with uniform noise.

    Column 0 of a copy of ``X_train`` is replaced by i.i.d. draws from
    ``U(-1, 1)`` (via ``np.random.uniform``). The same treatment is applied
    whether that column is numerical or categorical. Neither input is
    modified.

    Returns:
        ``(BL1_model, BL1_X_train)``: a deep copy of ``model`` and the noised
        copy of ``X_train``.
    """
    model_copy = deepcopy(model)

    noise = np.random.uniform(low=-1, high=1, size=X_train.shape[0])
    X_noisy = deepcopy(X_train)
    X_noisy[:, 0] = noise

    return model_copy, X_noisy


# baseline 2 (NDSS 2023): Machine Unlearning of Features and Labels
def BL2_prep_data(X_train: np.ndarray, model: TabularClassifier):
    """Prepare baseline 2 (NDSS 2023, "Machine Unlearning of Features and Labels").

    Builds the perturbed training set for the first-order update: a copy of
    ``X_train`` whose column 0 (the unlearned feature) is shifted by a constant
    ``+0.01``. Neither input is modified.

    Returns:
        ``(BL2_model, BL2_X_train)``: a deep copy of ``model`` and the shifted
        copy of ``X_train``.
    """
    model_copy = deepcopy(model)

    n_rows = X_train.shape[0]
    shift = np.full(n_rows, 0.01)

    X_shifted = deepcopy(X_train)
    X_shifted[:, 0] = X_shifted[:, 0] + shift

    return model_copy, X_shifted


def train_BL2_model(BL2_X_train: np.ndarray, y_train: np.ndarray, X_train: np.ndarray, BL2_model: TabularClassifier, epochs: int, device: torch.device, batch_size: int=64, lr: float=0.001):
    """Run baseline 2's iterated first-order unlearning update.

    Original and perturbed rows are visited in lockstep. Both loaders use
    ``shuffle=False`` and the same ``batch_size``, so batch ``i`` of each
    covers the same rows; this assumes ``BL2_X_train`` is row-aligned with
    ``X_train``, as produced by ``BL2_prep_data``. Each step minimises
    ``ori_loss - BL2_loss`` with Adam: it descends on the original rows'
    loss and ascends on the perturbed rows' loss. ``BL2_model`` is updated
    in place.

    NOTE(review): the first-order update in Warnecke et al. (NDSS 2023) is
    ``theta <- theta - tau * (grad L(perturbed) - grad L(original))``, which
    corresponds to minimising ``BL2_loss - ori_loss``. That is the opposite
    sign of what is optimised here; confirm the intended direction before
    relying on this baseline.

    Args:
        BL2_X_train: Perturbed copy of ``X_train``.
        y_train: Integer labels shared by both feature matrices.
        X_train: Original feature matrix.
        BL2_model: Model to update in place.
        epochs: Number of passes over the paired batches.
        device: Device for data tensors.
        batch_size: Mini-batch size for both loaders.
        lr: Adam learning rate (a float; previously mis-annotated as ``int``).

    Returns:
        ``(BL2_model, elapsed_seconds, epoch_losses)``. Each ``epoch_losses``
        entry is the mean of ``ori_loss - BL2_loss`` over that epoch's batches.
    """
    ori_train_dataset = TabularDataset(X=X_train, y=y_train, device=device)
    ori_train_loader = DataLoader(dataset=ori_train_dataset, batch_size=batch_size, shuffle=False)

    BL2_train_dataset = TabularDataset(X=BL2_X_train, y=y_train, device=device)
    BL2_train_loader = DataLoader(dataset=BL2_train_dataset, batch_size=batch_size, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(BL2_model.parameters(), lr=lr)

    running_loss_lst = []
    start_time = time.time()
    for epoch in range(epochs):
        BL2_model.train()
        running_loss = .0
        for (ori_inputs, labels), (BL2_inputs, _) in zip(ori_train_loader, BL2_train_loader):
            optimizer.zero_grad()
            ori_outputs = BL2_model(ori_inputs)
            BL2_outputs = BL2_model(BL2_inputs)

            ori_loss = criterion(ori_outputs, labels.to(device))
            BL2_loss = criterion(BL2_outputs, labels.to(device))

            # Equivalent to ori_loss - BL2_loss (see NOTE in the docstring).
            final_loss = -(BL2_loss - ori_loss)
            final_loss.backward()

            optimizer.step()
            running_loss += final_loss.cpu().item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(BL2_train_loader)}")
        running_loss_lst.append(running_loss/len(BL2_train_loader))

    end_time = time.time()

    return BL2_model, end_time - start_time, running_loss_lst
    

# Baseline 3 (arXiv): Efficient Attribute Unlearning: Towards Selective Removal of Input Attributes from Feature Representations
class BL3RepDetExtractor(nn.Module):
    """Representation extractor for baseline 3: ``input_dim -> 128 -> 64``.

    Both linear layers are followed by ReLU, so the 64-dimensional
    representation ``h`` is element-wise non-negative.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        hidden = self.relu(self.fc1(x))
        return self.relu(self.fc2(hidden))


class BL3Classifier(nn.Module):
    """Classification head for baseline 3.

    A single linear map ``64 -> output_dim`` followed by softmax over classes,
    so ``forward`` returns probabilities, not logits.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        self.fc3 = nn.Linear(64, output_dim)
        # Not used by ``forward``; kept so the registered submodules match.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, h: torch.Tensor):
        return self.softmax(self.fc3(h))
    
    
class BL3DecoderMIhx(nn.Module):
    """Baseline-3 decoder that reconstructs the retained features from ``h``.

    Maps the 64-d representation through ``64 -> 64`` (ReLU) and then
    ``64 -> output_dim`` (sigmoid). Every reconstructed value therefore lies
    in [0, 1].

    NOTE(review): a sigmoid output cannot match negative targets. If the
    reconstructed features are standardised (e.g. via ``StandardScaler`` in
    ``preprocess_df``), the reconstruction MSE has a non-zero floor. Confirm
    this is intended.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, h: torch.Tensor):
        hidden = self.relu(self.fc1(h))
        return self.sigmoid(self.fc2(hidden))
    
class BL3MIhy(nn.Module):
    """Baseline-3 auxiliary head predicting the label ``y`` from ``h``.

    Layout: ``64 -> 64 -> 64 -> output_dim``, with ReLU after the first two
    layers and softmax at the end, so outputs are class probabilities.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, h: torch.Tensor):
        for layer in (self.fc1, self.fc2):
            h = self.relu(layer(h))
        return self.softmax(self.fc3(h))


class BL3MIhz(nn.Module):
    """Baseline-3 auxiliary regressor predicting the unlearned feature from ``h``.

    Layout: ``64 -> 64 -> 64 -> 1``, with ReLU after the first two layers and
    a linear (unbounded) scalar output per row.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, h: torch.Tensor):
        for layer in (self.fc1, self.fc2):
            h = self.relu(layer(h))
        return self.fc3(h)




def train_BL3Backbone(X_train: np.ndarray, y_train: np.ndarray, output_dim: int, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3RepDetExtractor, BL3Classifier, float, list]:
    """Jointly train baseline 3's feature extractor and classification head.

    Both modules are freshly initialised (extractor first, then head) and
    optimised together by a single Adam optimiser on the cross-entropy loss.

    Args:
        X_train: Feature matrix; its column count sets the extractor's input size.
        y_train: Integer class labels.
        output_dim: Number of classes.
        device: Device for data and modules.
        epochs: Number of passes over the shuffled training set.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(extractor, head, elapsed_seconds, epoch_losses)``. Each loss entry
        is the mean over the epoch's mini-batches.
    """
    loader = DataLoader(
        dataset=TabularDataset(X=X_train, y=y_train, device=device),
        batch_size=batch_size,
        shuffle=True,
    )

    extractor = BL3RepDetExtractor(input_dim=X_train.shape[1]).to(device)
    head = BL3Classifier(output_dim=output_dim).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(chain(extractor.parameters(), head.parameters()), lr=lr)

    epoch_losses = []
    t0 = time.time()
    for epoch_idx in range(epochs):
        extractor.train()
        head.train()
        loss_sum = .0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            batch_loss = criterion(head(extractor(batch_x)), batch_y.to(device))
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.cpu().item()

        mean_loss = loss_sum / len(loader)
        print(f"Epoch [{epoch_idx+1}/{epochs}], Loss: {mean_loss:.4f}")
        epoch_losses.append(mean_loss)

    elapsed = time.time() - t0
    return extractor, head, elapsed, epoch_losses





def get_h_for_BL3(X_train: np.ndarray, device: torch.device, BL3RepDetExtractor_model: BL3RepDetExtractor, batch_size: int=64) -> np.ndarray:
    """Encode ``X_train`` with baseline 3's extractor.

    The extractor is switched to eval mode and applied without gradient
    tracking, in row order, ``batch_size`` rows at a time.

    Returns:
        One representation row per input row. The array is built from Python
        floats, so its dtype is ``float64``.
    """
    loader = DataLoader(
        dataset=TensorDataset(torch.FloatTensor(X_train).to(device)),
        batch_size=batch_size,
        shuffle=False,
    )

    rows = []
    BL3RepDetExtractor_model.eval()
    with torch.no_grad():
        for (batch_x,) in loader:
            rows += BL3RepDetExtractor_model(batch_x).cpu().tolist()

    return np.array(rows)


def train_BL3DecoderMIhx(h_train: np.ndarray, X_train: np.ndarray, device: torch.device, epochs: int, output_dim: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3DecoderMIhx, float]:
    """Fit baseline 3's decoder to reconstruct ``X_train[:, 1:]`` from ``h``.

    Column 0 (the unlearned feature) is excluded from the reconstruction
    target. Training uses MSE loss and Adam on shuffled mini-batches.

    Args:
        h_train: Representations, row-aligned with ``X_train``.
        X_train: Feature matrix; columns ``1:`` are the targets.
        device: Device for data and decoder.
        epochs: Number of passes over the data.
        output_dim: Decoder output size; should equal ``X_train.shape[1] - 1``.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(decoder, elapsed_seconds)``.
    """
    pairs = TensorDataset(torch.FloatTensor(h_train).to(device), torch.FloatTensor(X_train[:, 1:]).to(device))
    loader = DataLoader(dataset=pairs, batch_size=batch_size, shuffle=True)

    decoder = BL3DecoderMIhx(output_dim=output_dim).to(device)

    mse = nn.MSELoss()
    optimizer = optim.Adam(decoder.parameters(), lr=lr)

    t0 = time.time()
    for epoch_idx in range(epochs):
        decoder.train()
        loss_sum = .0
        for h_batch, x_batch in loader:
            batch_loss = mse(decoder(h_batch), x_batch)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()

        print(f"Epoch [{epoch_idx+1}/{epochs}], Loss: {loss_sum/len(loader):.4f}")

    return decoder, time.time() - t0


def cal_BL3MIhx(BL3DecoderMIhx_model: BL3DecoderMIhx, x: torch.Tensor, h: torch.Tensor):
    """Estimate I(h; x) for baseline 3 from the decoder's reconstruction error.

    Returns ``0.5 * log(2*pi*e*Var(x)) - MSE(decoder(h), x) / h.shape[0]``.
    ``Var(x)`` is the unbiased variance over every element of ``x``. No
    ``torch.no_grad`` is used, so the result stays on the autograd graph and
    can serve as a training signal.

    NOTE(review): ``nn.MSELoss`` already averages, so dividing by the batch
    size again makes the reconstruction term shrink as batches grow. A
    Gaussian bound would typically use ``0.5*log(2*pi*e*MSE)`` instead.
    Confirm against the reference paper.
    """
    BL3DecoderMIhx_model.eval()

    x_hat = BL3DecoderMIhx_model(h)
    recon_err = nn.MSELoss()(x_hat, x)
    entropy_x = 0.5 * torch.log(2 * np.pi * np.e * torch.var(x))

    return entropy_x - recon_err / h.shape[0]
    
    

def train_BL3MIhy(h_train: np.ndarray, y_train: np.ndarray, device: torch.device, epochs: int, output_dim: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3MIhy, float]:
    """Fit baseline 3's label head to predict ``y_train`` from ``h_train``.

    Training uses cross-entropy and Adam on shuffled mini-batches.

    Args:
        h_train: Representations, row-aligned with ``y_train``.
        y_train: Integer class labels.
        device: Device for data and head.
        epochs: Number of passes over the data.
        output_dim: Number of classes.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(label_head, elapsed_seconds)``.
    """
    pairs = TensorDataset(torch.FloatTensor(h_train).to(device), torch.LongTensor(y_train).to(device))
    loader = DataLoader(dataset=pairs, batch_size=batch_size, shuffle=True)

    label_head = BL3MIhy(output_dim=output_dim).to(device)

    ce = nn.CrossEntropyLoss()
    optimizer = optim.Adam(label_head.parameters(), lr=lr)

    t0 = time.time()
    for epoch_idx in range(epochs):
        label_head.train()
        loss_sum = .0
        for h_batch, y_batch in loader:
            batch_loss = ce(label_head(h_batch), y_batch)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()

        print(f"Epoch [{epoch_idx+1}/{epochs}], Loss: {loss_sum/len(loader):.4f}")

    return label_head, time.time() - t0

def cal_BL3MIhy(output_dim: int, BL3MIhy_model: BL3MIhy, h: torch.Tensor, y: torch.Tensor):
    """Estimate I(h; y) for baseline 3 from the label head's cross-entropy.

    Returns ``log(output_dim) - CE(BL3MIhy_model(h), y) / h.shape[0]``. The
    result stays on the autograd graph.

    NOTE(review): ``log(output_dim)`` equals H(y) only for uniformly
    distributed labels. Also, ``nn.CrossEntropyLoss`` is already a batch mean
    and is applied here to softmax outputs rather than logits. Confirm the
    formula against the reference paper.
    """
    BL3MIhy_model.eval()

    ce = nn.CrossEntropyLoss()(BL3MIhy_model(h), y)

    return np.log(output_dim) - ce / h.shape[0]

def train_BL3MIhz(h_train: np.ndarray, X_train: np.ndarray, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3MIhz, float]:
    """Fit baseline 3's regressor to predict the unlearned feature from ``h``.

    The target is column 0 of ``X_train``, kept 2-D as ``X_train[:, 0:1]``.
    Training uses MSE loss and Adam on shuffled mini-batches.

    Returns:
        ``(regressor, elapsed_seconds)``.
    """
    pairs = TensorDataset(torch.FloatTensor(h_train).to(device), torch.FloatTensor(X_train[:, 0:1]).to(device))
    loader = DataLoader(dataset=pairs, batch_size=batch_size, shuffle=True)

    regressor = BL3MIhz().to(device)
    mse = nn.MSELoss()
    optimizer = optim.Adam(regressor.parameters(), lr=lr)

    t0 = time.time()
    for epoch_idx in range(epochs):
        regressor.train()
        loss_sum = .0
        for h_batch, z_batch in loader:
            batch_loss = mse(regressor(h_batch), z_batch)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()

        print(f"Epoch [{epoch_idx+1}/{epochs}], Loss: {loss_sum/len(loader):.4f}")

    return regressor, time.time() - t0


def cal_BL3MIhz(BL3MIhz_model: BL3MIhz, h: torch.Tensor, z: torch.Tensor):
    """Estimate I(h; z) for baseline 3, where ``z`` is the unlearned feature.

    Returns ``0.5 * log(2*pi*e*Var(z)) - MSE(BL3MIhz_model(h), z) / h.shape[0]``.
    The result stays on the autograd graph.

    NOTE(review): as with ``cal_BL3MIhx``, the already-averaged MSE is divided
    by the batch size a second time. Confirm against the reference paper.
    """
    BL3MIhz_model.eval()

    z_hat = BL3MIhz_model(h)
    pred_err = nn.MSELoss()(z_hat, z)
    entropy_z = 0.5 * torch.log(2 * np.pi * np.e * torch.var(z))

    return entropy_z - pred_err / h.shape[0]
    
class BL3Dataset(Dataset):
    """In-memory dataset yielding ``(x, y, h)`` triples for baseline 3.

    All three arrays are converted once to ``float32`` tensors on ``device``.
    ``y`` is stored as float too, so cast it to ``long`` before feeding it to
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray, h: np.ndarray, device: torch.device):
        # Plain ``super()``: the previous ``super(Dataset, self)`` started the
        # MRO lookup *after* ``Dataset`` and so skipped it.
        super().__init__()
        self.X = torch.FloatTensor(X).to(device)
        self.y = torch.FloatTensor(y).to(device)
        self.h = torch.FloatTensor(h).to(device)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int):
        return self.X[idx], self.y[idx], self.h[idx]
        
        

def train_unlearning_BL3(
    BL3RepDetExtractor_model: BL3RepDetExtractor, BL3Classifier_model: BL3Classifier, 
    BL3DecoderMIhx_model: BL3DecoderMIhx, BL3MIhy_model: BL3MIhy, BL3MIhz_model: BL3MIhz,
    X_train: np.ndarray, y_train: np.ndarray, device: torch.device,
    BL3_lamda1: float, BL3_lamda2: float, BL3_lamda3: float,
    epochs: int, output_dim: int, batch_size: int=64, lr: float=0.001
) -> Tuple[BL3RepDetExtractor, BL3Classifier, float, list, list]:
    """Fine-tune baseline 3's extractor and head with mutual-information terms.

    Per mini-batch the minimised objective is

        task_loss - lamda1 * I(h; x_rest) - lamda2 * I(h; y) - lamda3 * I(h; z)

    where ``x_rest = inputs[:, 1:]``, ``z = inputs[:, 0:1]`` (the unlearned
    feature), and each ``I`` comes from the ``cal_BL3MI*`` estimators. Only
    the extractor and classifier are updated (in place). The three estimator
    networks are not in the optimiser, but gradients flow through them into
    ``h``.

    NOTE(review): with a positive ``BL3_lamda3``, the ``- lamda3 * I(h; z)``
    term *rewards* information about the feature being unlearned. Unlearning
    would normally penalise it (``+ lamda3 * I(h; z)``). Confirm the sign
    convention callers use for ``BL3_lamda3``.
    NOTE(review): the estimators' ``.grad`` buffers are never zeroed here, so
    they accumulate across steps.

    Returns:
        ``(extractor, classifier, elapsed_seconds, task_losses, mi_losses)``.
        Unlike the other trainers in this module, ``task_losses`` and
        ``mi_losses`` hold per-epoch *sums* over batches, not means.
    """
    train_dataset = TabularDataset(X=X_train, y=y_train, device=device)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    task_loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(chain(BL3RepDetExtractor_model.parameters(), BL3Classifier_model.parameters()), lr=lr)

    # The MI estimators are used frozen-in-mode; they are not optimised here.
    BL3DecoderMIhx_model.eval()
    BL3MIhy_model.eval()
    BL3MIhz_model.eval()

    task_loss_lst, MI_loss_lst = [], []
    start_time = time.time()
    for epoch in range(epochs):
        running_task_loss, running_MI_loss = .0, .0
        # The modules being optimised belong in training mode. This is
        # equivalent to eval() for the current Linear/ReLU layers, but matters
        # as soon as dropout or batch-norm layers are added.
        BL3RepDetExtractor_model.train()
        BL3Classifier_model.train()
        for inputs, labels in train_loader:
            h = BL3RepDetExtractor_model(inputs)
            preds = BL3Classifier_model(h)

            task_loss = task_loss_func(preds, labels)

            MIhx = cal_BL3MIhx(BL3DecoderMIhx_model=BL3DecoderMIhx_model, x=inputs[:, 1:], h=h)
            MIhy = cal_BL3MIhy(output_dim=output_dim, BL3MIhy_model=BL3MIhy_model, h=h, y=labels)
            MIhz = cal_BL3MIhz(BL3MIhz_model=BL3MIhz_model, h=h, z=inputs[:, 0:1])
            MIloss = -BL3_lamda1*MIhx - BL3_lamda2*MIhy - BL3_lamda3*MIhz

            final_loss = task_loss + MIloss

            optimizer.zero_grad()
            final_loss.backward()
            optimizer.step()

            running_task_loss += task_loss.cpu().item()
            running_MI_loss += MIloss.cpu().item()

        print(f"Epoch [{epoch+1}/{epochs}], Running Task Loss: {running_task_loss/len(train_loader):.4f}")
        task_loss_lst.append(running_task_loss)
        MI_loss_lst.append(running_MI_loss)

    end_time = time.time()

    return BL3RepDetExtractor_model, BL3Classifier_model, end_time-start_time, task_loss_lst, MI_loss_lst


def evaluate_BL3(BL3RepDetExtractor_model: BL3RepDetExtractor, BL3Classifier_model: BL3Classifier, X_test: np.ndarray, y_test: np.ndarray, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Measure the test accuracy of baseline 3's extractor + head.

    Both modules are switched to eval mode and applied without gradient
    tracking over the test set in its original order.

    Returns:
        ``(accuracy, y_pred, y_true)``, where ``y_pred`` and ``y_true`` are
        Python lists of class indices in test-set order.
    """
    test_loader = DataLoader(
        dataset=TabularDataset(X=X_test, y=y_test, device=device),
        batch_size=batch_size,
        shuffle=False,
    )

    BL3RepDetExtractor_model.eval()
    BL3Classifier_model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = BL3Classifier_model(BL3RepDetExtractor_model(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Only accuracy is returned. Precision/recall/F1 were previously computed
    # and discarded here; with sklearn's default ``average='binary'`` they
    # raise ValueError for multiclass targets. Compute them from
    # ``y_pred``/``y_true`` at the call site, with a suitable ``average``.
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true



















# # Define MINE 
# class MINE(nn.Module):
#     def __init__(self, input_dim: int, output_dim: int):
#         super(MINE, self).__init__()
#         self.fc1 = nn.Linear(input_dim + output_dim, 128)  # unlearned feature + n_classes
#         self.fc2 = nn.Linear(128, 64)
#         self.fc3 = nn.Linear(64, 1)
#         self.relu = nn.ReLU()
#         self.tanh = nn.Tanh()
    
#     def forward(self, x, y):
#         """
#         x: unlearned feature
#         y: n_class distribution
#         """
#         inputs = torch.cat((x, y), dim=1)  # Concatenate age and income
#         out = self.relu(self.fc1(inputs))
#         out = self.relu(self.fc2(out))
#         out = self.fc3(out)
#         # out = self.tanh(self.fc3(out))
        
#         return out
    

# def grad_norm(module):
#     parameters = module.parameters()
#     if isinstance(parameters, torch.Tensor):
#         parameters = [parameters]

#     parameters = list(filter(lambda p: p.grad is not None, parameters))

#     total_norm = 0
#     for p in parameters:
#         param_norm = p.grad.data.norm(2)
#         total_norm = param_norm.item() ** 2
#     total_norm = total_norm ** (1. / 2)
#     return total_norm


# def adaptive_gradient_clipping_(classifier: nn.Module, mi_module: nn.Module):
#     """
#     Clips the gradient according to the min norm of the generator and mi estimator

#     Arguments:
#         generator_module -- nn.Module 
#         mi_module -- nn.Module
#     """
#     norm_generator = grad_norm(classifier)
#     norm_estimator = grad_norm(mi_module)

#     min_norm = np.minimum(norm_generator, norm_estimator)

#     parameters = list(
#         filter(lambda p: p.grad is not None, mi_module.parameters()))
#     if isinstance(parameters, torch.Tensor):
#         parameters = [parameters]

#     for p in parameters:
#         p.grad.data.mul_(min_norm/(norm_estimator + 1e-9))

# # unlearn original model via MINE and VAE
# def deprecated_train_unlearning_with_MINE(init_MI_coeff: float, UL_model: TabularClassifier, X_train: np.ndarray, y_train: np.ndarray, device: torch.device, epochs: int, output_dim: int, batch_size: int=64, lr: float=0.001) -> Tuple[TabularClassifier, MINE, float, list, list, float]: 
#     # init dataset and dataloader
#     train_dataset = TabularDataset(X=X_train, y=y_train, device=device)
#     train_loader = DataLoader(dataset=train_dataset, batch_size=X_train.shape[0], shuffle=True)
    
#     MINE_model = MINE(input_dim=X_train.shape[1]-1, output_dim=output_dim).to(device)
    
#     # init loss function and optimizer
#     task_loss_func = nn.CrossEntropyLoss()
#     # UL_model_optimizer = optim.Adam(UL_model.parameters(), lr=lr*0.1)
#     # MINE_model_optimizer = optim.Adam(MINE_model.parameters(), lr=lr*0.001)
#     # optimizer = optim.Adam(chain(UL_model.parameters(), MINE_model.parameters()), lr=lr)
#     optimizer = optim.Adam(UL_model.parameters(), lr=lr)
    
#     # tar_MINE_model = deepcopy(MINE_model)
#     # tar_update_epoch = 1
#     # ma_rate = 0.9
    

#     MI_coeff = init_MI_coeff
#     # prev_avg_est_MI = None
#     task_loss_lst = []
#     start_time = time.time()
#     for epoch in range(epochs):
#         UL_model.train() 
#         MINE_model.train()
#         running_task_loss = 0
#         # cum_est_MI = .0
#         for inputs, labels in train_loader:
            
#             inputs[:, 0:1] = inputs[:, 0:1][torch.randperm(inputs.shape[0])] # shuffle unlearned feature!!!
            
#             outputs = UL_model(inputs) # get y_hat 
#             task_loss = task_loss_func(outputs, labels) # get task loss 
            
#             # ## generate x_i_hat from its marginal distribution via VAE
#             # # sample_ul_feature = inputs[:, 0:1][torch.randperm(outputs.shape[0])]
#             # sample_rem_features = inputs[:, 1:][torch.randperm(outputs.shape[0])] # get marginal distribution 
            
#             # # calculate MI for integrating into unlearning loss
#             # MINE_output_joint_dist= MINE_model.forward(x=inputs[:, 1:], y=outputs)
#             # MINE_output_product_margin_dist = MINE_model.forward(x=sample_rem_features, y=outputs)
            
#             # MI_first_term = MINE_output_joint_dist.mean()
#             # safe_vals = torch.clamp(MINE_output_product_margin_dist, min=1e-8)
#             # MI_second_term = torch.logsumexp(safe_vals, dim=0)[0] - np.log(safe_vals.shape[0])
            
#             # est_MI = MI_first_term - MI_second_term
#             # # est_MI_item = est_MI.detach().cpu().item()
#             # # cum_est_MI += est_MI_item
            
            
#             # MI_for_ULloss = -MI_coeff * (est_MI) # * cur_MINE_et / MINE_ma_et)
#             # MI_for_MINEloss = -est_MI
            
            
#             # final_loss = task_loss + MI_for_ULloss
#             # MI_for_ULloss = est_MI
            
#             # # first backward task loss and get gradients and its F-norm for the task loss part
#             # task_loss_grad = torch.autograd.grad(
#             #     task_loss, UL_model.parameters(), retain_graph=True, create_graph=True
#             # )
#             # task_loss_grad_norm = torch.sqrt(sum(torch.norm(grad)**2 for grad in task_loss_grad))
            
#             # # then backward MI loss and get corresponding gradients and F-norm
#             # MI_for_ULloss_grad = torch.autograd.grad(
#             #     MI_for_ULloss, UL_model.parameters(), retain_graph=True # create_graph=True
#             # )
#             # MI_for_ULloss_grad_norm = torch.sqrt(sum(torch.norm(grad)**2 for grad in MI_for_ULloss_grad))
            
#             # # MI Loss Gradients for MINE_model
#             # MI_for_MINEloss_grad = torch.autograd.grad(
#             #     MI_for_MINEloss, MINE_model.parameters()
#             # )
#             # MI_for_ULloss_MINE_grad_norm = torch.sqrt(sum(torch.norm(grad)**2 for grad in MI_for_ULloss_MINE_grad))

            

#             # # calculate the clipping factor to clip MI's gradient, because the MI does not have upper bound
#             # clip_factor = min(1.0, (task_loss_grad_norm / (MI_for_ULloss_grad_norm + 1e-8)).cpu().item())
#             # clip_factor = np.clip(clip_factor, a_min=1e-6, a_max=1e6) * 1.8
#             # clip_factor = 1.
            
#             # finalise update of unlearned model
#             # torch.nn.utils.clip_grad_norm_(UL_model.parameters(), 1.)
#             # final_loss.backward()
#             optimizer.zero_grad()
#             # with torch.no_grad():
#             #     for p, task_loss_grad_elem, MI_for_ULloss_grad_elem in zip(UL_model.parameters(), task_loss_grad, MI_for_ULloss_grad):
#             #         # ada_grad = MI_for_ULloss_grad_elem * clip_factor + task_loss_grad_elem
#             #         ada_grad = task_loss_grad_elem
#             #         p.grad = ada_grad
                    
#                 # for p, MI_for_MINEloss_grad_elem in zip(MINE_model.parameters(), MI_for_MINEloss_grad):
#                 #     p.grad = MI_for_MINEloss_grad_elem
#             # final_loss.backward()
#             task_loss.backward()
#             optimizer.step()
#             # adaptive_gradient_clipping_(classifier=UL_model, mi_module=MINE_model)
#             # final_loss.backward()

            
#             running_task_loss += task_loss.cpu().item()
            
#         # cur_avg_est_MI = cum_est_MI / len(train_loader)
#         # if prev_avg_est_MI is None:
#         #     prev_avg_est_MI = cur_avg_est_MI
#         # else:
#         #     est_MI_diff = max(cur_avg_est_MI, 0) - max(prev_avg_est_MI, 0)
#         #     if est_MI_diff >= 0: 
#         #         MI_coeff *= np.clip(max(cur_avg_est_MI, 0) / (max(prev_avg_est_MI, 0) + 1e-8), a_min=1.01, a_max=2)
#         #         MI_coeff = np.clip(MI_coeff, a_min=1e-5, a_max=100)
#         #     prev_avg_est_MI = cur_avg_est_MI
            
#         print(f"Epoch [{epoch+1}/{epochs}], Running Task Loss: {running_task_loss/len(train_loader):.4f},")
        
#         # # update MINE 
#         # if epoch % tar_update_epoch == 0:
#         #     UL_model.eval()
#         #     for MINE_epoch in range(1):
#         #         # for inputs, labels in train_loader:
#         #         inputs = torch.FloatTensor(X_train).to(device)
#         #         outputs = UL_model(inputs).detach()
#         #         sample_ul_feature = inputs[:, 0:1][torch.randperm(outputs.shape[0])]
                
#         #         MINE_output_joint_dist_for_MINEloss = MINE_model.forward(x=inputs[:, 0:1], y=outputs) # outputs.detach())
#         #         MINE_output_product_marin_dist_for_MINEloss = MINE_model.forward(x=sample_ul_feature, y=outputs) # outputs.detach())
                
#         #         MI_first_term_for_MINEloss = MINE_output_joint_dist_for_MINEloss.mean()
#         #         safe_vals = torch.clamp(MINE_output_product_marin_dist_for_MINEloss, min=1e-8)
#         #         MI_second_term_for_MINEloss = torch.logsumexp(safe_vals, dim=0) - np.log(safe_vals.shape[0])
                
#         #         MI_for_MINEloss = MI_first_term_for_MINEloss - MI_second_term_for_MINEloss # * cur_MINE_et / MINE_ma_et
#         #         MINEloss = -MI_for_MINEloss
                
#         #         MINE_model_optimizer.zero_grad()
#         #         MINEloss.backward()
#         #         # torch.nn.utils.clip_grad_norm_(MINE_model.parameters(), .1)
#         #         MINE_model_optimizer.step()
                    
#         #     UL_model.train()
        
#         task_loss_lst.append(running_task_loss/len(train_loader))
    
#     end_time = time.time()
        

#     return UL_model, MINE_model, task_loss_lst, end_time-start_time


















# # calculate estimated MI
# def cal_eMI(MINE_model: MINE, VAE_model: VAE, model: TabularClassifier, latent_dim: int, X_test: np.ndarray, y_test: np.ndarray, device: torch.device, batch_size: int=1) -> float:
#     # init dataset and dataloader 
#     test_dataset = TabularDataset(X=X_test, y=y_test, device=device)
#     test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    
#     model.eval()
#     MINE_model.eval()
    
#     estimated_MI_lst = []
#     for inputs, _ in test_loader:
#         outputs = model(inputs)
        
#         z = torch.randn(batch_size, latent_dim).to(device)
#         sample_ul_feature = VAE_model.decode(z).detach().reshape(-1, 1)
        
#         MINE_output_joint_dist = MINE_model.forward(x=inputs[:, 0:1], y=outputs)
#         MINE_output_product_margin_dists = MINE_model.forward(x=sample_ul_feature, y=outputs)
        
#         MI_first_term = MINE_output_joint_dist.mean()
#         MI_second_term = torch.exp(MINE_output_product_margin_dists).mean()
        
#         estimated_MI = MI_first_term - torch.log(MI_second_term)
        
#         estimated_MI_lst.append(estimated_MI)
        
#     return np.mean(estimated_MI_lst)











# import torch.nn.functional as F

# # Define the VAE
# class VAE(nn.Module):
#     def __init__(self, latent_dim: int):
#         super(VAE, self).__init__()
        
#         # Encoder
#         self.fc1 = nn.Linear(1, 16)
#         self.fc_mu = nn.Linear(16, latent_dim)  # Mean of latent space
#         self.fc_logvar = nn.Linear(16, latent_dim)  # Log-variance of latent space
        
#         # Decoder
#         self.fc2 = nn.Linear(latent_dim, 16)
#         self.fc3 = nn.Linear(16, 1)
    
#     def encode(self, x):
#         h = F.relu(self.fc1(x))
#         mu = self.fc_mu(h)
#         logvar = self.fc_logvar(h)
#         return mu, logvar
    
#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std
    
#     def decode(self, z):
#         h = F.relu(self.fc2(z))
#         return self.fc3(h)
    
#     def forward(self, x):
#         mu, logvar = self.encode(x)
#         z = self.reparameterize(mu, logvar)
#         return self.decode(z), mu, logvar
    

# # VAE loss 
# def vae_loss(recon_x, x, mu, logvar):
#     # Reconstruction loss (MSE)
#     recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    
#     # KL Divergence
#     kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
#     return recon_loss + kl_div
    




# # train VAE 
# def train_VAE(df: pd.DataFrame, numerical_cols: list, device: torch.device, epochs: int, batch_size: int=64, latent_dim: int=2, lr: float=0.001) -> Tuple[VAE, float]:
#     # get unlearned feature and normalise it
#     ul_feature_arr = df[numerical_cols[0]].values.reshape(-1, 1)
#     numerical_scaler = StandardScaler()
#     ul_feature_arr_scaled = numerical_scaler.fit_transform(ul_feature_arr)
    
#     ul_feature_tensor_scaled = torch.FloatTensor(ul_feature_arr_scaled).to(device)
#     ul_feature_loader = DataLoader(TensorDataset(ul_feature_tensor_scaled), batch_size=batch_size, shuffle=True)
    
#     model = VAE(latent_dim=latent_dim)
#     optimizer = optim.Adam(model.parameters(), lr=lr)
    
#     model.train()
#     start_time = time.time()
#     for epoch in range(epochs):
#         total_loss = .0
#         for batch_ul_feature in ul_feature_loader:
#             batch_ul_feature = batch_ul_feature[0]
#             optimizer.zero_grad()
#             recon_x, mu, logvar = model(batch_ul_feature)
#             loss = vae_loss(recon_x, batch_ul_feature, mu, logvar)
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.cpu().item()
        
#         print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(ul_feature_loader.dataset):.4f}")
#     end_time = time.time()
    
    
#     model.eval()
    
#     with torch.no_grad():
#         z = torch.randn(df.shape[0], latent_dim)  # Sample from standard Gaussian
#         sampled_ul_features = model.decode(z).numpy()

#     # Inverse transform to original scale
#     sampled_ul_features = numerical_scaler.inverse_transform(sampled_ul_features)

#     # Plot the distribution of sampled vs. original unlearned-feature values
#     plt.hist(sampled_ul_features, bins=50, alpha=0.7, label='Sampled Ages')
#     plt.hist(df[numerical_cols[0]], bins=50, alpha=0.2, label='Original Ages')
#     plt.legend()
#     plt.xlabel('Unlearned Features')
#     plt.ylabel('Frequency')
#     plt.title('Comparison of Original and Sampled Unlearned Feature Distributions')
#     plt.show()
#     plt.close()
    
#     return model, end_time -start_time









# from lime.lime_tabular import LimeTabularExplainer
# # evaluate the unlearned feature's importance via LIME (sum of |weight| per feature over all test rows)
# def cal_LIME(X_test: np.ndarray, numerical_cols: list, nominial_cols: list, output_dim: int, device: torch.device, model: TabularClassifier) -> Tuple[pd.DataFrame, LimeTabularExplainer]:
#     # Create LIME Explainer
#     explainer = LimeTabularExplainer(
#         training_data=X_test,
#         mode='classification',
#         feature_names=numerical_cols+nominial_cols,  # [f'Feature {i}' for i in range(X_test.shape[1])],
#         class_names=['Class {}'.format(i) for i in range(output_dim)],
#         verbose=True,
#         random_state=42
#     ) 
    
#     # loop all test data
#     feature_importance = {} 
#     for idx in range(X_test.shape[0]):
#         exp = explainer.explain_instance(
#             data_row=X_test[idx:idx+1, :],
#             predict_fn=lambda x: model_wrapper_for_XAI(x, device=device, model=model)
#         )
#         for feature, weight in exp.as_list():
#             if feature in feature_importance:
#                 feature_importance[feature] += abs(weight)
#             else:
#                 feature_importance[feature] = abs(weight)
                
#     feature_importance = pd.DataFrame.from_dict(
#         feature_importance, orient='index', columns=['Importance']
#     ).sort_values(by='Importance', ascending=False)

    
#     return feature_importance, explainer


# def cal_BL3_LIME(X_test: np.ndarray, numerical_cols: list, nominial_cols: list, output_dim: int, device: torch.device, BL3RepDetExtractor_model: BL3RepDetExtractor, BL3Classifier_model: BL3Classifier) -> Tuple[pd.DataFrame, LimeTabularExplainer]:
#     # Create LIME Explainer
#     explainer = LimeTabularExplainer(
#         training_data=X_test,
#         mode='classification',
#         feature_names=numerical_cols+nominial_cols,  # [f'Feature {i}' for i in range(X_test.shape[1])],
#         class_names=['Class {}'.format(i) for i in range(output_dim)],
#         verbose=True,
#         random_state=42
#     ) 
    
#     # loop all test data
#     feature_importance = {} 
#     for idx in range(X_test.shape[0]):
#         exp = explainer.explain_instance(
#             data_row=X_test[idx:idx+1, :],
#             predict_fn=lambda x: model_wrapper_for_BL3_XAI(x, device=device, BL3RepDetExtractor_model=BL3Classifier_model, BL3Classifier_model=BL3Classifier_model)
#         )
#         for feature, weight in exp.as_list():
#             if feature in feature_importance:
#                 feature_importance[feature] += abs(weight)
#             else:
#                 feature_importance[feature] = abs(weight)
                
#     feature_importance = pd.DataFrame.from_dict(
#         feature_importance, orient='index', columns=['Importance']
#     ).sort_values(by='Importance', ascending=False)

#     # No plotting here; see plot_feature_importance_LIME
    
    
#     return feature_importance, explainer


# def plot_feature_importance_LIME(FI_df: pd.DataFrame):
#     FI_df.plot(kind='barh', figsize=(10, 8), legend=False)
#     plt.title('Global Feature Importance (Approx. from LIME)')
#     plt.show()
