import os
import pickle
import random
import tarfile
import urllib.request

import numpy as np
import pandas as pd
import torch
import torchvision
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset
from PIL import Image
import io
import math

# np.random.seed(0)
# random.seed(0)
# torch.random.manual_seed(0)
# torch.manual_seed(0)
# torch.cuda.manual_seed_all(0)


def train_val_split(labels, n_labeled, positive_label_list):
    """Split sample indices into a labeled-positive set and an unlabeled pool.

    Every sample (positive or not) goes into the unlabeled pool, so the
    labeled positives are also contained in it.  For each class listed in
    ``positive_label_list`` the first ``int(n_labeled / len(positive_label_list))``
    indices of a shuffled copy of that class are taken as labeled.

    Args:
        labels: per-sample class labels (anything ``np.array`` accepts).
        n_labeled: total number of labeled positives to draw.
        positive_label_list: class labels treated as positive.

    Returns:
        ``(labeled_idxs, unlabeled_idxs, prior)`` where both index containers
        are shuffled lists and ``prior`` is the fraction of all samples whose
        label is in ``positive_label_list``.
    """
    label_arr = np.array(labels)
    classes = np.unique(label_arr)
    per_class = int(n_labeled / len(positive_label_list))

    labeled, unlabeled = [], []
    positive_count = 0
    for cls in classes:
        members = np.where(label_arr == cls)[0]
        np.random.shuffle(members)
        unlabeled.extend(members)
        if cls not in positive_label_list:
            continue
        labeled.extend(members[:per_class])
        positive_count += len(members)

    # Shuffle so that classes are interleaved rather than grouped.
    np.random.shuffle(labeled)
    np.random.shuffle(unlabeled)

    return labeled, unlabeled, positive_count / len(label_arr)


def train_val_split_prior(labels, n_labeled, positive_label_list, target_prior):
    """Build a PU split whose unlabeled pool has (approximately) a target prior.

    Only ``positive_label_list[0]`` is used as the positive class.  All of its
    samples go into the unlabeled pool, and the first ``n_labeled`` of a
    shuffled copy are also returned as labeled.  Negatives are sub-sampled
    evenly from every class not in ``positive_label_list`` so that the pool's
    positive fraction approaches ``target_prior``.  The prior actually
    achieved is printed and returned.

    Returns:
        ``(labeled_idxs, unlabeled_idxs, actual_prior)``.
    """
    label_arr = np.array(labels)
    classes = np.unique(label_arr)

    pos_class = positive_label_list[0]  # Single positive class
    pos_idxs = np.where(label_arr == pos_class)[0]
    np.random.shuffle(pos_idxs)
    n_pos = len(pos_idxs)
    labeled = list(pos_idxs[:n_labeled])

    # Number of negatives needed so that n_pos / (n_pos + n_neg) == target_prior.
    n_neg_wanted = int(n_pos * (1 - target_prior) / target_prior)
    neg_classes = [c for c in classes if c not in positive_label_list]
    per_neg_class = int(n_neg_wanted / len(neg_classes))

    neg_for_pool = []
    for neg_class in neg_classes:
        candidates = np.where(label_arr == neg_class)[0]
        np.random.shuffle(candidates)
        neg_for_pool.extend(candidates[:per_neg_class])
    unlabeled = list(pos_idxs) + neg_for_pool

    np.random.shuffle(labeled)
    np.random.shuffle(unlabeled)

    actual_prior = n_pos / len(unlabeled)

    print(f"\n=== Binary PU Data Split ===")
    print(f"Target prior: {target_prior:.4f}")
    print(f"Actual prior: {actual_prior:.4f}")
    print(f"============================\n")

    return labeled, unlabeled, actual_prior


def normalise_fashionmnist(x, mean, std):
    """Standardise raw 8-bit grey-scale pixels with [0, 1]-scale statistics.

    ``x`` holds raw pixel intensities in [0, 255].  ``mean`` and ``std`` are
    given on the [0, 1] pixel scale; the 0.2860 / 0.3530 passed by the callers
    match the commonly used FashionMNIST statistics for [0, 1]-scaled pixels.
    Previously the statistics were applied to the un-rescaled 0-255 values,
    which produced outputs in the hundreds instead of roughly zero-mean /
    unit-variance data.  The sibling ``normalise`` already rescales by 255.

    Args:
        x: image data (array-like, e.g. a uint8 tensor or ndarray).
        mean: per-channel mean on the [0, 1] scale (broadcast against ``x``).
        std: per-channel standard deviation on the [0, 1] scale.

    Returns:
        A new float32 ndarray ``(x / 255 - mean) / std``.
    """
    x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
    x /= 255.0  # bring raw intensities onto the same scale as mean/std
    x -= mean
    x /= std
    return x


def _3D_to_4(x):
    return x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])


def normalise(x, mean, std):
    """Standardise 0-255 pixel data using statistics given on the [0, 1] scale.

    ``mean`` and ``std`` broadcast against the trailing axis of ``x`` (the
    channel axis for NHWC input).

    Returns:
        A new float32 array equal to ``(x - 255 * mean) / (255 * std)``.
    """
    arr = np.array(x, np.float32)
    mean_arr = np.array(mean, np.float32)
    std_arr = np.array(std, np.float32)
    arr -= mean_arr * 255
    arr *= 1.0 / (255 * std_arr)
    return arr


def transpose(x, source='NHWC', target='NCHW'):
    '''
    Permute the axes of ``x`` from layout ``source`` to layout ``target``.

    Layout letters:
    N: batch size
    H: height
    W: width
    C: channel
    '''
    permutation = tuple(map(source.index, target))
    return x.transpose(permutation)


class FashionMNIST_labeled(torchvision.datasets.FashionMNIST):
    """FashionMNIST (optionally sub-indexed) kept as a pre-normalised array.

    After construction ``self.data`` is a float32 array of shape
    ``(N, 1, H, W)`` produced by ``normalise_fashionmnist`` + ``_3D_to_4``.

    Args:
        root: dataset directory forwarded to torchvision.
        indexs: optional indices selecting a subset of the split; when given,
            ``self.targets`` becomes a NumPy array.
        train, transform, target_transform, download: forwarded to
            ``torchvision.datasets.FashionMNIST``.
    """

    def __init__(self,
                 root,
                 indexs=None,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=True):
        super().__init__(root,
                         train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.targets = np.array(self.targets)[indexs]
            self.data = self.data[indexs]
        normalised = normalise_fashionmnist(self.data,
                                            mean=(0.2860, ),
                                            std=(0.3530, ))
        self.data = _3D_to_4(normalised)

    def __getitem__(self, index):
        """Return ``(img, target)``; ``img`` is a row of the normalised array."""
        img = self.data[index]
        target = self.targets[index]
        img = img if self.transform is None else self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


class FashionMNIST_unlabeled(FashionMNIST_labeled):
    """Unlabeled-pool view of FashionMNIST: ``FashionMNIST_labeled`` with required indices.

    The true targets are kept; masking them to -1 is intentionally left
    disabled, so ``target_transform`` still sees the real class.
    """

    def __init__(self, root, indexs, train=True, transform=None,
                 target_transform=None, download=True):
        super().__init__(root, indexs, train=train, transform=transform,
                         target_transform=target_transform, download=download)


def get_fashionMNIST_data(num_labeled,
                          positive_label_list,
                          root,
                          prior,
                          transform_train=None,
                          transform_val=None):
    """Build the binary PU splits of FashionMNIST.

    Args:
        num_labeled: total labeled positives, spread evenly over
            ``positive_label_list`` (see ``train_val_split``).
        positive_label_list: class ids mapped to +1; every other class maps to -1.
        root: torchvision dataset directory (``download=True`` is passed).
        prior: ignored; the prior is re-estimated from the training labels.
        transform_train: transform for the training-side datasets.
        transform_val: transform for the test dataset.

    Returns:
        ``(labeled, unlabeled, val, test, prior)``; ``val`` is the complete
        training split.
    """
    base = torchvision.datasets.FashionMNIST(root, train=True, download=True)
    # Alternative fixed-prior split:
    # train_val_split_prior(base.targets, num_labeled, positive_label_list, prior)
    labeled_idxs, unlabeled_idxs, est_prior = train_val_split(
        base.targets, num_labeled, positive_label_list)

    def target_transform(label):
        return 1 if label in positive_label_list else -1

    train_kwargs = {'train': True,
                    'transform': transform_train,
                    'target_transform': target_transform}
    labeled = FashionMNIST_labeled(root, labeled_idxs, **train_kwargs)
    unlabeled = FashionMNIST_unlabeled(root, unlabeled_idxs, **train_kwargs)
    val = FashionMNIST_labeled(root, **train_kwargs)
    test = FashionMNIST_labeled(root,
                                train=False,
                                transform=transform_val,
                                download=True,
                                target_transform=target_transform)

    return labeled, unlabeled, val, test, est_prior


class MNIST_labeled(torchvision.datasets.MNIST):
    """MNIST (optionally sub-indexed) kept as a pre-normalised array.

    After construction ``self.data`` is a float32 array of shape
    ``(N, 1, H, W)`` produced by ``normalise_fashionmnist`` + ``_3D_to_4``.

    The normalisation uses the MNIST statistics (mean 0.1307, std 0.3081 on
    the [0, 1] pixel scale, as in the PyTorch MNIST example).  Previously the
    FashionMNIST constants (0.2860 / 0.3530) were copy-pasted here.

    Args:
        root: dataset directory forwarded to torchvision.
        indexs: optional indices selecting a subset of the split; when given,
            ``self.targets`` becomes a NumPy array.
        train, transform, target_transform, download: forwarded to
            ``torchvision.datasets.MNIST``.
    """

    MEAN = (0.1307, )
    STD = (0.3081, )

    def __init__(self,
                 root,
                 indexs=None,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=True):
        super(MNIST_labeled,
              self).__init__(root,
                             train=train,
                             transform=transform,
                             target_transform=target_transform,
                             download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
        self.data = _3D_to_4(
            normalise_fashionmnist(self.data, mean=self.MEAN, std=self.STD))

    def __getitem__(self, index):
        """Return ``(img, target)``; ``img`` is a row of the normalised array."""
        img, target = self.data[index], self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class MNIST_unlabeled(MNIST_labeled):
    """Unlabeled-pool view of MNIST: ``MNIST_labeled`` with required indices.

    The true targets are kept; masking them to -1 is intentionally left
    disabled, so ``target_transform`` still sees the real class.
    """

    def __init__(self, root, indexs, train=True, transform=None,
                 target_transform=None, download=True):
        super().__init__(root, indexs, train=train, transform=transform,
                         target_transform=target_transform, download=download)


def get_MNIST_data(num_labeled,
                   positive_label_list,
                   root,
                   prior,
                   transform_train=None,
                   transform_val=None):
    """Build the binary PU splits of MNIST.

    Args:
        num_labeled: total labeled positives, spread evenly over
            ``positive_label_list`` (see ``train_val_split``).
        positive_label_list: digit classes mapped to +1; every other class maps to -1.
        root: torchvision dataset directory (``download=True`` is passed).
        prior: ignored; the prior is re-estimated from the training labels.
        transform_train: transform for the training-side datasets.
        transform_val: transform for the test dataset.

    Returns:
        ``(labeled, unlabeled, val, test, prior)``; ``val`` is the complete
        training split.
    """
    base = torchvision.datasets.MNIST(root, train=True, download=True)
    # Alternative fixed-prior split:
    # train_val_split_prior(base.targets, num_labeled, positive_label_list, prior)
    labeled_idxs, unlabeled_idxs, est_prior = train_val_split(
        base.targets, num_labeled, positive_label_list)

    def target_transform(label):
        return 1 if label in positive_label_list else -1

    train_kwargs = {'train': True,
                    'transform': transform_train,
                    'target_transform': target_transform}
    labeled = MNIST_labeled(root, labeled_idxs, **train_kwargs)
    unlabeled = MNIST_unlabeled(root, unlabeled_idxs, **train_kwargs)
    val = MNIST_labeled(root, **train_kwargs)
    test = MNIST_labeled(root,
                         train=False,
                         transform=transform_val,
                         download=True,
                         target_transform=target_transform)

    return labeled, unlabeled, val, test, est_prior


class CIFAR10_labeled(torchvision.datasets.CIFAR10):
    """CIFAR-10 (optionally sub-indexed) kept as a pre-normalised NCHW array.

    ``self.data`` is normalised per channel with mean = std = 0.5 (on the
    [0, 1] scale) by ``normalise``, then permuted from NHWC to NCHW.

    Args:
        root: dataset directory forwarded to torchvision.
        indexs: optional indices selecting a subset of the split; when given,
            ``self.targets`` becomes a NumPy array.
        train, transform, target_transform, download: forwarded to
            ``torchvision.datasets.CIFAR10``.
    """

    def __init__(self,
                 root,
                 indexs=None,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=False):
        super().__init__(root,
                         train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.targets = np.array(self.targets)[indexs]
            self.data = self.data[indexs]
        normalised = normalise(self.data,
                               mean=(0.5, 0.5, 0.5),
                               std=(0.5, 0.5, 0.5))
        self.data = transpose(normalised)  # NHWC -> NCHW

    def __getitem__(self, index):
        """Return ``(img, target)``; ``img`` is a CHW row of the normalised array."""
        img = self.data[index]
        target = self.targets[index]
        img = img if self.transform is None else self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


class CIFAR10_unlabeled(CIFAR10_labeled):
    """Unlabeled-pool view of CIFAR-10: ``CIFAR10_labeled`` with required indices.

    The true targets are kept; masking them to -1 is intentionally left
    disabled, so ``target_transform`` still sees the real class.
    """

    def __init__(self, root, indexs, train=True, transform=None,
                 target_transform=None, download=False):
        super().__init__(root, indexs, train=train, transform=transform,
                         target_transform=target_transform, download=download)


def get_cifar10_data(num_labeled,
                     positive_label_list,
                     root,
                     prior,
                     transform_train=None,
                     transform_val=None):
    """Build the binary PU splits of CIFAR-10.

    Args:
        num_labeled: total labeled positives, spread evenly over
            ``positive_label_list`` (see ``train_val_split``).
        positive_label_list: class ids mapped to +1; every other class maps to -1.
        root: torchvision dataset directory (the base train split and the
            test split are created with ``download=True``).
        prior: ignored; the prior is re-estimated from the training labels.
        transform_train: transform for the training-side datasets.
        transform_val: transform for the test dataset.

    Returns:
        ``(labeled, unlabeled, val, test, prior)``; ``val`` is the complete
        training split.
    """
    base = torchvision.datasets.CIFAR10(root, train=True, download=True)
    # Alternative fixed-prior split:
    # train_val_split_prior(base.targets, num_labeled, positive_label_list, prior)
    labeled_idxs, unlabeled_idxs, est_prior = train_val_split(
        base.targets, num_labeled, positive_label_list)

    def target_transform(label):
        return 1 if label in positive_label_list else -1

    train_kwargs = {'train': True,
                    'transform': transform_train,
                    'target_transform': target_transform}
    labeled = CIFAR10_labeled(root, labeled_idxs, **train_kwargs)
    unlabeled = CIFAR10_unlabeled(root, unlabeled_idxs, **train_kwargs)
    val = CIFAR10_labeled(root, **train_kwargs)
    test = CIFAR10_labeled(root,
                           train=False,
                           transform=transform_val,
                           download=True,
                           target_transform=target_transform)

    return labeled, unlabeled, val, test, est_prior


class STL10_labeled(torchvision.datasets.STL10):
    """STL-10 (optionally sub-indexed) kept as a pre-normalised NCHW array.

    Args:
        root: dataset directory forwarded to torchvision.
        indexs: optional indices selecting a subset of the split; when given,
            ``self.labels`` becomes a NumPy array.
        split, transform, target_transform, download: forwarded to
            ``torchvision.datasets.STL10``.
    """

    def __init__(self,
                 root,
                 indexs=None,
                 split='train+unlabeled',
                 transform=None,
                 target_transform=None,
                 download=False):
        super().__init__(root,
                         split=split,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.labels = np.array(self.labels)[indexs]
            self.data = self.data[indexs]
        # Go channels-last so the per-channel statistics broadcast over the
        # trailing axis inside `normalise`, then back to channels-first
        # (NHWC -> NCHW are `transpose`'s default layouts).
        channels_last = transpose(self.data, source='NCHW', target='NHWC')
        normalised = normalise(channels_last,
                               mean=(0.5, 0.5, 0.5),
                               std=(0.5, 0.5, 0.5))
        self.data = transpose(normalised)

    def __getitem__(self, index):
        """Return ``(img, target)``; ``img`` is a CHW row of the normalised array."""
        img = self.data[index]
        target = self.labels[index]
        img = img if self.transform is None else self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


class STL10_unlabeled(STL10_labeled):
    """Unlabeled-pool view of STL-10: ``STL10_labeled`` with required indices.

    The stored labels are kept; masking them to -1 is intentionally left
    disabled, so ``target_transform`` still sees the stored class.
    """

    def __init__(self, root, indexs, split='train+unlabeled', transform=None,
                 target_transform=None, download=False):
        super().__init__(root, indexs, split=split, transform=transform,
                         target_transform=target_transform, download=download)


def get_stl10_data(num_labeled,
                   positive_label_list,
                   root,
                   prior,
                   transform_train=None,
                   transform_val=None):
    """Build the binary PU splits of STL-10 from its 'train+unlabeled' split.

    NOTE(review): the images of the 'unlabeled' part presumably carry label -1
    in torchvision.  If so, they are counted as non-positive in the returned
    prior and mapped to -1 by ``target_transform`` — confirm this is intended.

    Args:
        num_labeled: total labeled positives, spread evenly over
            ``positive_label_list`` (see ``train_val_split``).
        positive_label_list: class ids mapped to +1; every other label maps to -1.
        root: torchvision dataset directory.
        prior: ignored; the prior is re-estimated from the split's labels.
        transform_train: transform for the training-side datasets.
        transform_val: transform for the test dataset.

    Returns:
        ``(labeled, unlabeled, val, test, prior)``; ``val`` is the 'train' split.
    """
    base = torchvision.datasets.STL10(root, split='train+unlabeled', download=True)
    # Alternative fixed-prior split:
    # train_val_split_prior(base.labels, num_labeled, positive_label_list, prior)
    labeled_idxs, unlabeled_idxs, est_prior = train_val_split(
        base.labels, num_labeled, positive_label_list)

    def target_transform(label):
        return 1 if label in positive_label_list else -1

    pool_kwargs = {'split': 'train+unlabeled',
                   'transform': transform_train,
                   'target_transform': target_transform}
    labeled = STL10_labeled(root, labeled_idxs, **pool_kwargs)
    unlabeled = STL10_unlabeled(root, unlabeled_idxs, **pool_kwargs)
    val = STL10_labeled(root,
                        split='train',
                        transform=transform_train,
                        target_transform=target_transform)
    test = STL10_labeled(root,
                         split='test',
                         transform=transform_val,
                         download=True,
                         target_transform=target_transform)

    return labeled, unlabeled, val, test, est_prior



class Alzheimer_labeled(Dataset):
    """Map-style dataset over parallel arrays of samples and labels.

    Args:
        X_train: indexable samples (in this module, an array of image paths).
        Y_train: per-sample labels.
        indexs: optional indices selecting a subset of both arrays.
        transform: optional callable applied to each sample in ``__getitem__``.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, X_train, Y_train, indexs=None, transform=None, target_transform=None):
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform
        if indexs is None:
            self.data, self.labels = X_train, Y_train
        else:
            self.data = X_train[indexs]
            self.labels = np.array(Y_train)[indexs]

    def __getitem__(self, index):
        """Return ``(sample, label)`` after the optional transforms."""
        img = self.data[index]
        target = self.labels[index]
        img = img if self.transform is None else self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.labels)


class Alzheimer_unlabeled(Alzheimer_labeled):
    """``Alzheimer_labeled`` whose labels are all replaced by -1.

    ``target_transform`` (if any) is still applied to the masked -1 value.
    """

    def __init__(self, X_train, Y_train, indexs=None, transform=None, target_transform=None):
        super().__init__(X_train, Y_train, indexs,
                         transform=transform, target_transform=target_transform)
        self.labels = np.array([-1] * len(self.labels))


def get_base_alzheimer(base_dir, txt='train'):
    """Read an Alzheimer split index file and resolve each image path.

    ``<base_dir>/<txt>.txt`` lists one ``<file_name>\\t<label>`` pair per line.
    Label ``'0'`` resolves to ``<base_dir>/<txt>/Negative/<file_name>``; any
    other label resolves to the ``Positive`` sub-folder.  Blank or
    whitespace-only lines are skipped.  Previously only a bare ``' '`` was
    skipped, so an empty line such as a trailing newline crashed the unpacking.

    Args:
        base_dir: dataset directory containing the index files.
        txt: split name, used for both the index file and the image sub-folder.

    Returns:
        ``(paths, labels)`` as NumPy arrays (str paths, int labels).

    Raises:
        ValueError: if a non-blank line does not contain exactly one tab, or
            its label is not an integer.
    """
    paths = []
    labels = []
    txt_path = os.path.join(base_dir, txt + '.txt')
    with open(txt_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            file_name, label = line.split('\t')
            sub_dir = 'Negative' if label == '0' else 'Positive'
            paths.append(os.path.join(base_dir, txt, sub_dir, file_name))
            labels.append(int(label))
    return np.array(paths), np.array(labels)

# Pre-processing used by `ad_transform` for every Alzheimer split: resize,
# random horizontal flip, tensor conversion, then normalisation with the same
# ImageNet channel statistics used in `get_ImageNette_data`.
# NOTE(review): the random flip also runs on the validation/test datasets
# because `get_alzheimer_data` uses `ad_transform` for all of them.
transform_Alzheimer = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(224),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def ad_transform(file_path):
    """Load the image at ``file_path`` as RGB and apply ``transform_Alzheimer``.

    The file is opened in a ``with`` block, so the underlying handle is
    released as soon as the pixels have been converted, including when
    decoding raises.  Previously the handle was left to the garbage collector.
    ``convert`` returns an independent image, so it stays valid after the
    file is closed.
    """
    with Image.open(file_path) as img:
        rgb = img.convert('RGB')
    return transform_Alzheimer(rgb)

def get_alzheimer_data(num_labeled,
                       positive_label_list,
                       root,
                       prior,
                       transform_train=None,
                       transform_val=None
                       ):
    """Build the PU splits for the Alzheimer MRI dataset.

    Images are listed from the hard-coded ``../data/Alzheimer_s-Dataset/``
    directory.  ``root``, ``prior``, ``transform_train`` and ``transform_val``
    are accepted for interface compatibility but ignored.  Every dataset
    loads its images lazily from their paths through ``ad_transform``.

    Args:
        num_labeled: total labeled positives (see ``train_val_split``).
        positive_label_list: labels mapped to +1; every other label maps to -1.

    Returns:
        ``(labeled, unlabeled, val, test, 0.5)``.  The unlabeled dataset
        masks its labels to -1, and ``val`` is the full training split.
    """
    data_dir = '../data/Alzheimer_s-Dataset/'

    transform = ad_transform
    target_transform = lambda x: 1 if x in positive_label_list else -1

    x_train, y_train = get_base_alzheimer(data_dir, txt='train')
    x_test, y_test = get_base_alzheimer(data_dir, txt='test')
    # Alternative fixed-prior split:
    # train_val_split_prior(y_train, num_labeled, positive_label_list, prior)
    # The prior estimated here is discarded; see the return statement.
    train_labeled_idxs, train_unlabeled_idxs, _ = train_val_split(
        y_train, num_labeled, positive_label_list)

    train_labeled_dataset = Alzheimer_labeled(x_train[train_labeled_idxs],
                                              y_train[train_labeled_idxs],
                                              transform=transform,
                                              target_transform=target_transform)

    train_unlabeled_dataset = Alzheimer_unlabeled(x_train[train_unlabeled_idxs],
                                                  y_train[train_unlabeled_idxs],
                                                  transform=transform,
                                                  target_transform=target_transform)

    val_dataset = Alzheimer_labeled(x_train,
                                    y_train,
                                    transform=transform,
                                    target_transform=target_transform)

    test_dataset = Alzheimer_labeled(x_test,
                                     y_test,
                                     transform=transform,
                                     target_transform=target_transform)

    # NOTE(review): a fixed prior of 0.5 is returned instead of the value
    # estimated by train_val_split — confirm this is intentional.
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset, 0.5



class ImageNetteSSL(torchvision.datasets.ImageFolder):
    """ImageFolder subset whose items are ``(index, img, target)`` triples.

    Args:
        root: ImageFolder root directory.
        indexs: optional indices into ``self.samples`` selecting the subset;
            when omitted, all samples are used.
        transform, target_transform: forwarded to ImageFolder and applied in
            ``__getitem__``.
    """

    def __init__(self, root, indexs=None, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        if indexs is not None:
            self.data = [self.samples[i] for i in indexs]
        else:
            # Previously `self.data` was only assigned on the indexed branch,
            # so an un-indexed instance raised AttributeError in
            # __getitem__ / __len__.
            self.data = self.samples
        self.targets = np.array([s[1] for s in self.data])

    def __getitem__(self, index):
        """Return ``(index, img, target)`` for the ``index``-th selected sample."""
        path, target = self.data[index]
        # Close the file handle deterministically; convert() yields an
        # independent in-memory image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return index, img, target

    def __len__(self):
        return len(self.data)


class ImageNette_labeled(torchvision.datasets.ImageFolder):
    """ImageFolder subset whose items are ``(img, target)`` pairs.

    Args:
        root: ImageFolder root directory.
        indexs: optional indices into ``self.samples`` selecting the subset;
            when omitted, all samples are used.
        transform, target_transform: forwarded to ImageFolder and applied in
            ``__getitem__``.
    """

    def __init__(self, root, indexs=None, transform=None, target_transform=None):
        super(ImageNette_labeled, self).__init__(root, transform=transform,
                                                 target_transform=target_transform)
        if indexs is not None:
            self.data = [self.samples[i] for i in indexs]
        else:
            self.data = self.samples
        self.targets = np.array([s[1] for s in self.data])

    def __getitem__(self, index):
        """Return ``(img, target)`` for the ``index``-th selected sample."""
        path, target = self.data[index]
        # Close the file handle deterministically (also if decoding raises);
        # convert() yields an independent in-memory image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.data)


class ImageNette_unlabeled(ImageNette_labeled):
    """Unlabeled-pool view of ImageNette; true targets are kept."""

    def __init__(self, root, indexs=None, transform=None, target_transform=None):
        super().__init__(root, indexs,
                         transform=transform, target_transform=target_transform)


def get_ImageNette_data(num_labeled,
                        positive_label_list,
                        root,
                        prior,
                        transform_train=None,
                        transform_val=None):
    """Build the binary PU splits of ImageNette at 64x64 resolution.

    Images are read from the hard-coded ``../data/imagenette2/{train,val}``
    folders.  ``root``, ``prior``, ``transform_train`` and ``transform_val``
    are accepted for interface compatibility but ignored.  All splits share
    one deterministic resize / centre-crop / normalise pipeline.

    Returns:
        ``(labeled, unlabeled, val, test, prior)``; ``val`` is the complete
        train folder and ``prior`` comes from ``train_val_split``.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    img_size = 64
    crop_ratio = 0.875
    # Resize to img_size / crop_ratio (64 / 0.875 -> 73), then crop back.
    resize_to = int(img_size / crop_ratio)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resize_to),
        torchvision.transforms.CenterCrop(img_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    data_dir = "../data/imagenette2/"
    train_dir = os.path.join(data_dir, "train")
    val_dir = os.path.join(data_dir, "val")

    # Plain ImageFolder scan used only to obtain the labels for splitting.
    base = torchvision.datasets.ImageFolder(train_dir)
    labeled_idxs, unlabeled_idxs, est_prior = train_val_split(
        base.targets, num_labeled, positive_label_list)

    def target_transform(label):
        return 1 if label in positive_label_list else -1

    shared = {'transform': transform, 'target_transform': target_transform}
    labeled = ImageNette_labeled(train_dir, indexs=labeled_idxs, **shared)
    unlabeled = ImageNette_unlabeled(train_dir, indexs=unlabeled_idxs, **shared)
    val = ImageNette_labeled(train_dir, **shared)
    test = ImageNette_labeled(val_dir, **shared)

    return labeled, unlabeled, val, test, est_prior



def get_loaders(train_labeled_dataset,
                train_unlabeled_dataset,
                val_dataset,
                test_dataset,
                batch_size=512):
    """Wrap the PU datasets in DataLoaders.

    The positive (``p``), unlabeled (``x``) and train loaders shuffle; the
    validation and test loaders keep dataset order.  No loader drops its last
    partial batch.  ``train_loader`` and ``val_loader`` both iterate
    ``val_dataset``.

    Returns:
        ``(p_loader, x_loader, train_loader, val_loader, test_loader)``.
    """
    def make(dataset, shuffle):
        return DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=False)

    return (
        make(train_labeled_dataset, shuffle=True),
        make(train_unlabeled_dataset, shuffle=True),
        make(val_dataset, shuffle=True),
        make(val_dataset, shuffle=False),
        make(test_dataset, shuffle=False),
    )


def load_image_dataset(dataset_name,
                       num_labeled,
                       batchsize,
                       positive_label_list,
                       root='../data',
                       prior=0.1,
                       with_bias=False,
                       resample_model=""):
    """Load a PU image benchmark and wrap its splits in DataLoaders.

    Args:
        dataset_name: one of "cifar10", "fashionmnist", "mnist", "stl10",
            "alzheimer", "imagenet".
        num_labeled: number of labeled positives requested from the builder.
        batchsize: batch size for every loader.
        positive_label_list: original class labels treated as positive.
        root: dataset root forwarded to the builder.
        prior: forwarded to the builder; the builders in this module replace
            it with their own value, which is what gets returned.
        with_bias, resample_model: accepted but unused.

    Returns:
        ``(p_loader, x_loader, train_loader, val_loader, test_loader,
        train_labeled_dataset, dim, prior)``.  ``dim`` is the number of
        stored elements per labeled sample (0 for "imagenet").

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    print("==================")
    print("loading data...")

    builders = (
        ("cifar10", get_cifar10_data),
        ("fashionmnist", get_fashionMNIST_data),
        ("mnist", get_MNIST_data),
        ("stl10", get_stl10_data),
        ("alzheimer", get_alzheimer_data),
        ("imagenet", get_ImageNette_data),
    )
    for name, builder in builders:
        if dataset_name == name:
            break
    else:
        raise ValueError("dataset name {} is unknown.".format(dataset_name))

    (train_labeled_dataset, train_unlabeled_dataset, val_dataset,
     test_dataset, prior) = builder(num_labeled=num_labeled,
                                    positive_label_list=positive_label_list,
                                    root=root, prior=prior)

    p_loader, x_loader, train_loader, val_loader, test_loader = get_loaders(
        train_labeled_dataset, train_unlabeled_dataset, val_dataset,
        test_dataset, batchsize)

    # ImageNette keeps a list of (path, class) tuples, which has no `.size`.
    # NOTE(review): for "alzheimer" `.data` is an array of file paths, so this
    # evaluates to 1.0 rather than a pixel count.
    if dataset_name == "imagenet":
        dim = 0
    else:
        stored = train_labeled_dataset.data
        dim = stored.size / len(stored)

    print("load data success!")
    print("==================")
    print('    # train data: ', len(x_loader.dataset))
    print('    # labeled train data: ', len(p_loader.dataset))
    print('    # val data: ', len(val_loader.dataset))
    print('    # test data: ', len(test_loader.dataset))
    print('    prior: ', prior)
    print('    dim: ', dim)

    return p_loader, x_loader, train_loader, val_loader, test_loader, train_labeled_dataset, dim, prior



