
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
    
    
def data_high_generator(samplesize,env):
    """Sample a 2-D, three-component Gaussian mixture with binarized labels.

    Each point is assigned a component id in {0, 1, 2}. Its two features are
    drawn as follows (std 10 for all of them):

    - First feature: mean 0, 30 or -30 for components 0, 1 and 2.
    - Second feature: mean env, -5*env or -env for components 0, 1 and 2.

    Args:
        samplesize: number of points to draw.
        env: environment value that shifts the means of the second feature.

    Returns:
        dict with:
          'images': float32 tensor of shape (samplesize, 2) holding the two
                    features divided by 150 (despite the key name, these are
                    plain 2-D feature vectors).
          'labels': float tensor of shape (samplesize, 1); 1.0 for
                    components 1 and 2, 0.0 for component 0.
    """
    component = np.random.randint(0, 3, samplesize)

    # Every point belongs to component 0, 1 or 2, so this N(0, 1) fill is
    # always overwritten. It still advances the global RNG, so it is kept
    # to reproduce seeded samples exactly.
    first = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for comp_id, mean in enumerate((0, 30, -30)):
        picked = component == comp_id
        first[picked] = np.random.normal(loc=mean, scale=10.0, size=first[picked].shape)

    # Same reasoning for the throwaway fill of the second feature.
    second = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for comp_id, mean in enumerate((env, -5*env, -env)):
        picked = component == comp_id
        second[picked] = np.random.normal(loc=mean, scale=10.0, size=second[picked].shape)

    # Collapse components 1 and 2 into the positive class.
    binary = (torch.from_numpy(component[:, None]).float() > 0).float()

    features = np.stack((first, second), axis=1)
    return {
        'images': torch.from_numpy(features.astype(np.float32)) / 150,
        'labels': binary,
    }


def data_generator(samplesize,env):
    """Sample a 2-D, three-component Gaussian mixture with 3-class labels.

    Author's note: the best possible accuracy on this data is about 84%.

    Each point is assigned a component id in {0, 1, 2}, which is also its
    label. Its two features are drawn as follows (std 10 for all of them):

    - First feature: mean 0, 30 or -30 for components 0, 1 and 2.
    - Second feature: mean env, -5*env or -env for components 0, 1 and 2.

    Args:
        samplesize: number of points to draw.
        env: environment value that shifts the means of the second feature.

    Returns:
        dict with:
          'images': float32 tensor of shape (samplesize, 2) holding the two
                    features divided by 150.
          'labels': float tensor of shape (samplesize, 1) with values
                    0.0, 1.0 or 2.0 (the component id).
    """
    component = np.random.randint(0, 3, samplesize)

    # The N(0, 1) fill is always overwritten (every point is in component
    # 0, 1 or 2). It is kept because it advances the RNG, which keeps
    # seeded samples reproducible.
    first = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for comp_id, mean in enumerate((0, 30, -30)):
        picked = component == comp_id
        first[picked] = np.random.normal(loc=mean, scale=10.0, size=first[picked].shape)

    second = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for comp_id, mean in enumerate((env, -5*env, -env)):
        picked = component == comp_id
        second[picked] = np.random.normal(loc=mean, scale=10.0, size=second[picked].shape)

    features = np.stack((first, second), axis=1)
    return {
        'images': torch.from_numpy(features.astype(np.float32)) / 150,
        'labels': torch.from_numpy(component[:, None]).float(),
    }

def data_loader(envs, sample_size):
    """Generate, plot and return one training and one test sample.

    The training sample is drawn from ``data_generator(sample_size, envs)``
    and the test sample from ``data_generator(sample_size, -envs)``. Both are
    shown side by side as scatter plots coloured by class label (0, 1, 2),
    each panel with its own legend. The function then calls ``plt.show()``.

    Args:
        envs: environment value for the training data; its negation is
            used for the test data.
        sample_size: number of points drawn for each of the two samples.

    Returns:
        list ``[train_data, test_data]`` of the dicts returned by
        ``data_generator``.
    """
    sample = []
    fig = plt.figure(figsize=(15, 13))
    fig.suptitle("Training and Test data on e={}".format(envs))

    # (panel title, environment value); the test data uses the mirrored env.
    panels = (("Training data", envs), ("Test data", -envs))
    for position, (title, env) in enumerate(panels, start=1):
        data = data_generator(sample_size, env)

        ax = fig.add_subplot(1, 2, position)
        ax.title.set_text(title)
        ax.set_ylim(-300, 300)
        # NOTE(review): the first feature has class means at +/-30 with std 10,
        # so some points fall outside this x-range. Confirm the limit is intended.
        ax.set_xlim(-50, 50)

        # Undo the /150 scaling applied by data_generator for plotting.
        points = data['images'] * 150
        flat_labels = data['labels'].reshape(sample_size)
        for value, color in ((0, "b"), (1, "r"), (2, "g")):
            mask = flat_labels == value  # compute each class mask once
            ax.scatter(points[mask][:, 0], points[mask][:, 1],
                       color=color, label="label:{}".format(value), s=10)
        # A per-axes legend: plt.legend() alone only decorates the last axes.
        ax.legend()

        sample.append(data)

    plt.show()
    return sample


def high_data_loader(env_number, envs_array, sample_size):
    """Generate, plot and return one binarized-mixture sample per environment.

    For every entry of ``envs_array`` a sample is drawn with
    ``data_high_generator`` and drawn in its own subplot, coloured by the
    binary label (0 blue, 1 red). The function then calls ``plt.show()``.

    Args:
        env_number: unused; kept so existing callers keep working.
        envs_array: array-like of environment values (anything
            ``np.asarray`` turns into a 1-D array); one sample and one
            subplot per entry.
        sample_size: number of points drawn per environment.

    Returns:
        list of the dicts returned by ``data_high_generator``, in the order
        of ``envs_array``. Each dict also gets an ``'env_labels'`` entry equal
        to ``labels + 2 * index``, so every (environment, label) pair has its
        own code.
    """
    envs = np.asarray(envs_array)
    n_envs = envs.shape[0]
    # Keep the original 3x2 grid for up to six environments. Add rows
    # (instead of letting add_subplot raise) when there are more.
    n_rows = max(3, (n_envs + 1) // 2)

    sample = []
    fig = plt.figure(figsize=(15, 13 * n_rows / 3))
    fig.suptitle("Visualization of high data under env={}".format(envs_array))
    for i, env in enumerate(envs):
        data = data_high_generator(sample_size, env)
        data['env_labels'] = data['labels'] + 2 * i

        ax = fig.add_subplot(n_rows, 2, i + 1)
        ax.set_title("env={}".format(env))
        ax.set_ylim(-300, 300)
        ax.set_xlim(-50, 50)

        # Undo the /150 scaling applied by data_high_generator for plotting.
        points = data['images'] * 150
        flat_labels = data['labels'].reshape(sample_size)
        for value, color in ((0, "b"), (1, "r")):
            mask = flat_labels == value
            ax.scatter(points[mask][:, 0], points[mask][:, 1],
                       color=color, label="label:{}".format(value), s=10)
        # A per-subplot legend: plt.legend() alone only decorates the last axes.
        ax.legend()

        sample.append(data)

    plt.show()
    return sample



