import os

import random
import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.cm
import matplotlib.pyplot as plt

from itertools import product

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from sklearn.model_selection import train_test_split
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy

def color_grayscale_arr(arr, red=True):
    """Turn a 2-D grayscale image into a 3-channel red or green image.

    Args:
        arr: 2-D array of shape (height, width).
        red: if true the intensities go into channel 0, otherwise into
            channel 1; the two remaining channels are filled with zeros.

    Returns:
        A new (height, width, 3) array with the same dtype as ``arr``.
    """
    assert arr.ndim == 2
    height, width = arr.shape
    blank = np.zeros((height, width), dtype=arr.dtype)
    # Channel order is (0, 1, 2); only one of the first two carries the image.
    if red:
        planes = (arr, blank, blank)
    else:
        planes = (blank, arr, blank)
    return np.stack(planes, axis=2)

class ColoredMNIST(WILDSDataset):
    """Binary, colour-confounded variant of the MNIST training set.

    Every MNIST training digit is converted into a red or green RGB image.
    The target starts as ``0`` for digits below 5 and ``1`` otherwise and is
    flipped with probability ``1 - invar_str``.  The colour follows the
    (possibly flipped) target -- red for 0, green for 1 -- and is itself
    flipped with probability ``1 - spur_str``.

    The metadata columns are ``['background', 'y']``; ``background`` is 0 for
    a red image and 1 for a green one.
    """

    _dataset_name = 'cmnist'

    def __init__(self, version=None, root_dir='', download=True, split_scheme='official',
                 invar_str=0.75, spur_str=0.9, test_pct=0.2, val_pct=0.1, data_seed=None):
        """Build the coloured dataset in memory.

        Args:
            version: accepted for interface compatibility; not used here.
            root_dir: passed to ``initialize_data_dir`` to obtain the data
                directory.
            download: forwarded to the torchvision MNIST constructor.
            split_scheme: only ``'official'`` is supported.
            invar_str: probability that the target keeps its digit-derived
                value.
            spur_str: probability that the colour agrees with the target.
            test_pct: fraction of the samples assigned split id 2 (test).
            val_pct: fraction of the samples assigned split id 1 (val).
            data_seed: if given, the global NumPy RNG is seeded with it while
                the dataset is generated and restored afterwards.

        Raises:
            ValueError: if ``split_scheme`` is not ``'official'``.
        """
        # Fail fast: reject an unsupported scheme before downloading anything
        # or touching the global RNG.
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')

        self.invar_str = invar_str
        self.spur_str = spur_str
        self._data_dir = self.initialize_data_dir(root_dir)
        train_mnist = datasets.mnist.MNIST(self._data_dir, train=True, download=download)
        n_samples = len(train_mnist)

        saved_state = None
        if data_seed is not None:
            saved_state = np.random.get_state()
            np.random.seed(data_seed)

        # The try/finally guarantees the caller's RNG state is restored even
        # if generation fails part-way through.
        try:
            X, Y, G = [], [], []
            for im, label in train_mnist:
                im_array = np.array(im)

                # Binary target derived from the digit.
                binary_label = 0 if label < 5 else 1

                # Flip the target with probability 1 - invar_str.
                if np.random.uniform() < 1 - self.invar_str:
                    binary_label = binary_label ^ 1

                # Colour follows the (possibly flipped) target ...
                color_red = binary_label == 0

                # ... and is flipped with probability 1 - spur_str.
                if np.random.uniform() < 1 - self.spur_str:
                    color_red = not color_red

                X.append(color_grayscale_arr(im_array, red=color_red))
                Y.append(binary_label)
                G.append(int(not color_red))

            # Targets
            self._y_array = torch.LongTensor(Y)
            self._y_size = 1
            self._n_classes = 2

            # Metadata: colour attribute and target.
            self._metadata_array = torch.stack(
                (torch.LongTensor(G), self._y_array),
                dim=1
            )
            self._metadata_fields = ['background', 'y']
            self._metadata_map = {
                'background': ['0', '1'],
                'y': ['0', '1']
            }

            self.X = X
            self._original_resolution = (28, 28)

            # Random splits: test is drawn first, val from what is left;
            # every other sample keeps split id 0.
            all_idxs = np.arange(n_samples)
            test_idxs = np.random.choice(all_idxs, size=int(n_samples * test_pct), replace=False)
            val_idxs = np.random.choice(np.setdiff1d(all_idxs, test_idxs),
                                        size=int(n_samples * val_pct), replace=False)
            self._split_array = np.zeros((n_samples, 1))
            self._split_array[val_idxs] = 1
            self._split_array[test_idxs] = 2

            self._eval_grouper = CombinatorialGrouper(
                dataset=self,
                groupby_fields=(['background', 'y']))
        finally:
            if saved_state is not None:
                np.random.set_state(saved_state)

        super().__init__(self._data_dir, split_scheme)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """Evaluate predictions with accuracy, grouped by (background, y).

        Args:
            y_pred: model predictions, passed through to
                ``standard_group_eval``.
            y_true: ground-truth targets.
            metadata: metadata rows matching ``y_true``.
            prediction_fn: optional function handed to the ``Accuracy``
                metric.

        Returns:
            The ``(results, results_str)`` pair produced by
            ``standard_group_eval``.
        """
        metric = Accuracy(prediction_fn=prediction_fn)

        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)

        return results, results_str

    def get_input(self, idx):
        """Return the coloured image at position ``idx`` as a PIL image."""
        return Image.fromarray(self.X[idx])


class SpuriousCIFAR10(WILDSDataset):
    """CIFAR-10 with a synthetic, class-correlated line pattern on every image.

    A pattern consists of four intensities, each either ``0.5 + 0.5 * B`` or
    ``0.5 - 0.5 * B`` (on a 0-1 scale).  The first is written to channel 0 of
    the horizontal centre row, the other three to the three channels of the
    vertical centre column.  Each class is assigned its own pattern; a sample
    receives the pattern of its class with probability ``spur_str`` and the
    pattern of a randomly chosen other class otherwise.

    The metadata columns are ``['background', 'y']``; ``background`` is 1
    when the image carries the pattern of its own class and 0 otherwise.
    """

    _dataset_name = 'spur_cifar10'

    def __init__(self, version=None, root_dir='', download=True, split_scheme='official',
                 invar_str=1., spur_str=0.95, test_pct=0.2, val_pct=0.1, B=0.5, data_seed=None):
        """Build the modified dataset in memory.

        Args:
            version: accepted for interface compatibility; not used here.
            root_dir: passed to ``initialize_data_dir`` to obtain the data
                directory.
            download: forwarded to the torchvision CIFAR10 constructor.
            split_scheme: only ``'official'`` is supported.
            invar_str: ``int(n * (1 - invar_str))`` label positions are drawn
                (with replacement) per CIFAR split and relabelled to a
                different class.
            spur_str: probability that a sample carries its own class's line
                pattern.
            test_pct: not used; the CIFAR-10 test set always forms split 2.
            val_pct: fraction of the CIFAR-10 training set assigned split
                id 1 (val).
            B: controls the contrast between the two possible line
                intensities.
            data_seed: if given, the global NumPy and Python RNGs are seeded
                with it while the dataset is generated and restored
                afterwards.

        Raises:
            ValueError: if ``split_scheme`` is not ``'official'``.
        """
        # Fail fast: reject an unsupported scheme before downloading anything
        # or touching the global RNGs.
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')

        self.invar_str = invar_str
        self.spur_str = spur_str
        self.B = B

        self._data_dir = self.initialize_data_dir(root_dir)

        cifar_train = datasets.cifar.CIFAR10(root=self._data_dir, train=True, download=download)
        cifar_test = datasets.cifar.CIFAR10(root=self._data_dir, train=False, download=download)

        # np.array copies, so the torchvision datasets are left untouched.
        train_X, train_Y = np.array(cifar_train.data), np.array(cifar_train.targets)
        test_X, test_Y = np.array(cifar_test.data), np.array(cifar_test.targets)

        classes = np.sort(np.unique(train_Y))
        mid = train_X.shape[1] // 2

        seeded = data_seed is not None
        np_state = py_state = None
        if seeded:
            np_state = np.random.get_state()
            py_state = random.getstate()
            np.random.seed(data_seed)
            random.seed(data_seed)

        # The try/finally guarantees the caller's RNG state is restored even
        # if generation fails part-way through.
        try:
            # 2**4 = 16 candidate patterns; one is kept per class after a
            # shuffle.
            configs = list(product([lambda x: 0.5 + 0.5 * x, lambda x: 0.5 - 0.5 * x], repeat=4))
            random.shuffle(configs)
            config_mapping = configs[:len(classes)]

            # Intensities are fractions of full scale.  They are converted to
            # 0-255 pixel values here because writing e.g. 0.75 directly into
            # the uint8 image array would truncate it to 0 and make all
            # patterns indistinguishable.
            line_values = [
                tuple(int(np.clip(np.rint(255.0 * fn(self.B)), 0, 255)) for fn in config)
                for config in config_mapping
            ]
            # For each class, the patterns belonging to every other class.
            alternatives = [
                line_values[:c] + line_values[c + 1:] for c in range(len(line_values))
            ]

            for labels in (train_Y, test_Y):
                n_flip = int(len(labels) * (1 - self.invar_str))
                # NOTE: positions are drawn with replacement, so the number
                # of distinct relabelled samples can be below n_flip.
                flip_inds = np.random.randint(0, len(labels), size=n_flip)
                # Select by the labels as they were before any flip;
                # otherwise a sample moved into a later class would be
                # flipped a second time (possibly back to its true label).
                original = labels.copy()
                for cls in classes:
                    cls_inds = np.intersect1d(flip_inds, (original == cls).nonzero()[0])
                    labels[cls_inds] = np.random.choice(
                        classes[classes != cls], size=len(cls_inds), replace=True)

            G = []
            for X, Y in ((train_X, train_Y), (test_X, test_Y)):
                spu_config = np.random.random(len(X)) >= (1 - self.spur_str)
                G.append(spu_config.astype(int))
                for i, (y, aligned) in enumerate(zip(Y, spu_config)):
                    values = line_values[y] if aligned else random.choice(alternatives[y])
                    X[i, mid, :, 0] = values[0]  # horizontal line, channel 0 only
                    for ch in range(3):
                        X[i, :, mid, ch] = values[ch + 1]  # vertical line, every channel

            self.X = np.concatenate((train_X, test_X))

            # Targets
            self._y_array = torch.from_numpy(np.concatenate((train_Y, test_Y))).long()
            self._y_size = 1
            self._n_classes = len(classes)

            # Metadata: pattern/label agreement flag and target.
            self._metadata_array = torch.stack(
                (torch.from_numpy(np.concatenate(G)).long(), self._y_array),
                dim=1
            )
            self._metadata_fields = ['background', 'y']
            self._metadata_map = {
                'background': ['0', '1'],
                'y': cifar_train.classes
            }

            self._original_resolution = (32, 32)

            # Splits: val is sampled from the CIFAR training set, the CIFAR
            # test set is split 2, everything else keeps split id 0.
            n_train = len(train_X)
            val_idxs = np.random.choice(np.arange(n_train), size=int(n_train * val_pct), replace=False)
            self._split_array = np.zeros((n_train + len(test_X), 1))
            self._split_array[val_idxs] = 1
            self._split_array[n_train:] = 2

            self._eval_grouper = CombinatorialGrouper(
                dataset=self,
                groupby_fields=(['background', 'y']))
        finally:
            if seeded:
                np.random.set_state(np_state)
                random.setstate(py_state)

        super().__init__(self._data_dir, split_scheme)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """Evaluate predictions with accuracy, grouped by (background, y).

        Args:
            y_pred: model predictions, passed through to
                ``standard_group_eval``.
            y_true: ground-truth targets.
            metadata: metadata rows matching ``y_true``.
            prediction_fn: optional function handed to the ``Accuracy``
                metric.

        Returns:
            The ``(results, results_str)`` pair produced by
            ``standard_group_eval``.
        """
        metric = Accuracy(prediction_fn=prediction_fn)

        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)

        return results, results_str

    def get_input(self, idx):
        """Return the modified image at position ``idx`` as a PIL image."""
        return Image.fromarray(self.X[idx])
