import os
from pathlib import Path
import pdb
from typing import Any, Tuple
import pickle

import pandas as pd
import PIL
import numpy as np

import torch
from torch.utils.data import Dataset
import torchvision.datasets
import torchvision.transforms as transforms
from two_step_zoo.datasets.supervised_dataset import SupervisedDataset

from tqdm import tqdm
from ..utils import deterministic_shuffle

class CenterCropLongEdge(object):
    """Center-crop a PIL image along its long edge so the result is square.

    The side length of the square is the image's shorter edge
    (``min(img.size)``), computed per image. The transform therefore takes
    no constructor arguments.
    """

    def __call__(self, img):
        """Crop ``img`` to a centered ``min(w, h) x min(w, h)`` square.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return transforms.functional.center_crop(img, min(img.size))

    def __repr__(self):
        return self.__class__.__name__

class CelebA(Dataset):
    '''
    CelebA PyTorch dataset.

    The built-in PyTorch dataset for CelebA is outdated. This class reads the
    split assignment from ``<root>/list_eval_partition.csv``. It loads images
    lazily from ``<root>/img_align_celeba/img_align_celeba/<image_id>``.
    Every item is returned as ``(image_tensor, 0)``; the label is a constant
    placeholder.
    '''

    # Maps a role name to the value of the "partition" column in
    # list_eval_partition.csv. None means "keep every row".
    _ROLE_TO_PARTITION = {
        "train": 0,
        "valid": 1,
        "test": 2,
        "all": None,
    }

    def __init__(self, root: str, role: str = "train"):
        """
        Args:
            root: Directory containing list_eval_partition.csv and the
                img_align_celeba/img_align_celeba image folder.
            role: One of "train", "valid", "test" or "all". Unknown roles
                raise KeyError.
        """
        self.root = Path(root)
        self.role = role

        # Every image is resized to 64x64 and converted with ToTensor.
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])

        splits_df = pd.read_csv(self.root / "list_eval_partition.csv")
        self.filename = self._filenames_for_role(splits_df, self.role)

    @classmethod
    def _filenames_for_role(cls, splits_df, role: str) -> list:
        """Return the ``image_id`` values of ``splits_df`` that belong to ``role``."""
        partition = cls._ROLE_TO_PARTITION[role]
        # Comparing the column against None matches nothing, so "all" must
        # bypass the filter rather than compare against it.
        if partition is None:
            return splits_df["image_id"].tolist()
        return splits_df.loc[splits_df["partition"] == partition, "image_id"].tolist()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load, resize and tensorize the image at ``index``; label is always 0."""
        # `import PIL` alone does not guarantee the Image submodule is loaded.
        from PIL import Image

        img_path = (self.root / "img_align_celeba" /
                    "img_align_celeba" / self.filename[index])
        # The context manager closes the file handle. The Resize in the
        # transform reads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            X = self.transform(img)

        return X, 0

    def __len__(self) -> int:
        return len(self.filename)

    def to(self, device):
        # No tensors are held by the dataset (images are read on demand), so
        # there is nothing to move; returns self for API compatibility.
        return self

def imagefolder_to_supervised(imagefolder_dset, dataset_role, data_root="", size=128, dataset_name="imagenet", transforms=None):
    """Materialize an image dataset into tensors and wrap it in a SupervisedDataset.

    The dataset is iterated once. Images are stacked into one tensor cast to
    ``torch.get_default_dtype()``, and labels into a ``long`` tensor. The
    result is cached as a pickle in the current working directory and reused
    on later calls that produce the same cache filename.

    Args:
        imagefolder_dset: Iterable of ``(image, label)`` pairs. Images must be
            tensors of a common shape so ``torch.stack`` can combine them.
        dataset_role: Split name; passed to SupervisedDataset and used in the
            cache filename.
        data_root: Dataset directory; its final path component is used in the
            cache filename.
        size: Used only in the cache filename; no resizing happens here.
        dataset_name: Passed to SupervisedDataset.
        transforms: Passed to SupervisedDataset.

    Returns:
        A ``SupervisedDataset`` built from the stacked images and labels.

    NOTE(review): the cache key is only (role, size, last component of
    data_root). Different subsets of the same root reuse the same cache file
    silently, for example a different class selection or split fraction.
    """
    # TODO: clean up caching
    # Path(...).name tolerates a trailing slash. str.split("/")[-1] would
    # yield "" there, making different roots share one cache file.
    cache_path = f"image_net_data_{dataset_role}_{size}_{Path(data_root).name}.pickle"

    if os.path.exists(cache_path):
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files this function wrote itself.
        with open(cache_path, "rb") as handle:
            cached = pickle.load(handle)
        images = cached["images"]
        labels = cached["labels"]
    else:
        image_list, label_list = [], []
        for im, lab in tqdm(imagefolder_dset, desc=f"Preprocessing imagenet, split={dataset_role}"):
            image_list.append(im)
            label_list.append(lab)

        images = torch.stack(image_list).to(dtype=torch.get_default_dtype())
        labels = torch.tensor(label_list).long()

        # Write to a temporary file, then rename it into place. This way an
        # interrupted run cannot leave a truncated pickle for later calls.
        tmp_path = cache_path + ".tmp"
        with open(tmp_path, "wb") as handle:
            pickle.dump({"images": images, "labels": labels}, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, cache_path)

    return SupervisedDataset(dataset_name, dataset_role, images, labels, transforms)

# TODO: make test fraction configurable
def get_imagenet_datasets(data_root, valid_fraction, class_ind, test_fraction=0.2, transforms=None):
    """Build train/valid/test datasets from an ImageFolder directory.

    Images are cropped along their long edge to a square, resized to 128 and
    converted with ``ToTensor``. Indices are shuffled with
    ``deterministic_shuffle``. They are then split in the order
    ``[test | valid | train]``.

    Args:
        data_root: Root directory passed to ``torchvision.datasets.ImageFolder``.
            Unless the path contains "golden", only the first 50000 shuffled
            examples are used.
        valid_fraction: Fraction of examples for the validation split. When it
            rounds to zero examples, the test split doubles as validation.
        class_ind: Class index to keep, or -1 for all classes.
        test_fraction: Fraction of examples for the test split.
        transforms: Forwarded to each ``SupervisedDataset``.

    Returns:
        ``(train_dset, valid_dset, test_dset)``.
    """
    # The `transforms` parameter shadows the module-level alias of
    # torchvision.transforms, so the module is referenced explicitly.
    tv_transforms = torchvision.transforms

    # TODO: configurable data transforms
    transform_list = [CenterCropLongEdge(), tv_transforms.Resize(128), tv_transforms.ToTensor()]
    dset = torchvision.datasets.ImageFolder(data_root, tv_transforms.Compose(transform_list))

    if class_ind != -1:
        cls_inds = [i for i, target in enumerate(dset.targets) if target == class_ind]
        dset = torch.utils.data.Subset(dset, cls_inds)

    # Get train, test, val splits
    all_inds = deterministic_shuffle(np.arange(len(dset)))

    # TODO: make the "golden" special case and the 50000 cap arguments
    if "golden" not in data_root:
        all_inds = all_inds[:50000]
    N = len(all_inds)

    test_len = round(N * test_fraction)
    valid_len = round(N * valid_fraction)

    test_inds = all_inds[:test_len]
    valid_inds = all_inds[test_len:(valid_len + test_len)]
    train_inds = all_inds[(valid_len + test_len):]

    train_dset = imagefolder_to_supervised(torch.utils.data.Subset(dset, train_inds), data_root=data_root, dataset_role="train", transforms=transforms)
    test_dset = imagefolder_to_supervised(torch.utils.data.Subset(dset, test_inds), data_root=data_root, dataset_role="test", transforms=transforms)

    if valid_len == 0:
        valid_dset = test_dset
    else:
        valid_dset = imagefolder_to_supervised(torch.utils.data.Subset(dset, valid_inds), data_root=data_root, dataset_role="valid", transforms=transforms)

    print(f"Dataset lengths: train={len(train_dset)}, val={len(valid_dset)}, test={len(test_dset)}")

    return train_dset, valid_dset, test_dset

def get_image_datasets_by_class(dataset_name, data_root, valid_fraction, class_ind, transforms=None):
    """Load the non-torchvision image datasets ("celeba" or any "*imagenet*").

    Args:
        dataset_name: "celeba", or a name containing "imagenet".
        data_root: For celeba, the parent of the ``celeba`` directory. For
            imagenet, the ImageFolder root itself.
        valid_fraction: Used for imagenet only. CelebA uses the fixed
            partitions from its CSV.
        class_ind: Class to keep, or -1. CelebA supports only -1.
        transforms: Forwarded to the imagenet datasets. CelebA ignores it and
            applies its own hard-coded transform.

    Returns:
        ``(train_dset, valid_dset, test_dset)``.

    Raises:
        ValueError: For an unknown dataset name, or if ``class_ind`` is not -1
            for celeba.
    """
    if dataset_name == "celeba":
        if class_ind != -1:
            raise ValueError(f"class_ind must be -1 for celeba, got {class_ind}")
        data_dir = os.path.join(data_root, dataset_name)
        return tuple(CelebA(root=data_dir, role=role) for role in ("train", "valid", "test"))

    if "imagenet" in dataset_name:
        return get_imagenet_datasets(data_root, valid_fraction, class_ind, transforms=transforms)

    raise ValueError(f"Unknown dataset {dataset_name}")


def image_tensors_to_dataset(dataset_name, dataset_role, images, labels, transforms):
    """Wrap image and label tensors in a ``SupervisedDataset``.

    Images are cast to torch's current default floating dtype and labels to
    ``long`` before being handed over, together with ``transforms``.
    """
    return SupervisedDataset(
        dataset_name,
        dataset_role,
        images.to(dtype=torch.get_default_dtype()),
        labels.long(),
        transforms,
    )

def visualize_dataset(labels, images, class_ind, out_dir="./trash", max_images=11):
    """Save the first few images as PNG files for a quick visual sanity check.

    Files are written as ``<out_dir>/<class_ind>_junk_<idx>.png``.

    Args:
        labels: Unused; kept for call-site compatibility.
        images: Iterable of image tensors. Integer tensors are treated as
            0..255 pixel values; floating tensors are assumed to be in [0, 1].
        class_ind: Used as the filename prefix.
        out_dir: Output directory; created if it does not exist.
        max_images: Number of images to save (the default of 11 matches the
            previous behaviour, which saved indices 0..10).
    """
    from torchvision.utils import save_image

    os.makedirs(out_dir, exist_ok=True)
    for idx, image in enumerate(images):
        if idx >= max_images:
            break
        img = image.to(torch.float32)
        # save_image expects floats in [0, 1]. Raw integer pixels (0..255)
        # would otherwise saturate to white, so rescale them first.
        if not image.is_floating_point():
            img = img / 255
        save_image(img, os.path.join(out_dir, f"{class_ind}_junk_{idx}.png"))

def get_raw_image_tensors(dataset_name, train, data_root, class_ind):
    """Load one split of a torchvision dataset as raw ``(images, labels)`` tensors.

    Both returned tensors are cast to ``torch.uint8``. ``images`` has shape
    ``(nimages, nchannels, nrows, ncols)`` and has entries in {0, ..., 255}.

    Args:
        dataset_name: One of "cifar10", "cifar100", "svhn", "mnist",
            "fashion-mnist".
        train: Load the training split when true, otherwise the test split.
        data_root: Parent directory. Data lives in ``data_root/dataset_name``
            (``download=True`` is passed to torchvision).
        class_ind: Keep only this class, or -1 to keep every class. When a
            class is selected, a few previews are saved via
            ``visualize_dataset``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    data_dir = os.path.join(data_root, dataset_name)

    if dataset_name in ("cifar10", "cifar100"):
        if dataset_name == "cifar10":
            cifar_cls = torchvision.datasets.CIFAR10
        else:
            cifar_cls = torchvision.datasets.CIFAR100
        raw = cifar_cls(root=data_dir, train=train, download=True)
        # CIFAR stores pixels channels-last; move channels to dim 1.
        images = torch.tensor(raw.data).permute((0, 3, 1, 2))
        labels = torch.tensor(raw.targets)
    elif dataset_name == "svhn":
        split = "train" if train else "test"
        raw = torchvision.datasets.SVHN(root=data_dir, split=split, download=True)
        images = torch.tensor(raw.data)
        labels = torch.tensor(raw.labels)
    elif dataset_name in ("mnist", "fashion-mnist"):
        if dataset_name == "mnist":
            mnist_cls = torchvision.datasets.MNIST
        else:
            mnist_cls = torchvision.datasets.FashionMNIST
        raw = mnist_cls(root=data_dir, train=train, download=True)
        # Grayscale images lack a channel axis; add one at dim 1.
        images = raw.data.unsqueeze(1)
        labels = raw.targets
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if class_ind != -1:
        print("Restricting dataset to class:", class_ind)
        keep = labels == class_ind
        labels, images = labels[keep], images[keep]
        visualize_dataset(labels, images, class_ind)

    return images.to(torch.uint8), labels.to(torch.uint8)

def get_torchvision_datasets(dataset_name, data_root, valid_fraction, class_ind, transforms):
    """Build train/valid/test datasets for a torchvision image dataset.

    The torchvision training split is shuffled with ``deterministic_shuffle``.
    Its first ``int(valid_fraction * N)`` examples become the validation set
    and the rest become the training set. The torchvision test split is used
    as-is for testing.

    Returns:
        ``(train_dset, valid_dset, test_dset)``.
    """
    pool_images, pool_labels = get_raw_image_tensors(dataset_name, train=True, data_root=data_root, class_ind=class_ind)
    num_examples = pool_images.shape[0]

    order = deterministic_shuffle(torch.arange(num_examples))
    print("Torchvision dataset first inds of perm:", order[:5])
    pool_images = pool_images[order]
    pool_labels = pool_labels[order]

    n_valid = int(valid_fraction * num_examples)

    train_dset = image_tensors_to_dataset(
        dataset_name, "train", pool_images[n_valid:], pool_labels[n_valid:], transforms
    )
    valid_dset = image_tensors_to_dataset(
        dataset_name, "valid", pool_images[:n_valid], pool_labels[:n_valid], transforms
    )

    held_out_images, held_out_labels = get_raw_image_tensors(dataset_name, train=False, data_root=data_root, class_ind=class_ind)
    test_dset = image_tensors_to_dataset(dataset_name, "test", held_out_images, held_out_labels, transforms)

    return train_dset, valid_dset, test_dset

def get_image_datasets(dataset_name, data_root, make_valid_dset, valid_fraction, class_ind=-1, transforms=None):
    """Entry point: return ``(train, valid, test)`` datasets for ``dataset_name``.

    Torchvision datasets (mnist, fashion-mnist, svhn, cifar10, cifar100) are
    routed to ``get_torchvision_datasets``; every other name goes to
    ``get_image_datasets_by_class``. When ``make_valid_dset`` is falsy, the
    validation fraction is forced to 0.
    """
    effective_valid_fraction = valid_fraction if make_valid_dset else 0

    if dataset_name in ("mnist", "fashion-mnist", "svhn", "cifar10", "cifar100"):
        loader = get_torchvision_datasets
    else:
        loader = get_image_datasets_by_class

    return loader(dataset_name, data_root, effective_valid_fraction, class_ind, transforms)
