import torchvision.transforms as transforms
import torchvision
import os
import numpy as np
import torch

from imbalanced_datasets.mipll_datasets import MIPLL_Dataset, Gold_Dataset
from imbalanced_datasets.imbalanced_datasets_utils import gen_imbalanced_data

# Shared MNIST preprocessing: convert the PIL image / uint8 array to a float
# tensor, then normalize with Normalize(mean, std) using the mean/std values
# conventionally used for MNIST.
mnist_img_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

def get_mipll_dataset(args, func, preimage_dict=None):
    """Load (building and caching on first use) an imbalanced MI-PLL dataset over MNIST.

    If no cache file exists for this configuration, the MNIST training set is
    made imbalanced via ``gen_imbalanced_data``. Then ``args.M`` index vectors of
    length ``args.size`` are drawn with replacement. For every position ``i`` the
    weak label ``s[i] = func([ys[0][i], ..., ys[M-1][i]])`` is computed. The
    result is saved under ``args.data_dir``; later calls with the same
    configuration load it from disk.

    Args:
        args: namespace read for ``dataset``, ``imb_type``, ``imb_ratio``,
            ``imb_factor``, ``seed``, ``size``, ``M``, ``num_class`` and ``data_dir``.
        func: callable mapping a list of ``args.M`` gold labels to a weak label;
            its ``__name__`` is part of the cache filename.
        preimage_dict: optional mapping from a weak label to its preimage. When
            given, ``p[i] = preimage_dict[s[i]]`` (also for cached data that was
            originally built without a preimage dict).

    Returns:
        Tuple ``(xs, ys, ss, p)``:
        - ``xs``: list of ``args.M`` image collections.
        - ``ys``: list of ``args.M`` gold-label NumPy arrays.
        - ``ss``: list of ``args.size`` weak labels.
        - ``p``: list of preimages, or ``None`` if no preimage dict is available.
    """
    print("==> Loading local data copy in the long-tailed setup")
    # NOTE(review): the cache key uses args.imb_ratio while generation below uses
    # args.imb_factor; if these can differ, a cached file may not match the
    # requested imbalance factor -- confirm they are kept in sync.
    data_file = "{ds}_imb_{it}_{imf}_sd{sd}_ni={ni}_fc={fc}_M={m}.npy".format(
        ds=args.dataset,
        it=args.imb_type,
        imf=args.imb_ratio,
        sd=args.seed,
        ni=args.size,
        fc=func.__name__,
        m=args.M,
    )
    save_path = os.path.join(args.data_dir, data_file)
    if not os.path.exists(save_path):
        train_dataset = torchvision.datasets.MNIST(
            root="../../data/" + "MNIST",
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        data, labels = (
            np.array(train_dataset.data),
            np.array(train_dataset.targets)
        )

        train_data, train_labels = gen_imbalanced_data(
            data, labels, args.num_class, args.imb_type, args.imb_factor
        )

        # Sample args.size indices per instance slot. This uses NumPy's global
        # RNG: args.seed only appears in the filename here, so reproducibility
        # relies on the caller having seeded np.random.
        rnd = [
            np.random.choice(len(train_data), size=args.size)
            for _ in range(args.M)
        ]
        xs = [train_data[r] for r in rnd]    # images for each slot
        ys = [train_labels[r] for r in rnd]  # gold labels for each slot

        # Weak label for position i aggregates the i-th gold label of every slot.
        s = [
            func([ys[j][i].item() for j in range(args.M)])
            for i in range(args.size)
        ]
        p = [preimage_dict[v] for v in s] if preimage_dict is not None else None

        data_dict = {"x": xs, "y": ys, "s": s, "p": p}
        os.makedirs(args.data_dir, exist_ok=True)
        # Write to a temporary file first and atomically move it into place, so
        # an interrupted run cannot leave a truncated cache that later runs load.
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "wb") as f:
            np.save(f, data_dict)
        os.replace(tmp_path, save_path)
        print("local data saved at ", save_path)

    # allow_pickle=True unpickles the cached dict: only load caches from a
    # trusted data_dir.
    data_dict = np.load(save_path, allow_pickle=True).item()
    xs, ys, ss, p = data_dict["x"], data_dict["y"], data_dict["s"], data_dict["p"]
    # The cache filename does not record whether a preimage dict was used, so a
    # cache built without one stores p=None; recompute it when one is supplied now.
    if p is None and preimage_dict is not None:
        p = [preimage_dict[v] for v in ss]
    # Presumably gen_imbalanced_data returns torch tensors (they are cached as
    # such); convert labels to NumPy arrays for the caller.
    ys = [y.numpy() for y in ys]

    return xs, ys, ss, p

def load_arithmetic_n(args, preimage_dict=None):
    """Build the train/test loaders for the MNIST max/sum MI-PLL task.

    Args:
        args: namespace; ``args.dataset`` selects the weak-label function
            (``"mmax"`` -> ``max``, ``"msum"`` -> ``sum``). ``size``,
            ``batch_size_train``, ``batch_size_test``, ``imb_test``,
            ``num_class``, ``imb_type`` and ``imb_factor`` are also read,
            plus whatever ``get_mipll_dataset`` reads.
        preimage_dict: optional weak-label -> preimage mapping forwarded to
            ``get_mipll_dataset``.

    Returns:
        Tuple ``(train_loader, test_loader, train_label_cnt)``. The last item
        is a length-10 tensor with the gold-label count of each digit 0-9,
        summed over all instance slots (0 for digits that never occur).

    Raises:
        ValueError: if ``args.dataset`` is not ``"mmax"`` or ``"msum"``.
    """
    weak_label_fns = {"mmax": max, "msum": sum}
    if args.dataset not in weak_label_fns:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            f"expected one of {sorted(weak_label_fns)}"
        )
    func = weak_label_fns[args.dataset]

    xs, ys, ss, p = get_mipll_dataset(args, func, preimage_dict=preimage_dict)

    ys = [torch.from_numpy(y) for y in ys]
    digits = list(range(10))

    # Per-digit gold-label frequency across all instance slots, used to
    # initialize the Acc-shot object. With few training samples some classes
    # may never be sampled; those digits get a count of 0.
    label_counts = {}
    for y in ys:
        labels, counts = torch.unique(y, return_counts=True)
        for label, count in zip(labels.tolist(), counts.tolist()):
            label_counts[label] = label_counts.get(label, 0) + count
    train_label_cnt = torch.tensor([label_counts.get(d, 0) for d in digits])

    if p is not None:
        print(len(p))

    train_pll_dataset = MIPLL_Dataset(
        xs,
        [y.float() for y in ys],
        ss,
        args.size,
        mnist_img_transform,
        preimage=p,
    )
    train_mipll_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_pll_dataset,
        collate_fn=MIPLL_Dataset.collate_fn,
        batch_size=args.batch_size_train,
        shuffle=True,
    )

    test_dataset = torchvision.datasets.MNIST(
        root="../../data/" + "MNIST", train=False, download=True, transform=mnist_img_transform)

    if args.imb_test:
        print("Using imbalanced test set (training distribution = test distribution).")
        test_data, test_labels = np.array(test_dataset.data), np.array(
            test_dataset.targets
        )
        test_data, test_labels = gen_imbalanced_data(
            test_data, test_labels, args.num_class, args.imb_type, args.imb_factor
        )
        test_label_ratio = np.bincount(test_labels)
        test_label_ratio = torch.tensor(test_label_ratio / test_label_ratio.sum())
        print("test label ratio: ", test_label_ratio)
        test_dataset = Gold_Dataset(test_data, test_labels, mnist_img_transform)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size_test,
        shuffle=True,
        num_workers=4,
    )

    return (
        train_mipll_dataset_loader,
        test_loader,
        train_label_cnt,
    )
