
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
    
    

def data_high_generator(images, labels, e, flip_rate, *, plot=True, device="cuda"):
    """Build a two-colour dataset with binary (digit > 4) labels.

    Each image is downsampled to 14x14 and duplicated into two channels.
    Labels are binarised as ``label > 4`` and then flipped with probability
    ``flip_rate``. A colour bit is derived from the (possibly flipped) label
    and flipped again with probability ``e``; the channel that does not match
    the colour bit is zeroed, so colour correlates with the label at rate
    ``1 - e``.

    Args:
        images: tensor reshapeable to (N, 28, 28). Pixel values are divided by
            255, so presumably uint8 in [0, 255] -- confirm with the caller.
        labels: length-N tensor of digit labels. Not modified.
        e: probability that the colour bit disagrees with the label.
        flip_rate: probability of flipping each binary label.
        plot: if True (default, original behaviour) draw a grid of up to 64
            samples, save it to ``highimg_e={e}_filp={flip_rate}.png`` and
            show it.
        device: device for the returned tensors (default "cuda", matching the
            original hard-coded ``.cuda()``).

    Returns:
        dict with ``'images'``: float tensor (N, 2, 14, 14) scaled by 1/255,
        and ``'labels'``: float tensor (N, 1) of 0/1 labels.
    """
    def torch_bernoulli(p, size):
        return (torch.rand(size) < p).float()

    def torch_xor(a, b):
        return (a - b).abs()

    images = images.reshape((-1, 28, 28))[:, ::2, ::2]
    labels = (labels > 4).float()
    labels = torch_xor(labels, torch_bernoulli(flip_rate, len(labels)))
    colors = torch_xor(labels, torch_bernoulli(e, len(labels)))
    images = torch.stack([images, images], dim=1)
    # Zero the channel that does not match each sample's colour bit.
    images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0

    if plot:
        # The original always indexed 64 samples and crashed for smaller N.
        n_show = min(64, len(images))
        fig = plt.figure(figsize=(15, 10))
        for idx in range(n_show):
            two_channel = images[idx].float() / 255.
            # Pad with an empty third channel so imshow gets an RGB image.
            rgb = torch.cat([two_channel, two_channel.new_zeros((1, 14, 14))])
            ax = fig.add_subplot(8, 8, idx + 1)
            ax.set_title("label:" + str(int(labels[idx])))
            # (C, H, W) -> (H, W, C); same result as the old .T/rot90/fliplr chain.
            ax.imshow(rgb.permute(1, 2, 0).cpu().numpy())
        # Figure-level title; plt.title() used to overwrite the last panel's label.
        fig.suptitle('data1')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.5, hspace=0.7)
        fig.savefig('highimg_e={}_filp={}.png'.format(e, flip_rate))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across calls

    return {
        'images': (images.float() / 255.).to(device),
        'labels': labels[:, None].to(device),
    }


def data_generator(images, labels, e, flip_rate, *, plot=True, device="cuda"):
    """Build a two-colour dataset that keeps the 10-class digit labels.

    Each image is downsampled to 14x14 and duplicated into two channels.
    With probability ``flip_rate`` a digit label is replaced by one of the
    other nine digits, chosen uniformly. The colour bit is the binarised
    (noisy) label ``label > 4``, flipped with probability ``e``; the channel
    that does not match the colour bit is zeroed.

    All randomness comes from torch's RNG, so ``torch.manual_seed`` alone
    makes a call reproducible (the previous version also drew from
    ``np.random``).

    Args:
        images: tensor reshapeable to (N, 28, 28). Pixel values are divided by
            255, so presumably uint8 in [0, 255] -- confirm with the caller.
        labels: length-N integer tensor of digits, assumed in 0..9. It is NOT
            modified; the previous version overwrote flipped entries in place.
        e: probability that the colour bit disagrees with the binary label.
        flip_rate: probability of replacing each digit label.
        plot: if True (default, original behaviour) draw a grid of up to 64
            samples, save it to ``img_e={e}_filp={flip_rate}.png`` and show it.
        device: device for the returned tensors (default "cuda", matching the
            original hard-coded ``.cuda()``).

    Returns:
        dict with ``'images'``: float tensor (N, 2, 14, 14) scaled by 1/255,
        and ``'labels'``: tensor (N, 1) of noisy digit labels (input dtype).
    """
    def torch_bernoulli(p, size):
        return (torch.rand(size) < p).float()

    def torch_xor(a, b):
        return (a - b).abs()

    images = images.reshape((-1, 28, 28))[:, ::2, ::2]

    # Label noise: (d + k) % 10 for k drawn uniformly from 1..9 hits each of
    # the other nine digits exactly once, i.e. a uniform "different digit".
    # torch.where returns a new tensor, leaving the caller's labels intact.
    flip_mask = torch_bernoulli(flip_rate, len(labels)).bool()
    offsets = torch.randint(1, 10, (len(labels),), dtype=labels.dtype)
    labels = torch.where(flip_mask, (labels + offsets) % 10, labels)

    binary_labels = (labels > 4).float()
    colors = torch_xor(binary_labels, torch_bernoulli(e, len(binary_labels)))
    images = torch.stack([images, images], dim=1)
    # Zero the channel that does not match each sample's colour bit.
    images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0

    if plot:
        # The original always indexed 64 samples and crashed for smaller N.
        n_show = min(64, len(images))
        fig = plt.figure(figsize=(15, 10))
        for idx in range(n_show):
            two_channel = images[idx].float() / 255.
            # Pad with an empty third channel so imshow gets an RGB image.
            rgb = torch.cat([two_channel, two_channel.new_zeros((1, 14, 14))])
            ax = fig.add_subplot(8, 8, idx + 1)
            ax.set_title("label:" + str(int(labels[idx])))
            # (C, H, W) -> (H, W, C); same result as the old .T/rot90/fliplr chain.
            ax.imshow(rgb.permute(1, 2, 0).cpu().numpy())
        # Figure-level title; plt.title() used to overwrite the last panel's label.
        fig.suptitle('data1')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.5, hspace=0.7)
        fig.savefig('img_e={}_filp={}.png'.format(e, flip_rate))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across calls

    return {
        'images': (images.float() / 255.).to(device),
        'labels': labels[:, None].to(device),
    }
    
def high_data_loader(env_number, envs_array, sample_size):
    """Generate one dataset per environment in ``envs_array`` and scatter-plot them.

    For each entry of ``envs_array`` a dataset is produced by
    ``data_high_generator``. It is tagged with ``env_labels`` and drawn on a
    3x2 grid of scatter plots. The figure is saved to
    ``"high_data env={envs_array}.png"`` and shown.

    Args:
        env_number: not used in the body.
        envs_array: iterated along axis 0 via ``.shape[0]``, so an array-like
            with a ``shape`` (presumably a NumPy array of env parameters).
        sample_size: passed to the generator and used to flatten the labels,
            so presumably the number of samples per environment.

    Returns:
        list of the per-environment data dicts, in ``envs_array`` order.

    NOTE(review): this calls ``data_high_generator(sample_size, envs_array[i])``
    with two positional arguments. The ``data_high_generator`` defined in this
    file requires ``(images, labels, e, flip_rate)``, so as written this raises
    TypeError. The scatter below treats ``data['images']`` as 2-D points
    (columns 0 and 1). That suggests it was written against a different
    generator -- confirm which one is intended.
    """
    sample = []
    fig = plt.figure(figsize = (15, 13))
    fig.suptitle("Visualization of high data under env={}".format(envs_array))
    for i in range(envs_array.shape[0]):
        data = data_high_generator(sample_size,envs_array[i])
        # Offset labels by 2*i; presumably labels are 0/1, giving a distinct
        # id per (environment, label) pair -- TODO confirm.
        data['env_labels'] = data['labels'] + 2*i
        
        # 3x2 grid: add_subplot only accepts indices 1..6, so more than
        # six environments cannot be plotted here.
        ax = fig.add_subplot(3,2, i+1)
        ax.set_ylim(-150, 150) 
        ax.set_xlim(-50, 50) 
        w = data['images']*50
        labels = data['labels']
        #print(labels)
        # Scatter feature columns 0 and 1 for label 0 (blue) and label 1 (red).
        # NOTE(review): assumes w is (sample_size, >=2) and matplotlib-convertible;
        # image tensors on CUDA (as the generators in this file return) would not be.
        ax.scatter(w[labels.reshape(sample_size)==0][:,0],w[labels.reshape(sample_size)==0][:,1],color="b",label="label:0",s=10)
        ax.scatter(w[labels.reshape(sample_size)==1][:,0],w[labels.reshape(sample_size)==1][:,1],color="r",label="label:1",s=10)
        
        sample.append(data)
    # plt.legend() attaches to the current (last-created) axes only.
    plt.legend()
    fig.savefig("high_data env={}.png".format(envs_array))
    plt.show()
    return sample


def data_loader(envs, sample_size):
    """Generate a training set at ``envs`` and a test set at ``-envs``, and plot both.

    Both datasets come from ``data_generator``. They are drawn side by side as
    scatter plots. The figure is saved to
    ``"Training and Test data on e={envs}.png"`` and shown.

    Args:
        envs: environment parameter for the training data; the test data uses
            its negation (``-envs``).
        sample_size: passed to the generator and used to flatten the labels,
            so presumably the number of samples per dataset.

    Returns:
        ``[data_train, data_test]``.

    NOTE(review): this calls ``data_generator(sample_size, envs)`` with two
    positional arguments. The ``data_generator`` defined in this file requires
    ``(images, labels, e, flip_rate)``, so as written this raises TypeError.
    The scatter treats ``data['images']`` as 2-D points, which suggests a
    different generator was intended -- confirm.
    """
    sample = []
    fig = plt.figure(figsize = (15, 13))
    # NOTE(review): the title text contains ".png", apparently copied from the
    # filename; left unchanged because it is runtime output.
    fig.suptitle("Training and Test data on e={}.png".format(envs))
    
    #Train data generating
    data_train = data_generator(sample_size,envs)
    ax = fig.add_subplot(1,2, 1)
    ax.title.set_text('Training data')
    ax.set_ylim(-50, 50) 
    ax.set_xlim(-50, 50) 
    w = data_train['images']*50
    labels = data_train['labels']
    # Only labels 0, 1 and 2 are drawn; any other label value is not plotted.
    ax.scatter(w[labels.reshape(sample_size)==0][:,0],w[labels.reshape(sample_size)==0][:,1],color="b",label="label:0",s=10)
    ax.scatter(w[labels.reshape(sample_size)==1][:,0],w[labels.reshape(sample_size)==1][:,1],color="r",label="label:1",s=10)
    ax.scatter(w[labels.reshape(sample_size)==2][:,0],w[labels.reshape(sample_size)==2][:,1],color="g",label="label:2",s=10)
    sample.append(data_train)
    
    #Test data generating
    # The test environment is the negated training environment.
    data_test = data_generator(sample_size,-envs)
    bx = fig.add_subplot(1,2, 2)
    bx.title.set_text('Test data')
    bx.set_ylim(-50, 50) 
    bx.set_xlim(-50, 50) 
    w = data_test['images']*50
    labels = data_test['labels']
    bx.scatter(w[labels.reshape(sample_size)==0][:,0],w[labels.reshape(sample_size)==0][:,1],color="b",label="label:0",s=10)
    bx.scatter(w[labels.reshape(sample_size)==1][:,0],w[labels.reshape(sample_size)==1][:,1],color="r",label="label:1",s=10)
    bx.scatter(w[labels.reshape(sample_size)==2][:,0],w[labels.reshape(sample_size)==2][:,1],color="g",label="label:2",s=10)
    sample.append(data_test)
    
    
    
    # plt.legend() attaches to the current axes (the test-data panel) only.
    plt.legend()
    fig.savefig("Training and Test data on e={}.png".format(envs))
    plt.show()
    return sample




