import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt

#from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting 

#from matplotlib import cm
#from matplotlib.ticker import LinearLocator

import pandas as pd

import torch
from onlinedatasets.models import TorchRewardsModel, TorchRewardsModelMultilayer
from onlinedatasets.datasets import get_dataset
#from algorithms import *
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from algorithms import get_max_reward_batch, evaluate_rank_observed_reward, get_random_reward_batch, DataSetWithRegressionResponses, RandomBatch, MeanOptimism, MaxOptimism, HingePNormOptimism, EnsembleOptimism, DeterminantOptimism, SequentialBatchOptimism, EnsembleSequentialBatchOptimism

from matplotlib import cm
from plotting_tools import plot_ranks, plot_rewards, get_experiment_name_stub, log_experiment_data
from datasets_CMAP import get_dataset_CMAP

import os

import itertools
import IPython

import ray

import sys

import random

#RANDOM_SEED = 10823

# When True, the repeated runs of each configuration are submitted as Ray tasks
# (run_experiment_remote) in waves of parallel_experiment_batching_size;
# when False they run serially in-process via run_experiment.
USE_RAY = True
# When True, results are logged to "<name_stub>.zip" (log_experiment_data is
# called with is_zip_file=True); when False, to "<name_stub>.p".
ZIP_FILE = True
# When True, rank and reward plots are written to the results directory after
# the results have been logged.
PLOT_PRELIMINARY = False






def get_experiments_dataset(dataset_name, dataset_info = None, representation_layer_sizes = None):
  """Load ``dataset_name`` and return its training split as a single named ``DataSet``.

  Supported names:

  * CMAP cell lines ("VCAP", "HA1E", "MCF7", "A375"): loaded with
    ``get_dataset_CMAP(..., return_dataframe=True)``; the returned DataSet is
    also switched to dataframe mode. ``dataset_info`` is not used here.
  * ``"<base>Regression"``: ``<base>`` is loaded with ``get_dataset`` and its
    training split is wrapped in ``DataSetWithRegressionResponses`` (MLP with
    ``representation_layer_sizes`` hidden layers, 10000 steps, batch size 20).
  * Anything else: loaded with ``get_dataset``.

  The training split is then materialised with ``get_batches(..., 100000)``
  and re-wrapped in a ``DataSet`` built from pandas DataFrames.

  Args:
    dataset_name: dataset identifier (see above).
    dataset_info: ``info`` mapping forwarded to ``get_dataset``. ``None``
      (the default) means a fresh empty dict on every call, so no state can
      leak between calls through a shared default.
    representation_layer_sizes: hidden layer sizes of the regression-response
      model. ``None`` (the default) means ``[5]``.

  Returns:
    A ``DataSet`` whose name is ``dataset_name``.
  """
  if dataset_info is None:
    dataset_info = {}
  if representation_layer_sizes is None:
    representation_layer_sizes = [5]

  cmap_names = ("VCAP", "HA1E", "MCF7", "A375")
  regression_suffix = "Regression"
  is_cmap = dataset_name in cmap_names

  if is_cmap:
    train_dataset, _ = get_dataset_CMAP(dataset_name, return_dataframe = True)

  elif dataset_name.endswith(regression_suffix):
    base_name = dataset_name[:-len(regression_suffix)]
    # get_dataset returns a 5-tuple; only the third element (training split) is used.
    _, _, train_dataset, _, _ = get_dataset(base_name, info = dataset_info)

    print("Fitting dataset with regression responses.")
    train_dataset = DataSetWithRegressionResponses(train_dataset, MLP = True,
      representation_layer_sizes = representation_layer_sizes,
      num_steps = 10000,
      batch_size = 20)
    print("Finished fitting dataset with regression responses.")

  else:
    _, _, train_dataset, _, _ = get_dataset(dataset_name, info = dataset_info)

  all_data_X, all_data_y = get_batches(train_dataset, 100000)
  dataset = DataSet(pd.DataFrame(all_data_X), pd.DataFrame(all_data_y))
  if is_cmap:
    dataset.set_return_dataframe(True)

  dataset.set_name(dataset_name)
  return dataset






def train_optimistic_reward_one_dim(algorithm, dataset, num_batches, num_opt_steps, 
    batch_size, opt_batch_size = 10, record_validation_loss = False):
    """Run up to ``num_batches`` rounds of batch acquisition with ``algorithm``.

    The labelled pool is materialised once with ``get_batches(dataset, 100000)``.
    Each round then:

    1. asks ``algorithm.get_batch`` for ``batch_size`` points from the current
       candidate pool; the returned remainder becomes the next round's pool;
    2. appends the chosen points to a ``GrowingNumpyDataSet``;
    3. stops early if fewer than ``batch_size`` points came back (that short
       batch is still added to the growing dataset, but its reward is not
       recorded and no fit happens);
    4. records the running maximum of observed rewards and calls
       ``algorithm.fit_data`` on the acquired data.

    Args:
        algorithm: acquisition strategy exposing ``name``, ``get_name()``,
            ``get_batch(...)``, ``fit_data(...)`` and, when validation losses
            are recorded, ``reward_model.get_loss(...)``.
        dataset: supervised dataset exposing ``return_dataframe`` and ``name``.
        num_batches: maximum number of acquisition rounds.
        num_opt_steps: forwarded to ``algorithm.fit_data``.
        batch_size: number of points requested per round.
        opt_batch_size: forwarded to ``algorithm.fit_data``.
        record_validation_loss: if True and the algorithm is not
            "RandomBatch", evaluate the reward model on the full pool after
            every fit.

    Returns:
        ``(algorithm, max_observed_rewards, model_fitting_losses,
        true_max_reward, validation_losses, extra_info)``.
        ``max_observed_rewards[k]`` is the best reward seen after completed
        round ``k + 1``. ``true_max_reward`` is the maximum over the pool.
        ``extra_info`` holds one ``(row_indices,)`` tuple per round, including
        a final short round, when the dataset is in dataframe mode; otherwise
        it is empty.
    """
    print("Algorithm type ", algorithm.name)

    pool_X, pool_y = get_batches(dataset, 100000)
    true_max_reward = np.max(pool_y)

    # Copy before wrapping, presumably so the candidate pool cannot alias
    # pool_X / pool_y, which are reused below for validation losses.
    candidates = DataSet(pd.DataFrame(np.copy(pool_X)), pd.DataFrame(np.copy(pool_y)))
    if dataset.return_dataframe:
        candidates.set_return_dataframe(True)
        candidates.set_name(dataset.name)

    acquired = GrowingNumpyDataSet()
    best_reward = -float("inf")
    max_observed_rewards = []
    model_fitting_losses = []
    validation_losses = []
    extra_info = []

    for round_index in range(num_batches):
        batch_X, batch_y, rest_X, rest_y = algorithm.get_batch(acquired, candidates, batch_size)
        if dataset.return_dataframe:
            extra_info.append((list(batch_X.index),))

        acquired.add_data(np.array(batch_X), np.array(batch_y))

        rest_unlabelled = DataSetUnsupervised(pd.DataFrame(rest_X))
        candidates = DataSet(pd.DataFrame(rest_X), pd.DataFrame(rest_y),
            return_dataframe = dataset.return_dataframe)

        print("Len filtered batch y ", len(batch_y), " ", algorithm.get_name(), " ", dataset.name)
        if len(batch_y) < batch_size:
            print("len filtered batch y is less than batch size")
            break

        best_reward = max(best_reward, np.max(np.array(batch_y)))
        max_observed_rewards.append(best_reward)
        print("Batch {}".format(round_index + 1))

        model_fitting_losses.append(
            algorithm.fit_data(num_opt_steps, acquired, rest_unlabelled, opt_batch_size))

        if record_validation_loss and algorithm.name != "RandomBatch":
            with torch.no_grad():
                loss = algorithm.reward_model.get_loss(pool_X, pool_y).detach().numpy()
            validation_losses.append(loss)

    return algorithm, max_observed_rewards, model_fitting_losses, true_max_reward, validation_losses, extra_info




def run_experiment(lambda_reward_max, dataset_info, l1, num_batches = 5, num_opt_steps = 5000, 
  batch_size = 3, opt_batch_size = 10, 
  dataset_name = 'MultiValleyOneDim', MLP = True, random_init = True, 
  representation_layer_sizes = None, algorithm_type = "RandomBatch", l2_regularizer = .1, 
  range_regularizer = 0, num_ensemble_elements = 10
  ):
  """Run one batch-acquisition experiment and rank the best reward after each batch.

  Loads ``dataset_name`` with ``get_experiments_dataset`` and builds the
  algorithm named by ``algorithm_type``. It then runs
  ``train_optimistic_reward_one_dim`` and ranks every running-best observed
  reward with ``evaluate_rank_observed_reward``.

  Args:
    lambda_reward_max, l1, l2_regularizer, range_regularizer: passed
      positionally as ``(l1, lambda_reward_max, l2_regularizer,
      range_regularizer)`` to every optimism-based algorithm; ignored by
      "RandomBatch".
    dataset_info: forwarded to ``get_experiments_dataset``.
    num_batches, num_opt_steps, batch_size, opt_batch_size: forwarded to
      ``train_optimistic_reward_one_dim``.
    dataset_name: name understood by ``get_experiments_dataset``.
    MLP, random_init: accepted for backward compatibility; unused.
    representation_layer_sizes: hidden layer sizes of each reward model.
      ``None`` (the default) means ``[20]``. NOTE(review): this is not
      forwarded to ``get_experiments_dataset``, so "...Regression" datasets
      are fitted with that function's own default sizes.
    algorithm_type: one of the keys of ``builders`` below.
    num_ensemble_elements: number of reward models for the Ensemble* variants.

  Returns:
    ``(ranks, max_observed_rewards, true_max_reward, validation_losses, extra_info)``.

  Raises:
    ValueError: if ``algorithm_type`` is unknown. The check happens before
      the dataset is loaded.
  """
  if representation_layer_sizes is None:
    representation_layer_sizes = [20]

  # Positional constructor arguments shared by every optimism-based algorithm.
  regularization = (l1, lambda_reward_max, l2_regularizer, range_regularizer)

  def make_reward_model():
    # `dataset` is assigned below, before any builder is invoked.
    return TorchRewardsModelMultilayer(dim = dataset.dimension,
      representation_layer_sizes = representation_layer_sizes, activation_type = 'relu')

  def make_ensemble():
    return [make_reward_model() for _ in range(num_ensemble_elements)]

  # Builders are deferred (lambdas). This lets the unknown-type check run
  # before the dataset load, while reward models are still constructed only
  # after the dataset exists.
  builders = {
    "RandomBatch": lambda: RandomBatch(),
    "MeanOptimism": lambda: MeanOptimism(make_reward_model(), *regularization),
    "MaxOptimism": lambda: MaxOptimism(make_reward_model(), *regularization),
    "HingePNormOptimism": lambda: HingePNormOptimism(make_reward_model(), *regularization),
    "DeterminantsDiversityOptimism": lambda: DeterminantOptimism(make_reward_model(), *regularization),
    "SequentialBatchOptimism": lambda: SequentialBatchOptimism(make_reward_model(), *regularization),
    "EnsembleOptimism": lambda: EnsembleOptimism(make_ensemble(), *regularization),
    "EnsembleOptimismNoiseY": lambda: EnsembleOptimism(make_ensemble(), *regularization,
      noise_injection = True),
    "EnsembleSequentialBatchOptimism": lambda: EnsembleSequentialBatchOptimism(make_ensemble(), *regularization),
    "EnsembleSequentialBatchOptimismNoiseY": lambda: EnsembleSequentialBatchOptimism(make_ensemble(), *regularization,
      noise_injection = True),
  }
  if algorithm_type not in builders:
    raise ValueError("Unknown algorithm_type!")

  dataset = get_experiments_dataset(dataset_name, dataset_info)
  algorithm = builders[algorithm_type]()

  algorithm, max_observed_rewards, _, true_max_reward, validation_losses, extra_info = train_optimistic_reward_one_dim(algorithm, dataset, 
    num_batches = num_batches, num_opt_steps = num_opt_steps, 
    batch_size = batch_size, opt_batch_size = opt_batch_size)

  ranks = [evaluate_rank_observed_reward(dataset, best_reward) for best_reward in max_observed_rewards]

  return ranks, max_observed_rewards, true_max_reward, validation_losses, extra_info



@ray.remote
def run_experiment_remote(lambda_reward_max, dataset_info, l1, num_batches = 5, num_opt_steps = 5000, 
  batch_size = 3, opt_batch_size = 10, 
  dataset_name = 'MultiValleyOneDim', MLP = True, random_init = True, 
  representation_layer_sizes = None, algorithm_type = "RandomBatch", l2_regularizer = .1, 
  range_regularizer = 0, num_ensemble_elements = 10
  ):
  """Ray task wrapper around ``run_experiment``.

  It takes the same arguments and returns the same value. Submit it with
  ``run_experiment_remote.remote(...)`` and collect the result with ``ray.get``.

  ``representation_layer_sizes=None`` (the default) means ``[20]``, the same
  default ``run_experiment`` uses. Resolving it here avoids a mutable default
  list shared across calls in a worker.
  """
  if representation_layer_sizes is None:
    representation_layer_sizes = [20]

  return run_experiment(lambda_reward_max, dataset_info, l1, num_batches = num_batches, num_opt_steps = num_opt_steps, 
    batch_size = batch_size, opt_batch_size = opt_batch_size, 
    dataset_name = dataset_name, MLP = MLP, random_init = random_init, 
    representation_layer_sizes = representation_layer_sizes, algorithm_type = algorithm_type, l2_regularizer = l2_regularizer, 
    range_regularizer = range_regularizer, num_ensemble_elements = num_ensemble_elements
    )


def process_results(results):
    """Aggregate repeated ``run_experiment`` outputs into mean/std summaries.

    Args:
        results: sequence of 5-tuples ``(ranks, max_observed_rewards,
            true_max_reward, validation_losses, extra_info)``, one per run.
            All runs must have equal-length ranks/rewards for the
            element-wise statistics.

    Returns:
        ``(ranks_mean, ranks_std, rewards_mean, rewards_std, max_true_reward,
        results)``. Means and stds are taken across runs (axis 0); the reward
        statistics are squeezed; ``max_true_reward`` is the mean of the
        per-run dataset maxima; ``results`` is returned unchanged.
    """
    # Unpack all five fields so malformed entries fail loudly.
    ranks_per_run = [ranks for ranks, _, _, _, _ in results]
    best_rewards_per_run = [best for _, best, _, _, _ in results]
    dataset_max_per_run = [top for _, _, top, _, _ in results]

    rewards_mean = np.mean(best_rewards_per_run, axis=0)
    rewards_std = np.std(best_rewards_per_run, axis=0)

    return (
        np.mean(ranks_per_run, axis=0),
        np.std(ranks_per_run, axis=0),
        rewards_mean.squeeze(),
        rewards_std.squeeze(),
        np.mean(dataset_max_per_run),
        results,
    )




# Algorithms run with a single configuration (lambda_reward_max = 0,
# l1 = False) whose results are stored under the bare algorithm name.
_SINGLE_CONFIGURATION_ALGORITHMS = (
  "RandomBatch",
  "EnsembleOptimism",
  "EnsembleSequentialBatchOptimism",
  "EnsembleOptimismNoiseY",
  "EnsembleSequentialBatchOptimismNoiseY",
)

_USAGE = ("Usage: {} dataset_name algorithm_type random_seed num_batches "
  "num_experiments batch_size parallel_experiment_batching_size "
  "representation_layer_sizes")


def _parse_representation_layer_sizes(spec):
  """Parse a CLI layer-size spec such as ``"100_10"`` into ``[100, 10]``.

  The special value ``"0"`` means "no hidden layers" and yields ``[]``.
  Raises ValueError if any ``_``-separated field is not an integer.
  """
  if spec == "0":
    return []
  return [int(size) for size in spec.split("_")]


def _model_parameter_configurations(algorithm_type, lambdas, l1_settings):
  """Return the ``{"algorithm_type", "lambda_reward_max", "l1"}`` dicts to run.

  Single-configuration algorithms get one greedy configuration. Every other
  algorithm is swept over ``lambdas x l1_settings``; for
  "SequentialBatchOptimism" the value 0 is excluded from ``lambdas``.
  """
  if algorithm_type in _SINGLE_CONFIGURATION_ALGORITHMS:
    return [{"algorithm_type": algorithm_type, "lambda_reward_max": 0, "l1": False}]

  if algorithm_type == "SequentialBatchOptimism":
    lambdas = [value for value in lambdas if value != 0]

  return [{"algorithm_type": algorithm_type, "lambda_reward_max": lambda_val, "l1": l1}
    for lambda_val, l1 in itertools.product(lambdas, l1_settings)]


def _results_name(algorithm_type, lambda_reward_max, l1):
  """Key under which one configuration's aggregated results are stored."""
  if algorithm_type in _SINGLE_CONFIGURATION_ALGORITHMS:
    return algorithm_type
  if lambda_reward_max == 0:
    return "Greedy-l1-{}-{}".format(l1, algorithm_type)
  return "Lambda-{}-l1-{}-{}".format(lambda_reward_max, l1, algorithm_type)


def _run_repeated_experiments(num_experiments, parallel_experiment_batching_size,
  lambda_reward_max, dataset_info, l1, **experiment_kwargs):
  """Run ``num_experiments`` independent copies of one experiment configuration.

  With ``USE_RAY``, copies are submitted as Ray tasks in waves of at most
  ``parallel_experiment_batching_size``; each wave is awaited with
  ``ray.get`` before the next is submitted. Otherwise they run serially
  in-process. Returns the per-run ``run_experiment`` results in order.
  """
  if not USE_RAY:
    return [run_experiment(lambda_reward_max, dataset_info, l1, **experiment_kwargs)
      for _ in range(num_experiments)]

  # A wave size below 1 would never make progress (infinite loop).
  wave_limit = max(1, parallel_experiment_batching_size)
  results = []
  while len(results) < num_experiments:
    wave_size = min(num_experiments - len(results), wave_limit)
    futures = [run_experiment_remote.remote(lambda_reward_max, dataset_info, l1, **experiment_kwargs)
      for _ in range(wave_size)]
    results += ray.get(futures)
  return results


def _plot_preliminary(dataset_name, results_dictionary, results_directory, name_stub):
  """Write rank plots (full, closeup, super-closeup) and a reward plot as PNGs."""
  rank_plots = [
    ("ranks", "", None),
    ("rankscloseup", "Ranks Closeup", 30),
    ("rankssupercloseup", "Ranks SuperCloseup", 10),
  ]
  for prefix, title_stub, upper_y_lim in rank_plots:
    filename = "{}/{}_{}.png".format(results_directory, prefix, name_stub)
    plot_ranks(dataset_name, title_stub, results_dictionary, filename, upper_y_lim = upper_y_lim)

  filename = "{}/rewards_{}.png".format(results_directory, name_stub)
  plot_rewards(dataset_name, "Rewards", results_dictionary, filename)


def main():
  """Command-line entry point.

  Positional arguments: dataset_name algorithm_type random_seed num_batches
  num_experiments batch_size parallel_experiment_batching_size
  representation_layer_sizes. The last is ``"0"`` for no hidden layers or
  ``_``-separated ints, e.g. ``"100_10"``.

  Runs every parameter configuration for ``algorithm_type``
  ``num_experiments`` times. It then logs the aggregated results to
  ``./results/<name_stub>.zip`` (or ``.p``) and, if ``PLOT_PRELIMINARY`` is
  set, writes plots. It exits early if that results file already exists.
  """
  if len(sys.argv) < 9:
    print(_USAGE.format(sys.argv[0]))
    sys.exit(1)

  dataset_name = sys.argv[1]
  algorithm_type = sys.argv[2]
  random_seed = int(sys.argv[3])
  num_batches = int(sys.argv[4])
  num_experiments = int(sys.argv[5])
  batch_size = int(sys.argv[6])
  parallel_experiment_batching_size = int(sys.argv[7])
  representation_layer_sizes = _parse_representation_layer_sizes(sys.argv[8])

  torch.manual_seed(random_seed)
  random.seed(random_seed)
  np.random.seed(random_seed)

  lambdas = [0, 0.001, 0.01, 0.1]
  l1_settings = [False]
  model_parameter_configurations = _model_parameter_configurations(algorithm_type, lambdas, l1_settings)

  results_directory = "{}/results".format(os.getcwd())
  dataset_info = {"noisy": False}

  # The name stub is always built with l1 = False.
  name_stub = get_experiment_name_stub(dataset_name, batch_size, algorithm_type, False, num_batches,
    representation_layer_sizes, num_experiments, random_seed)
  extension = "zip" if ZIP_FILE else "p"
  data_log_file_name = "{}/{}.{}".format(results_directory, name_stub, extension)

  if os.path.exists(data_log_file_name):
    print("Experiment {} is already logged in.".format(name_stub))
    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
    sys.exit()

  results_dictionary = {}
  for parameters in model_parameter_configurations:
    config_algorithm = parameters["algorithm_type"]
    lambda_reward_max = parameters["lambda_reward_max"]
    l1 = parameters["l1"]

    if config_algorithm not in _SINGLE_CONFIGURATION_ALGORITHMS:
      print("Started expriment {} lambda {} l1 {} {}".format(dataset_name, lambda_reward_max, l1, config_algorithm))

    results = _run_repeated_experiments(num_experiments, parallel_experiment_batching_size,
      lambda_reward_max, dataset_info, l1,
      dataset_name = dataset_name, num_batches = num_batches, batch_size = batch_size,
      algorithm_type = config_algorithm, representation_layer_sizes = representation_layer_sizes)

    results_name = _results_name(config_algorithm, lambda_reward_max, l1)
    results_dictionary[results_name] = process_results(results)
    print("Finished ", results_name, " ", dataset_name)

  log_experiment_data(results_dictionary, name_stub, results_directory, is_zip_file = ZIP_FILE)

  if PLOT_PRELIMINARY:
    _plot_preliminary(dataset_name, results_dictionary, results_directory, name_stub)


if __name__ == "__main__":
  main()

