import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt

#from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting 

#from matplotlib import cm
#from matplotlib.ticker import LinearLocator

import pandas as pd

import torch
from onlinedatasets.models import TorchRewardsModel, TorchRewardsModelMultilayer
from onlinedatasets.datasets import get_dataset
#from algorithms import *
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from matplotlib import cm
from plotting_tools import plot_ranks, plot_rewards, get_experiment_name_stub, log_experiment_data, load_experiment_data, plot_quantiles, plot_quantiles_bar
from datasets_CMAP import get_dataset_CMAP
from run_experiments import get_experiments_dataset
import os

import itertools
import IPython

import ray

def get_quantiles(dataset, max_observed_rewards):
  """Map observed rewards to their rank quantile within the dataset's labels.

  The dataset labels are ordered from largest to smallest. For each observed
  reward, the label closest to it (by absolute difference) is located.
  Its 1-based position ``rank`` in that descending order gives the
  quantile ``rank / dataset_size``. A value near 0 means the reward is close
  to the largest label; 1.0 means it is closest to the smallest one. When
  several labels are equally close, the first one in descending order (the
  largest) wins.

  Args:
    dataset: object exposing ``labels.values`` (array-like of label values,
      e.g. a single-column DataFrame); all values are flattened.
    max_observed_rewards: iterable of reward values to locate.

  Returns:
    list of float quantiles in (0, 1], one per entry of max_observed_rewards.

  Raises:
    ValueError: if the dataset has no labels and max_observed_rewards is
      non-empty (raised by ``np.argmin`` on an empty array).
  """
  # ravel() rather than squeeze(): squeeze() turns a single-label dataset into
  # a 0-d array, which cannot be iterated/converted with list().
  y_values = np.sort(np.asarray(dataset.labels.values).ravel())[::-1]
  dataset_size = len(y_values)

  quantiles = []
  for max_observed_reward in max_observed_rewards:
    # Index of the closest label; +1 makes the rank 1-based.
    rank = np.argmin(np.abs(y_values - max_observed_reward)) + 1
    quantiles.append(rank * 1.0 / dataset_size)
  return quantiles




# ---------------------------------------------------------------------------
# Experiment configuration. These values select which saved run is analysed
# and are passed to the experiment-name helper and the dataset loader below.
# ---------------------------------------------------------------------------
dataset_name = "A375"
batch_size = 50
algorithm_type = 'MeanOptimism' 
l1 = False  # presumably toggles an L1 variant of the model/run — confirm
num_batches = 2 
representation_layer_sizes = [100, 10]
# NOTE(review): `get_experiment_name_stub_old` is neither defined in this file
# nor imported (only `get_experiment_name_stub` is imported from
# plotting_tools), so this line raises NameError as written. Either import the
# `_old` variant or switch to the imported helper after confirming its
# signature.
name_stub = get_experiment_name_stub_old(dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes)


# Options forwarded to get_experiments_dataset; only the "noisy" flag is set.
dataset_info = dict([])
dataset_info["noisy"] = False
#IPython.embed()
# Dataset whose labels observed rewards are ranked against (presumably the
# same dataset the experiments ran on — confirm).
dataset = get_experiments_dataset(dataset_name, dataset_info, representation_layer_sizes)


# Saved results live under "<current working directory>/results".
path = os.getcwd()

results_directory = "{}/results".format(path)

# NOTE(review): `num_experiments` is not defined before this line, so this
# call raises NameError as written. Define it in the configuration section
# above; it presumably has to match the number of experiments in the saved
# run being loaded — confirm against load_experiment_data.
results_dictionary = load_experiment_data(results_directory, dataset_name, batch_size, algorithm_type, l1, num_batches, representation_layer_sizes, num_experiments,  is_zip_file = True)

# Maps each results key to (per-step mean quantile, per-step std of quantiles,
# raw per-experiment quantile lists); mean/std are taken over axis 0, i.e.
# across experiments.
quantile_experiments_dictionary = {}

for key, key_results in results_dictionary.items():
  # Only the last element of each entry is used; it is treated as a sequence
  # of (ranks, max_observed_rewards, true_max_reward) tuples.
  results = key_results[-1]
  # Deliberately not named `num_experiments`: that name is the configuration
  # value handed to load_experiment_data, and this loop must not overwrite it.
  num_key_experiments = len(results)
  print("num experiments ", num_key_experiments)

  quantiles_experiments = []
  for ranks, max_observed_rewards, true_max_reward in results:
    print(ranks, max_observed_rewards, true_max_reward)
    quantiles_experiments.append(get_quantiles(dataset, max_observed_rewards))

  quantiles_mean = np.mean(quantiles_experiments, 0)
  quantiles_std = np.std(quantiles_experiments, 0)

  quantile_experiments_dictionary[key] = (quantiles_mean, quantiles_std, quantiles_experiments)


#IPython.embed()

# Write all quantile figures into results_directory, each filename tagged with
# name_stub.
title_stub = "Quantiles"
filename_bar = "{}/quantiles_bar_{}.png".format(results_directory, name_stub )

# Bar chart of the quantiles at the probe levels 0.8, 0.5 and 0.2.
plot_quantiles_bar(dataset_name, title_stub, quantile_experiments_dictionary, filename_bar, quantile_probes = [.8, .5, .2])

# Line plots with progressively tighter y windows: (filename template,
# upper_y_lim, lower_y_lim). None leaves the limit to plot_quantiles
# (presumably autoscaled — confirm in plotting_tools).
# NOTE(review): get_quantiles yields values in (0, 1], so an upper limit of 10
# only makes sense if plot_quantiles rescales (e.g. to percent) — confirm.
quantile_plot_specs = [
  ("{}/quantiles_{}.png", None, None),
  ("{}/quantiles_closeup_{}.png", 10, 0),
  ("{}/quantiles_supercloseup_{}.png", .5, 0),
]
for filename_template, y_upper, y_lower in quantile_plot_specs:
  filename = filename_template.format(results_directory, name_stub )
  plot_quantiles(dataset_name, title_stub, quantile_experiments_dictionary, filename, upper_y_lim = y_upper, lower_y_lim = y_lower)

