import sys
import numpy as np
import torch
import os
import pickle
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
from torch.utils.data import Dataset

class CLCIFAR20(Dataset):
    """CLCIFAR20 training set.

    The 20-class (CIFAR-100 superclass) training set with human-annotated
    complementary labels. Each sample carries one ordinary label and the
    first of its annotated complementary labels; the complementary label is
    stored one-hot in ``targets``.

    Args:
        root: directory that contains ``clcifar20/clcifar20.pkl``.
        transform: optional feature transformation applied to each image.

    Attributes:
        data: the raw images as stored in the pickle under ``"images"``.
        targets: float array of shape (len(data), 20); row i is the one-hot
            encoding of the first complementary label of sample i.
        ord_labels: numpy array of the ordinary (true) labels.
    """

    def __init__(self, root="./dataset", transform=None):
        dataset_path = os.path.join(root, "clcifar20", "clcifar20.pkl")

        # NOTE: pickle.load can execute arbitrary code -- only point `root`
        # at a trusted copy of the dataset. The `with` block guarantees the
        # file handle is closed (the previous bare open() leaked it).
        with open(dataset_path, "rb") as f:
            data = pickle.load(f)

        self.transform = transform
        self.input_dim = 32 * 32 * 3
        self.num_classes = 20
        self.data = data["images"]

        # One-hot encode only the first complementary label of each sample.
        num_samples = len(self.data)
        first_cl = np.asarray(
            [data["cl_labels"][i][0] for i in range(num_samples)], dtype=np.intp
        )
        self.targets = np.zeros((num_samples, self.num_classes))
        self.targets[np.arange(num_samples), first_cl] = 1

        self.ord_labels = np.array(data["ord_labels"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, one_hot_complementary_label, ordinary_label, index)``."""
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index], self.ord_labels[index], index

class CLCIFAR10(Dataset):
    """CLCIFAR10 training set.

    The CIFAR-10 training set with human-annotated complementary labels.
    Each sample carries one ordinary label and the first of its annotated
    complementary labels; the complementary label is stored one-hot in
    ``targets``.

    Args:
        root: directory that contains ``clcifar10/clcifar10.pkl``.
        transform: optional feature transformation applied to each image.

    Attributes:
        data: the raw images as stored in the pickle under ``"images"``.
        targets: float array of shape (len(data), 10); row i is the one-hot
            encoding of the first complementary label of sample i.
        ord_labels: numpy array of the ordinary (true) labels.
    """

    def __init__(self, root="./dataset", transform=None):
        dataset_path = os.path.join(root, "clcifar10", "clcifar10.pkl")

        # NOTE: pickle.load can execute arbitrary code -- only point `root`
        # at a trusted copy of the dataset. The `with` block guarantees the
        # file handle is closed (the previous bare open() leaked it).
        with open(dataset_path, "rb") as f:
            data = pickle.load(f)

        self.transform = transform
        self.input_dim = 32 * 32 * 3
        self.num_classes = 10
        self.data = data["images"]

        # One-hot encode only the first complementary label of each sample.
        num_samples = len(self.data)
        first_cl = np.asarray(
            [data["cl_labels"][i][0] for i in range(num_samples)], dtype=np.intp
        )
        self.targets = np.zeros((num_samples, self.num_classes))
        self.targets[np.arange(num_samples), first_cl] = 1

        self.ord_labels = np.array(data["ord_labels"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, one_hot_complementary_label, ordinary_label, index)``."""
        image = self.data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[index], self.ord_labels[index], index

def class_prior(complementary_labels, num_classes=None):
    """Return the empirical frequency of each (complementary) label value.

    Args:
        complementary_labels: 1-D sequence/array of non-negative integer labels.
        num_classes: optional total number of classes. When given, the result
            has at least this many entries, so classes that never occur
            (including trailing ones) get a prior of 0 instead of being
            silently dropped from the end of the vector.

    Returns:
        A float ndarray whose entry k is the fraction of labels equal to k.
    """
    # np.bincount rejects minlength=None on recent NumPy; 0 means "no minimum".
    minlength = 0 if num_classes is None else num_classes
    return np.bincount(complementary_labels, minlength=minlength) / len(complementary_labels)
# NOTE(review): the block below is a disabled MNIST loader kept as a bare
# module-level string literal, so it is never executed. Consider deleting it
# (version control keeps history) or restoring it as a real function.
'''
def prepare_mnist_data(batch_size):
    ordinary_train_dataset = dsets.MNIST(root='./data/mnist', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = dsets.MNIST(root='./data/mnist', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    full_train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=len(ordinary_train_dataset.data), shuffle=True)
    num_classes = len(ordinary_train_dataset.classes)
    return full_train_loader, train_loader, test_loader, ordinary_train_dataset, test_dataset, num_classes
'''
# CIFAR-100 fine label -> CIFAR-100 superclass ("CIFAR-20") label, built once at
# import instead of on every lookup. Each superclass owns exactly five fine
# classes. Entry 81 (streetcar) was previously 18; streetcar belongs to
# vehicles_2 (19), and the old value left superclass 18 with six members and
# 19 with four.
CIFAR100_TO_CIFAR20 = {
    0: 4, 1: 1, 2: 14, 3: 8, 4: 0, 5: 6, 6: 7, 7: 7, 8: 18, 9: 3,
    10: 3, 11: 14, 12: 9, 13: 18, 14: 7, 15: 11, 16: 3, 17: 9, 18: 7, 19: 11,
    20: 6, 21: 11, 22: 5, 23: 10, 24: 7, 25: 6, 26: 13, 27: 15, 28: 3, 29: 15,
    30: 0, 31: 11, 32: 1, 33: 10, 34: 12, 35: 14, 36: 16, 37: 9, 38: 11, 39: 5,
    40: 5, 41: 19, 42: 8, 43: 8, 44: 15, 45: 13, 46: 14, 47: 17, 48: 18, 49: 10,
    50: 16, 51: 4, 52: 17, 53: 4, 54: 2, 55: 0, 56: 17, 57: 4, 58: 18, 59: 17,
    60: 10, 61: 3, 62: 2, 63: 12, 64: 12, 65: 16, 66: 12, 67: 1, 68: 9, 69: 19,
    70: 2, 71: 10, 72: 0, 73: 1, 74: 16, 75: 12, 76: 9, 77: 13, 78: 15, 79: 13,
    80: 16, 81: 19, 82: 2, 83: 4, 84: 6, 85: 19, 86: 5, 87: 5, 88: 8, 89: 19,
    90: 18, 91: 1, 92: 2, 93: 15, 94: 6, 95: 0, 96: 17, 97: 8, 98: 14, 99: 13,
}


def _cifar_transforms(mean, std):
    """Build the (train, test) transform pair for 32x32 images with given mean/std."""
    train_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(32, 4),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    return train_transform, test_transform


def prepare_cv_datasets(dataname, batch_size):
    """Build datasets and DataLoaders for a complementary-label CIFAR benchmark.

    Args:
        dataname: either ``'clcifar10'`` or ``'clcifar20'``.
        batch_size: batch size for the shuffled train loaders and the test loader.

    Returns:
        ``(full_train_loader, complementary_train_loader, train_loader,
        test_loader, train_dataset, test_dataset, num_classes)``, where
        ``full_train_loader`` yields the whole complementary-label training
        set as one unshuffled batch and ``train_loader`` iterates the
        ordinary-label training set.

    Raises:
        ValueError: if ``dataname`` is not a supported dataset. This is checked
            before anything is downloaded; previously an unknown name crashed
            later with UnboundLocalError.
    """
    if dataname not in ("clcifar10", "clcifar20"):
        raise ValueError(
            f"Unknown dataname {dataname!r}; expected 'clcifar10' or 'clcifar20'"
        )

    if dataname == "clcifar10":
        train_transform, test_transform = _cifar_transforms(
            (0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))
        train_dataset = CLCIFAR10(transform=train_transform)
        ordinary_train_dataset = dsets.CIFAR10(root='./dataset', train=True, transform=train_transform, download=True)
        test_dataset = dsets.CIFAR10(root='./dataset', train=False, transform=test_transform)
        num_classes = 10
    else:  # "clcifar20"
        train_transform, test_transform = _cifar_transforms(
            (0.5068, 0.4854, 0.4402), (0.2672, 0.2563, 0.2760))
        train_dataset = CLCIFAR20(transform=train_transform)
        # NOTE(review): unlike the clcifar10 branch, the ordinary train set
        # here uses test_transform (no augmentation) -- confirm this is intended.
        ordinary_train_dataset = dsets.CIFAR100(root='./dataset', train=True, transform=test_transform, download=True)
        test_dataset = dsets.CIFAR100(root='./dataset', train=False, transform=test_transform)
        # Collapse fine labels onto the 20 superclasses for BOTH splits.
        # Previously only the test set was mapped, so train_loader yielded
        # labels in 0..99 while num_classes == 20.
        ordinary_train_dataset.targets = [CIFAR100_TO_CIFAR20[t] for t in ordinary_train_dataset.targets]
        test_dataset.targets = [CIFAR100_TO_CIFAR20[t] for t in test_dataset.targets]
        num_classes = 20

    complementary_train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    # One batch containing the entire complementary-label training set.
    full_train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset.data), shuffle=False, num_workers=0)

    return full_train_loader, complementary_train_loader, train_loader, test_loader, train_dataset, test_dataset, num_classes

def prepare_train_loaders(full_train_loader):
    """Return the per-sample feature dimension and the batch's targets.

    ``full_train_loader`` is expected to yield ``(data, labels, true_labels,
    index)`` tuples and, typically, a single full batch. If it yields several
    batches, the last one is used, as before.

    Returns:
        dim: number of features per sample (product of all non-batch dims
            of ``data``).
        labels: the ``labels`` element of the last batch.

    Raises:
        ValueError: if the loader yields no batches. Previously this crashed
            with UnboundLocalError.
    """
    last_batch = None
    for last_batch in full_train_loader:
        pass
    if last_batch is None:
        raise ValueError("full_train_loader yielded no batches")

    data, labels, _true_labels, _index = last_batch
    # Features per sample; an empty shape tail (1-D data) gives 1, as before.
    dim = int(np.prod(data.shape[1:]))
    return dim, labels