import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt


from matplotlib import cm
import os
import IPython
import zipfile
import pickle

from plotting_tools import get_colors


def get_new_names(algos):
    """Map raw algorithm identifiers to short legend labels.

    Each identifier is split on "-". The first token selects the algorithm
    family. For the "Greedy" and "Lambda" families the last token selects an
    optimism variant. For "Lambda" the second token is also kept in the label.

    Args:
        algos: Iterable of raw algorithm name strings (e.g. dict keys).

    Returns:
        List of display names, in the same order as ``algos``.

    Raises:
        ValueError: If the family, or a "Lambda" variant suffix, is unknown.
    """
    # Families whose label does not depend on the rest of the name.
    fixed_labels = {
        "RandomBatch": "RandomBatch",
        "EnsembleOptimism": "Ensemble",
        "EnsembleSequentialBatchOptimism": "Ensemble-SeqB",
        "EnsembleSequentialBatchOptimismNoiseY": "Ensemble-SeqB-NoiseY",
        "EnsembleOptimismNoiseY": "Ensemble-NoiseY",
    }
    # Variant suffixes accepted for the "Lambda" family.
    lambda_suffixes = {
        "SequentialBatchOptimism": "SeqB",
        "DeterminantsDiversityOptimism": "DetD",
        "MeanOptimism": "MeanOpt",
        "MaxOptimism": "MaxOpt",
        "HingePNormOptimism": "HingeOpt",
    }

    def shorten(raw_name):
        tokens = raw_name.split("-")
        family, variant = tokens[0], tokens[-1]
        if family in fixed_labels:
            return fixed_labels[family]
        if family == "Greedy":
            suffix = " DetD" if variant == "DeterminantsDiversityOptimism" else ""
            return "Greedy{}".format(suffix)
        if family == "Lambda":
            if variant not in lambda_suffixes:
                raise ValueError("Strange name encountered for post {}".format(variant))
            return "Lambda-{} {}".format(tokens[1], lambda_suffixes[variant])
        raise ValueError("Strange Name encountered {}".format(raw_name))

    return [shorten(raw_name) for raw_name in algos]




def plot_ranks_simple_names(dataset_name, title_stub, results_dictionary, filename, upper_y_lim=None):
    """Plot each algorithm's per-batch mean rank and save the figure.

    Each algorithm is drawn as a line of its mean rank. A shaded band covers
    +/- half a standard deviation. Lines are labelled with the short names
    produced by ``get_new_names``.

    Args:
        dataset_name: First part of the plot title.
        title_stub: Second part of the plot title.
        results_dictionary: Maps raw algorithm names to indexable records.
            Item 0 holds the per-batch mean ranks and item 1 the matching
            standard deviations. Lists and numpy arrays are both accepted.
        filename: Path handed to ``plt.savefig``.
        upper_y_lim: If not None, the y axis is fixed to ``[0, upper_y_lim]``.

    Raises:
        ValueError: Propagated from ``get_new_names`` for unrecognised names.

    Draws on the current pyplot figure. All figures are closed afterwards,
    even when an error occurs, so no partial plot leaks into the next call.
    """
    algo_names = list(results_dictionary.keys())
    colors = get_colors(len(algo_names))
    # Resolve labels before touching pyplot state so a bad name cannot leave
    # a half-drawn figure behind.
    display_names = get_new_names(algo_names)

    try:
        plt.title("{} {}".format(dataset_name, title_stub))
        plt.xlabel("Num Batches")
        plt.ylabel("Resulting Ranks")
        if upper_y_lim is not None:
            plt.ylim([0, upper_y_lim])

        for i, (algo, label) in enumerate(zip(algo_names, display_names)):
            # asarray makes the band arithmetic work for plain lists too.
            means = np.asarray(results_dictionary[algo][0])
            stds = np.asarray(results_dictionary[algo][1])
            # Batches are numbered from 1; length comes from this series itself.
            batches = np.arange(1, len(means) + 1)
            plt.plot(batches, means, linewidth=3, color=colors[i], label=label)
            plt.fill_between(batches, means - .5 * stds, means + .5 * stds,
                             color=colors[i], alpha=.2)

        plt.legend(loc="upper right")
        plt.savefig(filename)
    finally:
        plt.close("all")



def plot_rewards_simple_names(dataset_name, title_stub, results_dictionary, filename, upper_y_lim=None, lower_y_lim=None):
    """Plot each algorithm's per-batch mean max reward and save the figure.

    Each algorithm is drawn as a line of its mean reward. A shaded band covers
    +/- half a standard deviation. Lines are labelled with the short names
    produced by ``get_new_names``.

    Args:
        dataset_name: First part of the plot title.
        title_stub: Second part of the plot title.
        results_dictionary: Maps raw algorithm names to indexable records.
            Item 2 holds the per-batch mean rewards and item 3 the matching
            standard deviations. Lists and numpy arrays are both accepted.
        filename: Path handed to ``plt.savefig``.
        upper_y_lim: Upper y-axis limit. Must be given together with
            ``lower_y_lim``.
        lower_y_lim: Lower y-axis limit. Must be given together with
            ``upper_y_lim``.

    Raises:
        ValueError: If exactly one of the y limits is None, or (from
            ``get_new_names``) an algorithm name is unrecognised.

    Draws on the current pyplot figure. All figures are closed afterwards,
    even when an error occurs, so no partial plot leaks into the next call.
    """
    # Validate before drawing anything so a bad call leaves no pyplot state.
    if upper_y_lim is None and lower_y_lim is not None:
        raise ValueError("upper y lim is none but lower y lim is not")
    if upper_y_lim is not None and lower_y_lim is None:
        raise ValueError("upper y lim is not none but lower y lim is")

    algo_names = list(results_dictionary.keys())
    colors = get_colors(len(algo_names))
    display_names = get_new_names(algo_names)

    try:
        plt.title("{} {}".format(dataset_name, title_stub))
        plt.xlabel("Num Batches")
        plt.ylabel("Resulting Max Rewards")
        if upper_y_lim is not None:
            plt.ylim([lower_y_lim, upper_y_lim])

        for i, (algo, label) in enumerate(zip(algo_names, display_names)):
            means = np.asarray(results_dictionary[algo][2])
            stds = np.asarray(results_dictionary[algo][3])
            # x length must come from the plotted reward series (item 2),
            # not from some other series of the record.
            batches = np.arange(1, len(means) + 1)
            plt.plot(batches, means, linewidth=3, color=colors[i], label=label)
            plt.fill_between(batches, means - .5 * stds, means + .5 * stds,
                             color=colors[i], alpha=.2)

        plt.legend(loc="lower right")
        plt.savefig(filename)
    finally:
        plt.close("all")



def plot_quantiles_simple_names(dataset_name, title_stub, results_dictionary, filename, upper_y_lim=None, lower_y_lim=None):
    """Plot each algorithm's per-batch mean quantile (in percent) and save it.

    Values are multiplied by 100 before plotting. Each algorithm gets a line
    of its mean and a shaded band of +/- half a standard deviation. Lines are
    labelled with the short names produced by ``get_new_names``.

    Args:
        dataset_name: First part of the plot title.
        title_stub: Second part of the plot title.
        results_dictionary: Maps raw algorithm names to indexable records.
            Item 0 holds the per-batch mean quantiles and item 1 the matching
            standard deviations. Lists and numpy arrays are both accepted.
        filename: Path handed to ``plt.savefig``.
        upper_y_lim: Upper y-axis limit in percent. Must be given together
            with ``lower_y_lim``.
        lower_y_lim: Lower y-axis limit in percent. Must be given together
            with ``upper_y_lim``.

    Raises:
        ValueError: If exactly one of the y limits is None, or (from
            ``get_new_names``) an algorithm name is unrecognised.

    Draws on the current pyplot figure. All figures are closed afterwards,
    even when an error occurs, so no partial plot leaks into the next call.
    """
    # Validate before drawing anything so a bad call leaves no pyplot state.
    if upper_y_lim is None and lower_y_lim is not None:
        raise ValueError("upper y lim is none but lower y lim is not")
    if upper_y_lim is not None and lower_y_lim is None:
        raise ValueError("upper y lim is not none but lower y lim is")

    algo_names = list(results_dictionary.keys())
    colors = get_colors(len(algo_names))
    display_names = get_new_names(algo_names)

    try:
        plt.title("{} {}".format(dataset_name, title_stub))
        plt.xlabel("Num Batches")
        plt.ylabel("Resulting Quantiles (%)")
        if upper_y_lim is not None:
            plt.ylim([lower_y_lim, upper_y_lim])

        for i, (algo, label) in enumerate(zip(algo_names, display_names)):
            # asarray first: on a plain list, 100 * lst would repeat it 100
            # times instead of scaling it.
            quantile_means = 100 * np.asarray(results_dictionary[algo][0])
            quantile_stds = 100 * np.asarray(results_dictionary[algo][1])
            batches = np.arange(1, len(quantile_means) + 1)
            plt.plot(batches, quantile_means, linewidth=3, color=colors[i], label=label)
            plt.fill_between(batches, quantile_means - .5 * quantile_stds,
                             quantile_means + .5 * quantile_stds,
                             color=colors[i], alpha=.2)

        plt.legend(loc="upper right")
        plt.savefig(filename)
    finally:
        plt.close("all")



def _batches_to_reach_quantiles(quantile_runs, quantile_probes):
    """Return one (mean, std) pair per probe: batches needed to reach that quantile, over runs."""
    stats = []
    for probe in quantile_probes:
        # Per run: number of batches whose quantile (scaled to percent) is
        # still above the probe, +1 so the batch that first enters it is
        # counted. NOTE(review): this equals the first-hit batch only if each
        # run is non-increasing (e.g. best-so-far quantiles) -- confirm.
        needed = [np.sum(100 * np.asarray(run) > probe) + 1 for run in quantile_runs]
        stats.append((np.mean(needed), np.std(needed)))
    return stats


def plot_quantiles_bar_simple_names(dataset_name, title_stub, results_dictionary, filename, quantile_probes=(1, 5, 10)):
    """Bar-plot how many batches each algorithm needs to reach given quantiles.

    For every algorithm and every probe ``q`` (a percentage), one bar is drawn.
    Its height is the mean, over that algorithm's experiments, of the batches
    needed to reach quantile ``q``. The error bar is 0.2 times the matching
    standard deviation. Bars are grouped per probe along the x axis.

    Args:
        dataset_name: First part of the plot title.
        title_stub: Second part of the plot title.
        results_dictionary: Maps raw algorithm names to indexable records.
            Item 2 is a sequence of experiments. Each experiment is a
            per-batch sequence of quantiles, which is multiplied by 100
            before being compared with the probes.
        filename: Path handed to ``plt.savefig``.
        quantile_probes: Quantile thresholds in percent. They are also used
            as the x tick labels.

    Raises:
        ValueError: Propagated from ``get_new_names`` for unrecognised names.

    The font size of 14 applies only while this figure is built; it is not
    written permanently into the global rcParams. All figures are closed
    afterwards, even on error.
    """
    algo_names = list(results_dictionary.keys())
    colors = get_colors(len(algo_names))
    display_names = get_new_names(algo_names)

    # One row per algorithm, holding a (mean, std) pair for each probe.
    data = [_batches_to_reach_quantiles(results_dictionary[algo][2], quantile_probes)
            for algo in algo_names]

    num_algos = len(algo_names)
    # Each probe group holds num_algos bars plus one bar-width of spacing.
    width = 1.0 / (num_algos + 1)
    group_starts = np.arange(len(quantile_probes))

    with plt.rc_context({'font.size': 14}):
        try:
            plt.title("{} {}".format(dataset_name, title_stub))
            plt.xlabel("Quantiles (%)")
            plt.ylabel("Required Batches")

            for i, (label, probe_stats) in enumerate(zip(display_names, data)):
                heights = [mean for mean, _ in probe_stats]
                errors = [.2 * std for _, std in probe_stats]
                plt.bar(group_starts + i * width, heights, color=colors[i],
                        yerr=errors, width=width, label=label)

            # plt.bar centres each bar on its x position, so a group's
            # midpoint lies (num_algos - 1) / 2 bar widths past its start.
            plt.xticks(group_starts + (num_algos - 1) * width / 2, list(quantile_probes))
            plt.legend(loc="upper left", fontsize=11)
            plt.savefig(filename)
        finally:
            plt.close("all")






