import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt


from matplotlib import cm
import os
import IPython
import zipfile
import pickle


def get_colors(num_colors):
  """Return ``num_colors`` colors evenly spaced along the viridis colormap.

  Args:
    num_colors: number of colors to generate.

  Returns:
    A list with ``num_colors`` entries, each the value obtained by calling the
    (resampled) viridis colormap at an evenly spaced position in [0, 1].
  """
  color_numbers = np.linspace(0, 1, num_colors)
  try:
    # Registry API. ``cm.get_cmap`` was deprecated in Matplotlib 3.7 and
    # removed in 3.9, so prefer this path whenever it is available.
    color_map = matplotlib.colormaps['viridis'].resampled(num_colors)
  except AttributeError:
    # Older Matplotlib without ``matplotlib.colormaps`` or
    # ``Colormap.resampled``: fall back to the legacy helper.
    color_map = cm.get_cmap('viridis', num_colors)
  return [color_map(position) for position in color_numbers]


def plot_ranks(dataset_name, title_stub, results_dictionary, filename, upper_y_lim = None):
    """Plot one rank curve per entry of ``results_dictionary`` and save the figure.

    For every entry, element 0 is drawn as a line against the batch numbers
    1..len(element 0), and a translucent band of total width equal to element 1
    is drawn centred on that line (presumably mean and standard deviation;
    confirm against the caller).

    Args:
      dataset_name: first part of the figure title.
      title_stub: second part of the figure title.
      results_dictionary: maps a legend label to an indexable whose elements 0
        and 1 are equally long numeric sequences (lists or arrays).
      filename: destination passed to ``plt.savefig``.
      upper_y_lim: when not None, the y axis is limited to ``[0, upper_y_lim]``.

    Side effects:
      Draws on the current pyplot figure, saves it, then closes all figures.
    """
    colors = get_colors(len(results_dictionary))

    plt.title("{} {}".format(dataset_name, title_stub))
    plt.xlabel("Num Batches")
    plt.ylabel("Resulting Ranks")

    if upper_y_lim is not None:
        plt.ylim([0, upper_y_lim])

    for color, (results_name, results) in zip(colors, results_dictionary.items()):
        # np.asarray makes the band arithmetic work for plain lists as well
        # as arrays.
        centers = np.asarray(results[0])
        half_band = .5 * np.asarray(results[1])
        # Batches are numbered from 1; the axis length follows the series
        # actually being drawn.
        batch_numbers = np.arange(len(centers)) + 1

        plt.plot(batch_numbers, centers, linewidth = 3, color = color, label = results_name)
        plt.fill_between(batch_numbers, centers - half_band, centers + half_band,
                         color = color, alpha = .2)

    plt.legend(loc = "upper right")
    plt.savefig(filename)
    plt.close("all")



def plot_rewards(dataset_name, title_stub, results_dictionary, filename, upper_y_lim = None, lower_y_lim = None):
    """Plot one max-reward curve per entry of ``results_dictionary`` and save it.

    For every entry, element 2 is drawn as a line against the batch numbers
    1..len(element 2), and a translucent band of total width equal to element 3
    is drawn centred on that line (presumably mean and standard deviation;
    confirm against the caller).

    Args:
      dataset_name: first part of the figure title.
      title_stub: second part of the figure title.
      results_dictionary: maps a legend label to an indexable whose elements 2
        and 3 are equally long numeric sequences (lists or arrays).
      filename: destination passed to ``plt.savefig``.
      upper_y_lim: upper y-axis limit; must be given together with
        ``lower_y_lim``.
      lower_y_lim: lower y-axis limit; must be given together with
        ``upper_y_lim``.

    Raises:
      ValueError: if exactly one of ``upper_y_lim`` / ``lower_y_lim`` is None.

    Side effects:
      Draws on the current pyplot figure, saves it, then closes all figures.
    """
    # Validate before touching pyplot so a bad call does not leave a
    # half-configured figure behind.
    if upper_y_lim is None and lower_y_lim is not None:
        raise ValueError("upper y lim is none but lower y lim is not")
    if upper_y_lim is not None and lower_y_lim is None:
        raise ValueError("upper y lim is not none but lower y lim is")

    colors = get_colors(len(results_dictionary))

    plt.title("{} {}".format(dataset_name, title_stub))
    plt.xlabel("Num Batches")
    plt.ylabel("Resulting Max Rewards")

    if upper_y_lim is not None:
        plt.ylim([lower_y_lim, upper_y_lim])

    for color, (results_name, results) in zip(colors, results_dictionary.items()):
        # np.asarray makes the band arithmetic work for plain lists as well
        # as arrays.
        centers = np.asarray(results[2])
        half_band = .5 * np.asarray(results[3])
        # The axis length comes from the series that is plotted (element 2),
        # not from element 0.
        batch_numbers = np.arange(len(centers)) + 1

        plt.plot(batch_numbers, centers, linewidth = 3, color = color, label = results_name)
        plt.fill_between(batch_numbers, centers - half_band, centers + half_band,
                         color = color, alpha = .2)

    plt.legend(loc = "lower right")
    plt.savefig(filename)
    plt.close("all")



def plot_quantiles(dataset_name, title_stub, results_dictionary, filename, upper_y_lim = None, lower_y_lim = None):
    """Plot one quantile curve (in percent) per entry and save the figure.

    For every entry, elements 0 and 1 are multiplied by 100. The scaled
    element 0 is drawn as a line against the batch numbers 1..len(element 0),
    with a translucent band of total width equal to the scaled element 1
    centred on it.

    Args:
      dataset_name: first part of the figure title.
      title_stub: second part of the figure title.
      results_dictionary: maps a legend label to an indexable whose elements 0
        and 1 are equally long numeric sequences (lists or arrays) holding the
        quantile means and spreads as fractions.
      filename: destination passed to ``plt.savefig``.
      upper_y_lim: upper y-axis limit; must be given together with
        ``lower_y_lim``.
      lower_y_lim: lower y-axis limit; must be given together with
        ``upper_y_lim``.

    Raises:
      ValueError: if exactly one of ``upper_y_lim`` / ``lower_y_lim`` is None.

    Side effects:
      Draws on the current pyplot figure, saves it, then closes all figures.
    """
    # Validate before touching pyplot so a bad call does not leave a
    # half-configured figure behind.
    if upper_y_lim is None and lower_y_lim is not None:
        raise ValueError("upper y lim is none but lower y lim is not")
    if upper_y_lim is not None and lower_y_lim is None:
        raise ValueError("upper y lim is not none but lower y lim is")

    colors = get_colors(len(results_dictionary))

    plt.title("{} {}".format(dataset_name, title_stub))
    plt.xlabel("Num Batches")
    plt.ylabel("Resulting Quantiles (%)")

    # After the validation above, either both limits are set or neither is.
    if upper_y_lim is not None:
        plt.ylim([lower_y_lim, upper_y_lim])

    for color, (results_name, results) in zip(colors, results_dictionary.items()):
        # Convert before scaling: ``100 * some_list`` would repeat the list
        # 100 times instead of multiplying its values.
        quantile_means = 100 * np.asarray(results[0])
        quantile_stds = 100 * np.asarray(results[1])
        batch_numbers = np.arange(len(quantile_means)) + 1

        plt.plot(batch_numbers, quantile_means, linewidth = 3, color = color,
                 label = results_name)
        plt.fill_between(batch_numbers, quantile_means - .5*quantile_stds,
                         quantile_means + .5*quantile_stds, color = color, alpha = .2)

    plt.legend(loc = "upper right")
    plt.savefig(filename)
    plt.close("all")



def plot_quantiles_bar(dataset_name, title_stub, results_dictionary, filename, quantile_probes = (1, 5, 10)):
    """Save a grouped bar chart of batch counts per quantile probe and algorithm.

    For each algorithm and each probe, every experiment in element 2 of the
    algorithm's results is scaled by 100, and the number of batches whose
    scaled value is >= the probe is counted. The bar height is the mean of
    these counts over the experiments; the error bar is 0.2 times their
    standard deviation.

    Args:
      dataset_name: first part of the figure title.
      title_stub: second part of the figure title.
      results_dictionary: maps an algorithm name to an indexable whose element
        2 is a sequence of per-experiment quantile sequences (fractions).
        Names must start with "RandomBatch", "Greedy" or "Lambda-<value>"
        (dash-separated).
      filename: destination passed to ``plt.savefig``.
      quantile_probes: quantile thresholds in percent, one bar group each.

    Raises:
      ValueError: if an algorithm name has an unrecognised prefix.

    Side effects:
      Draws on the current pyplot figure, saves it, then closes all figures.
    """
    algo_keys = list(results_dictionary.keys())
    num_algos = len(algo_keys)

    # Build the legend labels first so a bad name fails before pyplot is
    # touched.
    legend_labels = []
    for algo in algo_keys:
        split_algo_name = algo.split("-")
        if split_algo_name[0] == "RandomBatch":
            legend_labels.append("RandomBatch")
        elif split_algo_name[0] == "Greedy":
            legend_labels.append("Greedy")
        elif split_algo_name[0] == "Lambda":
            legend_labels.append("Lambda-{}".format(split_algo_name[1]))
        else:
            raise ValueError("Strange Name encountered {}".format(algo))

    colors = get_colors(num_algos)

    plt.title("{} {}".format(dataset_name, title_stub))
    plt.xlabel("Quantiles (%)")
    plt.ylabel("Required Batches")

    # data[a][p] == (mean, std) of the batch counts for algorithm a, probe p.
    data = []
    for results_name in algo_keys:
        # Scale each experiment to percent once, not once per probe.
        scaled_experiments = [100*np.array(experiment)
                              for experiment in results_dictionary[results_name][2]]

        quantile_probe_results = []
        for quantile_probe in quantile_probes:
            batch_counts = [np.sum(experiment >= quantile_probe)
                            for experiment in scaled_experiments]
            quantile_probe_results.append((np.mean(batch_counts), np.std(batch_counts)))
        data.append(quantile_probe_results)

    # Each probe group spans one x unit split into num_algos + 1 slots, which
    # leaves a one-slot gap between neighbouring groups.
    X = np.arange(len(quantile_probes))
    slots = float(num_algos + 1)
    for i in range(num_algos):
        plt.bar(X + i/slots, [mean for (mean, _) in data[i]], color = colors[i],
                yerr = [.2*std for (_, std) in data[i]], width = 1/slots,
                label = legend_labels[i])

    plt.xticks(X + num_algos/(2*slots), quantile_probes)
    plt.legend(loc = "upper left")
    plt.savefig(filename)

    plt.close("all")



def get_plot_directory_stub(results_parent_directory, dataset_name, batch_size, num_batches, num_experiments):
    """Return ``<results_parent_directory>/<dataset_name>/T<num_batches>/B<batch_size>``.

    ``num_experiments`` is accepted but does not influence the returned path.
    """
    return "{}/{}".format(
        results_parent_directory,
        "{}/T{}/B{}".format(dataset_name, num_batches, batch_size),
    )



def get_multi_experiment_name_stub(dataset_name, batch_size, algorithm_types, l1, num_batches, representation_layer_sizes, num_experiments):
    """Build the name stub for a run that covers several algorithms.

    The stub has the form
    ``<dataset>_batchsize_<B>[_l1_<l1>]_<algo1>_<algo2>..._T<T>_hidden_<s1>_<s2>..._E<E>``;
    the ``_l1_<l1>`` part is present only when ``l1`` is truthy.
    """
    # "hidden" followed by every layer size, underscore separated.
    layer_sizes_stub = "_".join(
        ["hidden"] + [str(layer_size) for layer_size in representation_layer_sizes]
    )

    # The first algorithm name is taken as-is; the rest are appended with "_".
    algorithms_stub = algorithm_types[0]
    for algorithm_name in algorithm_types[1:]:
        algorithms_stub += "_" + algorithm_name

    if l1:
        return "{}_batchsize_{}_l1_{}_{}_T{}_{}_E{}".format(
            dataset_name, batch_size, l1, algorithms_stub, num_batches,
            layer_sizes_stub, num_experiments)

    return "{}_batchsize_{}_{}_T{}_{}_E{}".format(
        dataset_name, batch_size, algorithms_stub, num_batches,
        layer_sizes_stub, num_experiments)





def get_experiment_name_stub(dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments, random_seed):
    """Build the name stub for a single-algorithm run.

    "RandomBatch" runs omit both the hidden-layer part and the l1 part. Every
    other algorithm includes the hidden-layer part, plus ``_l1_<l1>`` when
    ``l1`` is truthy.
    """
    # "hidden" followed by every layer size, underscore separated.
    layer_sizes_stub = "_".join(
        ["hidden"] + [str(layer_size) for layer_size in representation_layer_sizes]
    )

    if algorithm_type == "RandomBatch":
        return "{}_batchsize_{}_{}_T{}_E{}_RS{}".format(
            dataset_name, batch_size, algorithm_type, num_batches,
            num_experiments, random_seed)

    if l1:
        return "{}_batchsize_{}_l1_{}_{}_T{}_{}_E{}_RS{}".format(
            dataset_name, batch_size, l1, algorithm_type, num_batches,
            layer_sizes_stub, num_experiments, random_seed)

    return "{}_batchsize_{}_{}_T{}_{}_E{}_RS{}".format(
        dataset_name, batch_size, algorithm_type, num_batches,
        layer_sizes_stub, num_experiments, random_seed)



# def get_experiment_name_stub_old(dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes):
#     layer_sizes_stub = "hidden"
#     for layer_size in representation_layer_sizes:
#       layer_sizes_stub += "_" + str(layer_size)

#     name_stub = "{}_batchsize_{}_{}_T{}_{}".format(dataset_name, batch_size, algorithm_type, num_batches, layer_sizes_stub)
#     if l1:
#       name_stub = "{}_batchsize_{}_l1_{}_{}_T{}_{}".format(dataset_name, batch_size, l1, algorithm_type, num_batches, layer_sizes_stub)

#     return name_stub




def log_experiment_data(results_dictionary, results_filename_stub, base_data_dir, is_zip_file = False):
  """Pickle ``results_dictionary`` into ``base_data_dir``, optionally zipped.

  Args:
    results_dictionary: object to serialise with ``pickle``.
    results_filename_stub: file name without extension.
    base_data_dir: existing directory that receives the output file.
    is_zip_file: when False, ``<stub>.p`` is written and kept. When True, the
      pickle is stored (deflated) as the single member ``<stub>.p`` of
      ``<stub>.zip`` and the loose ``<stub>.p`` file is removed afterwards.
  """
  pickle_path = "{}/{}.p".format(base_data_dir, results_filename_stub)

  # Context managers close the handles even if dumping or zipping raises.
  with open(pickle_path, "wb") as pickle_file:
    pickle.dump(results_dictionary, pickle_file)

  if not is_zip_file:
    return

  zip_path = "{}/{}.zip".format(base_data_dir, results_filename_stub)
  with zipfile.ZipFile(zip_path, 'w') as zip_file:
    # Store under the bare file name so the archive holds no directory part.
    zip_file.write(pickle_path, compress_type = zipfile.ZIP_DEFLATED,
      arcname = os.path.basename(pickle_path))

  # Remove the loose pickle only after the archive was written and closed.
  os.remove(pickle_path)


def load_experiment_data(base_data_dir, dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments, random_seed, is_zip_file = False):
  """Load the pickled results of one experiment from ``base_data_dir``.

  The file name stub is derived from the experiment parameters with
  ``get_experiment_name_stub``.

  Args:
    base_data_dir: directory holding ``<stub>.p`` or ``<stub>.zip``.
    dataset_name, batch_size, algorithm_type, l1, num_batches,
      representation_layer_sizes, num_experiments, random_seed: forwarded
      unchanged to ``get_experiment_name_stub``.
    is_zip_file: when True, read member ``<stub>.p`` from ``<stub>.zip``;
      otherwise read ``<stub>.p`` directly.

  Returns:
    The unpickled results object.

  Raises:
    KeyError: if ``is_zip_file`` is True and the archive has no ``<stub>.p``
      member.

  NOTE: ``pickle`` can execute arbitrary code while loading; only use this on
  result files from a trusted source.
  """
  results_filename_stub = get_experiment_name_stub(dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments, random_seed)
  pickle_results_filename = "{}.p".format(results_filename_stub)

  if is_zip_file:
    zip_path = "{}/{}.zip".format(base_data_dir, results_filename_stub)
    # Read the member straight from the archive: nothing is extracted to
    # disk, so an existing ``<stub>.p`` next to the zip is never overwritten
    # or deleted, and no temporary file is left behind on failure.
    with zipfile.ZipFile(zip_path, "r") as zip_file:
      return pickle.loads(zip_file.read(pickle_results_filename))

  with open("{}/{}".format(base_data_dir, pickle_results_filename), "rb") as pickle_file:
    return pickle.load(pickle_file)






