import matplotlib
matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt

#from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting 

#from matplotlib import cm
#from matplotlib.ticker import LinearLocator

import pandas as pd

import torch
from onlinedatasets.models import TorchRewardsModel
from onlinedatasets.datasets import get_dataset
#from algorithms import *
from onlinedatasets.datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from algorithms import get_max_reward_batch, evaluate_rank_observed_reward, get_random_reward_batch

import os

import itertools
import IPython

import ray

# Start the Ray runtime as an import-time side effect; this call is
# unconditional, i.e. it runs whatever value USE_RAY has below.
ray.init()
# Switch read by the __main__ sweep: when True the parameter sweep is submitted
# through run_experiment_remote.remote(...), otherwise run_experiment is called
# sequentially in this process.
USE_RAY = False


def train_optimistic_reward_colors_one_dim(reward_model, dataset, num_batches, num_opt_steps, 
    batch_size, opt_batch_size = 10, lambda_reward_max = 0.01, 
    l2_regularizer = .1, range_regularizer = 0, verbose = False, l1 = False, 
    colorplot_frequency = 10, lambda_constrain = 1, random_batch = False):
    """Iteratively query a pool of points and refit ``reward_model`` on them.

    A pool is drawn once with ``get_batches(dataset, 1000)``. For each of the
    ``num_batches`` rounds a batch of ``batch_size`` points is taken out of the
    remaining pool (by ``get_max_reward_batch`` using the current model, or by
    ``get_random_reward_batch`` when ``random_batch`` is true), appended to the
    growing training set, and the model is reset and retrained for
    ``num_opt_steps`` Adam steps. The training objective is the model's fitting
    loss minus ``lambda_reward_max`` times the mean predicted reward on a batch
    of still-unqueried points.

    Every ``colorplot_frequency`` rounds, four diagnostic plots are written via
    ``plot_dataset_values_one_dim`` into ``./line_results/...``. Note that the
    folder name only encodes ``l1`` and ``lambda_reward_max`` (or "random"), so
    runs differing in anything else write to the same directory.

    Args:
        reward_model: model exposing ``network.parameters()``,
            ``reset_weights()``, ``get_loss``/``get_loss_l1`` and
            ``get_reward``.
        dataset: source dataset handed to ``get_batches`` and to the plot
            helper.
        num_batches: number of query/refit rounds.
        num_opt_steps: optimisation steps per round.
        batch_size: number of points queried per round.
        opt_batch_size: minibatch size used for each optimisation step.
        lambda_reward_max: weight of the mean-prediction bonus subtracted from
            the loss.
        l2_regularizer: forwarded to the model's loss function.
        range_regularizer: forwarded to the model's loss function.
        verbose: print progress and the loss at every optimisation step.
        l1: use ``get_loss_l1`` instead of ``get_loss``.
        colorplot_frequency: plot every this-many rounds.
        lambda_constrain: currently unused; kept for interface compatibility.
        random_batch: select query batches at random instead of by model
            prediction.

    Returns:
        Tuple ``(reward_model, max_observed_rewards, model_fitting_losses)``:
        the trained model, the running maximum of observed rewards after each
        round, and the fitting loss (without the bonus term) recorded at the
        last optimisation step of each round.
    """
    if random_batch:
        results_folder_name = "line_results/random"
    else:
        results_folder_name = "line_results/l1_{}_lambdaregmax_{}".format(l1, lambda_reward_max)

    plot_dir = "./{}".format(results_folder_name)
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(plot_dir, exist_ok = True)

    all_data_X, all_data_y = get_batches(dataset, 1000)

    # The "complement" is the part of the pool that has not been queried yet.
    complement_batch_X = np.copy(all_data_X)
    complement_batch_y = np.copy(all_data_y)
    complement_supervised_dataset = DataSet(pd.DataFrame(complement_batch_X), pd.DataFrame(complement_batch_y))

    growing_dataset = GrowingNumpyDataSet()
    max_observed_reward_during_training = -float("inf")
    max_observed_rewards = []
    model_fitting_losses = []

    # (filename template, extra keyword arguments) for the four diagnostic plots.
    plot_variants = (
        ("curve_scatter_iter_{}", {}),
        ("scatter_scatter_iter_{}", {"all_dataset": all_data_X}),
        ("model_scatter_iter_{}", {"model": reward_model}),
        ("model_scatter_with_all_iter_{}", {"model": reward_model, "all_dataset": all_data_X}),
    )

    for j in range(num_batches):
        if random_batch:
            selection = get_random_reward_batch(complement_supervised_dataset, batch_size)
        else:
            selection = get_max_reward_batch(reward_model, complement_supervised_dataset, batch_size)
        filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y = selection

        growing_dataset.add_data(filtered_batch_X, filtered_batch_y)

        complement_unsupervised_dataset = DataSetUnsupervised(pd.DataFrame(complement_batch_X))
        complement_supervised_dataset = DataSet(pd.DataFrame(complement_batch_X), pd.DataFrame(complement_batch_y))

        max_observed_reward_during_training = max(max_observed_reward_during_training, max(filtered_batch_y))
        max_observed_rewards.append(max_observed_reward_during_training)

        if j % colorplot_frequency == 0:
            #### Plot the landscape (model state is the one used to pick this batch)
            for filename_template, extra_kwargs in plot_variants:
                plot_dataset_values_one_dim(
                    -10, 10, .01, dataset,
                    filename = filename_template.format(j),
                    all_queries = growing_dataset.dataset_X,
                    current_query = filtered_batch_X,
                    folder_name = results_folder_name,
                    **extra_kwargs
                )

        reward_model.reset_weights()
        # Build the optimizer only after the reset so that it is guaranteed to
        # track the parameters that are actually trained, even if
        # reset_weights() swaps in new parameter tensors.
        optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)

        is_last_batch = j == num_batches - 1
        for i in range(num_opt_steps):
            is_last_step = i == num_opt_steps - 1

            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, opt_batch_size)
            optimizer.zero_grad()
            if verbose:
                print("Global batch num ", j, " opt step ", i)
            if l1:
                loss = reward_model.get_loss_l1(batch_X, batch_y, l2_regularizer = l2_regularizer, range_regularizer = range_regularizer)
            else:
                loss = reward_model.get_loss(batch_X, batch_y, l2_regularizer = l2_regularizer, range_regularizer = range_regularizer)

            if is_last_step:
                # .numpy() shares memory with the tensor, so take a copy;
                # otherwise later updates to `loss` would leak into the record.
                model_fitting_losses.append(loss.detach().numpy().copy())

            predictions = reward_model.get_reward(unsupervised_batch_X)
            # Optimism bonus: reward high predictions on unqueried points.
            # Out-of-place on purpose, so the pure fitting loss is left intact.
            loss = loss - lambda_reward_max*torch.mean(predictions)

            if verbose or (is_last_batch and is_last_step):
                print("loss" , loss)
            loss.backward()
            optimizer.step()

    return reward_model, max_observed_rewards, model_fitting_losses




def plot_dataset_values_one_dim(bottom_x, top_x, epsilon, 
  generative_dataset, 
  title = "", filename = "", model = None, all_dataset = (), all_queries = (), current_query = (), folder_name = "line_results"):
  """Plot the one-dimensional reward curve plus optional overlays and save it.

  The reward returned by ``generative_dataset.compute_reward`` is drawn on an
  evenly spaced grid over ``[bottom_x, top_x]`` with
  ``int((top_x - bottom_x) / epsilon)`` points. Optional overlays are the
  model's predictions on the same grid, a scatter of the whole dataset, of all
  queries so far, and of the current query. The figure is written to
  ``./{folder_name}/{filename}.png`` and then closed.

  Args:
      bottom_x: left end of the plotted interval.
      top_x: right end of the plotted interval.
      epsilon: grid spacing used to choose the number of grid points.
      generative_dataset: object exposing ``compute_reward(array)``.
      title: figure title; omitted when empty.
      filename: output file name without the ``.png`` extension.
      model: optional model exposing ``get_reward``; plotted when not None.
      all_dataset: optional points scattered as "All Dataset"; when given,
          the curves are drawn with thin lines.
      all_queries: optional points scattered as "Old Queries".
      current_query: optional points scattered as "Current Query".
      folder_name: output directory relative to the working directory;
          created if missing.
  """
  show_all_dataset = len(all_dataset) > 0
  # One grid shared by the reward curve and the model curve.
  grid = np.linspace(bottom_x, top_x, int((top_x - bottom_x)/epsilon))

  try:
    if show_all_dataset:
      all_dataset_values = generative_dataset.compute_reward(np.array(all_dataset))
      plt.scatter(all_dataset, all_dataset_values, marker = "x", color = "blue", label = "All Dataset")
      # Thin, unlabeled curve so the scatter stays readable.
      plt.plot(grid, generative_dataset.compute_reward(grid), linewidth = .1, color = "blue")
    else:
      plt.plot(grid, generative_dataset.compute_reward(grid), label = "Reward Values", color = "blue")

    if model is not None:
      ### PLOT THE MODEL
      with torch.no_grad():
        model_preds = model.get_reward(np.expand_dims(grid, axis = 1)).numpy()
      if show_all_dataset:
        plt.plot(grid, model_preds, linewidth = .2, color = "violet", label = "Model values")
      else:
        plt.plot(grid, model_preds, label = "Model Values", color = "violet")

    if len(all_queries) > 0:
      all_queries_values = generative_dataset.compute_reward(np.array(all_queries))
      plt.scatter(all_queries, all_queries_values, marker = "x", color = "black", label = "Old Queries")

    if len(current_query) > 0:
      current_values = generative_dataset.compute_reward(np.array(current_query))
      plt.scatter(current_query, current_values, marker = "x", color = "red", label = "Current Query")

    if title != "":
      plt.title(title)

    plt.legend(loc = "upper left")

    # Do not rely on the caller having created the output directory.
    os.makedirs("./{}".format(folder_name), exist_ok = True)
    plt.savefig("./{}/{}.png".format(folder_name, filename))
  finally:
    # Always release the figure, even if plotting or saving failed, so that a
    # half-drawn figure cannot bleed into the next plot.
    plt.close('all')





def run_experiment(lambda_reward_max, info, l1, num_batches = 200, num_opt_steps = 5000, 
  batch_size = 3, opt_batch_size = 10, colorplot_frequency = 5, random_batch = False):
  """Run one training experiment on the 'MultiValleyOneDim' dataset.

  Loads the dataset with the given ``info`` options, builds a fresh MLP
  ``TorchRewardsModel`` with a 20-unit representation layer, and hands both to
  ``train_optimistic_reward_colors_one_dim`` with both regularizers set to 0
  and verbosity off. The training results are not returned; the function
  returns None.
  """
  # Only the third element of the 5-tuple returned by get_dataset is used.
  _, _, train_dataset, _, _ = get_dataset('MultiValleyOneDim', info = info)

  reward_model = TorchRewardsModel(
      random_init = True,
      MLP = True,
      dim = train_dataset.dimension,
      representation_layer_size = 20,
  )

  # The drawn batch itself is never used; the call is kept because it may
  # advance the dataset's internal state -- TODO confirm whether it can go.
  _unused_batch_X, _unused_batch_y = train_dataset.get_batch(100)

  _trained_model, _max_rewards, _fitting_losses = train_optimistic_reward_colors_one_dim(
      reward_model,
      train_dataset,
      num_batches = num_batches,
      num_opt_steps = num_opt_steps,
      batch_size = batch_size,
      opt_batch_size = opt_batch_size,
      lambda_reward_max = lambda_reward_max,
      l2_regularizer = 0,
      range_regularizer = 0,
      verbose = False,
      l1 = l1,
      colorplot_frequency = colorplot_frequency,
      random_batch = random_batch,
  )



@ray.remote
def run_experiment_remote(lambda_reward_max, info, l1, num_batches = 200, num_opt_steps = 5000, 
  batch_size = 3, opt_batch_size = 10, colorplot_frequency = 5, random_batch = False):
  """Ray task wrapper: forward every argument unchanged to ``run_experiment``
  and return whatever it returns."""
  return run_experiment(
      lambda_reward_max,
      info,
      l1,
      num_batches = num_batches,
      num_opt_steps = num_opt_steps,
      batch_size = batch_size,
      opt_batch_size = opt_batch_size,
      colorplot_frequency = colorplot_frequency,
      random_batch = random_batch,
  )


if __name__ == "__main__":

  # Axes of the parameter sweep; the sweep is their full cartesian product.
  lambdas_reward_max = [0, 0.001, 0.01, 0.1]
  infos = [{"noisy": False}, {"noisy": True}]
  l1_values = [False, True]

  parameter_configurations = itertools.product(lambdas_reward_max, l1_values, infos)

  # Fixed settings for the random-batch run.
  info_random_batch = {"noisy": False}
  l1_random_batch = False
  lambda_reward_max_random_batch = 0.01

  ### START BY RUNNING THE RANDOM BATCH EXPERIMENT
  run_experiment(
      lambda_reward_max_random_batch,
      info_random_batch,
      l1_random_batch,
      random_batch = True,
  )

  if not USE_RAY:
    # Sequential sweep in this process.
    for lambda_reward_max, l1, info in parameter_configurations:
      run_experiment(lambda_reward_max, info, l1)
  else:
    # Submit every configuration as a Ray task, then block until all finish.
    plotting_run = [
        run_experiment_remote.remote(lambda_reward_max, info, l1)
        for lambda_reward_max, l1, info in parameter_configurations
    ]
    plotting_run = ray.get(plotting_run)
