import matplotlib

matplotlib.use('Agg')

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting 

from matplotlib import cm
from matplotlib.ticker import LinearLocator

import pandas as pd

import torch
from models import TorchRewardsModel
from datasets import get_dataset
#from algorithms import *
from datasets import SVMDataset, get_dataset, get_batches, GrowingNumpyDataSet, DataSetUnsupervised, DataSet
from algorithms import get_max_reward_batch, evaluate_rank_observed_reward

import itertools
import IPython



def _plot_reward_landscapes(reward_model, all_data_X, all_data_y, stage, batch_index):
    """Save the four diagnostic plots of ``reward_model`` for one training stage.

    Two surface plots of the model and two scatter plots comparing the model
    against ``all_data_y`` are written, each once with the "frontal" view and
    once with the "normal" view of ``plot_model_values``.

    Args:
        reward_model: model handed to ``plot_model_values``.
        all_data_X: points at which model and reference values are compared.
        all_data_y: reference responses for ``all_data_X``.
        stage: "before" or "after"; embedded in the output file names.
        batch_index: outer-loop iteration number; embedded in the file names.
    """
    views = (("frontal", ""), ("normal", "normal_"))

    # Model surface on the fixed [-5, 8] x [-5, 8] window.
    for view, view_tag in views:
        plot_model_values(-5, 8, -5, 8, .01, reward_model,
            filename = "reward_model_{}_{}opt_iter_{}".format(stage, view_tag, batch_index),
            view = view)

    # Model predictions vs. reference responses at the data points.
    for view, view_tag in views:
        plot_model_values(-5, 8, -5, 8, .01, reward_model, special_points_X = all_data_X,
            special_points_response = all_data_y, all_scatter_plot = True,
            filename = "reward_model_vs_true_{}_{}opt_iter_{}".format(stage, view_tag, batch_index),
            view = view)


def train_optimistic_reward_colors(reward_model, dataset, num_batches, num_opt_steps, 
    batch_size, opt_batch_size = 10, lambda_reward_max = 0.01, 
    l2_regularizer = 0, range_regularizer = 0, verbose = False, l1 = False, colorplot_frequency = 10):
    """Train ``reward_model`` with an optimism bonus, saving landscape plots along the way.

    Each of the ``num_batches`` outer iterations moves ``batch_size`` points
    (chosen by ``get_max_reward_batch``) from the not-yet-observed pool into a
    growing training set, then runs ``num_opt_steps`` Adam steps on

        fitting_loss(observed batch) - lambda_reward_max * mean(prediction on unobserved batch)

    Args:
        reward_model: model exposing ``network``, ``get_loss``, ``get_loss_l1``
            and ``get_reward``; it is updated in place.
        dataset: source of the candidate pool (sampled via ``get_batches``) and
            of the ground-truth landscape plot.
        num_batches: number of outer acquisition rounds.
        num_opt_steps: optimizer steps per round.
        batch_size: points acquired per round.
        opt_batch_size: minibatch size for each optimizer step.
        lambda_reward_max: weight of the optimism term.
        l2_regularizer: forwarded to the model's loss function.
        range_regularizer: forwarded to the model's loss function.
        verbose: print progress and the loss at every step.
        l1: use ``get_loss_l1`` instead of ``get_loss``.
        colorplot_frequency: plots are saved on rounds where
            ``j % colorplot_frequency == 0``.

    Returns:
        ``(reward_model, max_observed_rewards, model_fitting_losses)`` where
        ``max_observed_rewards[j]`` is the running maximum of acquired responses
        after round ``j`` and ``model_fitting_losses[j]`` is the fitting loss
        (without the optimism term) at the last optimizer step of round ``j``.
    """
    # Candidate pool. NOTE(review): 30000 is presumably meant to cover the
    # whole dataset -- confirm against get_batches.
    all_data_X, all_data_y = get_batches(dataset, 30000)

    # Work on copies so all_data_X / all_data_y stay intact for plotting.
    complement_supervised_dataset = DataSet(
        pd.DataFrame(np.copy(all_data_X)), pd.DataFrame(np.copy(all_data_y)))
    plot_dataset_values(-5, 8, -5, 8, .01, dataset, filename = "dataset_landscape_continuous")

    growing_dataset = GrowingNumpyDataSet()
    max_observed_reward_during_training = -float("inf")
    max_observed_rewards = []
    model_fitting_losses = []

    # Bound once; the choice does not change during training.
    fitting_loss = reward_model.get_loss_l1 if l1 else reward_model.get_loss

    for j in range(num_batches):
        # A new optimizer is built every round, so Adam's running statistics restart.
        optimizer = torch.optim.Adam(reward_model.network.parameters(), lr = 0.01)

        # Split the pool into the acquired batch and the remaining complement.
        filtered_batch_X, filtered_batch_y, complement_batch_X, complement_batch_y = get_max_reward_batch(
            reward_model, complement_supervised_dataset, batch_size)
        growing_dataset.add_data(filtered_batch_X, filtered_batch_y)

        complement_unsupervised_dataset = DataSetUnsupervised(pd.DataFrame(complement_batch_X))
        complement_supervised_dataset = DataSet(pd.DataFrame(complement_batch_X), pd.DataFrame(complement_batch_y))

        max_observed_reward_during_training = max(max_observed_reward_during_training, max(filtered_batch_y))
        max_observed_rewards.append(max_observed_reward_during_training)

        save_plots = j % colorplot_frequency == 0
        if save_plots:
            _plot_reward_landscapes(reward_model, all_data_X, all_data_y, "before", j)

        for i in range(num_opt_steps):
            batch_X, batch_y = get_batches(growing_dataset, opt_batch_size)
            unsupervised_batch_X = get_batches(complement_unsupervised_dataset, opt_batch_size)
            optimizer.zero_grad()
            if verbose:
                print("Global batch num ", j, " opt step ", i)

            loss = fitting_loss(batch_X, batch_y, l2_regularizer = l2_regularizer,
                range_regularizer = range_regularizer)

            if i == num_opt_steps - 1:
                # clone() is required: detach().numpy() shares memory with `loss`,
                # so without it any later in-place update of `loss` would
                # silently rewrite the recorded fitting loss.
                model_fitting_losses.append(loss.detach().clone().numpy())

            predictions = reward_model.get_reward(unsupervised_batch_X)
            # Out-of-place on purpose (see the snapshot above).
            loss = loss - lambda_reward_max * torch.mean(predictions)

            is_final_step = (j == num_batches - 1) and (i == num_opt_steps - 1)
            if verbose or is_final_step:
                print("loss", loss)

            loss.backward()
            optimizer.step()

        if save_plots:
            _plot_reward_landscapes(reward_model, all_data_X, all_data_y, "after", j)

    return reward_model, max_observed_rewards, model_fitting_losses





def get_grid_to_eval(bottom_x, top_x, bottom_y, top_y, epsilon):
    """Build a flat list of 2-D grid points covering a rectangle.

    The number of points per axis is ``round((top - bottom) / epsilon)``; the
    points are spread with ``np.linspace`` so both end points are included
    (the actual spacing is therefore slightly larger than ``epsilon``).

    Args:
        bottom_x, top_x: range of the first coordinate.
        bottom_y, top_y: range of the second coordinate.
        epsilon: target grid resolution (must be non-zero).

    Returns:
        ``(mesh, num_evaluation_points_x, num_evaluation_points_y)`` where
        ``mesh`` has shape ``(num_x * num_y, 2)`` with the first coordinate
        varying slowest, so ``mesh[:, k].reshape(num_x, num_y)`` recovers the
        2-D grids.
    """
    # round() instead of bare int(): float division such as 0.3 / 0.1 yields
    # 2.9999999999999996, which truncation would turn into 2 points.
    num_evaluation_points_x = int(round((top_x - bottom_x) / epsilon))
    num_evaluation_points_y = int(round((top_y - bottom_y) / epsilon))

    mesh_x = np.linspace(bottom_x, top_x, num_evaluation_points_x)
    mesh_y = np.linspace(bottom_y, top_y, num_evaluation_points_y)

    # "ij" indexing + C-order ravel reproduces the ordering of
    # itertools.product(mesh_x, mesh_y) without creating Python tuples.
    grid_x, grid_y = np.meshgrid(mesh_x, mesh_y, indexing="ij")
    mesh = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    return mesh, num_evaluation_points_x, num_evaluation_points_y
  




def plot_dataset_values(bottom_x, top_x, 
  bottom_y, top_y, epsilon, 
  generative_dataset, 
  title = "", filename = ""):
    """Save a 3-D surface plot of a dataset's reward over a rectangular grid.

    The reward is obtained from ``generative_dataset.compute_reward`` on the
    grid returned by ``get_grid_to_eval`` and written to
    ``./surface_results/<filename>.png``.

    Args:
        bottom_x, top_x, bottom_y, top_y: plotting window.
        epsilon: grid resolution passed to ``get_grid_to_eval``.
        generative_dataset: object exposing ``compute_reward(mesh)``; the
            result must support ``reshape`` (assumed array-like with one value
            per grid point -- TODO confirm).
        title: optional axes title; nothing is drawn when empty.
        filename: output file stem (without extension).
    """
    import os  # local import: keeps this function independent of file-level imports

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    try:
        mesh, num_evaluation_points_x, num_evaluation_points_y = get_grid_to_eval(
            bottom_x, top_x, bottom_y, top_y, epsilon)
        grid_shape = (num_evaluation_points_x, num_evaluation_points_y)

        preds = generative_dataset.compute_reward(mesh).reshape(grid_shape)
        first_dim_mesh = mesh[:, 0].reshape(grid_shape)
        second_dim_mesh = mesh[:, 1].reshape(grid_shape)

        surf = ax.plot_surface(first_dim_mesh, second_dim_mesh, preds, cmap=cm.coolwarm,
                               linewidth=0, antialiased=False)

        ax.zaxis.set_major_locator(LinearLocator(10))
        # A StrMethodFormatter is used automatically
        ax.zaxis.set_major_formatter('{x:.02f}')

        # Add a color bar which maps values to colors.
        fig.colorbar(surf, shrink=0.5, aspect=5)

        ax.set_zlabel("Reward Values")
        if title:
            ax.set_title(title)

        # Only draw a legend when something is labelled; otherwise matplotlib
        # warns and renders an empty legend.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc = "lower right")

        # savefig does not create missing directories.
        os.makedirs("./surface_results", exist_ok=True)
        fig.savefig("./surface_results/{}.png".format(filename))
    finally:
        # Release the figure even if evaluation or saving failed.
        plt.close('all')










def plot_model_values(bottom_x, top_x, bottom_y, top_y, epsilon, model, 
  special_points_X = None, special_points_response = None, 
  title = "", stretch = True, all_scatter_plot = False, filename = "", view = "frontal"):
    """Save a 3-D plot of a reward model's predictions to ``./surface_results/<filename>.png``.

    Two modes:

    * ``all_scatter_plot=False``: the model is evaluated on a grid and drawn as
      a surface. If special points are given they are scattered in red and the
      window is enlarged to contain them.
    * ``all_scatter_plot=True``: the model is evaluated only at
      ``special_points_X`` and drawn as blue markers next to the red provided
      responses.

    Args:
        bottom_x, top_x, bottom_y, top_y: plotting window for the surface.
        epsilon: grid resolution passed to ``get_grid_to_eval``.
        model: object exposing ``get_reward(points)`` returning a torch tensor.
        special_points_X: optional array indexed as ``[:, 0]`` / ``[:, 1]``;
            ``None`` or empty means "no special points".
        special_points_response: values plotted at ``special_points_X``.
        title: optional axes title; nothing is drawn when empty.
        stretch: unused; kept for backward compatibility.
        all_scatter_plot: select scatter mode (requires special points).
        filename: output file stem (without extension).
        view: "frontal" (elev=0, azim=30) or "normal" (matplotlib default).

    Raises:
        ValueError: if ``view`` is not recognised, or if ``all_scatter_plot``
            is set without any special points.
    """
    import os  # local import: keeps this function independent of file-level imports

    # Validate before creating a figure so a bad argument cannot leak one.
    if view not in ("frontal", "normal"):
        raise ValueError("Unrecognized 3d plot view!")

    has_special_points = special_points_X is not None and len(special_points_X) != 0
    if all_scatter_plot and not has_special_points:
        raise ValueError("all_scatter_plot requires non-empty special_points_X")
    if special_points_response is None:
        special_points_response = []

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    try:
        if has_special_points:
            first_dim_special = special_points_X[:, 0]
            second_dim_special = special_points_X[:, 1]
            ax.scatter(first_dim_special, second_dim_special, special_points_response,
                       marker = "x", color = "red", label = "Provided response")

            # Grow the window so the surface also covers the special points.
            bottom_x = min(bottom_x, np.min(first_dim_special))
            top_x = max(top_x, np.max(first_dim_special))
            bottom_y = min(bottom_y, np.min(second_dim_special))
            top_y = max(top_y, np.max(second_dim_special))

        if all_scatter_plot:
            with torch.no_grad():
                preds = model.get_reward(special_points_X)
            preds = preds.numpy()

            ax.scatter(first_dim_special, second_dim_special, preds,
                       marker = "x", color = "blue", label = "Model prediction")
        else:
            mesh, num_evaluation_points_x, num_evaluation_points_y = get_grid_to_eval(
                bottom_x, top_x, bottom_y, top_y, epsilon)
            grid_shape = (num_evaluation_points_x, num_evaluation_points_y)

            with torch.no_grad():
                preds = model.get_reward(mesh)
            preds = preds.numpy().reshape(grid_shape)

            first_dim_mesh = mesh[:, 0].reshape(grid_shape)
            second_dim_mesh = mesh[:, 1].reshape(grid_shape)

            surf = ax.plot_surface(first_dim_mesh, second_dim_mesh, preds, color = "blue", cmap=cm.coolwarm,
                                   linewidth=0, antialiased=False)

            ax.zaxis.set_major_locator(LinearLocator(10))
            # A StrMethodFormatter is used automatically
            ax.zaxis.set_major_formatter('{x:.02f}')

            # Add a color bar which maps values to colors.
            fig.colorbar(surf, shrink=0.5, aspect=5)

        if view == "frontal":
            ax.view_init(elev = 0, azim = 30)
        # "normal" keeps matplotlib's default camera.

        ax.set_zlim(-.2, .7)
        ax.set_zlabel("Reward prediction")
        if title:
            ax.set_title(title)

        # Only draw a legend when something is labelled (the surface is not).
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc = "lower right")

        # savefig does not create missing directories.
        os.makedirs("./surface_results", exist_ok=True)
        fig.savefig("./surface_results/{}.png".format(filename))
    finally:
        # Release the figure even if evaluation or saving failed.
        plt.close("all")






  # plt.grid(True)

  # plt.show()




if __name__ == "__main__":
    # ---- Experiment configuration -------------------------------------
    # "MultiSVM" was the previously assigned dataset name; it was
    # immediately overwritten, so only the effective value is kept.
    dataset_name = 'MultiValley'
    random_init = True
    MLP = True
    representation_layer_size = 10
    l1 = False
    lambda_reward_max = 0.01

    # ---- Data and model -----------------------------------------------
    train_dataset, _, unsupervised_dataset = get_dataset(dataset_name, 30, 10)

    reward_model = TorchRewardsModel(
        random_init = random_init,
        MLP = MLP,
        dim = train_dataset.dimension,
        representation_layer_size = representation_layer_size,
    )
    new_batch_X, new_batch_y = train_dataset.get_batch(100)

    # ---- Sanity-check plots before any training -----------------------
    plot_dataset_values(-5, 8, -5, 8, .01, train_dataset, filename = "test_dataset")
    plot_model_values(
        -1, 1, -1, 1, .01, reward_model,
        special_points_X = new_batch_X,
        special_points_response = new_batch_y,
        all_scatter_plot = True,
        filename = "test",
    )

    # ---- Optimistic training ------------------------------------------
    training_output = train_optimistic_reward_colors(
        reward_model, train_dataset,
        num_batches = 200, num_opt_steps = 2000,
        batch_size = 3, opt_batch_size = 10,
        lambda_reward_max = lambda_reward_max,
        l2_regularizer = 0, range_regularizer = 0,
        verbose = False, l1 = l1, colorplot_frequency = 10,
    )
    reward_model, max_observed_rewards, model_fitting_losses = training_output

    # ---- Report ---------------------------------------------------------
    # Rank of each running-maximum reward, as computed by
    # evaluate_rank_observed_reward against the training dataset.
    ranks = []
    for observed_max in max_observed_rewards:
        ranks.append(evaluate_rank_observed_reward(train_dataset, observed_max))

    print("Max observed rewards optimistic rewards - {}".format(lambda_reward_max), max_observed_rewards)
    print("Ranks ", ranks)
    print("L1 {} - ".format(l1))
    print("lambda - {}".format(lambda_reward_max))

