import os
import matplotlib
import pickle
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import ray

import IPython  
from onlinedatasets.datasets import get_batches
from onlinedatasets.models import TorchRewardsModelMultilayer


from run_experiments import get_experiments_dataset
from algorithms import train_simple_regression

# Execution-mode switch read by main(): when true, each experiment is submitted
# through compute_losses_remote.remote(...) and collected with ray.get; when
# false, compute_losses is called directly, one experiment after another.
RAY = False




def compute_losses(dataset_info, num_opt_steps=5000, opt_batch_size=10,
                   dataset_name='MultiValleyOneDim',
                   representation_layer_sizes=None, verbose=True,
                   logging_frequency=10, lr=0.01, l2_lambda=1):
    """Train a regression model on one dataset and record its evaluation loss.

    The dataset is obtained from ``get_experiments_dataset(dataset_name,
    dataset_info)`` and a ``TorchRewardsModelMultilayer`` with ReLU activations
    is optimized with Adam for ``num_opt_steps`` steps on batches drawn with
    ``get_batches``.

    Args:
        dataset_info: Options forwarded unchanged to ``get_experiments_dataset``.
        num_opt_steps: Number of optimizer steps to run.
        opt_batch_size: Batch size requested from ``get_batches`` at each step.
        dataset_name: Name forwarded to ``get_experiments_dataset``.
        representation_layer_sizes: Hidden layer sizes handed to the model.
            ``None`` (the default) means ``[20]``. An empty sequence
            additionally enables the L2 parameter penalty below.
        verbose: When true, the evaluation loss is computed, printed and
            recorded every ``logging_frequency`` steps. When false nothing is
            recorded and the returned list is empty.
        logging_frequency: Number of steps between two evaluations.
        lr: Learning rate passed to Adam.
        l2_lambda: Weight of the squared-parameter penalty; only used when
            ``representation_layer_sizes`` is empty.

    Returns:
        A list of Python floats, one per evaluation, in step order.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if representation_layer_sizes is None:
        representation_layer_sizes = [20]

    dataset = get_experiments_dataset(dataset_name, dataset_info)

    # NOTE: a leftover debugging `IPython.embed()` followed by an unconditional
    # `raise ValueError` used to sit here and made everything below
    # unreachable; both were removed so the function actually trains.

    reward_model = TorchRewardsModelMultilayer(
        dim=dataset.dimension,
        representation_layer_sizes=representation_layer_sizes,
        activation_type='relu')

    optimizer = torch.optim.Adam(reward_model.network.parameters(), lr)
    test_loss_list = []
    for i in range(num_opt_steps):
        batch_X, batch_y = get_batches(dataset, opt_batch_size)
        batch_X = np.array(batch_X)
        batch_y = np.array(batch_y)

        optimizer.zero_grad()
        loss = reward_model.get_loss(batch_X, batch_y)
        if len(representation_layer_sizes) == 0:
            # No hidden layers: penalize the squared norm of all parameters.
            print("Added parameter regularization loss - {}".format(representation_layer_sizes))
            loss += l2_lambda*sum(p.pow(2.0).sum() for p in reward_model.network.parameters())

        # The evaluation happens before this step's update is applied.
        if verbose and i % logging_frequency == 0:
            print("iteration ", i)

            # A very large batch size is requested; presumably get_batches
            # caps it at the dataset size and returns everything -- TODO confirm.
            test_batch_X, test_batch_y = get_batches(dataset, 100000000000000000)
            test_batch_X = np.array(test_batch_X)
            test_batch_y = np.array(test_batch_y)

            test_loss = reward_model.get_loss(test_batch_X, test_batch_y)
            test_loss = test_loss.detach().cpu().numpy()
            print("test loss ", test_loss)

            # `np.float` was removed in NumPy 1.24; it was only an alias of
            # the builtin `float`, so this conversion is equivalent.
            test_loss_list.append(float(test_loss))

        loss.backward()
        optimizer.step()

    return test_loss_list




@ray.remote
def compute_losses_remote(dataset_info, num_opt_steps = 5000, opt_batch_size = 10, 
  dataset_name = 'MultiValleyOneDim', 
  representation_layer_sizes = [20], verbose = True, logging_frequency = 10, lr = 0.01, l2_lambda = .1):
  """Ray task wrapper that forwards every argument to compute_losses.

  Returns whatever compute_losses returns for the same arguments.

  NOTE(review): the l2_lambda default here is .1 while compute_losses
  defaults to 1, so a caller that omits l2_lambda gets a different penalty
  weight depending on which entry point it uses -- confirm which is intended.
  """
  return compute_losses(
    dataset_info,
    num_opt_steps = num_opt_steps,
    opt_batch_size = opt_batch_size,
    dataset_name = dataset_name,
    representation_layer_sizes = representation_layer_sizes,
    verbose = verbose,
    logging_frequency = logging_frequency,
    lr = lr,
    l2_lambda = l2_lambda)



def main():
    """Run the regression-fit experiments and save one plot and pickle per dataset.

    For every dataset name, each architecture in
    ``representation_layer_sizes_list`` is trained ``num_experiments`` times
    (serially, or through Ray when the module-level ``RAY`` flag is true).
    The recorded loss curves are pickled to ``./results`` and their mean with a
    +/- 0.5 standard deviation band is plotted to ``./paper_results``.

    Raises:
        ValueError: If ``logging_frequency`` does not divide ``num_opt_steps``.
    """
    ylims = [[0, .6], [0, .4], [0, 2], [0, .25]]
    dataset_names = ["MultiValleyOneDim", "MultiValleyOneDimHole",
                     "OneHillValleyOneDim", "OneHillValleyOneDimHole"]

    # Alternative dataset groups that were used with this script:
    #   ["BikeSharingDay", "BikeSharingHour"]  with ylims [[0,.03], [0,.03]]
    #   ["VCAP", "HA1E", "MCF7", "A375"]       with ylims
    #       [[0,.005], [0.001,.003], [0.001,.003], [0.002,.008]]

    random_seed = 102929
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    # Settings shared by every dataset.
    representation_layer_sizes_list = [[], [10, 5], [100, 10], [300, 100]]
    colors = ["blue", "red", "black", "orange"]
    opt_batch_size = 20
    num_opt_steps = 5000
    logging_frequency = 20
    num_experiments = 100

    if num_opt_steps % logging_frequency != 0:
        raise ValueError("The logging frequency does not divide the number of opt steps")
    num_logged_points = num_opt_steps // logging_frequency

    for dataset_name, ylim in zip(dataset_names, ylims):
        datalog_file = "./results/regression_fit_{}.p".format(dataset_name)
        filename = "./paper_results/regression_fit_{}.png".format(dataset_name)

        # Create the output directories up front so a fresh checkout does not
        # fail only after all experiments have finished.
        os.makedirs(os.path.dirname(datalog_file), exist_ok=True)
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        dataset_info = {"noisy": False}

        # Experiments are always rerun; an existing pickle is overwritten.
        all_data = []
        for representation_layer_sizes in representation_layer_sizes_list:
            if not RAY:
                test_loss_list_results = [
                    compute_losses(dataset_info, num_opt_steps=num_opt_steps,
                                   opt_batch_size=opt_batch_size,
                                   dataset_name=dataset_name,
                                   representation_layer_sizes=representation_layer_sizes,
                                   logging_frequency=logging_frequency, lr=.01)
                    for _ in range(num_experiments)
                ]
            else:
                pending = [
                    compute_losses_remote.remote(dataset_info, num_opt_steps=num_opt_steps,
                                                 opt_batch_size=opt_batch_size,
                                                 dataset_name=dataset_name,
                                                 representation_layer_sizes=representation_layer_sizes,
                                                 logging_frequency=logging_frequency, lr=.01)
                    for _ in range(num_experiments)
                ]
                test_loss_list_results = ray.get(pending)

            all_data.append((representation_layer_sizes, test_loss_list_results))

        # Persist the raw curves before plotting so a plotting error cannot
        # discard the experiment results; `with` closes the file handle.
        with open(datalog_file, "wb") as datalog:
            pickle.dump(all_data, datalog)

        plt.title("{} Regression Fit".format(dataset_name), fontsize=15)
        plt.xlabel("Num Batches")
        plt.ylabel("Mean Square Loss")

        # NOTE(review): compute_losses records the k-th loss before optimizer
        # step k*logging_frequency (k = 0, 1, ...), while the x values below
        # start at logging_frequency -- the curve looks shifted by one logging
        # interval; confirm whether that is intended.
        num_batches_list = logging_frequency*(np.arange(num_logged_points) + 1)

        for i, (representation_layer_sizes, test_loss_list_results) in enumerate(all_data):
            if len(representation_layer_sizes) == 0:
                results_name = "Linear"
            else:
                # Same label as before for two hidden layers ("NN 10-5"), and
                # also valid for any other number of layers.
                results_name = "NN " + "-".join(str(size) for size in representation_layer_sizes)

            mean_loss_values = np.mean(test_loss_list_results, axis=0)
            std_loss_values = np.std(test_loss_list_results, axis=0)

            plt.plot(num_batches_list, mean_loss_values, linewidth=3,
                     color=colors[i], label=results_name)
            plt.fill_between(num_batches_list,
                             mean_loss_values - .5*std_loss_values,
                             mean_loss_values + .5*std_loss_values,
                             color=colors[i], alpha=.2)

        plt.legend(loc="upper right", fontsize=13)
        plt.ylim(ylim[0], ylim[1])
        plt.savefig(filename)
        plt.close("all")


   

    #IPython.embed()




# Run the experiments only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()





