import time 
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms

import timm

from typing import Tuple
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score





def operator_retrain(area: list, img: np.ndarray) -> np.ndarray:
    """Paint the rectangle(s) described by ``area`` white, in place.

    Args:
        area: Either ``[top, bottom, left, right]`` (one rectangle) or
            ``[top_n, bottom_n, left_n, right_n, top, bottom, left, right]``
            (two rectangles). Any other length fails on unpacking.
        img: Image array indexed as ``img[row, col]``; modified in place.

    Returns:
        The same ``img`` object that was passed in.
    """
    white = [255, 255, 255]

    # Single-rectangle form: handle it and leave early.
    if len(area) == 4:
        top, bottom, left, right = area
        img[top:bottom, left:right] = white
        return img

    # Everything else is treated as the two-rectangle form.
    top_n, bottom_n, left_n, right_n, top, bottom, left, right = area
    img[top_n:bottom_n, left_n:right_n] = white
    img[top:bottom, left:right] = white
    return img
    

def _shuffle_region(img: np.ndarray, top: int, bottom: int, left: int, right: int) -> None:
    """Randomly permute the pixels of ``img[top:bottom, left:right]`` in place."""
    patch = img[top:bottom, left:right, :]
    pixels = patch.reshape(-1, 3)
    # np.random.shuffle permutes along the first axis, i.e. whole pixels move
    np.random.shuffle(pixels)
    # Write the permuted pixels back into the image
    img[top:bottom, left:right, :] = pixels.reshape(patch.shape)


def operator_shuffle(area: list, img: np.ndarray) -> np.ndarray:
    """Shuffle pixel positions inside the rectangle(s) described by ``area``.

    Args:
        area: Either ``[top, bottom, left, right]`` (one rectangle) or
            ``[top_n, bottom_n, left_n, right_n, top, bottom, left, right]``
            (two rectangles). Any other length fails on unpacking.
        img: Image array whose last axis is reshaped as 3 channels;
            modified in place.

    Returns:
        The same ``img`` object that was passed in.
    """
    if len(area) == 4:
        top, bottom, left, right = area
        _shuffle_region(img, top, bottom, left, right)
        return img

    top_n, bottom_n, left_n, right_n, top, bottom, left, right = area
    # The trailing rectangle is shuffled first, then the leading ("_n") one
    _shuffle_region(img, top, bottom, left, right)
    _shuffle_region(img, top_n, bottom_n, left_n, right_n)
    return img


def _add_noise_to_region(img: np.ndarray, top: int, bottom: int, left: int, right: int, noise_level: float) -> None:
    """Add zero-mean Gaussian noise to ``img[top:bottom, left:right]`` in place."""
    patch = img[top:bottom, left:right, :]
    noise = np.random.normal(0, noise_level, patch.shape).astype(np.int16)
    # Sum in int16 so the addition cannot wrap, then clip to [0, 255] before casting back
    noisy = np.clip(patch.astype(np.int16) + noise, 0, 255).astype(np.uint8)
    img[top:bottom, left:right, :] = noisy


def operator_add_noise(area: list, img: np.ndarray, noise_level: float=5.) -> np.ndarray:
    """Perturb the rectangle(s) described by ``area`` with Gaussian noise.

    Args:
        area: Either ``[top, bottom, left, right]`` (one rectangle) or
            ``[top_n, bottom_n, left_n, right_n, top, bottom, left, right]``
            (two rectangles). Any other length fails on unpacking.
        img: Image array indexed as ``img[row, col, channel]``; modified in place.
        noise_level: Standard deviation passed to ``np.random.normal``.

    Returns:
        The same ``img`` object that was passed in.
    """
    if len(area) == 4:
        top, bottom, left, right = area
        _add_noise_to_region(img, top, bottom, left, right, noise_level)
        return img

    top_n, bottom_n, left_n, right_n, top, bottom, left, right = area
    # The leading ("_n") rectangle is processed first, then the trailing one
    _add_noise_to_region(img, top_n, bottom_n, left_n, right_n, noise_level)
    _add_noise_to_region(img, top, bottom, left, right, noise_level)
    return img




# Custom Dataset for CelebA
class CelebADataset(Dataset):
    """Labelled images read from ``data_cv/img_align_celeba``."""

    def __init__(self, attr_dict: dict, device: torch.device) -> None:
        """
        Args:
            attr_dict: Dictionary with image filenames as keys and labels as values
            device: Device the image and label tensors are moved to in ``__getitem__``
        """

        super(CelebADataset, self).__init__()

        self.attr_dict = attr_dict
        self.device = device

        self.root_dir = 'data_cv/img_align_celeba'

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        # Fixes the sample order to the insertion order of attr_dict
        self.image_list = list(self.attr_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (one per key of ``attr_dict``)."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and label the ``idx``-th image.

        Returns:
            ``(image, label)`` tensors, both already moved to ``self.device``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        img_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; PIL expects RGB
        img = self.transform(Image.fromarray(img))

        label = torch.tensor(self.attr_dict[img_name], dtype=torch.long)

        return img.to(self.device), label.to(self.device)





class Shuffle_CelebADataset(Dataset):
    """Images from ``data_cv/celeba_{where_to_unl}/shuffle`` with a region re-shuffled on every access."""

    def __init__(self, attr_dict: dict, area_dict: dict, where_to_unl: str, device: torch.device) -> None:
        """
        Args:
            attr_dict: Dictionary with image filenames as keys and labels as values
            area_dict: Dictionary with image filenames as keys and the region
                coordinates passed to ``operator_shuffle`` as values
            where_to_unl: Name inserted into the image directory path
            device: Device the image and label tensors are moved to in ``__getitem__``
        """

        super(Shuffle_CelebADataset, self).__init__()

        self.attr_dict = attr_dict
        self.area_dict = area_dict
        self.where_to_unl = where_to_unl
        self.device = device

        self.root_dir = f'data_cv/celeba_{where_to_unl}/shuffle'

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        # Fixes the sample order to the insertion order of attr_dict
        self.image_list = list(self.attr_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (one per key of ``attr_dict``)."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the ``idx``-th image, shuffle its region, transform and label it.

        Returns:
            ``(image, label)`` tensors, both already moved to ``self.device``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
            KeyError: If ``area_dict`` has no entry for the image.
        """
        img_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; PIL expects RGB

        # The shuffle draws from the global NumPy RNG, so each access yields a new permutation
        area = self.area_dict[img_name]
        img = operator_shuffle(area=area, img=img)

        img = self.transform(Image.fromarray(img))

        label = torch.tensor(self.attr_dict[img_name], dtype=torch.long)

        return img.to(self.device), label.to(self.device)


class Retrain_CelebADataset(Dataset):
    """Labelled images read from ``data_cv/celeba_{where_to_unl}/retrain``."""

    def __init__(self, attr_dict: dict, where_to_unl: str, device: torch.device) -> None:
        """
        Args:
            attr_dict: Dictionary with image filenames as keys and labels as values
            where_to_unl: Name inserted into the image directory path
            device: Device the image and label tensors are moved to in ``__getitem__``
        """

        super(Retrain_CelebADataset, self).__init__()

        self.attr_dict = attr_dict
        self.where_to_unl = where_to_unl
        self.device = device

        self.root_dir = f'data_cv/celeba_{where_to_unl}/retrain'

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        # Fixes the sample order to the insertion order of attr_dict
        self.image_list = list(self.attr_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (one per key of ``attr_dict``)."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and label the ``idx``-th image.

        Returns:
            ``(image, label)`` tensors, both already moved to ``self.device``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        img_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; PIL expects RGB
        img = self.transform(Image.fromarray(img))

        label = torch.tensor(self.attr_dict[img_name], dtype=torch.long)

        return img.to(self.device), label.to(self.device)



class BL2_CelebADataset(Dataset):
    """Labelled images read from ``data_cv/BL2_celeba_{where_to_unl}``."""

    def __init__(self, attr_dict: dict, where_to_unl: str, device: torch.device) -> None:
        """
        Args:
            attr_dict: Dictionary with image filenames as keys and labels as values
            where_to_unl: Name inserted into the image directory path
            device: Device the image and label tensors are moved to in ``__getitem__``
        """

        super(BL2_CelebADataset, self).__init__()

        self.attr_dict = attr_dict
        self.where_to_unl = where_to_unl
        self.device = device

        self.root_dir = f'data_cv/BL2_celeba_{where_to_unl}'

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        # Fixes the sample order to the insertion order of attr_dict
        self.image_list = list(self.attr_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (one per key of ``attr_dict``)."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and label the ``idx``-th image.

        Returns:
            ``(image, label)`` tensors, both already moved to ``self.device``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        img_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; PIL expects RGB
        img = self.transform(Image.fromarray(img))

        label = torch.tensor(self.attr_dict[img_name], dtype=torch.long)

        return img.to(self.device), label.to(self.device)



class BL2_RASI_CelebADataset(Dataset):
    """Images from ``data_cv/BL2_celeba_{where_to_unl}`` with a region re-shuffled on every access."""

    def __init__(self, attr_dict: dict, area_dict: dict, where_to_unl: str, device: torch.device) -> None:
        """
        Args:
            attr_dict: Dictionary with image filenames as keys and labels as values
            area_dict: Dictionary with image filenames as keys and the region
                coordinates passed to ``operator_shuffle`` as values
            where_to_unl: Name inserted into the image directory path
            device: Device the image and label tensors are moved to in ``__getitem__``
        """

        super(BL2_RASI_CelebADataset, self).__init__()

        self.attr_dict = attr_dict
        self.area_dict = area_dict
        self.where_to_unl = where_to_unl
        self.device = device

        self.root_dir = f'data_cv/BL2_celeba_{where_to_unl}'

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        # Fixes the sample order to the insertion order of attr_dict
        self.image_list = list(self.attr_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (one per key of ``attr_dict``)."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the ``idx``-th image, shuffle its region, transform and label it.

        Returns:
            ``(image, label)`` tensors, both already moved to ``self.device``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
            KeyError: If ``area_dict`` has no entry for the image.
        """
        img_name = self.image_list[idx]
        img_path = os.path.join(self.root_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; PIL expects RGB

        # The shuffle draws from the global NumPy RNG, so each access yields a new permutation
        area = self.area_dict[img_name]
        img = operator_shuffle(area=area, img=img)

        img = self.transform(Image.fromarray(img))

        label = torch.tensor(self.attr_dict[img_name], dtype=torch.long)

        return img.to(self.device), label.to(self.device)




def prep_data(label: str='Male', where_to_unl: str='nose', retrain_or_shuffle: str='retrain', train_test_ratio: float=.8, random_seed: int=2025) -> Tuple[dict, dict, dict, dict]:
    """Load the pickled label and area dictionaries and split them into train/test parts.

    The split is positional: the first ``int(train_test_ratio * n)`` keys of
    the label dictionary (in its stored order) form the training set and the
    rest form the test set. No shuffling takes place here.

    Args:
        label: Name used in ``data_cv/label_{label}.dict``.
        where_to_unl: First part of the area file name.
        retrain_or_shuffle: Second part of the area file name
            (``data_cv/{where_to_unl}_{retrain_or_shuffle}_area.dict``).
        train_test_ratio: Fraction of keys assigned to the training split.
        random_seed: Accepted for interface compatibility; not read by this function.

    Returns:
        ``(train_attr_dict, test_attr_dict, train_area_dict, test_area_dict)``.

    Raises:
        KeyError: If the area dictionary lacks a key present in the label dictionary.
    """
    def _unpickle(path: str) -> dict:
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    attr_dict = _unpickle(f'data_cv/label_{label}.dict')
    area_dict = _unpickle(f'data_cv/{where_to_unl}_{retrain_or_shuffle}_area.dict')

    names = list(attr_dict.keys())
    cut = int(train_test_ratio*len(names))
    train_names, test_names = names[:cut], names[cut:]

    train_attr_dict = {name: attr_dict[name] for name in train_names}
    test_attr_dict = {name: attr_dict[name] for name in test_names}
    train_area_dict = {name: area_dict[name] for name in train_names}
    test_area_dict = {name: area_dict[name] for name in test_names}

    return train_attr_dict, test_attr_dict, train_area_dict, test_area_dict




# Custom model class that uses frozen ViT as backbone with MLP classifier
class ViTBackBone(nn.Module):
    """Frozen ``vit_base_patch16_224`` feature extractor.

    Weights are read from ``./pretrained_models/pytorch_model.bin``. The
    classification head is replaced by an identity, so ``forward`` returns
    whatever the timm model produces ahead of its head (presumably one
    ``embed_dim``-sized vector per image; confirm against the timm version in use).
    """

    def __init__(self):
        super(ViTBackBone, self).__init__()

        # Build the architecture without downloading weights, then load the local checkpoint
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=False)
        # map_location='cpu' makes loading independent of the device the checkpoint
        # was saved from; callers move the whole module with .to(device) afterwards.
        state_dict = torch.load('./pretrained_models/pytorch_model.bin', map_location='cpu')
        self.backbone.load_state_dict(state_dict)

        # Remove the original classification head
        self.backbone.head = nn.Identity()

        # Freeze the backbone parameters
        self.backbone.requires_grad_(False)

    def forward(self, x: torch.Tensor):
        """Return backbone features for ``x`` without building an autograd graph."""
        with torch.no_grad():
            features = self.backbone(x)

        return features
        

class ViTGenderClassifier(nn.Module):
    """Trainable MLP head mapping feature vectors to class logits."""

    def __init__(self, embedding_dim: int, num_classes: int=2) -> None:
        """
        Build the three-layer MLP head.

        Args:
            embedding_dim: Size of each incoming feature vector
            num_classes: Number of output classes (2 for binary gender classification)
        """
        super().__init__()

        # embedding_dim -> 512 -> 128 -> num_classes, with light dropout after each hidden layer
        layers = [
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Map feature vectors to logits.

        Args:
            features: Tensor whose last dimension is ``embedding_dim``

        Returns:
            Tensor of logits whose last dimension is ``num_classes``
        """
        return self.classifier(features)





    


def train_ori_classifier(train_attr_dict: dict, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[nn.Module, nn.Module, float, list]:
    """Train a fresh classifier head on frozen-backbone features of the original images.

    Args:
        train_attr_dict: Filename -> label mapping handed to ``CelebADataset``.
        device: Device for the models and the data.
        epochs: Number of passes over the training loader.
        batch_size: Mini-batch size.
        lr: Adam learning rate (only the classifier head is optimised).

    Returns:
        ``(backbone, classifier, elapsed_seconds, per_epoch_mean_loss)``.
    """
    loader = DataLoader(
        dataset=CelebADataset(attr_dict=train_attr_dict, device=device),
        batch_size=batch_size,
        shuffle=True,
    )

    vit_backbone = ViTBackBone().to(device)
    vit_classifier = ViTGenderClassifier(embedding_dim=vit_backbone.backbone.embed_dim).to(device)

    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(vit_classifier.parameters(), lr=lr)

    epoch_losses = []
    t0 = time.time()
    for epoch in range(epochs):
        vit_backbone.train()
        vit_classifier.train()

        total = .0
        for images, targets in loader:
            adam.zero_grad()
            logits = vit_classifier(vit_backbone(images))
            batch_loss = loss_fn(logits, targets.to(device))
            batch_loss.backward()
            adam.step()
            total += batch_loss.cpu().item()

        # Mean of the per-batch losses for this epoch
        mean_loss = total/len(loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.8f}")
        epoch_losses.append(mean_loss)

    elapsed = time.time() - t0

    return vit_backbone, vit_classifier, elapsed, epoch_losses


def train_retrain_classifier(train_attr_dict: dict, where_to_unl: str, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[nn.Module, nn.Module, float, list]:
    """Train a fresh classifier head on frozen-backbone features of the "retrain" images.

    Args:
        train_attr_dict: Filename -> label mapping handed to ``Retrain_CelebADataset``.
        where_to_unl: Forwarded to ``Retrain_CelebADataset`` to select the image directory.
        device: Device for the models and the data.
        epochs: Number of passes over the training loader.
        batch_size: Mini-batch size.
        lr: Adam learning rate (only the classifier head is optimised).

    Returns:
        ``(backbone, classifier, elapsed_seconds, per_epoch_mean_loss)``.
    """
    retrain_set = Retrain_CelebADataset(attr_dict=train_attr_dict, where_to_unl=where_to_unl, device=device)
    retrain_loader = DataLoader(dataset=retrain_set, batch_size=batch_size, shuffle=True)

    vit_backbone = ViTBackBone().to(device)
    feature_dim = vit_backbone.backbone.embed_dim
    vit_classifier = ViTGenderClassifier(embedding_dim=feature_dim).to(device)

    cross_entropy = nn.CrossEntropyLoss()
    opt = optim.Adam(vit_classifier.parameters(), lr=lr)

    history = []
    started = time.time()
    for epoch in range(epochs):
        vit_backbone.train()
        vit_classifier.train()

        loss_sum = .0
        for batch_imgs, batch_labels in retrain_loader:
            opt.zero_grad()
            feats = vit_backbone(batch_imgs)
            preds = vit_classifier(feats)
            step_loss = cross_entropy(preds, batch_labels.to(device))
            step_loss.backward()
            opt.step()
            loss_sum += step_loss.cpu().item()

        # Mean of the per-batch losses for this epoch
        avg = loss_sum/len(retrain_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg:.8f}")
        history.append(avg)

    finished = time.time()

    return vit_backbone, vit_classifier, finished - started, history


def train_BL2_classifier(backbone: nn.Module, classifier: nn.Module, train_attr_dict: dict, where_to_unl: str, epochs: int, device: torch.device, batch_size: int=64, lr: float=0.001) -> Tuple[nn.Module, nn.Module, float, list]:
    """Update ``classifier`` with the BL2 objective ``ori_loss - BL2_loss``.

    Each step pairs a batch of original images with the same-index batch from
    the BL2 image directory (both loaders are unshuffled, so they stay
    aligned) and minimises ``-(BL2_loss - ori_loss)``: the loss on original
    images is pushed down while the loss on BL2 images is pushed up.

    Args:
        backbone: Feature extractor applied to both image batches.
        classifier: Head whose parameters are optimised.
        train_attr_dict: Filename -> label mapping used for both datasets.
        where_to_unl: Forwarded to ``BL2_CelebADataset`` to select the image directory.
        epochs: Number of passes over the paired loaders.
        device: Device for the data.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(backbone, classifier, elapsed_seconds, per_epoch_mean_loss)``.
    """
    ori_train_dataset = CelebADataset(attr_dict=train_attr_dict, device=device)
    ori_train_loader = DataLoader(dataset=ori_train_dataset, batch_size=batch_size, shuffle=False)

    BL2_train_dataset = BL2_CelebADataset(attr_dict=train_attr_dict, where_to_unl=where_to_unl, device=device)
    BL2_train_loader = DataLoader(dataset=BL2_train_dataset, batch_size=batch_size, shuffle=False)

    # init loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=lr)

    # training loop
    running_loss_lst = []
    start_time = time.time()
    for epoch in range(epochs):
        backbone.train()
        classifier.train()
        running_loss = .0
        for (ori_inputs, labels), (BL2_inputs, _) in zip(ori_train_loader, BL2_train_loader):
            optimizer.zero_grad()

            # Both losses share the same targets; move them once per step
            labels = labels.to(device)

            ori_outputs = classifier(backbone(ori_inputs))
            BL2_outputs = classifier(backbone(BL2_inputs))

            ori_loss = criterion(ori_outputs, labels)
            BL2_loss = criterion(BL2_outputs, labels)

            # first-order update
            final_loss = -(BL2_loss - ori_loss)
            final_loss.backward()

            optimizer.step()
            running_loss += final_loss.cpu().item()

        epoch_loss = running_loss/len(BL2_train_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}")
        running_loss_lst.append(epoch_loss)

    end_time = time.time()

    return backbone, classifier, end_time - start_time, running_loss_lst
    


def train_shuffle_classifier(UL_backbone: nn.Module, UL_classifier: nn.Module, train_attr_dict: dict, train_area_dict: dict, where_to_unl: str, device: torch.device, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[nn.Module, nn.Module, float, list]:
    """Continue training ``UL_classifier`` on region-shuffled images.

    Args:
        UL_backbone: Feature extractor applied to every batch.
        UL_classifier: Head whose parameters are optimised.
        train_attr_dict: Filename -> label mapping handed to ``Shuffle_CelebADataset``.
        train_area_dict: Filename -> region mapping handed to ``Shuffle_CelebADataset``.
        where_to_unl: Forwarded to ``Shuffle_CelebADataset`` to select the image directory.
        device: Device for the data.
        epochs: Number of passes over the training loader.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(UL_backbone, UL_classifier, elapsed_seconds, per_epoch_mean_loss)``.
    """
    shuffled_set = Shuffle_CelebADataset(attr_dict=train_attr_dict, area_dict=train_area_dict, where_to_unl=where_to_unl, device=device)
    shuffled_loader = DataLoader(dataset=shuffled_set, batch_size=batch_size, shuffle=True)

    objective = nn.CrossEntropyLoss()
    opt = optim.Adam(UL_classifier.parameters(), lr=lr)

    loss_trace = []
    tic = time.time()
    for epoch in range(epochs):
        UL_backbone.train()
        UL_classifier.train()

        accumulated = .0
        for x_batch, y_batch in shuffled_loader:
            opt.zero_grad()
            scores = UL_classifier(UL_backbone(x_batch))
            current = objective(scores, y_batch.to(device))
            current.backward()
            opt.step()
            accumulated += current.cpu().item()

        # Mean of the per-batch losses for this epoch
        epoch_mean = accumulated/len(shuffled_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_mean:.8f}")
        loss_trace.append(epoch_mean)

    toc = time.time()

    return UL_backbone, UL_classifier, toc - tic, loss_trace




def evaluate_ori_classifier(backbone: nn.Module, classifier: nn.Module, test_attr_dict: dict, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Evaluate ``classifier`` on the original images listed in ``test_attr_dict``.

    Args:
        backbone: Feature extractor applied to every batch.
        classifier: Head producing class logits from the features.
        test_attr_dict: Filename -> label mapping handed to ``CelebADataset``.
        device: Device for the data.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists are in dataset order.
    """
    test_dataset = CelebADataset(attr_dict=test_attr_dict, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    classifier.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = classifier(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Accuracy is the only metric in the return value, so no other metric is computed
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true



def evaluate_retrain_classifier(backbone: nn.Module, classifier: nn.Module, test_attr_dict: dict, where_to_unl: str, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Evaluate ``classifier`` on the "retrain" images listed in ``test_attr_dict``.

    Args:
        backbone: Feature extractor applied to every batch.
        classifier: Head producing class logits from the features.
        test_attr_dict: Filename -> label mapping handed to ``Retrain_CelebADataset``.
        where_to_unl: Forwarded to ``Retrain_CelebADataset`` to select the image directory.
        device: Device for the data.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists are in dataset order.
    """
    test_dataset = Retrain_CelebADataset(attr_dict=test_attr_dict, where_to_unl=where_to_unl, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    classifier.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = classifier(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Accuracy is the only metric in the return value, so no other metric is computed
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true



def evaluate_BL2_classifier(backbone: nn.Module, classifier: nn.Module, test_attr_dict: dict, where_to_unl: str, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Evaluate ``classifier`` on the BL2 images listed in ``test_attr_dict``.

    Args:
        backbone: Feature extractor applied to every batch.
        classifier: Head producing class logits from the features.
        test_attr_dict: Filename -> label mapping handed to ``BL2_CelebADataset``.
        where_to_unl: Forwarded to ``BL2_CelebADataset`` to select the image directory.
        device: Device for the data.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists are in dataset order.
    """
    test_dataset = BL2_CelebADataset(attr_dict=test_attr_dict, where_to_unl=where_to_unl, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    classifier.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = classifier(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Accuracy is the only metric in the return value, so no other metric is computed
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true


def evaluate_BL2_RASI_classifier(backbone: nn.Module, classifier: nn.Module, test_attr_dict: dict, test_area_dict: dict, where_to_unl: str, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Evaluate ``classifier`` on region-shuffled BL2 images listed in ``test_attr_dict``.

    Args:
        backbone: Feature extractor applied to every batch.
        classifier: Head producing class logits from the features.
        test_attr_dict: Filename -> label mapping handed to ``BL2_RASI_CelebADataset``.
        test_area_dict: Filename -> region mapping handed to ``BL2_RASI_CelebADataset``.
        where_to_unl: Forwarded to ``BL2_RASI_CelebADataset`` to select the image directory.
        device: Device for the data.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists are in dataset order.
    """
    test_dataset = BL2_RASI_CelebADataset(attr_dict=test_attr_dict, area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    classifier.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = classifier(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Accuracy is the only metric in the return value, so no other metric is computed
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true

def evaluate_shuffle_classifier(backbone: nn.Module, classifier: nn.Module, test_attr_dict: dict, test_area_dict: dict, where_to_unl: str, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Evaluate ``classifier`` on region-shuffled images listed in ``test_attr_dict``.

    Args:
        backbone: Feature extractor applied to every batch.
        classifier: Head producing class logits from the features.
        test_attr_dict: Filename -> label mapping handed to ``Shuffle_CelebADataset``.
        test_area_dict: Filename -> region mapping handed to ``Shuffle_CelebADataset``.
        where_to_unl: Forwarded to ``Shuffle_CelebADataset`` to select the image directory.
        device: Device for the data.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists are in dataset order.
    """
    test_dataset = Shuffle_CelebADataset(attr_dict=test_attr_dict, area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    classifier.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = classifier(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Accuracy is the only metric in the return value, so no other metric is computed
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true




class BL3DecoderMIhx(nn.Module):
    """Two-layer decoder: ``embedding_dim -> 64 -> output_dim`` with a sigmoid output."""

    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()

        hidden_units = 64
        self.fc1 = nn.Linear(embedding_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, h: torch.Tensor):
        """Decode ``h`` into values squashed to the open interval (0, 1)."""
        hidden = self.fc1(h)
        hidden = self.relu(hidden)
        pre_activation = self.fc2(hidden)
        return self.sigmoid(pre_activation)
    
    
    
class BL3MIhy(nn.Module):
    """Three-layer MLP ``embedding_dim -> 64 -> 64 -> output_dim`` returning softmax probabilities."""

    def __init__(self, embedding_dim: int, output_dim: int=2):
        super().__init__()

        hidden_units = 64
        self.fc1 = nn.Linear(embedding_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, h: torch.Tensor):
        """Return class probabilities (softmax over dim 1), not raw logits."""
        hidden = self.relu(self.fc1(h))
        hidden = self.relu(self.fc2(hidden))
        scores = self.fc3(hidden)
        return self.softmax(scores)


class BL3MIhz(nn.Module):
    """Three-layer MLP ``embedding_dim -> 64 -> 64 -> output_dim`` with a linear output."""

    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()

        hidden_units = 64
        self.fc1 = nn.Linear(embedding_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

        self.relu = nn.ReLU()

    def forward(self, h: torch.Tensor):
        """Map ``h`` through two ReLU hidden layers and an unactivated output layer."""
        hidden = self.fc1(h)
        hidden = self.relu(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu(hidden)
        return self.fc3(hidden)





def get_h_for_BL3(attr_dict: dict, device: torch.device, backbone: ViTBackBone, batch_size: int=64) -> np.ndarray:
    """Run ``backbone`` over every image in ``attr_dict`` and collect the outputs.

    Args:
        attr_dict: Filename -> label mapping handed to ``CelebADataset``.
        device: Device for the data.
        backbone: Feature extractor; switched to eval mode here.
        batch_size: Mini-batch size.

    Returns:
        Array with one row of backbone output per image, in dataset order.
    """
    loader = DataLoader(
        dataset=CelebADataset(attr_dict=attr_dict, device=device),
        batch_size=batch_size,
        shuffle=False,
    )

    backbone.eval()
    collected = []
    with torch.no_grad():
        for batch in loader:
            images = batch[0]  # labels (batch[1]) are not needed here
            collected += backbone(images).cpu().tolist()

    return np.array(collected)




def train_BL3DecoderMIhx(h_train: np.ndarray, attr_dict: dict, device: torch.device, backbone: ViTBackBone, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3DecoderMIhx, float]:
    """Train a ``BL3DecoderMIhx`` to regress the images from their embeddings ``h_train``.

    Args:
        h_train: Embeddings, one row per image, paired by position with the
            images that ``CelebADataset`` yields for ``attr_dict``.
        attr_dict: Filename -> label mapping handed to ``CelebADataset``.
        device: Device for the model and the data.
        backbone: Only ``backbone.backbone.embed_dim`` is read, to size the decoder input.
        epochs: Number of passes over the (embedding, image) pairs.
        batch_size: Mini-batch size for the training loader.
        lr: Adam learning rate.

    Returns:
        ``(trained_model, elapsed_seconds)``.

    NOTE(review): the regression targets are the transformed image tensors
    from ``CelebADataset`` (its transform resizes to 224x224 and normalises),
    while the decoder emits ``224*224`` sigmoid-bounded values per sample.
    If the targets are 3-channel tensors, ``nn.MSELoss`` cannot match them to
    the decoder output -- confirm the intended target layout and value range.
    """
    dataset = CelebADataset(attr_dict=attr_dict, device=device)
    # batch_size=len(dataset): the loop below runs once and X holds the whole dataset in memory
    dataloader = DataLoader(dataset=dataset, batch_size=len(dataset), shuffle=False)
    for inputs, _ in dataloader:
        X = inputs.detach().cpu().numpy()
    
    # Pair each embedding with its image; both tensors live on `device`
    dataset = TensorDataset(torch.FloatTensor(h_train).to(device), torch.FloatTensor(X).to(device))
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    
    BL3DecoderMIhx_model = BL3DecoderMIhx(output_dim=int(224*224), embedding_dim=backbone.backbone.embed_dim).to(device)
    
    criterion = nn.MSELoss()
    optimizer = optim.Adam(BL3DecoderMIhx_model.parameters(), lr=lr)
    
    start_time = time.time()
    for epoch in range(epochs):
        BL3DecoderMIhx_model.train()
        running_loss = .0
        for h, x in dataloader:
            x_hat = BL3DecoderMIhx_model(h)
            loss = criterion(x_hat, x)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.cpu().item()
            
        
        # Mean of the per-batch losses for this epoch
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(dataloader):.4f}")
    
    end_time = time.time()
    
    return BL3DecoderMIhx_model, end_time-start_time


def cal_BL3MIhx(BL3DecoderMIhx_model: BL3DecoderMIhx, x: torch.Tensor, h: torch.Tensor):
    """Compute the ``MIhx`` quantity from the decoder's reconstruction error.

    The value returned is
    ``0.5 * log(2 * pi * e * var(x)) - mse(decoder(h), x) / h.shape[0]``.

    Args:
        BL3DecoderMIhx_model: Decoder; switched to eval mode here.
        x: Reconstruction target.
        h: Decoder input; its first dimension divides the MSE term.

    Returns:
        Scalar tensor ``MIhx``.
    """
    BL3DecoderMIhx_model.eval()

    reconstruction = BL3DecoderMIhx_model(h)
    mse = nn.functional.mse_loss(reconstruction, x)

    # Entropy of a Gaussian whose variance equals the variance of x
    gaussian_entropy = 0.5*torch.log(2*np.pi*np.e*torch.var(x))

    return gaussian_entropy - mse/h.shape[0]
    
    

def train_BL3MIhy(h_train: np.ndarray, attr_dict: dict, device: torch.device, backbone: ViTBackBone, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3MIhy, float]:
    """Train a ``BL3MIhy`` network to predict the labels from the embeddings ``h_train``.

    Args:
        h_train: Embeddings, one row per sample, paired by position with
            ``attr_dict.values()``.
        attr_dict: Filename -> label mapping; only the values are used.
        device: Device for the model and the data.
        backbone: Only ``backbone.backbone.embed_dim`` is read, to size the network input.
        epochs: Number of passes over the (embedding, label) pairs.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(trained_model, elapsed_seconds)``.
    """
    features = torch.FloatTensor(h_train).to(device)
    targets = torch.LongTensor(list(attr_dict.values())).to(device)
    dataloader = DataLoader(dataset=TensorDataset(features, targets), batch_size=batch_size, shuffle=True)

    BL3MIhy_model = BL3MIhy(embedding_dim=backbone.backbone.embed_dim).to(device)

    # BL3MIhy.forward already ends in a softmax, so its output is a probability
    # vector. nn.CrossEntropyLoss expects raw logits and would apply log-softmax
    # a second time; the negative log-likelihood of the log-probabilities is the
    # actual cross-entropy.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(BL3MIhy_model.parameters(), lr=lr)
    eps = 1e-12  # keeps log() finite if a probability underflows to 0

    start_time = time.time()
    for epoch in range(epochs):
        BL3MIhy_model.train()
        running_loss = .0
        for h, y in dataloader:
            y_hat = BL3MIhy_model(h)
            loss = criterion(torch.log(y_hat.clamp_min(eps)), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.cpu().item()

        # Mean of the per-batch losses for this epoch
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(dataloader):.4f}")

    end_time = time.time()

    return BL3MIhy_model, end_time-start_time


def cal_BL3MIhy(output_dim: int, BL3MIhy_model: BL3MIhy, h: torch.Tensor, y: torch.Tensor):
    """Compute the ``MIhy`` quantity from the label-prediction loss.

    The value returned is ``log(output_dim) - cross_entropy / h.shape[0]``,
    where the cross-entropy is the mean negative log-likelihood of ``y`` under
    the probabilities predicted from ``h``.

    Args:
        output_dim: Number of classes; ``log(output_dim)`` is the leading term.
        BL3MIhy_model: Network returning class probabilities; switched to eval mode here.
        h: Network input; its first dimension divides the loss term.
        y: Integer class targets.

    Returns:
        Scalar tensor ``MIhy``.
    """
    BL3MIhy_model.eval()

    # BL3MIhy.forward already applies a softmax, so y_hat holds probabilities.
    # nn.CrossEntropyLoss expects raw logits and would apply log-softmax again;
    # take the NLL of the log-probabilities to get the actual cross-entropy.
    y_hat = BL3MIhy_model(h)
    log_probs = torch.log(y_hat.clamp_min(1e-12))  # clamp keeps log() finite
    loss = nn.functional.nll_loss(log_probs, y)

    # NOTE(review): `loss` is already a batch mean; dividing by h.shape[0] again
    # is kept as in the original formula -- confirm this is intended.
    MIhy = np.log(output_dim) - loss/h.shape[0]

    return MIhy

def train_BL3MIhz(h_train: np.ndarray, attr_dict: dict, device: torch.device, backbone: ViTBackBone, epochs: int, batch_size: int=64, lr: float=0.001) -> Tuple[BL3MIhz, float]:
    """Fit a new ``BL3MIhz`` model that regresses the dataset inputs from ``h_train``.

    The whole ``CelebADataset`` built from ``attr_dict`` is read in a single
    unshuffled batch, and its inputs become the regression targets. They are
    paired by position with the rows of ``h_train``. The model is optimised
    with Adam against MSE, and the mean batch loss is printed once per epoch.

    Args:
        h_train: Embeddings used as model input.
        attr_dict: Passed to ``CelebADataset`` to build the target inputs.
        device: Device that holds the tensors and the model.
        backbone: Only ``backbone.backbone.embed_dim`` is read, to size the model.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size for the training loader.
        lr: Adam learning rate.

    Returns:
        ``(trained_model, seconds_spent_in_the_training_loop)``.
    """
    # Pull every sample at once: one batch whose size equals the dataset length.
    source = CelebADataset(attr_dict=attr_dict, device=device)
    full_pass = DataLoader(dataset=source, batch_size=len(source), shuffle=False)
    inputs, _ = next(iter(full_pass))
    X = inputs.detach().cpu().numpy()

    pairs = TensorDataset(torch.FloatTensor(h_train).to(device), torch.FloatTensor(X).to(device))
    loader = DataLoader(dataset=pairs, batch_size=batch_size, shuffle=True)

    BL3MIhz_model = BL3MIhz(output_dim=int(224*224), embedding_dim=backbone.backbone.embed_dim).to(device)
    loss_fn = nn.MSELoss()
    adam = optim.Adam(BL3MIhz_model.parameters(), lr=lr)

    started = time.time()
    for epoch in range(epochs):
        BL3MIhz_model.train()
        epoch_loss = .0
        for h, z in loader:
            adam.zero_grad()
            batch_loss = loss_fn(BL3MIhz_model(h), z)
            batch_loss.backward()
            adam.step()

            epoch_loss += batch_loss.item()

        mean_loss = epoch_loss / len(loader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")

    elapsed = time.time() - started

    return BL3MIhz_model, elapsed


def cal_BL3MIhz(BL3MIhz_model: BL3MIhz, h: torch.Tensor, z: torch.Tensor):
    """Return the regressor-based score relating the embedding ``h`` to ``z``.

    The model is put into eval mode (and left that way) and applied to ``h``;
    its output is compared with ``z`` using ``nn.MSELoss`` with the default
    reduction. The value returned is::

        0.5 * log(2 * pi * e * var(z)) - mse / h.shape[0]

    where ``var(z)`` is ``torch.var`` taken over every element of ``z``.
    Gradient tracking is not disabled here.

    Args:
        BL3MIhz_model: Model called as ``model(h)``; its output is compared
            against ``z``.
        h: Model input; ``h.shape[0]`` divides the MSE term.
        z: Regression target.

    Returns:
        The tensor produced by the expression above.
    """
    BL3MIhz_model.eval()

    regression_error = nn.MSELoss()(BL3MIhz_model(h), z)
    entropy_term = 0.5 * torch.log(2 * np.pi * np.e * torch.var(z))

    return entropy_term - regression_error / h.shape[0]


    
class BL3Dataset(Dataset):
    """In-memory dataset that yields ``(X[i], y[i], h[i])`` triples.

    All three arrays are converted to float32 tensors and moved to ``device``
    once, at construction time. The dataset length is taken from ``y``.

    Args:
        X: First per-sample array.
        y: Second per-sample array; stored as a float tensor and used for ``len()``.
        h: Third per-sample array.
        device: Device the tensors are moved to.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray, h: np.ndarray, device: torch.device):
        # Zero-argument super() resolves to the class that follows BL3Dataset
        # in the MRO (i.e. Dataset). The previous ``super(Dataset, self)``
        # started the lookup *after* Dataset and therefore skipped it.
        super().__init__()
        self.X = torch.FloatTensor(X).to(device)
        self.y = torch.FloatTensor(y).to(device)
        self.h = torch.FloatTensor(h).to(device)

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, idx: int):
        return self.X[idx], self.y[idx], self.h[idx]
        
        

def train_unlearning_BL3(
    backbone: ViTBackBone, BL3Classifier_model: nn.Module, 
    BL3DecoderMIhx_model: BL3DecoderMIhx, BL3MIhy_model: BL3MIhy, BL3MIhz_model: BL3MIhz,
    attr_dict: dict, device: torch.device,
    BL3_lamda1: float, BL3_lamda2: float, BL3_lamda3: float,
    epochs: int, batch_size: int=64, lr: float=0.001
) -> Tuple[nn.Module, float, list, list]:
    """Optimise the classifier on a task loss plus weighted MI terms.

    For each batch the backbone produces ``h``, the classifier predicts from
    ``h``, and the objective is::

        cross_entropy(preds, labels)
            - BL3_lamda1 * MIhx - BL3_lamda2 * MIhy - BL3_lamda3 * MIhz

    where the three MI values come from ``cal_BL3MIhx``, ``cal_BL3MIhy`` and
    ``cal_BL3MIhz``. Only the parameters of ``BL3Classifier_model`` are handed
    to the Adam optimiser.

    Args:
        backbone: Feature extractor, kept in eval mode.
        BL3Classifier_model: Classifier whose parameters are updated.
        BL3DecoderMIhx_model: Estimator passed to ``cal_BL3MIhx``; eval mode.
        BL3MIhy_model: Estimator passed to ``cal_BL3MIhy``; eval mode.
        BL3MIhz_model: Estimator passed to ``cal_BL3MIhz``; eval mode.
        attr_dict: Passed to ``CelebADataset`` to build the training set.
        device: Passed to ``CelebADataset``.
        BL3_lamda1: Weight of the ``MIhx`` term.
        BL3_lamda2: Weight of the ``MIhy`` term.
        BL3_lamda3: Weight of the ``MIhz`` term.
        epochs: Number of passes over the data; must be an int because it is
            passed to ``range()``.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        ``(classifier, seconds_spent_in_the_loop, task_loss_lst, MI_loss_lst)``.
        Each list holds one entry per epoch: the *sum* of the per-batch values
        (not the mean that is printed).
    """
    train_dataset = CelebADataset(attr_dict=attr_dict, device=device)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    task_loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(BL3Classifier_model.parameters(), lr=lr)

    backbone.eval()
    BL3DecoderMIhx_model.eval()
    BL3MIhy_model.eval()
    BL3MIhz_model.eval()

    task_loss_lst, MI_loss_lst = [], []
    start_time = time.time()
    for epoch in range(epochs):
        running_task_loss, running_MI_loss = .0, .0
        # NOTE(review): the classifier is optimised while in eval mode, whereas
        # train_BL3MIhy / train_BL3MIhz call .train() at this point. Left
        # unchanged to preserve behaviour -- confirm this is intentional.
        BL3Classifier_model.eval()
        for inputs, labels in train_loader:
            h = backbone(inputs)
            preds = BL3Classifier_model(h)

            task_loss = task_loss_func(preds, labels)

            # NOTE(review): the MI terms are computed from ``h``, ``inputs`` and
            # ``labels`` through the estimator models only; none of them goes
            # through BL3Classifier_model. Unless those modules share parameters
            # with the classifier, these terms change the reported loss but not
            # the gradients of the parameters being stepped -- confirm intent.
            MIhx = cal_BL3MIhx(BL3DecoderMIhx_model=BL3DecoderMIhx_model, x=inputs, h=h)
            MIhy = cal_BL3MIhy(output_dim=2, BL3MIhy_model=BL3MIhy_model, h=h, y=labels)
            MIhz = cal_BL3MIhz(BL3MIhz_model=BL3MIhz_model, h=h, z=inputs)
            MIloss = -BL3_lamda1*MIhx - BL3_lamda2*MIhy - BL3_lamda3*MIhz

            final_loss = task_loss + MIloss

            optimizer.zero_grad()
            final_loss.backward()
            optimizer.step()

            running_task_loss += task_loss.cpu().item()
            running_MI_loss += MIloss.cpu().item()

        print(f"Epoch [{epoch+1}/{epochs}], Running Task Loss: {running_task_loss/len(train_loader):.4f}")
        # The history lists keep the un-normalised epoch sums.
        task_loss_lst.append(running_task_loss)
        MI_loss_lst.append(running_MI_loss)

    end_time = time.time()

    return BL3Classifier_model, end_time-start_time, task_loss_lst, MI_loss_lst




def evaluate_BL3(backbone: ViTBackBone, BL3Classifier_model: nn.Module, attr_dict: dict, device: torch.device, batch_size: int=64) -> Tuple[float, list, list]:
    """Run the backbone + classifier over a dataset and report accuracy.

    Both modules are put into eval mode (and left that way). Inference runs
    under ``torch.no_grad()`` over an unshuffled loader, and the predicted
    class is the arg-max over dimension 1 of the classifier output.

    Args:
        backbone: Feature extractor applied to each input batch.
        BL3Classifier_model: Classifier applied to the backbone output.
        attr_dict: Passed to ``CelebADataset`` to build the evaluation set.
        device: Passed to ``CelebADataset``.
        batch_size: Mini-batch size.

    Returns:
        ``(accuracy, y_pred, y_true)`` where the two lists hold the predicted
        and reference labels in dataset order.
    """
    test_dataset = CelebADataset(attr_dict=attr_dict, device=device)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    backbone.eval()
    BL3Classifier_model.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = BL3Classifier_model(backbone(inputs))
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Only accuracy is part of the return value. Precision, recall and F1 were
    # previously computed here and then discarded, so they are no longer
    # calculated.
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, y_pred, y_true
