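# -*- coding: utf-8 -*-
"""plane.ipynb

Automatically generated by Colaboratory.

Evaluates train/test loss and error on a grid over a 2-D plane in weight
space. The plane passes through three trained networks (model 1, permuted
model 2 and model 2). The script also reports where model 1, model 2, the
fused model and the two permuted models lie on that plane. Results are saved
to ``<result_path>/<experiment_name>_<model_name>_<dataset_name>/plane.npz``.
"""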


import torch
import numpy as np
import torch.nn.functional as F
import tqdm
import tabulate
import os
import argparse
import logging

from torch.utils import data
from torchvision import datasets, transforms

from src.tlp_model_fusion.utils import average_meter
from src.tlp_model_fusion import train_models
from src.tlp_model_fusion import datasets as mydatasets
from src.tlp_model_fusion import fuse_models
from src.tlp_model_fusion import model
from src.tlp_model_fusion import resnet_models
from src.tlp_model_fusion import tlp_fusion
from src.tlp_model_fusion import vgg_models
from src.tlp_model_fusion import init
from fuse_models import get_model
from src.tlp_rnn_fusion import fuse_rnn_models
from src.tlp_rnn_fusion import train_rnn
from init import make_dirs
from src.tlp_rnn_fusion.rnn_models import RNNWithDecoder, RNNWithEncoderDecoder, LSTMWithDecoder, LSTMWithEncoderDecoder
from src.tlp_rnn_fusion import embedding

### Load neural networks
def load_model(model_name, model_path):
  """Build a network of type `model_name` and load its weights from a checkpoint.

  The checkpoint is read as a dict with a ``'config'`` entry (passed to the
  model factory) and a ``'model_state_dict'`` entry (loaded into the model).

  Args:
    model_name: architecture name. 'RNN', 'LSTM', 'rnn' and 'lstm' are built
      with ``fuse_rnn_models.get_model`` (input_dim=28, no encoder); every
      other name goes through ``fuse_models.get_model``.
    model_path: path of the checkpoint file.

  Returns:
    The constructed model with the checkpoint weights loaded.

  Raises:
    ValueError: if `model_path` is None (i.e. the CLI path argument was not given).
  """
  if model_path is None:
    raise ValueError(
        'No checkpoint path given for model {!r}; pass the corresponding '
        '--init_start/--init_end/--fused_model_path/--permuted_model_*_path '
        'argument.'.format(model_name))
  # Deserialize onto the CPU so checkpoints saved from a GPU also load on
  # CPU-only machines; load_state_dict below copies the values into the
  # freshly built model's own parameters regardless of where they were read.
  checkpoint = torch.load(model_path, map_location='cpu')
  if model_name in ['RNN', 'LSTM', 'rnn', 'lstm']:
    net = fuse_rnn_models.get_model(model_name, input_dim=28, config=checkpoint['config'], encoder=False)
  else:
    net = fuse_models.get_model(model_name, checkpoint['config'])
  net.load_state_dict(checkpoint['model_state_dict'])
  return net

### Get coordinate
def get_xy(point, origin, vector_x, vector_y):
  """Return the 2-D coordinates of `point` in the frame (`origin`, `vector_x`, `vector_y`).

  Each coordinate is the dot product of ``point - origin`` with the
  corresponding axis vector; the result is a length-2 NumPy array.
  """
  displacement = point - origin
  x_coord = np.dot(displacement, vector_x)
  y_coord = np.dot(displacement, vector_y)
  return np.array([x_coord, y_coord])

##### Test function for FC NN and deep CNN
def test(dataloader, model, model_name='FC'):
  """Evaluate a fully-connected ('FC') or convolutional classifier.

  Args:
    dataloader: iterable of ``(images, labels)`` batches.
    model: the network; it is moved to the GPU when CUDA is available and put
      in eval mode.
    model_name: 'FC' flattens each batch to ``(batch, -1)`` before the forward
      pass; any other value feeds the images unchanged. Pass the *model*
      name here, not the dataset name.

  Returns:
    dict with 'loss' (cross-entropy averaged over every sample of the loader)
    and 'accuracy' (percentage of correct argmax predictions), both floats.
  """
  total = 0
  correct = 0
  loss_sum = 0.0

  if torch.cuda.is_available():
    model = model.cuda()

  model.eval()

  # Evaluation only: skip building autograd graphs.
  with torch.no_grad():
    for images, labels in dataloader:
      if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
      if model_name == 'FC':
        logits = model(images.view(images.size(0), -1))
      else:
        logits = model(images)

      # Sum per-sample losses (instead of averaging batch means) so a smaller
      # final batch is weighted correctly in the overall mean.
      loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()
      prediction = torch.argmax(logits, dim=1)
      total += images.size(0)
      correct += torch.sum(labels == prediction).item()

  accuracy = 100.0 * correct / total
  return {
      'loss': loss_sum / total,
      'accuracy': accuracy,
  }


# This is an evaluation helper, not a unit test: keep pytest from collecting
# it when the module is imported into a test file.
test.__test__ = False

def loss_func(last_output, y):
  """Mean negative log-likelihood of targets `y` under softmax(`last_output`) over dim 1."""
  log_probs = F.log_softmax(last_output, dim=1)
  return F.nll_loss(log_probs, y, reduction='mean')


### Test function for RNN and LSTM
def test_rnn(dataloader, model, dataset_name='MNISTNorm', glove_embedding=None):
  """Evaluate an RNN/LSTM classifier (or, for 'UDPOS', a per-token tagger).

  Args:
    dataloader: iterable of ``(x, y)`` batches. For 'UDPOS' its
      ``dataset.pad_idx`` attribute is read to mask padding labels.
    model: the network; moved to the GPU when CUDA is available and put in
      eval mode. For non-UDPOS datasets its output is indexed as
      ``logits[:, -1, :]``; for 'UDPOS' it is flattened to ``(-1, classes)``.
    dataset_name: selects input preprocessing and the prediction target.
    glove_embedding: object providing ``get_batch_embedding``; used for the
      text datasets when `model` is not an encoder-decoder variant.

  Returns:
    dict with float entries:
      'loss': mean cross-entropy per prediction (for 'UDPOS' averaged over
        every token, padding included, matching the original loss).
      'accuracy': percentage of correct predictions (for 'UDPOS' padding
        tokens are excluded from both numerator and denominator).
  """
  correct = 0
  num_predictions = 0  # accuracy denominator (non-pad predictions)
  loss_sum = 0.0
  num_loss_terms = 0   # loss denominator (rows fed to the loss)

  if torch.cuda.is_available():
    model = model.cuda()

  model.eval()

  with torch.no_grad():
    for x_batched, y_batched in dataloader:
      if torch.cuda.is_available():
        x_batched = x_batched.cuda()
        y_batched = y_batched.cuda()
      if dataset_name in ['MNISTNorm', 'SplitMNIST', 'MNIST']:
        # Drop singleton dims (e.g. a channel dim) like torch.squeeze, but
        # never the batch dim, so a final batch of size 1 keeps its shape.
        x_batched = x_batched.reshape(
            x_batched.size(0), *[d for d in x_batched.shape[1:] if d != 1])
      elif dataset_name in ['SST', 'SSTPT', 'UDPOS', 'DBpedia', 'AG_NEWS'] and not isinstance(model, (LSTMWithEncoderDecoder, RNNWithEncoderDecoder)):
        x_batched = glove_embedding.get_batch_embedding(x_batched)

      logits = model(x_batched)
      if dataset_name in ['UDPOS']:
        # One prediction per token.
        last_logits = logits.reshape(-1, logits.shape[-1])
        y_batched = y_batched.reshape(-1)
      else:
        # Output at the last position of dim 1 (presumably the time axis).
        last_logits = logits[:, -1, :]

      # Sum (not mean) so the final average is per prediction, not per batch.
      loss_sum += F.cross_entropy(last_logits, y_batched, reduction='sum').item()
      num_loss_terms += last_logits.size(0)

      if dataset_name in ['UDPOS']:
        selected = y_batched != dataloader.dataset.pad_idx
        y_batched = y_batched[selected]
        last_logits = last_logits[selected, :]

      prediction = torch.argmax(last_logits, dim=1)
      correct += torch.sum(y_batched == prediction).item()
      num_predictions += prediction.size(0)

  return {
      'loss': loss_sum / num_loss_terms,
      'accuracy': 100.0 * correct / num_predictions,
  }


# This is an evaluation helper, not a unit test: keep pytest from collecting
# it when the module is imported into a test file.
test_rnn.__test__ = False

    

### Flatten all the weights of a neural network into one vector
def get_weight(model):
  """Concatenate every parameter of `model`, flattened, into one 1-D NumPy array.

  Parameters are taken in ``model.parameters()`` order and copied to host
  memory, detached from autograd.
  """
  flat_parts = []
  for param in model.parameters():
    flat_parts.append(param.cpu().detach().numpy().ravel())
  return np.concatenate(flat_parts)

# Architectures that go through the RNN pipeline (dataloaders and test routine).
_RNN_MODEL_NAMES = ('RNN', 'LSTM', 'rnn', 'lstm')


def _evaluate(dataloader, net, args, glove_embedding):
  """Return the {'loss', 'accuracy'} dict of `net` on `dataloader` for this run's model type."""
  if args.model_name in _RNN_MODEL_NAMES:
    return test_rnn(dataloader, net, args.dataset_name, glove_embedding)
  # `test` switches on the *model* name ('FC' inputs get flattened), so the
  # dataset name must not be passed here.
  return test(dataloader, net, args.model_name)


def _load_flat_weights(net, flat_weights):
  """Copy consecutive slices of the 1-D array `flat_weights` into `net`'s parameters, in order."""
  offset = 0
  for parameter in net.parameters():
    size = parameter.numel()
    value = flat_weights[offset:offset + size].reshape(parameter.size())
    parameter.data.copy_(torch.from_numpy(value))
    offset += size


def main():
  """Parse CLI arguments, evaluate networks on a 2-D weight-space plane and save the grid.

  The plane has origin model 1 and is spanned by (model 2 - model 1) and the
  component of (permuted model 2 - model 1) orthogonal to it. Train/test
  loss and error are computed on a `grid_points` x `grid_points` grid and
  written to ``plane.npz`` under the output directory.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--experiment_name', type=str, default='test')
  parser.add_argument('--model_name', type=str, default='FC')
  parser.add_argument('--dataset_name', type=str, default='MNIST')
  parser.add_argument('--result_path', type=str, default='result')

  parser.add_argument('--data_path', type=str, default='./data')
  parser.add_argument('--train_data_path', type=str, default='./data')
  parser.add_argument('--test_data_path', type=str, default='./data')
  parser.add_argument('--glove_path', type=str, default='./data/custom_datasets/glove.6B.100d.txt')
  parser.add_argument('--batch_size', type=int, default=64)

  parser.add_argument('--normalize', default=False, action='store_true')
  parser.add_argument('--nsplits', type=int, default=1,
                        help='Number of splits of the dataset')
  parser.add_argument('--split_index', type=int, default=1,
                        help='The current index of split dataset used!')
  parser.add_argument('--ds_scale_factor', type=float, default=1.0,
                        help='To understand effect of ds scaling')
  parser.add_argument('--alpha_h', type=float, default=None, nargs='+',
                      help='The weight for the hidden to hidden matrix costs')

  parser.add_argument('--grid_points', type=int, default=21,
                      help='number of points in the grid (default: 21)')
  parser.add_argument('--margin_left', type=float, default=0.2,
                      help='left margin (default: 0.2)')
  parser.add_argument('--margin_right', type=float, default=0.2,
                      help='right margin (default: 0.2)')
  parser.add_argument('--margin_bottom', type=float, default=0.2,
                      help='bottom margin (default: 0.2)')
  parser.add_argument('--margin_top', type=float, default=0.2,
                      help='top margin (default: 0.2)')

  parser.add_argument('--input_dim', type=int, default=784)
  parser.add_argument('--hidden_dims', type=int, nargs='+', default=[])
  parser.add_argument('--output_dim', type=int, default=10)

  parser.add_argument('--evaluate', default=False, action='store_true')
  parser.add_argument('--resume', default=False, action='store_true')
  parser.add_argument('--init_start', type=str, default=None,
                      help='checkpoint to init start point (default: None)')
  parser.add_argument('--init_end', type=str, default=None,
                      help='checkpoint to init end point (default: None)')
  parser.add_argument('--fused_model_path', type=str, default=None,
                      help='checkpoint to the fused model')
  parser.add_argument('--permuted_model_1_path', type=str, default=None)
  parser.add_argument('--permuted_model_2_path', type=str, default=None)
  parser.add_argument('--random_initialization_plot_1', default=False, action='store_true')
  parser.add_argument('--random_initialization_plot_2', default=False, action='store_true')

  parser.add_argument('--no_cuda', default=False, action='store_true')
  parser.add_argument('--gpu_ids', type=str, default='0')
  parser.add_argument("--seed", default=24601, type=int)

  parser.add_argument('--hetero_special_digit', type=int, default=4,
                        help='Special digit for heterogeneous MNIST')
  parser.add_argument('--hetero_special_train', default=False, action='store_true',
                        help='If HeteroMNIST with special digit split is used for training.')
  parser.add_argument('--hetero_other_digits_train_split', default=0.9, type=float)
  parser.add_argument('--heterogeneous', default=False, action='store_true')

  parser.add_argument('--finetune_visualization', default=False, action='store_true')
  parser.add_argument('--finetuned_model_path', type=str, default=None,
                      help='checkpoint to the finetuned model')

  parser.add_argument('--donot_use_embedding', default=False, action='store_true')
  parser.add_argument('--use_compact_embedding', default=False, action='store_true')
  parser.add_argument('--initialize_embedding', default=False, action='store_true')

  logging.basicConfig(level=logging.INFO)

  args = parser.parse_args()

  model_name = args.experiment_name + '_' + args.model_name + '_' + args.dataset_name
  output_path = os.path.join(args.result_path, model_name)
  make_dirs(output_path)
  logging.info("Generating grid plane for %s", model_name)

  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
  args.gpu_id_list = [int(s) for s in args.gpu_ids.split(',')]
  args.cuda = not args.no_cuda and torch.cuda.is_available()
  is_rnn = args.model_name in _RNN_MODEL_NAMES

  ### Loading models
  model_1 = load_model(args.model_name, args.init_start)
  model_2 = load_model(args.model_name, args.init_end)
  fused_model = load_model(args.model_name, args.fused_model_path)
  permuted_model_1 = load_model(args.model_name, args.permuted_model_1_path)
  permuted_model_2 = load_model(args.model_name, args.permuted_model_2_path)

  ### Flatten each model's weights once
  model_1_w = get_weight(model_1)
  model_2_w = get_weight(model_2)
  fused_model_w = get_weight(fused_model)
  permuted_model_1_w = get_weight(permuted_model_1)
  permuted_model_2_w = get_weight(permuted_model_2)

  # Points defining the plane: model 1 (origin), permuted model 2, model 2.
  w = [model_1_w, permuted_model_2_w, model_2_w]

  logging.info("Weight space dimentionality: {}".format(w[0].shape[0]))

  config = model_1.get_model_config()

  # NOTE(review): test_rnn also embeds 'UDPOS' inputs with GloVe for models
  # that are not encoder-decoder variants, but no embedding is built for
  # UDPOS here -- confirm UDPOS runs only use encoder-decoder models.
  if args.dataset_name in ['SST', 'SSTPT', 'DBpedia', 'AG_NEWS']:
    glove_embedding = embedding.GloveEmbedding(args.glove_path)
  else:
    glove_embedding = None

  ### Orthonormal basis (u, v) of the plane via Gram-Schmidt, origin w[0]
  u = w[2] - w[0]
  dx = np.linalg.norm(u)
  if dx == 0:
    raise ValueError('model 1 and model 2 have identical weights; cannot span a plane.')
  u /= dx

  v = w[1] - w[0]
  v -= np.dot(u, v) * u
  dy = np.linalg.norm(v)
  if dy == 0:
    raise ValueError('permuted model 2 lies on the line through model 1 and model 2; '
                     'cannot span a plane.')
  v /= dy

  ### Coordinates of each model on the plane
  bend_coordinates = np.stack([get_xy(p, w[0], u, v) for p in w])
  model_1_coordinates = get_xy(model_1_w, w[0], u, v)
  model_2_coordinates = get_xy(model_2_w, w[0], u, v)
  fused_model_coordinates = get_xy(fused_model_w, w[0], u, v)
  permuted_model_1_coordinates = get_xy(permuted_model_1_w, w[0], u, v)
  permuted_model_2_coordinates = get_xy(permuted_model_2_w, w[0], u, v)

  logging.info('The coordinates of model 1 on the plane: {}'.format(model_1_coordinates))
  logging.info('The coordinates of model 2 on the plane {}'.format(model_2_coordinates))
  logging.info('The coordinates of the fused model on the plane: {}'.format(fused_model_coordinates))
  logging.info('The coordinates of permuted model 1 on the plane {}'.format(permuted_model_1_coordinates))
  logging.info('The coordinates of permuted model 2 on the plane {}'.format(permuted_model_2_coordinates))

  ### Data loaders and test metrics of the individual models
  if is_rnn:
    trainloader, _valloader, testloader = train_rnn.get_dataloaders(args)
  else:
    trainloader, _valloader, testloader = train_models.get_dataloaders(args)

  named_models = [
      ('model 1', model_1),
      ('model 2', model_2),
      ('fused model', fused_model),
      ('permuted model 1', permuted_model_1),
      ('permuted model 2', permuted_model_2),
  ]
  for label, net in named_models:
    logging.info('Test accuracy of %s:%s', label, _evaluate(testloader, net, args, glove_embedding))

  ### Generate the grid plane
  G = args.grid_points
  alphas = np.linspace(0.0 - args.margin_left, 1.0 + args.margin_right, G)
  betas = np.linspace(0.0 - args.margin_bottom, 1.0 + args.margin_top, G)

  tr_loss = np.zeros((G, G))
  tr_acc = np.zeros((G, G))
  tr_err = np.zeros((G, G))

  te_loss = np.zeros((G, G))
  te_acc = np.zeros((G, G))
  te_err = np.zeros((G, G))

  grid = np.zeros((G, G, 2))

  if is_rnn:
    base_model = fuse_rnn_models.get_model(args.model_name, input_dim=28, config=config, encoder=False)
  else:
    base_model = get_model(args.model_name, config)
  if torch.cuda.is_available():
    base_model.cuda()

  # The flat weight vectors must map exactly onto base_model's parameters.
  num_base_params = sum(parameter.numel() for parameter in base_model.parameters())
  if num_base_params != w[0].size:
    raise ValueError('base model has {} parameters but the loaded models have {}'.format(
        num_base_params, w[0].size))

  columns = ['X', 'Y', 'Train loss', 'Train error (%)', 'Test error (%)']
  logging.info("Begin to generate grid plane.")

  for i, alpha in enumerate(alphas):
    for j, beta in enumerate(betas):
      # Weights of the network at grid point (i, j). Only parameters are
      # overwritten; buffers (if any, e.g. normalization statistics) keep
      # base_model's values.
      p = w[0] + alpha * dx * u + beta * dy * v
      _load_flat_weights(base_model, p)

      tr_res = _evaluate(trainloader, base_model, args, glove_embedding)
      te_res = _evaluate(testloader, base_model, args, glove_embedding)

      grid[i, j] = [alpha * dx, beta * dy]

      # float() accepts Python numbers as well as 0-d tensors (including CUDA
      # ones), which NumPy item assignment may reject.
      tr_loss[i, j] = float(tr_res['loss'])
      tr_acc[i, j] = float(tr_res['accuracy'])
      tr_err[i, j] = 100.0 - tr_acc[i, j]

      te_loss[i, j] = float(te_res['loss'])
      te_acc[i, j] = float(te_res['accuracy'])
      te_err[i, j] = 100.0 - te_acc[i, j]

      values = [
              grid[i, j, 0], grid[i, j, 1], tr_loss[i, j], tr_err[i, j], te_err[i, j]
      ]

      table = tabulate.tabulate([values], columns, tablefmt = 'simple', floatfmt='10.4f')
      if j == 0:
        # Print separator, header, separator and the first row.
        table = table.split('\n')
        table = '\n'.join([table[1]] + table)
      else:
        # Print only the data row.
        table = table.split('\n')[2]
      print(table)

  np.savez(
      os.path.join(output_path, 'plane.npz'),
      bend_coordinates = bend_coordinates,
      fused_model_coordinates = fused_model_coordinates,
      permuted_model_1_coordinates = permuted_model_1_coordinates,
      permuted_model_2_coordinates = permuted_model_2_coordinates,
      alphas = alphas,
      betas = betas,
      grid = grid,
      tr_loss = tr_loss,
      tr_acc = tr_acc,
      tr_err = tr_err,
      te_loss = te_loss,
      te_acc = te_acc,
      te_err = te_err
  )

if __name__ == '__main__':
  # Run the plane evaluation only when executed as a script, not on import.
  main()