import pandas as pd
import torch
import numpy as np
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet, get_autoencoder_dataset
from onlinedatasets.models import TorchRewardsModel, AutoEncoder, TorchRewardsModelMultilayer
import pandas as pd
import random
import IPython
import os
import copy

import matplotlib
# Select the Agg backend before pyplot is imported; every figure in this
# script is written to disk with savefig and none is shown on screen.
matplotlib.use('Agg')
import matplotlib.pyplot as plt


from algorithms import train_autoencoder, train_simple_regression
import pickle

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
# The models built below receive this object through their `device` argument.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE ", device)





def regression_analysis(dim, encoded_supervised_dataset, num_steps, batch_size, logging_frequency, data_index, representation_layer_sizes = None):
    """Fit a regression model on an encoded dataset and save diagnostic plots.

    A ``TorchRewardsModelMultilayer`` (relu activations, batch norm enabled) is
    trained with ``train_simple_regression``. Two figures are written to
    ``./autoencoder_results/``: the logged regression loss on a log-scaled
    y axis, and overlaid histograms of the labels and of the model's
    predictions. The first 100 predictions and labels are printed at the end.

    Args:
        dim: Passed to the model constructor as ``dim``.
        encoded_supervised_dataset: Object exposing ``.dataset`` and
            ``.labels``; both are read through their ``.values`` attribute.
        num_steps: Forwarded to ``train_simple_regression``.
        batch_size: Forwarded to ``train_simple_regression`` and shown in the
            loss-plot legend.
        logging_frequency: Forwarded to ``train_simple_regression``; also used
            as the spacing between points on the loss-plot x axis.
        data_index: Suffix formatted into both output file names.
        representation_layer_sizes: Layer sizes handed to the model. ``None``
            (the default) means ``[100, 10]``.

    Returns:
        None. The results are the saved figures and the printed samples.
    """
    # Build the default per call so no list object is shared between calls.
    if representation_layer_sizes is None:
        representation_layer_sizes = [100, 10]

    # savefig does not create missing folders; make sure the target exists
    # before spending time on training.
    os.makedirs("./autoencoder_results", exist_ok = True)

    reward_model = TorchRewardsModelMultilayer(dim = dim, representation_layer_sizes = representation_layer_sizes,
        activation_type = 'relu', batch_norm = True, device = device )

    reward_model, test_loss_list_regression = train_simple_regression(reward_model, encoded_supervised_dataset, num_steps, batch_size, verbose = True, logging_frequency = logging_frequency)

    ##### PLOT REGRESSION LOSS LIST ####
    plt.title("Regression")
    plt.xlabel("Num Batches")
    plt.ylabel("Regression Loss")
    # NOTE(review): the x axis assumes one loss entry per `logging_frequency`
    # batches -- confirm against train_simple_regression.
    plt.plot(logging_frequency*np.arange(len(test_loss_list_regression)), test_loss_list_regression, linewidth = 3, color = "red", label = "Batch Size {}".format(batch_size))
    plt.yscale('log')
    plt.legend(loc = "upper right")
    plt.savefig("./autoencoder_results/regression_loss{}.png".format(data_index))
    plt.close('all')

    ###### RUN REGRESSION ANALYSIS #####
    predictions = reward_model.get_reward(encoded_supervised_dataset.dataset.values)

    ##### RESPONSES Y HISTOGRAM #######
    plt.title("Histogram Comparison" )
    plt.xlabel("Y responses")
    plt.ylabel("Counts")
    plt.hist( encoded_supervised_dataset.labels.values, bins = 50, color = "blue", label = "Y values")
    plt.hist( predictions.detach().cpu().numpy(), bins = 50, color = "red", label = "Predictions")
    plt.legend(loc = "upper right")
    plt.savefig("./autoencoder_results/y_vs_prediction_histogram{}.png".format(data_index))
    plt.close('all')

    ##################################
    print("predictions        ", predictions[:100])
    print("original responses ", np.squeeze(encoded_supervised_dataset.labels.values[:100]))




def _relabel_by_unique_pert_id(expression_df, sig_info):
    """Keep signatures whose pert_id is unique in this matrix and index them by pert_id.

    Args:
        expression_df: Expression matrix whose index holds sig_ids and whose
            index is named ``cid`` (the name is relied on when re-indexing).
        sig_info: Table with at least the ``sig_id`` and ``pert_id`` columns.

    Returns:
        A new DataFrame with the same columns as ``expression_df``, restricted
        to signatures whose pert_id occurs exactly once among the signatures of
        this matrix, and indexed by that pert_id (index name ``cid``).
    """
    # sig_id -> pert_id pairs for the signatures present in this matrix.
    pert_map = sig_info[sig_info['sig_id'].isin(list(expression_df.index))][['sig_id', 'pert_id']]

    # pert_ids that occur exactly once among those signatures.
    pert_id_counts = pert_map['pert_id'].value_counts()
    single_use_pert_ids = list(pert_id_counts[pert_id_counts == 1].index)

    unique_sig_ids = list(pert_map[pert_map['pert_id'].isin(single_use_pert_ids)]['sig_id'])
    trimmed = expression_df.loc[unique_sig_ids]

    # Look the pert_ids up from the trimmed index so they line up row for row.
    sig_to_pert = pert_map.set_index('sig_id')
    row_pert_ids = list(sig_to_pert.loc[trimmed.index]['pert_id'])

    # Overwrite the sig_id column with the pert_ids and make it the index again.
    # NOTE(review): the round trip through one ndarray presumably leaves the
    # gene columns object-typed; main() casts the values back to float64.
    flat = trimmed.reset_index()
    values = flat.values
    values[:, 0] = row_pert_ids

    relabeled = pd.DataFrame(values)
    relabeled.columns = flat.columns
    return relabeled.set_index('cid')


def get_data(data_dir = "../CMAP"):
    """Load the four CMAP expression matrices and align them on shared pert_ids.

    Reads ``trt_sh.<CELL>.978genes.csv`` for VCAP, HA1E, MCF7 and A375 plus
    ``GSE92742_Broad_LINCS_sig_info.txt.gz`` from ``data_dir``. Each matrix is
    reduced to the signatures whose pert_id is unique within that matrix and is
    re-indexed by pert_id.

    Args:
        data_dir: Folder holding the input files. Defaults to ``"../CMAP"``.

    Returns:
        A 4-tuple ``(vcap, ha1e, mcf7, a375)`` of DataFrames indexed by
        pert_id. ``vcap`` holds every unique-pert_id row of the VCAP matrix;
        the other three hold only rows whose pert_id also appears in ``vcap``.
        Rows keep the order they have in their own filtered matrix, so the
        result is identical from run to run.
    """
    cell_lines = ("VCAP", "HA1E", "MCF7", "A375")
    expression_frames = [
        pd.read_csv(os.path.join(data_dir, "trt_sh.{}.978genes.csv".format(cell_line)), index_col = 0)
        for cell_line in cell_lines
    ]
    sig_info = pd.read_csv(os.path.join(data_dir, "GSE92742_Broad_LINCS_sig_info.txt.gz"), sep = '\t', dtype = str)

    relabeled = [_relabel_by_unique_pert_id(frame, sig_info) for frame in expression_frames]
    reference, others = relabeled[0], relabeled[1:]

    # Filter with a mask rather than a list built from a set: iterating a set
    # of strings has no stable order across interpreter runs.
    shared = [other.loc[other.index.isin(reference.index)] for other in others]

    return reference, shared[0], shared[1], shared[2]







def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file once the write is done."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


def main():
    """Run the full pipeline: filter CMAP data, train the autoencoder, regress.

    Steps, in order:
      1. Load the filtered expression frames with ``get_data`` and pickle them
         under ``../CMAP/cmap_processed/``.
      2. Use the per-row mean expression of each frame as the response.
      3. Train an ``AutoEncoder`` on the VCAP frame and plot its loss.
      4. Encode the VCAP data, pickle it, and run ``regression_analysis`` on it.
      5. Reuse the VCAP encodings for the pert_ids shared with HA1E, MCF7 and
         A375, pair them with those cell lines' responses, pickle the datasets
         and run ``regression_analysis`` on each.

    Figures are written to ``./autoencoder_results/``. Returns None.
    """
    # Neither pickle nor savefig creates missing folders.
    os.makedirs("../CMAP/cmap_processed", exist_ok = True)
    os.makedirs("./autoencoder_results", exist_ok = True)

    df, df1, df2, df3 = get_data()

    _dump_pickle(df, "../CMAP/cmap_processed/filtered_trt_sh.VCAP.978genes.p")
    _dump_pickle(df1, "../CMAP/cmap_processed/filtered_trt_sh.HA1E.978genes.p")
    _dump_pickle(df2, "../CMAP/cmap_processed/filtered_trt_sh.MCF7.978genes.p")
    _dump_pickle(df3, "../CMAP/cmap_processed/filtered_trt_sh.A375.978genes.p")

    # NOTE(review): interactive breakpoint -- the run pauses here until the
    # IPython session is closed. Remove for unattended runs.
    IPython.embed()

    representation_layer_sizes = [1500, 300, 100]

    raw_dataset = df.values.astype('float64')
    raw_dataset1 = df1.values.astype('float64')
    raw_dataset2 = df2.values.astype('float64')
    raw_dataset3 = df3.values.astype('float64')

    # The response for each row is the mean of its values.
    responses_Y = pd.DataFrame(np.mean(raw_dataset, 1))
    responses_Y1 = pd.DataFrame(np.mean(raw_dataset1, 1))
    responses_Y2 = pd.DataFrame(np.mean(raw_dataset2, 1))
    responses_Y3 = pd.DataFrame(np.mean(raw_dataset3, 1))

    ##### RESPONSES Y HISTOGRAM #######
    plt.title("Histogram")
    plt.xlabel("Y responses")
    plt.ylabel("Counts")
    # The label gives plt.legend an entry to draw; without one the legend is empty.
    plt.hist( responses_Y.values, bins = 50, label = "Y values")
    plt.legend(loc = "upper right")
    plt.savefig("./autoencoder_results/y_histogram.png")
    plt.close('all')

    ##################################
    raw_dataset_df = pd.DataFrame(raw_dataset.astype('float64'))

    unsupervised_dataset = DataSetUnsupervised(raw_dataset_df)
    autoencoder = AutoEncoder(dim = raw_dataset_df.shape[1], encoder_representation_layer_sizes = representation_layer_sizes,
        activation_type = 'leaky_relu', batch_norm = False, device = device )

    autoencoder_train_batch_size = 200
    num_steps = 100000
    logging_frequency = 1000

    # The defaults are bound when the function is defined, so the autoencoder
    # keeps these settings even though the names are reassigned further down.
    def autoencoder_analysis(autoencoder, unsupervised_dataset, num_steps = num_steps, batch_size = autoencoder_train_batch_size, logging_frequency = logging_frequency):
        # NOTE(review): the huge batch size presumably yields the whole
        # dataset as a single evaluation batch -- confirm in get_batches.
        eval_batch = get_batches(unsupervised_dataset, 1000000000)
        initial_eval_loss = autoencoder.get_loss(eval_batch)
        print("Initial loss ", initial_eval_loss)
        autoencoder, test_loss_list = train_autoencoder( autoencoder, unsupervised_dataset, num_steps, batch_size, logging_frequency = logging_frequency)
        eval_loss = autoencoder.get_loss(eval_batch)
        print("Eval loss ", eval_loss)
        return autoencoder, test_loss_list

    autoencoder, test_loss_list_autoencoder = autoencoder_analysis(autoencoder, unsupervised_dataset)

    ##### PLOT AUTOENCODER LOSS LIST ####
    plt.title("Autoencoder")
    plt.xlabel("Num Batches")
    plt.ylabel("Reconstruction Loss")
    plt.plot(logging_frequency*np.arange(len(test_loss_list_autoencoder)), test_loss_list_autoencoder, linewidth = 3, color = "red", label = "Batch Size {}".format(autoencoder_train_batch_size))
    plt.legend(loc = "upper right")
    plt.savefig("./autoencoder_results/reconstruction_loss.png")
    plt.close('all')

    ###### RUN REGRESSION ANALYSIS #####
    base_supervised_dataset = DataSet(raw_dataset_df, responses_Y)
    encoded_supervised_dataset = get_autoencoder_dataset(base_supervised_dataset, autoencoder)

    _dump_pickle(encoded_supervised_dataset, "../CMAP/cmap_processed/encoded_trt_sh.VCAP.978genes.p")

    num_steps = 5000 #12000
    batch_size = 200
    logging_frequency = 100

    dim = encoded_supervised_dataset.dataset.values.shape[1]

    # regression_analysis takes data_index as its sixth parameter; passing the
    # responses positionally as well as data_index by keyword raises TypeError.
    # The labels already travel inside the dataset object.
    regression_analysis(dim, encoded_supervised_dataset, num_steps, batch_size, logging_frequency, data_index = "", representation_layer_sizes = [1500, 300, 100, 10])

    ### TRANSFER ANALYSIS
    # Index the VCAP encodings by pert_id so the rows matching each other
    # cell line can be selected in that cell line's row order.
    autoencoded_df = pd.DataFrame(copy.deepcopy(encoded_supervised_dataset.dataset))
    autoencoded_df['cid'] = list(df.index)
    autoencoded_df = autoencoded_df.set_index('cid')
    autoencoded_df1 = pd.DataFrame(autoencoded_df.loc[df1.index].values)
    autoencoded_df2 = pd.DataFrame(autoencoded_df.loc[df2.index].values)
    autoencoded_df3 = pd.DataFrame(autoencoded_df.loc[df3.index].values)

    encoded_supervised_dataset1 = DataSet(autoencoded_df1, responses_Y1)
    encoded_supervised_dataset2 = DataSet(autoencoded_df2, responses_Y2)
    encoded_supervised_dataset3 = DataSet(autoencoded_df3, responses_Y3)

    _dump_pickle(encoded_supervised_dataset1, "../CMAP/cmap_processed/encoded_trt_sh.HA1E.978genes.p")
    _dump_pickle(encoded_supervised_dataset2, "../CMAP/cmap_processed/encoded_trt_sh.MCF7.978genes.p")
    _dump_pickle(encoded_supervised_dataset3, "../CMAP/cmap_processed/encoded_trt_sh.A375.978genes.p")

    regression_analysis(dim, encoded_supervised_dataset1, num_steps, batch_size, logging_frequency, data_index = "1", representation_layer_sizes = [1500, 300, 100, 10])
    regression_analysis(dim, encoded_supervised_dataset2, num_steps, batch_size, logging_frequency, data_index = "2", representation_layer_sizes = [1500, 300, 100, 10])
    regression_analysis(dim, encoded_supervised_dataset3, num_steps, batch_size, logging_frequency, data_index = "3", representation_layer_sizes = [1500, 300, 100, 10])

    # NOTE(review): interactive breakpoint left at the end for inspecting results.
    IPython.embed()




# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()