import copy
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
from torchvision import utils
from typing import List, Union
def equisplit(targets, num_classes, retain_ratio, size_ratio, samples_per_class=5000):
    """
    Split the dataset into retain, forget, test, shadow_train, and shadow_test sets
    in a balanced manner across all classes.

    For every class ``i`` in ``range(num_classes)`` the first
    ``int(samples_per_class * size_ratio)`` positions whose label equals ``i``
    are collected and randomly permuted. The first half of the permutation is
    the "target" pool, which is re-shuffled and cut into retain / forget / test.
    The remaining positions form the "shadow" pool, cut into shadow_train
    (``ret_size + forget_size`` items) and shadow_test (everything left).

    Args:
        targets: Sequence of integer class labels, one per sample.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are used.
        retain_ratio: Fraction of each class's target pool assigned to retain.
        size_ratio: Scales ``samples_per_class`` to get the per-class cap.
        samples_per_class: Per-class cap before scaling. Defaults to 5000,
            the value that was previously hard-coded.

    Returns:
        Tuple ``(retain, forget, test, shadow_train, shadow_test)`` of lists of
        integer positions into ``targets``.

    Raises:
        AssertionError: If, for some class, the part of the target pool left
            after the retain cut cannot be divided into equal forget and test
            halves (i.e. it has odd length).
    """
    retain, forget, test = [], [], []
    shadow_train, shadow_test = [], []

    # Convert targets to numpy array for easier manipulation
    targets = np.array(targets)

    # The per-class cap does not depend on the class, so compute it once.
    size = int(samples_per_class * size_ratio)

    for i in range(num_classes):
        idx_list = (targets == i).nonzero()[0][:size]  # Positions of class i, capped
        half_size = len(idx_list) // 2
        ret_size = int(half_size * retain_ratio)

        # Permutations are still drawn from torch's RNG (so torch.manual_seed
        # keeps controlling the split) but are converted to NumPy before being
        # used as indices: a one-element tensor slice can be interpreted by
        # NumPy as a scalar index, which would yield a scalar instead of an
        # array and break the ``.tolist()`` / ``extend`` calls below.
        rand_perm = torch.randperm(len(idx_list)).numpy()
        target = idx_list[rand_perm[:half_size]]

        # Whatever is not retained is shared between forget and test.
        forget_size = (half_size - ret_size) // 2
        test_size = half_size - ret_size - forget_size

        # Ensure forget and test sets have the same size
        assert forget_size == test_size, "Forget and test sizes must be equal"

        # Second, independent shuffle of the target pool before cutting it.
        randperm = torch.randperm(half_size).numpy()
        forget_end = ret_size + forget_size
        retain.extend(target[randperm[:ret_size]].tolist())
        forget.extend(target[randperm[ret_size:forget_end]].tolist())
        test.extend(target[randperm[forget_end:]].tolist())

        # The shadow pool is the second half of the first permutation.
        shadow_cut = half_size + forget_end
        shadow_train.extend(idx_list[rand_perm[half_size:shadow_cut]].tolist())
        shadow_test.extend(idx_list[rand_perm[shadow_cut:]].tolist())

    return retain, forget, test, shadow_train, shadow_test

def confmat(n_C, conftype, num_change, exch_classes=None):
    """
    Build an ``n_C x n_C`` confusion matrix as a list of row lists.

    Entry ``[i][j]`` holds ``num_change`` for every class pair that should be
    confused and 0 everywhere else; the diagonal is always 0.

    * ``conftype == 'noise'``: every off-diagonal entry is ``num_change``.
    * ``conftype == 'exchange'``: only pairs of distinct classes taken from
      ``exch_classes`` are set (``exch_classes`` must not be ``None``).
    * any other ``conftype``: the all-zero matrix is returned.
    """
    # from EIU/
    if conftype == 'noise':
        grid = [[num_change if row != col else 0 for col in range(n_C)]
                for row in range(n_C)]
    else:
        grid = [[0] * n_C for _ in range(n_C)]

    if conftype == 'exchange':
        assert(exch_classes is not None)
        for src in exch_classes:
            for dst in exch_classes:
                if src == dst:
                    continue
                grid[src][dst] = num_change

    return grid

def add_confusion(dataset: torch.utils.data.Dataset, confusion: List[List[int]]):
    """
    Relabel samples of a subset according to a confusion matrix.

    ``dataset`` must expose ``.indices`` and ``.dataset.targets`` (as a
    ``Subset`` does). A deep copy of it is made and, for each ordered class
    pair ``(i, j)`` with ``i != j``, the next ``confusion[i][j]`` not-yet-used
    samples of class ``i`` (in subset order) get label ``j`` in the copy. The
    labels reachable through the input ``dataset`` are left unchanged.

    Returns:
        ``(confset, confused_indices, dataset)``: the relabelled deep copy,
        the underlying-dataset indices that were relabelled (in the order they
        were changed), and the input ``dataset`` itself.

    Raises:
        IndexError: If a row of ``confusion`` asks for more samples than the
            subset contains for that class.
    """
    class_count = len(confusion)
    noisy = copy.deepcopy(dataset)
    flipped = []
    # Labels of the subset's samples, in subset order.
    labels = np.array(dataset.dataset.targets)[dataset.indices]

    for src in range(class_count):
        members = np.flatnonzero(labels == src)
        cursor = 0  # first member of class `src` not yet relabelled
        for dst in range(class_count):
            if src != dst:
                quota = confusion[src][dst]
                for pos in range(cursor, cursor + quota):
                    base_index = dataset.indices[members[pos]]
                    noisy.dataset.targets[base_index] = dst
                    flipped.append(base_index)
                cursor += quota

    return noisy, flipped, dataset

def confuse_dataset(args, dataset):
    """
    Apply 'exchange' label confusion to ``dataset``.

    Builds the exchange confusion matrix from ``args.num_classes``,
    ``args.num_change`` and ``args.exch_classes`` and applies it with
    ``add_confusion``.

    Returns:
        ``(confset, confused_indices, dataset)`` exactly as produced by
        ``add_confusion``.
    """
    exchange_matrix = confmat(
        args.num_classes,
        'exchange',
        args.num_change,
        args.exch_classes,
    )
    relabelled, changed, untouched = add_confusion(dataset, exchange_matrix)
    return relabelled, changed, untouched

def CIFAR10loader(batch_size, retain_ratio, args, size_ratio=1):
    """
    Download/load CIFAR-10 and build the data loaders for one experiment.

    Side effects: sets ``args.num_classes = 10`` and stores the dataset under
    ``./data`` (``download=True``).

    Args:
        batch_size: Batch size used for every returned loader.
        retain_ratio: Passed to ``equisplit`` as the retain fraction.
        args: Options object. ``args.eiu`` selects the label-confusion mode;
            in that mode ``confuse_dataset`` additionally reads
            ``args.num_change`` and ``args.exch_classes``.
        size_ratio: Passed to ``equisplit`` to scale the per-class sample cap.

    Returns:
        If ``args.eiu`` is truthy, a 7-tuple
        ``(confset, confused_indices, fgloader, reloader, train_loader,
        test_loader, classes)`` where the forget loader serves the relabelled
        samples and the retain loader serves the remaining, unmodified ones.

        Otherwise an 8-tuple
        ``(rtloader, rfloader, reloader, fgloader, teloader,
        shadow_trainloader, shadow_testloader, classes)``.
    """
    # NOTE(review): these channel statistics are the values commonly used for
    # ImageNet rather than CIFAR-10-specific ones. They are kept unchanged so
    # that previously trained models see the same inputs.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Same transformation for training and testing (no augmentation)
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load CIFAR10 dataset
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)

    # Class names
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    args.num_classes = 10

    def make_loader(data):
        # Every loader in this function shares the same settings.
        return DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2)

    # Split datasets
    retain_indices, forget_indices, test_indices, shadow_trsplit, shadow_tesplit = equisplit(
        trainset.targets, args.num_classes, retain_ratio, size_ratio)

    if args.eiu:
        all_dataset = Subset(trainset, retain_indices + forget_indices)
        confset, confused_indices, clean_subset = confuse_dataset(args, all_dataset)

        # Retain set = everything in the subset that was not relabelled.
        # Built as a new list with a set lookup: this avoids the quadratic
        # cost of repeated list.remove() and does not mutate the Subset's own
        # index list in place.
        confused_lookup = set(confused_indices)
        retain_indices = [idx for idx in clean_subset.indices
                          if idx not in confused_lookup]

        train_loader = make_loader(confset)
        test_loader = make_loader(testset)
        # Forget loader reads from the relabelled copy, retain loader from the
        # dataset with the original labels.
        fgloader = make_loader(Subset(confset.dataset, confused_indices))
        reloader = make_loader(Subset(clean_subset.dataset, retain_indices))

        return confset, confused_indices, fgloader, reloader, train_loader, test_loader, classes

    # Data loaders for different splits (all drawn from the training set)
    rtloader = make_loader(Subset(trainset, retain_indices + test_indices))
    rfloader = make_loader(Subset(trainset, retain_indices + forget_indices))
    reloader = make_loader(Subset(trainset, retain_indices))
    fgloader = make_loader(Subset(trainset, forget_indices))
    teloader = make_loader(Subset(trainset, test_indices))
    shadow_trainloader = make_loader(Subset(trainset, shadow_trsplit))
    shadow_testloader = make_loader(Subset(trainset, shadow_tesplit))

    return rtloader, rfloader, reloader, fgloader, teloader, shadow_trainloader, shadow_testloader, classes

def get_dataloader(dataset, batch_size, unlearn_ratio, args, size_ratio):
    """
    Dispatch to the loader factory for the named dataset.

    Only ``'CIFAR10'`` is handled, in which case the result of
    ``CIFAR10loader`` is returned with ``unlearn_ratio`` passed as its retain
    ratio. For any other name a message is printed and
    ``(None, None, None)`` is returned.
    """
    if dataset != 'CIFAR10':
        # Unknown dataset: report it and hand back placeholders.
        print('Dataset not supported')
        return None, None, None

    return CIFAR10loader(batch_size, unlearn_ratio, args, size_ratio)