import pandas as pd
import torch
import numpy as np
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet, get_autoencoder_dataset
from onlinedatasets.models import TorchRewardsModel, AutoEncoder, TorchRewardsModelMultilayer
import pandas as pd
import random
import IPython
import os
import copy

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


from algorithms import train_autoencoder, train_simple_regression
import pickle

from autoencoder_test import get_data, regression_analysis

# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise the CPU. The choice is echoed so it shows up in the run log.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("DEVICE ", device)


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only point this at trusted,
    # locally produced files.
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


def _mean_response_frame(frame):
    """Return a one-column DataFrame holding the mean of every row of ``frame``."""
    row_means = np.mean(frame.values.astype('float64'), 1).squeeze()
    return pd.DataFrame(row_means)


def get_autoencoded_CMAP_datasets(interactive=True):
    """Load the four pickled, autoencoded CMAP datasets from disk.

    The raw frames returned by ``get_data()`` are also reduced to per-row mean
    responses. Those responses are not part of the return value; they are kept
    as locals so they can be inspected from the interactive shell.

    Args:
        interactive: When True (the default, matching the original behavior),
            an IPython shell is opened after loading and the function resumes
            once that shell is exited. Pass False for unattended runs.

    Returns:
        A 4-tuple with the objects unpickled from the VCAP, HA1E, MCF7 and
        A375 ``encoded_trt_sh.*.978genes.p`` files, in that order.
        NOTE(review): callers access ``.dataset.values`` on these objects, so
        they are presumably ``DataSet`` instances -- confirm against the code
        that wrote the pickles.
    """
    df, df1, df2, df3 = get_data()

    # Row-wise mean of each raw frame, wrapped as a single-column DataFrame.
    # NOTE(review): presumably df..df3 follow the same cell-line order as the
    # pickles below -- verify against get_data().
    responses_Y, responses_Y1, responses_Y2, responses_Y3 = (
        _mean_response_frame(frame) for frame in (df, df1, df2, df3)
    )

    # Each file is opened through a context manager so its handle is released
    # as soon as the object has been read.
    encoded_supervised_dataset = _load_pickle("../CMAP/cmap_processed/encoded_trt_sh.VCAP.978genes.p")
    encoded_supervised_dataset1 = _load_pickle("../CMAP/cmap_processed/encoded_trt_sh.HA1E.978genes.p")
    encoded_supervised_dataset2 = _load_pickle("../CMAP/cmap_processed/encoded_trt_sh.MCF7.978genes.p")
    encoded_supervised_dataset3 = _load_pickle("../CMAP/cmap_processed/encoded_trt_sh.A375.978genes.p")

    if interactive:
        # Debugging hook: blocks until the IPython shell is closed.
        IPython.embed()

    return encoded_supervised_dataset, encoded_supervised_dataset1, encoded_supervised_dataset2, encoded_supervised_dataset3

def main1():
    """Train regression models on every encoded CMAP dataset and plot the losses.

    For each of the four cell lines, ``num_experiments`` fresh
    ``TorchRewardsModelMultilayer`` models are trained with
    ``train_simple_regression``. The per-run loss lists it returns are stored
    under the cell-line name, pickled to
    ``./autoencoder_results/autoencoder_paper_plots_data.p`` and plotted
    (mean with a +/- 0.5 std band) to
    ``./autoencoder_results/regression_loss_all.png``. An IPython shell is
    opened at the end for inspection.
    """
    dataset_names = ["VCAP", "HA1E", "MCF7", "A375"]

    encoded_supervised_datasets = list(get_autoencoded_CMAP_datasets())
    # The input width of the first dataset is used for every model.
    dim = encoded_supervised_datasets[0].dataset.values.shape[1]

    num_steps = 20000
    batch_size = 200
    logging_frequency = 10
    representation_layer_sizes = [300, 40]
    num_experiments = 5

    # Make sure the output directory exists before the (long) training loop,
    # so the results are not lost to a missing directory at the very end.
    os.makedirs("./autoencoder_results", exist_ok=True)

    data = {}
    for dataset_name, dataset in zip(dataset_names, encoded_supervised_datasets):
        dataset_data = []
        for j in range(num_experiments):
            reward_model = TorchRewardsModelMultilayer(
                dim=dim,
                representation_layer_sizes=representation_layer_sizes,
                activation_type='relu',
                batch_norm=True,
                device=device,
            )

            reward_model, test_loss_list_regression = train_simple_regression(
                reward_model,
                dataset,
                num_steps,
                batch_size,
                verbose=True,
                logging_frequency=logging_frequency,
            )
            print("Finished {}, experiment {}".format(dataset_name, j+1))
            dataset_data.append(test_loss_list_regression)
        data[dataset_name] = dataset_data

    # Context manager guarantees the results are flushed and the file closed.
    with open("./autoencoder_results/autoencoder_paper_plots_data.p", "wb") as results_file:
        pickle.dump(data, results_file)

    colors = ["red", "blue", "black", "green"]

    ##### PLOT REGRESSION LOSS LIST ####
    plt.title("Regression")
    plt.xlabel("Num Batches - Batch Size {}".format(batch_size))
    plt.ylabel("Regression Loss")
    for dataset_name, color in zip(dataset_names, colors):
        # Aggregate across the repeated experiments of this dataset.
        mean_losses = np.mean(data[dataset_name], axis=0)
        std_losses = np.std(data[dataset_name], axis=0)

        # One x position per logged loss value.
        steps = logging_frequency*(np.arange(len(mean_losses))+1)
        plt.plot(steps, mean_losses, linewidth=3, color=color, label=dataset_name)
        # Shade the band in the same color as its line.
        plt.fill_between(steps, mean_losses - .5*std_losses, mean_losses + .5*std_losses, alpha=.2, color=color)

    # Switch to the log scale first and cap only the top: a lower limit of 0
    # is not representable on a logarithmic axis.
    plt.yscale('log')
    plt.ylim(top=.015)

    plt.legend(loc="upper right")
    plt.savefig("./autoencoder_results/regression_loss_all.png")

    plt.close('all')

    # Debugging hook: leaves the results available in an interactive shell.
    IPython.embed()


def main():
    """Re-plot the saved regression losses, smoothed over fixed-size windows.

    Loads the per-dataset loss lists from
    ``./autoencoder_results/autoencoder_paper_plots_data.p``, averages the
    across-run mean and std over consecutive windows of ``averaging_window``
    logged points, and writes the log-scale plot to
    ``./autoencoder_results/regression_loss_all.png``. An IPython shell is
    opened at the end for inspection.
    """
    # NOTE(review): the return value is unused here; this call only loads the
    # datasets (and runs whatever the helper does while loading). Consider
    # dropping it if that is not needed for plotting.
    get_autoencoded_CMAP_datasets()

    dataset_names = ["VCAP", "HA1E", "MCF7", "A375"]
    logging_frequency = 10
    averaging_window = 50

    with open("./autoencoder_results/autoencoder_paper_plots_data.p", "rb") as results_file:
        data = pickle.load(results_file)

    colors = ["red", "blue", "black", "green"]

    ##### PLOT REGRESSION LOSS LIST ####
    plt.title("Regression")
    plt.xlabel("Num Batches")
    plt.ylabel("Regression Loss")
    for dataset_name, color in zip(dataset_names, colors):
        # Aggregate across the repeated experiments of this dataset.
        mean_losses = np.mean(data[dataset_name], axis=0)
        std_losses = np.std(data[dataset_name], axis=0)

        # Only whole windows can be averaged; a trailing partial window is
        # dropped so the reshape below never fails on an uneven length.
        num_windows = len(mean_losses) // averaging_window
        usable = num_windows * averaging_window

        mean_losses = mean_losses[:usable].reshape((num_windows, averaging_window))
        mean_losses = np.mean(mean_losses, axis=1)
        std_losses = std_losses[:usable].reshape((num_windows, averaging_window))
        std_losses = np.mean(std_losses, axis=1)

        # Each averaged point is placed at the last step of its window.
        steps = logging_frequency*averaging_window*(np.arange(num_windows)+1)
        plt.plot(steps, mean_losses, linewidth=3, color=color, label=dataset_name)
        plt.fill_between(steps, mean_losses - .5*std_losses, mean_losses + .5*std_losses, alpha=.2, color=color)

    plt.yscale('log')
    # Cap only the top: a lower limit of 0 is invalid on a log axis and would
    # be ignored with a warning anyway.
    plt.ylim(top=.001)

    plt.legend(loc="upper right")
    plt.savefig("./autoencoder_results/regression_loss_all.png")

    plt.close('all')

    # Debugging hook: leaves the loaded results available in an interactive shell.
    IPython.embed()



# Script entry point: running this file directly executes main(), which
# re-plots the previously saved results (main1() is not called from here).
if __name__ == "__main__":
    main()
