import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt

#from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting 

#from matplotlib import cm
#from matplotlib.ticker import LinearLocator

import pandas as pd

import torch
from onlinedatasets.models import TorchRewardsModel, TorchRewardsModelMultilayer
from onlinedatasets.datasets import get_dataset
#from algorithms import *
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from matplotlib import cm
from plotting_paper_tools import plot_ranks_simple_names, plot_rewards_simple_names, plot_quantiles_simple_names, plot_quantiles_bar_simple_names
from plotting_tools import get_experiment_name_stub, load_experiment_data, get_multi_experiment_name_stub , get_plot_directory_stub
from datasets_CMAP import get_dataset_CMAP
from run_experiments import get_experiments_dataset
import os

import itertools
import IPython
import random

import sys



def get_quantiles(dataset, max_observed_rewards):
  """Convert each reward into a rank-based quantile within the dataset labels.

  The labels are ordered from largest to smallest. For every reward, the
  nearest label is located (the first one in that ordering on ties), and its
  1-based position is divided by the number of labels. Smaller results
  therefore mean the reward sits closer to the top of the dataset.

  Args:
    dataset: object whose ``labels.values`` squeezes to a 1-D array
      (e.g. a single-column pandas DataFrame).
    max_observed_rewards: iterable of reward values to locate.

  Returns:
    list with one quantile in (0, 1] per reward.
  """
  descending_labels = np.array(sorted(dataset.labels.values.squeeze())[::-1])
  label_count = len(descending_labels)

  def _quantile_of(reward):
    nearest_position = np.argmin(np.abs(descending_labels - reward))
    return (nearest_position + 1) * 1.0 / label_count

  return [_quantile_of(reward) for reward in max_observed_rewards]





def plot_validation(results_dictionary, filename='./paper_results/test_validation.png'):
  """Plot the mean validation-loss curve of every non-random algorithm.

  For each algorithm other than "RandomBatch", the per-experiment results
  stored in ``results_dictionary[algo_name][-1]`` are read. Element 3 of
  each experiment tuple is treated as that experiment's sequence of
  validation losses; experiments whose tuple has 3 or fewer entries carry
  no validation data and are skipped.

  Each algorithm's mean curve is drawn with a shaded band of +/- half a
  standard deviation. A single figure containing all algorithms is then
  written to ``filename``.

  Args:
    results_dictionary: maps algorithm name -> sequence whose last element
      is the list of per-experiment result tuples.
    filename: output image path; its parent directory is created if needed.

  Note:
    Assumes all validation-loss sequences of one algorithm have equal length
    (required by ``np.mean(..., axis=0)``) -- TODO confirm with the producer.
  """
  plt.figure()
  plotted_any = False

  for algo_name, algo_entry in results_dictionary.items():
    if algo_name == "RandomBatch":
      continue

    validation_losses = [experiment[3] for experiment in algo_entry[-1] if len(experiment) > 3]
    if not validation_losses:
      continue

    mean_validation = np.mean(validation_losses, axis = 0)
    std_validation = np.std(validation_losses, axis = 0)
    # x-axis is the 1-based index of each validation measurement.
    steps = np.arange(len(mean_validation)) + 1

    # Reuse the colour matplotlib picked for the line so the band matches it.
    (line,) = plt.plot(steps, mean_validation, linewidth = 3, label = algo_name)
    plt.fill_between(steps, mean_validation - .5*std_validation,
        mean_validation + .5*std_validation, color = line.get_color(), alpha = .2)
    plotted_any = True

  if plotted_any:
    plt.legend()

  output_directory = os.path.dirname(filename)
  if output_directory:
    os.makedirs(output_directory, exist_ok = True)

  # Save once, after every algorithm has been drawn onto the same figure.
  plt.savefig(filename)
  plt.close("all")



def _extract_max_observed_rewards(experiment_result):
  """Return the ``max_observed_rewards`` entry of one saved experiment result.

  Saved results are sequences of 3, 4 or 5 entries whose first three are
  ``(ranks, max_observed_rewards, true_max_reward)``.

  Raises:
    ValueError: if the result does not have 3, 4 or 5 entries.
  """
  if len(experiment_result) not in (3, 4, 5):
    raise ValueError("results size is unrecognized {}.".format(len(experiment_result)))
  return experiment_result[1]


def _summarize_quantiles(results_dictionary, dataset):
  """Compute per-algorithm reward quantiles over all saved experiments.

  Returns a dict mapping each key of ``results_dictionary`` to
  ``(quantiles_mean, quantiles_std, quantiles_experiments)``.
  ``quantiles_experiments`` holds one ``get_quantiles`` list per experiment,
  and the mean/std are taken across experiments (axis 0).
  """
  quantile_experiments_dictionary = {}
  for key, algo_entry in results_dictionary.items():
    # The last element of each entry holds the list of per-experiment results.
    quantiles_experiments = [
        get_quantiles(dataset, _extract_max_observed_rewards(experiment_result))
        for experiment_result in algo_entry[-1]
    ]
    quantile_experiments_dictionary[key] = (
        np.mean(quantiles_experiments, 0),
        np.std(quantiles_experiments, 0),
        quantiles_experiments,
    )
  return quantile_experiments_dictionary


def produce_paper_plots(dataset_name, batch_size, num_batches, algorithm_types, representation_layer_sizes, l1, num_experiments, random_seed):
  """Load saved results for several algorithms on one dataset and write the paper plots.

  All RNGs are seeded with ``random_seed`` and the dataset is rebuilt. Then,
  for every entry of ``algorithm_types``, the corresponding
  ``./results/<stub>.zip`` is loaded. Reward quantiles are computed per
  algorithm, and the following are written under ``./paper_results/...``:
  the quantile bar plot, three quantile-curve plots (full range, close-up,
  super close-up) and the rewards plot.

  If the results zip for any algorithm is missing, a message is printed and
  the function returns without plotting anything.

  Raises:
    ValueError: if a saved experiment result does not have 3, 4 or 5 entries.
    OSError: if the output directory cannot be created.
  """
  dataset_info = {"noisy": False}

  # Seed every RNG before building the dataset so the run is reproducible.
  torch.manual_seed(random_seed)
  random.seed(random_seed)
  np.random.seed(random_seed)
  torch.cuda.manual_seed(random_seed)
  torch.cuda.manual_seed_all(random_seed)

  dataset = get_experiments_dataset(dataset_name, dataset_info, representation_layer_sizes)

  path = os.getcwd()
  data_directory = "{}/results".format(path)

  results_dictionaries_items = []
  for algorithm_type in algorithm_types:
    results_filename_stub = get_experiment_name_stub(dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments, random_seed)
    zip_results_filename = "{}/{}.zip".format(data_directory, results_filename_stub)

    # Every algorithm in the group is needed for the comparison plots.
    if not os.path.exists(zip_results_filename):
      print("Data for experiment {} is not in yet.".format(results_filename_stub))
      return
    print("Data for experiment {} found.".format(results_filename_stub))

    partial_results_dictionary = load_experiment_data(data_directory, dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments, random_seed, is_zip_file = True)
    results_dictionaries_items.extend(partial_results_dictionary.items())

  name_stub = get_multi_experiment_name_stub(dataset_name, batch_size, algorithm_types, l1, num_batches, representation_layer_sizes, num_experiments)
  print("Plotting {}".format(name_stub))

  results_dictionary = dict(results_dictionaries_items)
  quantile_experiments_dictionary = _summarize_quantiles(results_dictionary, dataset)

  results_parent_directory = "{}/paper_results".format(path)
  # Use the configured experiment count, matching the name stubs above. The
  # previous code silently shadowed this parameter with the result count of
  # whichever algorithm happened to be processed last.
  results_directory = get_plot_directory_stub(results_parent_directory, dataset_name, batch_size, num_batches, num_experiments)

  if not os.path.isdir(results_directory):
    # Let creation failures raise here instead of printing and failing
    # later when the plots are written.
    os.makedirs(results_directory, exist_ok = True)
    print ("Successfully created the figs directory ")

  title_stub = "Quantiles"
  filename_bar = "{}/quantiles_bar_{}.png".format(results_directory, name_stub )
  plot_quantiles_bar_simple_names(dataset_name, title_stub, quantile_experiments_dictionary, filename_bar, quantile_probes = [.8, .5, .2])

  # The same quantile curves at three y-axis zoom levels.
  zoom_levels = (
      ("quantiles", None, None),
      ("quantiles_closeup", 10, 0),
      ("quantiles_supercloseup", .5, 0),
  )
  for filename_prefix, upper_y_lim, lower_y_lim in zoom_levels:
    filename = "{}/{}_{}.png".format(results_directory, filename_prefix, name_stub)
    plot_quantiles_simple_names(dataset_name, title_stub, quantile_experiments_dictionary, filename, upper_y_lim = upper_y_lim, lower_y_lim = lower_y_lim)

  filename = "{}/rewards_{}.png".format(results_directory, name_stub )
  title_stub = "Rewards"
  plot_rewards_simple_names(dataset_name, title_stub, results_dictionary, filename, upper_y_lim = None, lower_y_lim = None)

  print("Finished plotting {}".format(name_stub))



# ---------------------------------------------------------------------------
# Script section: everything below runs as soon as this module is executed.
# ---------------------------------------------------------------------------

# Settings shared by every plotting call. They are passed into the result
# file-name stubs, so they must describe the experiments that were run.
num_experiments = 25
l1 = False
random_seed = 10823

# Dataset families that can be plotted.
bio_datasets = ["VCAP", "HA1E", "MCF7",  "A375"]
uci_datasets = ["BlogFeedback", "BikeSharingDay", "BikeSharingHour" ]
# Other fitted-regression datasets seen before: BikeSharingHourRegression, BankRegression.
regression_fitted_datasets = ["BikeSharingDayRegression", "AdultRegression"]
simulated_datasets = ["OneHillValleyOneDim",'MultiValleyOneDim', "MultiValleyOneDimHole", "OneHillValleyOneDimHole"]

# Each inner list is one group of algorithms compared on a single figure.
# Groups without a diversity mechanism:
algorithm_types_list_nodiv = [
    ["MeanOptimism"],
    ["RandomBatch", "MeanOptimism"],
]
# Groups pairing a base method with its diversity-aware batch variant:
algorithm_types_list_div = [
    ["RandomBatch", "DeterminantsDiversityOptimism"],
    ["MeanOptimism", "SequentialBatchOptimism"],
    ["EnsembleOptimism", "EnsembleSequentialBatchOptimism"],
    ["EnsembleOptimismNoiseY", "EnsembleSequentialBatchOptimismNoiseY"],
]

algorithm_types_list = algorithm_types_list_div + algorithm_types_list_nodiv

# Batch settings for the simulated-dataset runs. Configurations used
# previously: batch_size=50 / num_batches=20 on bio_datasets, and
# batch_size=20 / num_batches=40 on uci + regression_fitted + simulated.
batch_size = 3
num_batches = 150

for algorithm_types in algorithm_types_list:
  for representation_layer_sizes in [[300, 100]]:
    for dataset_name in simulated_datasets:
      produce_paper_plots(dataset_name, batch_size, num_batches, algorithm_types,
                          representation_layer_sizes, l1, num_experiments, random_seed)




