# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import torch
from PIL import Image, ImageFile
from torchvision import transforms as T
from torch.utils.data import TensorDataset
from torchvision.datasets import MNIST, ImageFolder
from torchvision.transforms.functional import rotate
ImageFile.LOAD_TRUNCATED_IMAGES = True
from typing import Any, Callable, cast, Dict, List, Optional, Tuple


import itertools
import torch.utils.data as data


# Names of the datasets exposed by this module. `get_dataset_class` looks
# names up case-sensitively in the module globals, so every entry is expected
# to match a class name exactly.
DATASETS = [
    # Debug
    "Debug28",
    "Debug224",
    # Small images
    "ColoredMNIST",
    "RotatedMNIST",
    # Big images
    "VLCS",
    "PACS",
    "TerraIncognita",
    "DomainNet",
    # TeacherEnsemble
    # NOTE(review): no class named "TeacherEnsemblePacs" is visible in this
    # module (the ensemble dataset defined here is `DomED_PACS`) -- confirm
    # whether this entry is provided elsewhere or is stale.
    "TeacherEnsemblePacs",
]


def get_dataset_class(dataset_name):
    """Look up a dataset class in this module's namespace by name.

    Args:
        dataset_name: Name bound at module level, e.g. ``"PACS"``.

    Returns:
        The object bound to ``dataset_name`` in the module globals.

    Raises:
        NotImplementedError: If no module-level name matches ``dataset_name``.
    """
    module_namespace = globals()
    if dataset_name in module_namespace:
        return module_namespace[dataset_name]
    raise NotImplementedError("Dataset not found: {}".format(dataset_name))


def num_environments(dataset_name):
    """Return how many environments the named dataset class declares.

    The count is the length of the class-level ``ENVIRONMENTS`` attribute of
    the class resolved by ``get_dataset_class``.
    """
    dataset_cls = get_dataset_class(dataset_name)
    return len(dataset_cls.ENVIRONMENTS)


class MultipleDomainDataset:
    """Base container holding one sub-dataset per domain.

    This class only reads ``self.datasets``; subclasses are responsible for
    populating it. Indexing and ``len`` operate on that list.
    """

    INPUT_SHAPE = None  # Subclasses should override
    ENVIRONMENTS = None  # Subclasses should override
    N_WORKERS = 2  # Default, subclasses may override
    CHECKPOINT_FREQ = 100  # Default, subclasses may override
    N_STEPS = 15001  # Default, subclasses may override

    def __len__(self):
        """Return the number of per-domain sub-datasets."""
        domain_datasets = self.datasets
        return len(domain_datasets)

    def __getitem__(self, index):
        """Return the sub-dataset stored at position ``index``."""
        domain_datasets = self.datasets
        return domain_datasets[index]


class Debug(MultipleDomainDataset):
    """Synthetic dataset of random tensors, split into three environments.

    Each environment holds 16 samples shaped like the subclass's
    ``INPUT_SHAPE`` with integer labels drawn from ``[0, num_classes)``.
    The ``root`` argument is accepted but not used.
    """

    def __init__(self, root):
        super().__init__()
        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2
        samples_per_env = 16
        # Inputs are drawn before labels for every environment, in order.
        self.datasets = [
            TensorDataset(
                torch.randn(samples_per_env, *self.INPUT_SHAPE),
                torch.randint(0, self.num_classes, (samples_per_env,)),
            )
            for _ in range(3)
        ]


class Debug28(Debug):
    """Random-tensor debug dataset with 3x28x28 inputs and three environments."""

    ENVIRONMENTS = ["0", "1", "2"]
    INPUT_SHAPE = (3, 28, 28)


class Debug224(Debug):
    """Random-tensor debug dataset with 3x224x224 inputs and three environments."""

    ENVIRONMENTS = ["0", "1", "2"]
    INPUT_SHAPE = (3, 224, 224)


class MultipleEnvironmentMNIST(MultipleDomainDataset):
    """MNIST split into several environments, each built by a transform."""

    def __init__(self, root, environments, dataset_transform, input_shape, num_classes):
        """Build one sub-dataset per environment from shuffled MNIST.

        Args:
            root: Directory where MNIST is stored (downloaded if missing).
            environments: One property value per environment; each value is
                handed to ``dataset_transform``.
            dataset_transform: Callable ``(images, labels, environment)``
                whose return value becomes that environment's sub-dataset.
            input_shape: Stored as ``self.input_shape``.
            num_classes: Stored as ``self.num_classes``.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        super().__init__()
        if root is None:
            raise ValueError("Data directory not specified!")

        # Pool the train split followed by the test split.
        train_split = MNIST(root, train=True, download=True)
        test_split = MNIST(root, train=False, download=True)
        pooled_images = torch.cat((train_split.data, test_split.data))
        pooled_labels = torch.cat((train_split.targets, test_split.targets))

        # One shared random permutation keeps images and labels aligned.
        order = torch.randperm(len(pooled_images))
        pooled_images = pooled_images[order]
        pooled_labels = pooled_labels[order]

        self.datasets = []
        self.environments = environments

        # Environment k receives every n-th sample starting at offset k.
        n_envs = len(environments)
        for offset, env in enumerate(environments):
            env_images = pooled_images[offset::n_envs]
            env_labels = pooled_labels[offset::n_envs]
            self.datasets.append(dataset_transform(env_images, env_labels, env))

        self.input_shape = input_shape
        self.num_classes = num_classes


class ColoredMNIST(MultipleEnvironmentMNIST):
    """Two-channel MNIST where the active channel correlates with a binary label.

    Each environment is defined by the probability (0.1, 0.2, 0.9) that the
    color is flipped relative to the (noisy) label.
    """

    ENVIRONMENTS = ["+90%", "+80%", "-90%"]

    def __init__(self, root):
        super().__init__(
            root,
            environments=[0.1, 0.2, 0.9],
            dataset_transform=self.color_dataset,
            input_shape=(2, 28, 28),
            num_classes=2,
        )

    def color_dataset(self, images, labels, environment):
        """Build the colored sub-dataset for one environment.

        Args:
            images: Batch of single-channel images; they are duplicated into
                two channels along a new dimension 1.
            labels: Digit labels; digits below 5 map to 1, the rest to 0.
            environment: Probability of flipping the color relative to the
                binary label.

        Returns:
            ``TensorDataset`` of float images scaled by 1/255 and long labels.
        """
        n_samples = len(labels)

        # Binary target from the digit, then flipped with probability 0.25.
        binary = (labels < 5).float()
        label_noise = self.torch_bernoulli_(0.25, n_samples)
        noisy_labels = self.torch_xor_(binary, label_noise)

        # Color follows the noisy label, flipped with probability `environment`.
        color_noise = self.torch_bernoulli_(environment, n_samples)
        colors = self.torch_xor_(noisy_labels, color_noise)

        two_channel = torch.stack([images, images], dim=1)

        # Zero the channel that does not correspond to the sample's color.
        sample_idx = torch.tensor(range(len(two_channel)))
        unused_channel = (1 - colors).long()
        two_channel[sample_idx, unused_channel, :, :] *= 0

        x = two_channel.float() / 255.0
        y = noisy_labels.view(-1).long()
        return TensorDataset(x, y)

    def torch_bernoulli_(self, p, size):
        """Return a float tensor of ``size`` entries, each 1.0 with probability ``p``."""
        draws = torch.rand(size)
        return (draws < p).float()

    def torch_xor_(self, a, b):
        """Return ``|a - b|``, which is XOR when both inputs hold only 0s and 1s."""
        return torch.abs(a - b)


class RotatedMNIST(MultipleEnvironmentMNIST):
    """MNIST where every environment applies a fixed rotation angle."""

    ENVIRONMENTS = ["0", "15", "30", "45", "60", "75"]

    def __init__(self, root):
        super().__init__(
            root,
            environments=[0, 15, 30, 45, 60, 75],
            dataset_transform=self.rotate_dataset,
            input_shape=(1, 28, 28),
            num_classes=10,
        )

    def rotate_dataset(self, images, labels, angle):
        """Rotate every image by ``angle`` and pair it with its label.

        Args:
            images: Batch of images; each is converted to PIL, rotated and
                converted back to a tensor written into a ``1x28x28`` slot.
            labels: Labels, flattened with ``view(-1)``.
            angle: Rotation angle forwarded to torchvision's ``rotate``.

        Returns:
            ``TensorDataset`` of rotated images and flattened labels.
        """

        def _rotate_pil(pil_image):
            # NOTE(review): the `resample=` keyword is the older torchvision
            # spelling (newer releases use `interpolation=`) -- confirm it is
            # accepted by the installed torchvision version.
            return rotate(pil_image, angle, fill=(0,), resample=Image.BICUBIC)

        pipeline = T.Compose([T.ToPILImage(), T.Lambda(_rotate_pil), T.ToTensor()])

        rotated = torch.zeros(len(images), 1, 28, 28)
        for position, image in enumerate(images):
            rotated[position] = pipeline(image)

        return TensorDataset(rotated, labels.view(-1))


class MultipleEnvironmentImageFolder(MultipleDomainDataset):
    """One ``ImageFolder`` per sub-directory of ``root``, in sorted name order."""

    def __init__(self, root):
        """Scan ``root`` and wrap each sub-directory as an environment.

        Args:
            root: Directory whose immediate sub-directories are environments.
        """
        super().__init__()
        self.environments = sorted(
            entry.name for entry in os.scandir(root) if entry.is_dir()
        )
        self.datasets = [
            ImageFolder(os.path.join(root, env_name))
            for env_name in self.environments
        ]

        self.input_shape = (3, 224, 224)
        # The class count is read from the last environment's folder.
        self.num_classes = len(self.datasets[-1].classes)
        
class WarpImageFolder(ImageFolder):
    """Wrap an ``ImageFolder`` and optionally attach per-sample ensemble tensors.

    Samples are read from an inner ``ImageFolder`` stored in ``self.folder``.
    When ``is_in_environments_commom`` is true, each item additionally carries
    a stack of the entries ``env_tensor[i][0][index]`` for every ensemble
    member ``i``.

    NOTE(review): ``ImageFolder.__init__`` is not invoked on this object, so
    attributes such as ``classes`` live on ``self.folder`` only.
    """

    def __init__(self, root, is_in_environments_commom=False,env_tensor=None):
        """Create the wrapper.

        Args:
            root: Directory passed to the inner ``ImageFolder``.
            is_in_environments_commom: Whether items should include the
                stacked ensemble tensors.
            env_tensor: Sequence indexed as ``env_tensor[member][0][sample]``;
                required when ``is_in_environments_commom`` is true.

        Raises:
            TypeError: If ``is_in_environments_commom`` is true but
                ``env_tensor`` is ``None``.
        """
        self.folder = ImageFolder(root)
        self.is_in_environments_commom = is_in_environments_commom

        # Always define these so instance state does not depend on the flag.
        self.env_tensor = None
        self.ensemble_size = 0

        if is_in_environments_commom:
            if env_tensor is None:
                raise TypeError(
                    "env_tensor is required when is_in_environments_commom is True"
                )
            self.env_tensor = env_tensor
            self.ensemble_size = len(env_tensor)

    def __len__(self) -> int:
        """Return the number of samples in the inner ``ImageFolder``."""
        return len(self.folder)

    def __getitem__(self, index):
        """Return ``(sample, target)`` or ``(sample, target, stacked_tensors)``.

        The third element is present only when ``is_in_environments_commom``
        is true; it stacks one entry per ensemble member along dimension 0.
        """
        item = self.folder[index]
        if not self.is_in_environments_commom:
            return item[0], item[1]

        member_entries = [
            self.env_tensor[member][0][index]
            for member in range(self.ensemble_size)
        ]
        return item[0], item[1], torch.stack(member_entries, dim=0)
        

def load_tensors_from_directories(root1, test_domain, sub_domain):
    """Load ``tensor.pt`` for ``sub_domain`` from every other domain folder.

    The immediate sub-directories of ``root1`` are treated as domains. After
    dropping ``test_domain`` and ``sub_domain``, the file
    ``<root1>/<domain>/<sub_domain>/tensor.pt`` is loaded for each remaining
    domain. Domains are visited in sorted name order so the order of the
    returned entries does not depend on the filesystem's listing order.

    Args:
        root1: Directory containing one sub-directory per domain.
        test_domain: Domain name to exclude.
        sub_domain: Domain name to exclude and whose sub-folder is read.

    Returns:
        A list with one single-element list ``[tensor]`` per remaining domain.
        If ``sub_domain == test_domain`` a diagnostic is printed and an empty
        list is returned.

    Raises:
        ValueError: If ``test_domain`` or ``sub_domain`` is not a
            sub-directory of ``root1`` (raised by ``list.remove``).
    """
    # os.scandir yields entries in arbitrary order; sort for reproducibility.
    domains = sorted(entry.name for entry in os.scandir(root1) if entry.is_dir())

    original_dirs = []
    if sub_domain != test_domain:
        domains.remove(test_domain)
        domains.remove(sub_domain)
        original_dirs = [
            os.path.join(root1, domain, sub_domain) for domain in domains
        ]
    else:
        print(
            "something wrong! sub_domain must differ from test_domain "
            "(both are {!r})".format(test_domain)
        )

    print(original_dirs)

    # NOTE: torch.load deserializes with pickle; only point root1 at trusted files.
    return [
        [torch.load(os.path.join(tensor_dir, 'tensor.pt'))]
        for tensor_dir in original_dirs
    ]



class Ensemble3EnvironmentImageFolder(MultipleDomainDataset):
    """Image-folder environments, with precomputed tensors on non-test ones.

    Environments are the sorted sub-directories of ``root``. Every environment
    except the one at ``test_envs[0]`` is wrapped together with tensors loaded
    from ``pre_root``; the test environment is wrapped without them.
    """

    def __init__(self, root, pre_root, test_envs):
        """Build the per-environment datasets.

        Args:
            root: Directory whose sub-directories are the environments.
            pre_root: Directory handed to ``load_tensors_from_directories``.
            test_envs: Sequence of indices; only its first element is read.
        """
        super().__init__()
        environments = sorted(
            entry.name for entry in os.scandir(root) if entry.is_dir()
        )
        held_out_index = test_envs[0]
        test_name = environments[held_out_index]

        # Every environment except the held-out one.
        environments_pred = list(environments)
        del environments_pred[held_out_index]
        environments_commom = set(environments).intersection(environments_pred)

        self.environments = environments
        self.datasets = []
        self.env_tensor = []
        for environment in environments:
            path = os.path.join(root, environment)

            if environment not in environments_commom:
                env_dataset = WarpImageFolder(path, False)
            else:
                # self.env_tensor is rebound on every non-test environment, so
                # it ends up holding the tensors of the last one processed.
                self.env_tensor = load_tensors_from_directories(
                    pre_root, test_name, environment
                )
                env_dataset = WarpImageFolder(path, True, self.env_tensor)

            self.datasets.append(env_dataset)

        self.input_shape = (3, 224, 224)
        # Class count = length of the first entry of the first loaded tensor.
        self.num_classes = len(self.env_tensor[0][0][0])


class PACS(MultipleEnvironmentImageFolder):
    """Image-folder dataset read from the ``PACS/`` sub-directory of ``root``."""

    ENVIRONMENTS = ["A", "C", "P", "S"]
    N_STEPS = 5001
    CHECKPOINT_FREQ = 200

    def __init__(self, root):
        data_dir = os.path.join(root, "PACS/")
        self.dir = data_dir
        super().__init__(data_dir)
        

class DomainNet(MultipleEnvironmentImageFolder):
    """Image-folder dataset read from the ``domain_net/`` sub-directory of ``root``."""

    ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"]
    N_STEPS = 15001
    CHECKPOINT_FREQ = 1000

    def __init__(self, root):
        data_dir = os.path.join(root, "domain_net/")
        self.dir = data_dir
        super().__init__(data_dir)


class VLCS(MultipleEnvironmentImageFolder):
    """Image-folder dataset read from the ``db/`` sub-directory of ``root``.

    NOTE(review): ``TerraIncognita`` in this module reads the same ``db/``
    sub-directory -- confirm the two are meant to share a location.
    """

    ENVIRONMENTS = ["A", "C", "P", "R"]
    CHECKPOINT_FREQ = 1000

    def __init__(self, root):
        data_dir = os.path.join(root, "db/")
        self.dir = data_dir
        super().__init__(data_dir)

        
class TerraIncognita(MultipleEnvironmentImageFolder):
    """Image-folder dataset read from the ``db/`` sub-directory of ``root``."""

    ENVIRONMENTS = ["L100", "L38", "L43", "L46"]
    CHECKPOINT_FREQ = 1000

    def __init__(self, root):
        data_dir = os.path.join(root, "db/")
        self.dir = data_dir
        super().__init__(data_dir)

class DomED_PACS(Ensemble3EnvironmentImageFolder):
    """PACS image folders combined with precomputed ensemble tensors.

    Images are read from the ``PACS/`` sub-directory of ``root``; the
    per-environment tensors are loaded from ``pre_root``.
    """

    CHECKPOINT_FREQ = 200
    N_STEPS = 5001
    # The base class builds its datasets in sorted directory order, so the
    # labels follow the same order the `PACS` class declares for this folder.
    ENVIRONMENTS = ["A", "C", "P", "S"]

    def __init__(self, root, pre_root, test_envs):
        """Set up the dataset.

        Args:
            root: Directory containing the ``PACS/`` folder.
            pre_root: Directory holding the precomputed tensors.
            test_envs: Sequence of environment indices, forwarded unchanged
                to the base class.
        """
        self.dir1 = os.path.join(root, "PACS/")
        self.dir2 = pre_root
        super().__init__(self.dir1, self.dir2, test_envs)




