import sys, os
import numpy as np
import chainer
from chainer.datasets import get_mnist
import matplotlib
# Disable interactive backend
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
from chainer.backends.cuda import to_cpu,to_gpu

base = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(base, './../'))
from datasets import dataset_envwise as data_env

import ipdb
import copy


'''
Provides Colored MNIST as data_env datasets, one per environment.
'''

class ColoredMNISTDataset(object):
    """Colored MNIST split into per-environment train and test datasets.

    Attributes:
        tr_list, ts_list: the raw per-environment dicts (keys ``'x'``, ``'y'``)
            returned by ``colored_mnist``.
        train_datasetlist, test_datasetlist: each of those dicts wrapped in
            ``data_env.Dataset_envwise`` (presumably a dataset adapter over the
            dict -- confirm in ``datasets.dataset_envwise``).
        in_ch: number of image channels (2: red and green).

    Positional arguments are ignored; keyword arguments are forwarded to
    ``colored_mnist`` (e.g. ``data_type``, ``train_env_num``).
    """

    def __init__(self, *args, **kwargs):
        raw_train, raw_test = get_mnist(withlabel=True, ndim=2)

        train_envs, test_envs = colored_mnist(
            train=raw_train, test=raw_test, **kwargs)

        self.train_datasetlist = [
            data_env.Dataset_envwise(env) for env in train_envs]
        self.test_datasetlist = [
            data_env.Dataset_envwise(env) for env in test_envs]

        self.ts_list = test_envs
        self.tr_list = train_envs
        self.in_ch = 2


def _colored_pool(train, test):
    """Apply steps 1-3 of ``colored_mnist`` to train + test (in that order).

    Returns ``(images, labels)``: ``images`` has shape ``(n, 2, 14, 14)`` and
    ``labels`` is int32 with shape ``(n, 1)``, where
    ``n = len(train) + len(test)``.
    """
    samples = [train[i] for i in range(len(train))]
    samples += [test[i] for i in range(len(test))]
    # Inputs are 28x28 digits (flattened 784-vectors also reshape cleanly).
    gray = np.asarray([image for image, _ in samples]).reshape(-1, 28, 28)
    digits = np.asarray([label for _, label in samples])

    # step 1 : make binary -- digits 0-4 -> 0, digits 5-9 -> 1
    labels = (digits > 4).astype(np.int32)

    # step 2 : flip binary labels with 25%.
    # One uniform draw per sample, in sample order: the same random stream
    # as drawing them one at a time, so seeded runs are unchanged.
    noisy = np.random.uniform(0.0, 1.0, size=len(labels)) <= 0.25
    labels = np.where(noisy, 1 - labels, labels).astype(np.int32)

    # step 3 : color with the possibly flipped label, downsampled to 14x14
    # label == 0 -> channel 0 (red), label == 1 -> channel 1 (green)
    small = gray[:, ::2, ::2]
    images = np.zeros((len(labels), 2) + small.shape[1:],
                      dtype=np.result_type(small.dtype, np.float32))
    images[np.arange(len(labels)), labels] = small
    return images, labels[:, np.newaxis]


def colored_mnist(train=None, test=None, data_type='paper',
                  train_env_num=2, test_env_num=5,
                  tr_env_size=10000, **kwargs):
    """Build Colored MNIST environments (see "Invariant Risk Minimization").

    The ``train`` and ``test`` splits are concatenated, then:

    1. digits 0-4 are labelled 0 and digits 5-9 are labelled 1;
    2. each binary label is flipped with probability 0.25;
    3. each image is downsampled to 14x14 and put in channel 0 (red) if its
       possibly flipped label is 0, or channel 1 (green) if it is 1;
    4. per environment, red and green are swapped per image with that
       environment's probability (``flip_color``).

    The first 50000 samples form the training pool; the remaining samples
    are used (unsampled) by every test environment.

    Args:
        train, test: indexable datasets of ``(image, label)`` pairs, e.g. from
            ``chainer.datasets.get_mnist(withlabel=True, ndim=2)``.
        data_type: ``'paper'`` -> two training envs of 25000 samples (color
            flip 0.2 and 0.1) and one test env (0.9).
            ``'paper_exact_eval'`` -> the same training envs and
            ``test_env_num`` test envs with flip probabilities ~ U(0, 1).
            Any other value -> ``train_env_num`` training envs of
            ``tr_env_size`` samples with flip probabilities ~ U(0, 0.2)
            (printed), plus ``test_env_num`` test envs ~ U(0, 1).
        train_env_num, test_env_num, tr_env_size: see ``data_type``.
        **kwargs: accepted and ignored.

    Returns:
        ``(train_envs, test_envs)``: lists of dicts with ``'x'`` (images of
        shape ``(n, 2, 14, 14)``) and ``'y'`` (int32 labels, shape ``(n, 1)``).
    """
    images, labels = _colored_pool(train, test)

    # before step 4 : the first tr_size samples are the training pool
    tr_size = 50000
    half = tr_size // 2
    test_pool = slice(tr_size, None)

    def make_env(index, prob):
        # step 4 : flip color with this environment's probability
        return flip_color({'x': images[index], 'y': labels[index]},
                          env_prob=prob)

    if data_type in ('paper', 'paper_exact_eval'):
        tr0 = make_env(slice(None, half), 0.2)
        tr1 = make_env(slice(half, tr_size), 0.1)
        if data_type == 'paper':
            return [tr0, tr1], [make_env(test_pool, 0.9)]
        ts_env_prob_list = np.random.uniform(low=0.0, high=1.0,
                                             size=test_env_num)
        return [tr0, tr1], [make_env(test_pool, p) for p in ts_env_prob_list]

    tr_env_prob_list = np.random.uniform(low=0.0, high=0.2, size=train_env_num)
    ts_env_prob_list = np.random.uniform(low=0.0, high=1.0, size=test_env_num)
    print('train env', tr_env_prob_list)

    tr_list = []
    if tr_env_size * train_env_num > tr_size:
        # The pool cannot hold disjoint environments: sample each one
        # independently, so environments may share samples.
        for tr_prob in tr_env_prob_list:
            index = np.random.choice(tr_size, tr_env_size, replace=False)
            tr_list.append(make_env(index, tr_prob))
    else:
        # One draw without replacement, cut into consecutive chunks, so the
        # environments are disjoint.
        all_index = np.random.choice(tr_size, train_env_num * tr_env_size,
                                     replace=False)
        for i, tr_prob in enumerate(tr_env_prob_list):
            index = all_index[i * tr_env_size:(i + 1) * tr_env_size]
            tr_list.append(make_env(index, tr_prob))

    ts_list = [make_env(test_pool, p) for p in ts_env_prob_list]
    return tr_list, ts_list


def save_irm_mnist(images, labels, fname):
    """Save up to the first ten colored images as PNGs (debugging helper).

    Not called in the visible code; safe to remove.

    Args:
        images: sequence of 2-channel 14x14 images, either shaped
            ``(2, 14, 14)`` (as produced by ``flip_color``) or flattened to
            392 values (red plane first, then green).
        labels: sequence of labels; ``labels[i]`` is printed for each image.
        fname: file-name prefix; files are written to
            ``./<fname>colored_<i>.png``.
    """
    # Bound by the input size so fewer than ten images does not IndexError.
    for i in range(min(10, len(images))):
        two_channel = np.asarray(images[i]).reshape(2, 14, 14)
        label = labels[i]
        # Add an empty blue plane so matplotlib can display it as RGB.
        blue = np.zeros((1, 14, 14))
        rgb_image = np.transpose(np.concatenate((two_channel, blue)),
                                 (1, 2, 0))
        plt.imshow(rgb_image)
        plt.savefig("./" + fname + "colored_" + str(i) + ".png")
        print(label)
        plt.clf()


'''
Note: flip_color's first parameter is named `list`, which shadows the
built-in `list`; the naming here needs cleanup.
'''

def flip_color(list, env_prob):
    """Swap the red and green channels of each image with probability env_prob.

    Args:
        list: dict whose ``'x'`` entry holds images of shape
            ``(n, 2, 14, 14)`` (channel 0 = red, channel 1 = green). The name
            shadows the built-in ``list``; it is kept so keyword callers
            keep working.
        env_prob: the environment's flip rate. One ``np.random.uniform(0, 1)``
            value is drawn per image, in image order, and the image's
            channels are swapped when the draw is ``<= env_prob``.

    Returns:
        The same dict, mutated in place: ``'x'`` is replaced by a new array
        of shape ``(n, 2, 14, 14)``; the input array and other keys are left
        untouched.
    """
    images = np.asarray(list['x'])
    # A single vectorized draw consumes the same random stream as one
    # scalar draw per image, so seeded results match the per-image loop.
    swap = np.random.uniform(0.0, 1.0, size=len(images)) <= env_prob
    new_images = np.array(images, copy=True)
    if swap.any():
        # swap Red axis and Green axis only for the selected images
        new_images[swap, 0] = images[swap, 1]
        new_images[swap, 1] = images[swap, 0]
    list['x'] = new_images.reshape(-1, 2, 14, 14)
    return list
    
    
if __name__ == '__main__':
    # Smoke run: build the default ('paper') Colored MNIST environments
    # from the raw chainer MNIST splits; the result is discarded.
    raw_train, raw_test = get_mnist(withlabel=True, ndim=2)
    colored_mnist(train=raw_train, test=raw_test)
