import sys, os
import numpy as np
import chainer
from chainer.datasets import get_mnist
import matplotlib
# Disable interactive backend
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
from chainer.backends.cuda import to_cpu, to_gpu
from PIL import Image

base = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(base, './../'))
from datasets import dataset_envwise as data_env

import ipdb
import copy


'''
Builds the per-environment datasets for data_env: each environment is a dict holding 'x' images and 'y' labels.
'''

class ExtendedColoredMNISTDataset(object):
    """Environment-wise datasets produced by ``extended_colored_mnist``.

    Positional arguments are accepted and ignored; keyword arguments are
    forwarded unchanged to ``extended_colored_mnist``.

    Attributes:
        tr_list, ts_list: raw per-environment dicts returned by
            ``extended_colored_mnist`` (train and test, respectively).
        train_datasetlist, test_datasetlist: one ``data_env.Dataset_envwise``
            per entry of ``tr_list`` / ``ts_list``, in the same order.
    """

    def __init__(self, *args, **kwargs):
        mnist_train, mnist_test = get_mnist(withlabel=True, ndim=2, scale=255.0)

        train_envs, test_envs = extended_colored_mnist(
            train=mnist_train, test=mnist_test, **kwargs)

        # Wrap every environment dict in its own dataset object.
        self.train_datasetlist = [data_env.Dataset_envwise(env) for env in train_envs]
        self.test_datasetlist = [data_env.Dataset_envwise(env) for env in test_envs]

        self.ts_list = test_envs
        self.tr_list = train_envs
        

class ExtendedColoredMNISTDatasetTest(object):
    """Environment-wise datasets produced by ``extended_colored_mnist_test``.

    Positional arguments are accepted and ignored; keyword arguments are
    forwarded unchanged to ``extended_colored_mnist_test``.

    Attributes:
        tr_list, ts_list: raw per-environment dicts returned by
            ``extended_colored_mnist_test`` (train and test, respectively).
        train_datasetlist, test_datasetlist: one ``data_env.Dataset_envwise``
            per entry of ``tr_list`` / ``ts_list``, in the same order.
    """

    def __init__(self, *args, **kwargs):
        mnist_train, mnist_test = get_mnist(withlabel=True, ndim=2, scale=255.0)

        train_envs, test_envs = extended_colored_mnist_test(
            train=mnist_train, test=mnist_test, **kwargs)

        # Wrap every environment dict in its own dataset object.
        self.train_datasetlist = [data_env.Dataset_envwise(env) for env in train_envs]
        self.test_datasetlist = [data_env.Dataset_envwise(env) for env in test_envs]

        self.ts_list = test_envs
        self.tr_list = train_envs


def _to_binary_red_images(train, test, image_L):
    """Merge two ``(image, label)`` datasets into binary-label RGB arrays.

    Samples from ``train`` come first, then those from ``test``.  Digits
    0-4 become label 0 and 5-9 become label 1.  Each 2-D image is placed in
    channel 0 of a 3-channel image (channels 1 and 2 are zero), resized to
    ``image_L x image_L`` with bicubic interpolation and divided by 255.
    Pixel values are presumably in [0, 255] (``get_mnist(scale=255.0)`` as
    used by the wrapper classes) -- TODO confirm for other callers.

    Returns:
        images: float32 array of shape (N, 3, image_L, image_L).
        labels: int32 array of shape (N, 1).
    """
    images = []
    labels = []
    for dataset in (train, test):
        for idx in range(len(dataset)):
            image, label = dataset[idx]
            labels.append(np.int32(0) if label <= 4 else np.int32(1))
            pad = np.zeros(image.shape, dtype=np.float32)
            # Height x width x channel layout, as PIL expects for RGB.
            hwc = np.stack((image, pad, pad), axis=-1).astype('i')
            pil_img = Image.fromarray(np.uint8(hwc))
            pil_img = pil_img.resize((image_L, image_L), Image.BICUBIC)
            images.append(np.asarray(pil_img).astype('f').transpose(2, 0, 1) / 255.0)
    return np.asarray(images), np.asarray(labels)[:, np.newaxis]


def _apply_ch3_and_label_noise(env, zero_prob, label_flip_rate, backcolor_type):
    """Randomly fill channel 2 and flip labels of ``env`` in place.

    For every sample two uniforms are drawn (channel draw first, then label
    draw).  If the channel draw is ``>= zero_prob``, channel 2 is set to 1.0
    (``backcolor_type``) or to a copy of channel 0, and the label is flipped
    with probability ``label_flip_rate[1]``.  Otherwise, channel 2 is left
    untouched and the label is flipped with probability ``label_flip_rate[0]``.
    """
    x, y = env['x'], env['y']
    for j in range(len(x)):
        ch3_prob = np.random.uniform(0.0, 1.0)
        label_prob = np.random.uniform(0.0, 1.0)
        if ch3_prob >= zero_prob:
            if backcolor_type:
                x[j, 2, :, :] = 1.0
            else:
                x[j, 2, :, :] = x[j, 0, :, :]
            flip_rate = label_flip_rate[1]
        else:
            flip_rate = label_flip_rate[0]
        if label_prob <= flip_rate:
            y[j] = 1 - y[j]


def extended_colored_mnist(train=None, test=None,
                           tr_size=50000, tr_env_size=25000,
                           train_env_num=2, test_env_num=5,
                           data_type='default',
                           ch12_prob_env1=0.1,
                           ch12_prob_env2=0.2,
                           ch3_prob_env1=0.3,
                           ch3_prob_env2=0.4,
                           ch3_train_lower=0.4,
                           ch3_train_upper=0.6,
                           label_flip_rate_0=0.25,
                           label_flip_rate_1=0.65,
                           backcolor_type=False,
                           image_L=14,
                           post_prc='default',
                           **kwargs):
    """Build colored, binary-label MNIST training and test environments.

    The samples of ``train`` and ``test`` are merged (train first), relabelled
    (digits 0-4 -> 0, 5-9 -> 1) and rendered into channel 0 of resized RGB
    images.  The first ``tr_size`` samples feed the training environments.
    The remaining samples are copied into each of the ``test_env_num`` test
    environments.  Every environment then gets random channel-2 filling and
    label noise (see ``_apply_ch3_and_label_noise``), followed by
    ``flip_color`` with a per-environment probability, and an optional
    ``post_prc`` step.

    Args:
        train, test: indexable datasets yielding ``(image, label)`` with 2-D
            images (e.g. from ``get_mnist(ndim=2, scale=255.0)``).
        tr_size: number of leading samples reserved for training.
        tr_env_size, train_env_num: size and number of training environments.
            Both are forced to 12500 and 4 when ``data_type == 'default'``.
        test_env_num: number of test environments.
        data_type: 'default' gives four disjoint consecutive training splits
            with fixed probabilities.  Any other value gives random splits and
            random training probabilities.
        ch12_prob_env1, ch12_prob_env2: ``flip_color`` probabilities of the
            default training environments (order env1, env2, env2, env1).
        ch3_prob_env1, ch3_prob_env2: per-environment thresholds below which
            channel 2 is not filled, for the default training environments
            (order env1, env2, env1, env2).
        ch3_train_lower, ch3_train_upper: range of those thresholds for
            non-default training environments.
        label_flip_rate_0, label_flip_rate_1: label-flip probabilities for
            samples without / with channel 2 filled.
        backcolor_type: fill channel 2 with 1.0 instead of a copy of channel 0.
        image_L: side length of the output images.
        post_prc: 'default' (no post-processing), 'oracle' or 'figonly'.
        **kwargs: ignored.

    Returns:
        ``(tr_list, ts_list)``: lists of dicts with float32 ``'x'`` of shape
        (n, 3, image_L, image_L) and int32 ``'y'`` of shape (n, 1).

    Raises:
        NotImplementedError: unknown ``post_prc``.  This is checked before any
            work is done.
        ValueError: the default disjoint split does not fit into ``tr_size``.
    """
    # Fail fast instead of after minutes of preprocessing.
    if post_prc not in ('default', 'oracle', 'figonly'):
        raise NotImplementedError('unknown post_prc: {!r}'.format(post_prc))

    if data_type == 'default':
        train_env_num = 4
        tr_env_size = 12500
        if train_env_num * tr_env_size > tr_size:
            raise ValueError('default split needs tr_size >= {}, got {}'.format(
                train_env_num * tr_env_size, tr_size))

    label_flip_rate = [label_flip_rate_0, label_flip_rate_1]
    # Merge train & test data into a single images/labels pair.
    images, labels = _to_binary_red_images(train, test, image_L)

    if data_type == 'default':
        # Disjoint consecutive slices of the training part only.  Previously
        # the last environment ran to the end of the data and so also
        # contained every test sample.
        tr_list = [{'x': images[i * tr_env_size:(i + 1) * tr_env_size].copy(),
                    'y': labels[i * tr_env_size:(i + 1) * tr_env_size].copy()}
                   for i in range(train_env_num)]
    elif tr_env_size * train_env_num > tr_size:
        # Not enough samples for disjoint environments: sample each one
        # independently (environments may overlap).
        tr_list = []
        for _ in range(train_env_num):
            index = np.random.choice(tr_size, tr_env_size, replace=False)
            tr_list.append({'x': images[index], 'y': labels[index]})
    else:
        index_all = np.random.choice(tr_size, train_env_num * tr_env_size, replace=False)
        tr_list = []
        for i in range(train_env_num):
            index = index_all[i * tr_env_size:(i + 1) * tr_env_size]
            tr_list.append({'x': images[index], 'y': labels[index]})

    # Every test environment starts from its own copy of the held-out samples.
    ts_list = [{'x': images[tr_size:].copy(), 'y': labels[tr_size:].copy()}
               for _ in range(test_env_num)]

    if data_type == 'default':
        zero_prob_list_train_ch3 = np.array([ch3_prob_env1, ch3_prob_env2,
                                             ch3_prob_env1, ch3_prob_env2])
        zero_prob_list_test_ch3 = np.random.uniform(0.0, 1.0, test_env_num)
    else:
        zero_prob_list_train_ch3 = np.random.uniform(ch3_train_lower,
                                                     ch3_train_upper,
                                                     train_env_num)
        zero_prob_list_test_ch3 = np.random.uniform(0.0, 1.0, test_env_num)
        print(zero_prob_list_train_ch3)

    for env, zero_prob in zip(tr_list, zero_prob_list_train_ch3):
        _apply_ch3_and_label_noise(env, zero_prob, label_flip_rate, backcolor_type)
    for env, zero_prob in zip(ts_list, zero_prob_list_test_ch3):
        _apply_ch3_and_label_noise(env, zero_prob, label_flip_rate, backcolor_type)

    if data_type == 'default':
        tr_env_prob_list = np.array([ch12_prob_env1, ch12_prob_env2,
                                     ch12_prob_env2, ch12_prob_env1])
    else:
        tr_env_prob_list = np.random.uniform(low=0.0, high=0.2, size=train_env_num)
    ts_env_prob_list = np.linspace(0.1, 0.9, test_env_num)

    tr_list = [flip_color(env, env_prob=prob)
               for env, prob in zip(tr_list, tr_env_prob_list)]
    ts_list = [flip_color(env, env_prob=prob)
               for env, prob in zip(ts_list, ts_env_prob_list)]

    if post_prc == 'oracle':
        tr_list = [generate_oracle(env) for env in tr_list]
        ts_list = [generate_oracle(env) for env in ts_list]
    elif post_prc == 'figonly':
        tr_list = [generate_figonly(env) for env in tr_list]
        ts_list = [generate_figonly(env) for env in ts_list]

    return tr_list, ts_list
    
    
def flip_color(env_list, env_prob):
    """Swap colour channels 0 and 1 depending on the label and on chance.

    Samples whose label is 1 get channels 0 and 1 exchanged (the background
    colour follows the label).  Independently, with probability ``env_prob``,
    one ``np.random.uniform`` draw per sample triggers another exchange.  Two
    exchanges cancel, so a sample ends up swapped exactly when one of the two
    conditions holds.

    Args:
        env_list: dict with ``'x'`` (samples indexed as ``x[i][channel, ...]``)
            and ``'y'`` (labels read as ``y[i][0]``).  It is modified in place.
        env_prob: probability of the extra random exchange.

    Returns:
        The same ``env_list`` object.
    """
    samples = env_list['x']
    targets = env_list['y']
    swap_order = [1, 0, 2]
    for idx in range(len(samples)):
        is_positive = targets[idx][0] == 1
        random_swap = np.random.uniform(0.0, 1.0) <= env_prob
        if is_positive != random_swap:
            samples[idx] = samples[idx][swap_order, :]
    return env_list
    
def generate_oracle(env_list):
    """Fill an empty channel 0 or 1 from the other one, in place.

    Channel sums are computed once per sample before any change.  A channel
    counts as empty when its sum is <= 0.1.  An empty channel 0 receives
    channel 1.  After that, an empty channel 1 receives (the possibly updated)
    channel 0.  Channel 2 is never touched.

    Returns:
        The same ``env_list`` object.
    """
    samples = env_list['x']
    for idx in range(len(samples)):
        sample = samples[idx]
        totals = sample.sum(axis=(1, 2))
        ch0_empty = totals[0] <= 0.1
        ch1_empty = totals[1] <= 0.1
        if ch0_empty:
            sample[0] = sample[1]
        if ch1_empty:
            sample[1] = sample[0]
    return env_list
    
def generate_figonly(env_list):
    """Like ``generate_oracle``, but also fill an empty channel 2, in place.

    Channel sums are computed once per sample before any change.  A channel
    counts as empty when its sum is <= 0.1.  If channel 0 is empty, it is
    copied from channel 1, and an empty channel 2 is copied from channel 1 as
    well.  Then, if channel 1 is empty, it is copied from channel 0, and an
    empty channel 2 is copied from channel 0.

    Returns:
        The same ``env_list`` object.
    """
    samples = env_list['x']
    for idx in range(len(samples)):
        sample = samples[idx]
        totals = sample.sum(axis=(1, 2))
        ch0_empty = totals[0] <= 0.1
        ch1_empty = totals[1] <= 0.1
        ch2_empty = totals[2] <= 0.1
        if ch0_empty:
            sample[0] = sample[1]
            if ch2_empty:
                sample[2] = sample[1]
        if ch1_empty:
            sample[1] = sample[0]
            if ch2_empty:
                sample[2] = sample[0]
    return env_list
    
    
def extended_colored_mnist_test(train=None, test=None,
                                tr_size=50000, tr_env_size=25000,
                                train_env_num=2, test_env_num=5,
                                data_type='default',
                                ch12_prob_list=(0.1, 0.2),
                                ch3_prob_list=(0.75, 0.25),
                                label_flip_rate_0=0.25,
                                label_flip_rate_1=0.65,
                                backcolor_type=True,
                                image_L=14,
                                **kwargs):
    """Build binary-label MNIST environments with channel-2 noise only.

    This variant does not call ``flip_color`` or any post-processing.
    ``train`` and ``test`` are merged (train first), relabelled (digits
    0-4 -> 0, 5-9 -> 1) and rendered into channel 0 of resized RGB images.
    The first ``tr_size`` samples feed the training environments; the rest
    are copied into each test environment.  Per sample, two uniforms are
    drawn.  If the first is ``>= zero_prob`` of the environment, channel 2 is
    filled (1.0 if ``backcolor_type`` else a copy of channel 0) and the label
    flips with probability ``label_flip_rate_1``.  Otherwise the label flips
    with probability ``label_flip_rate_0``.

    Args:
        train, test: indexable datasets yielding ``(image, label)`` with 2-D
            images (e.g. from ``get_mnist(ndim=2, scale=255.0)``).
        tr_size: number of leading samples reserved for training.
        tr_env_size, train_env_num: size and number of training environments.
            Both are forced to 25000 and 2 when ``data_type == 'default'``.
        test_env_num: number of test environments.
        data_type: 'default' gives disjoint consecutive training splits with
            ``ch3_prob_list`` thresholds.  Anything else gives independently
            sampled (possibly overlapping) splits with random thresholds.
        ch12_prob_list: accepted for interface compatibility; unused here.
        ch3_prob_list: channel-2 thresholds of the default training
            environments; needs at least ``train_env_num`` entries.
        label_flip_rate_0, label_flip_rate_1: label-flip probabilities for
            samples without / with channel 2 filled.
        backcolor_type: fill channel 2 with 1.0 instead of a copy of channel 0.
        image_L: side length of the output images.
        **kwargs: ignored.

    Returns:
        ``(tr_list, ts_list)``: lists of dicts with float32 ``'x'`` of shape
        (n, 3, image_L, image_L) and int32 ``'y'`` of shape (n, 1).

    Raises:
        ValueError: the default disjoint split does not fit into ``tr_size``.
    """
    # Force the default geometry *before* splitting.  This used to happen
    # only after the environments were sliced, so caller-supplied sizes
    # produced mis-sized environments or an IndexError.
    if data_type == 'default':
        train_env_num = 2
        tr_env_size = 25000
        if train_env_num * tr_env_size > tr_size:
            raise ValueError('default split needs tr_size >= {}, got {}'.format(
                train_env_num * tr_env_size, tr_size))

    label_flip_rate = [label_flip_rate_0, label_flip_rate_1]

    # Merge train & test (train first), binarize labels, render as RGB.
    images = []
    labels = []
    for dataset in (train, test):
        for idx in range(len(dataset)):
            image, label = dataset[idx]
            labels.append(np.int32(0) if label <= 4 else np.int32(1))
            pad = np.zeros(image.shape, dtype=np.float32)
            # Height x width x channel layout, as PIL expects for RGB.
            hwc = np.stack((image, pad, pad), axis=-1).astype('i')
            pil_img = Image.fromarray(np.uint8(hwc))
            pil_img = pil_img.resize((image_L, image_L), Image.BICUBIC)
            images.append(np.asarray(pil_img).astype('f').transpose(2, 0, 1) / 255.0)
    images = np.asarray(images)
    labels = np.asarray(labels)[:, np.newaxis]

    if data_type == 'default':
        tr_list = [{'x': images[i * tr_env_size:(i + 1) * tr_env_size].copy(),
                    'y': labels[i * tr_env_size:(i + 1) * tr_env_size].copy()}
                   for i in range(train_env_num)]
    else:
        tr_list = []
        for _ in range(train_env_num):
            index = np.random.choice(tr_size, tr_env_size, replace=False)
            tr_list.append({'x': images[index], 'y': labels[index]})

    # Every test environment starts from its own copy of the held-out samples.
    ts_list = [{'x': images[tr_size:].copy(), 'y': labels[tr_size:].copy()}
               for _ in range(test_env_num)]

    if data_type == 'default':
        zero_prob_list_train_ch3 = np.array(ch3_prob_list)
    else:
        zero_prob_list_train_ch3 = np.random.uniform(0.0, 1.0, train_env_num)
    zero_prob_list_test_ch3 = np.random.uniform(0.0, 1.0, test_env_num)

    for env_group, zero_probs in ((tr_list, zero_prob_list_train_ch3),
                                  (ts_list, zero_prob_list_test_ch3)):
        for env_idx, env in enumerate(env_group):
            # Indexed lookup keeps the IndexError for a too-short ch3_prob_list.
            zero_prob = zero_probs[env_idx]
            x, y = env['x'], env['y']
            for j in range(len(x)):
                ch3_prob = np.random.uniform(0.0, 1.0)
                label_prob = np.random.uniform(0.0, 1.0)
                if ch3_prob >= zero_prob:
                    if backcolor_type:
                        x[j, 2, :, :] = 1.0
                    else:
                        x[j, 2, :, :] = x[j, 0, :, :]
                    flip_rate = label_flip_rate[1]
                else:
                    flip_rate = label_flip_rate[0]
                if label_prob <= flip_rate:
                    y[j] = 1 - y[j]

    return tr_list, ts_list


if __name__ == '__main__':

    mnist_train, mnist_test = get_mnist(withlabel=True, ndim=2, scale=255.0)

    # Keyword names must match extended_colored_mnist's signature; unknown
    # names would be silently swallowed by its **kwargs.  The ch12_* and
    # ch3_prob_env* values only take effect when data_type == 'default'.
    tr_list, ts_list = extended_colored_mnist(train=mnist_train, test=mnist_test,
                                              ch12_prob_env1=0.1,
                                              ch12_prob_env2=0.2,
                                              ch3_prob_env1=0.75,
                                              ch3_prob_env2=0.25,
                                              label_flip_rate_0=0.25,
                                              label_flip_rate_1=0.65,
                                              data_type='original',
                                              post_prc='figonly')

    # Count the samples of each label in test environment 1.  len() of the
    # boolean mask would just print the environment size.
    print(np.count_nonzero(ts_list[1]['y'] == 0))
    print(np.count_nonzero(ts_list[1]['y'] == 1))
