import numpy as np
import torch
import torch.optim
import os

import torch.nn as nn
import torch.nn.functional as F
from methods import backbone
from methods.backbone import model_dict
from data.datamgr import SimpleDataManager, SetDataManager
from methods.baselinetrain import BaselineTrain
from methods.gnnnet_mlp import GnnNet as GnnNet
from methods.gnn_mlp import Mlp
from options import parse_args


def train(base_loader, val_loader, model_1shot, model_2shot, model_mlp, start_epoch, stop_epoch):
  """Fit the MLP heads for epochs ``start_epoch .. stop_epoch - 1``.

  Each epoch:
    * the 1-shot models and ``model_2shot`` are put in eval mode, the MLP heads
      in train mode, and ``train_loop`` runs one pass over ``base_loader``;
    * the MLP heads are switched to eval mode and ``test_loop`` evaluates on
      ``val_loader``;
    * for every MLP output directory ``k``: if the accuracy gain of MLP ``k``
      over 1-shot model ``k`` beats the best gain seen so far (initially 0),
      MLP ``k`` is saved as ``mlp{k+1}.tar`` in that directory;
    * every ``params.save_freq`` epochs, and on the last epoch, every MLP is
      additionally checkpointed as ``{epoch}.tar`` in its directory.

  Only the MLP parameters are handed to an optimizer; the GNN models are not
  updated here. Reads the module-level ``params``
  (``mlps_1_to_bi2_mlp_dir``, ``save_freq``). Always returns 0.
  """
  mlp_dirs = params.mlps_1_to_bi2_mlp_dir
  # One Adam optimizer (default hyper-parameters) per MLP head.
  optimizer = [torch.optim.Adam(mlp.parameters()) for mlp in model_mlp]
  for mlp_dir in mlp_dirs:
    if not os.path.isdir(mlp_dir):
      os.makedirs(mlp_dir)

  best_gain = [0] * len(mlp_dirs)
  total_it = 0
  for epoch in range(start_epoch, stop_epoch):
    # GNN models stay in eval mode; only the MLP heads train.
    for net in model_1shot:
      net.eval()
    model_2shot.eval()
    for mlp in model_mlp:
      mlp.train()
    total_it = train_loop(epoch, base_loader, optimizer, total_it, model_1shot, model_2shot, model_mlp)

    for mlp in model_mlp:
      mlp.eval()
    acc_few, acc_mlp = test_loop(val_loader, model_1shot, model_mlp)

    # Keep the MLP whose improvement over its 1-shot model is the largest so far.
    for k, mlp_dir in enumerate(mlp_dirs):
      gain = acc_mlp[k] - acc_few[k]
      if gain > best_gain[k]:
        best_gain[k] = gain
        print('best rise acc mlp' + str(k) + '! save... ', best_gain[k])
        best_file = os.path.join(mlp_dir, 'mlp{:d}.tar'.format(k + 1))
        torch.save({'epoch': epoch, 'state': model_mlp[k].state_dict()}, best_file)
      else:
        print("GG! best mlp" + str(k), " ", best_gain[k])

    # Periodic checkpoints (plus one after the final epoch).
    is_last_epoch = epoch == stop_epoch - 1
    if (epoch + 1) % params.save_freq == 0 or is_last_epoch:
      for k, mlp in enumerate(model_mlp):
        ckpt_file = os.path.join(mlp_dirs[k], '{:d}.tar'.format(epoch))
        torch.save({'epoch': epoch, 'state': mlp.state_dict()}, ckpt_file)

  return 0

def train_loop(epoch, train_loader, optimizer, total_it, model_1shot, model_2shot, model_mlp):
    """Run one training epoch for the two MLP heads.

    For every episode, the 1-shot models (``model_1shot``) and the many-shot
    model (``model_2shot``) are run under ``torch.no_grad()``. Then each MLP
    head is updated in turn with

        MSE(mlp features, model_2shot features)
        + CrossEntropy(mlp scores, query labels)
        + kl(mlp scores, other MLP's scores computed without gradients)

    MLP 0 is updated first. MLP 1 then uses the already-updated MLP 0 as its
    KL target.

    Reads the module-level globals ``params``, ``loss_class_fn``,
    ``loss_mse_fn`` and ``kl``.

    Args:
        epoch: current epoch index (used for logging only).
        train_loader: sized iterable of ``(x, _)`` episodes.
        optimizer: list of optimizers; ``optimizer[k]`` updates ``model_mlp[k]``.
        total_it: running iteration counter carried across epochs.
        model_1shot: list of 1-shot models; ``model_1shot[k]`` feeds ``model_mlp[k]``.
        model_2shot: model whose features are the MSE regression target.
        model_mlp: list of (exactly two) MLP heads.

    Returns:
        ``total_it`` incremented by the number of batches processed.
    """
    num_batches = len(train_loader)
    # `num_batches // 10` is 0 for loaders with fewer than 10 batches, which
    # would make the `% print_freq` below raise ZeroDivisionError.
    print_freq = max(1, num_batches // 10)
    avg_loss = [0 for _ in range(len(model_mlp))]
    # Number of leading samples dropped along dim 1 before feeding the 1-shot
    # models; presumably dim 1 indexes samples per class, so they then see
    # n_shot_few support examples. TODO confirm against the data manager.
    n_drop = params.n_shot_much - params.n_shot_few
    # The query labels are identical for every episode (class-major, 5 classes
    # x 16 queries), so build them once instead of once per batch.
    # NOTE(review): 5 / 16 are hard-coded and must match the loader's
    # n_way / n_query.
    y_query = torch.from_numpy(np.repeat(range(5), 16)).cuda()

    for i, (x, _) in enumerate(train_loader):
        with torch.no_grad():
            # Second element returned by set_forward_loss is used as features.
            f_1shot = [net.set_forward_loss(x[:, n_drop:])[1] for net in model_1shot]
            _, f_2shot, _ = model_2shot.set_forward_loss(x)

        # Mutual distillation: the student is optimized, the teacher only
        # provides gradient-free KL targets.
        for student, teacher in ((0, 1), (1, 0)):
            optimizer[student].zero_grad()
            with torch.no_grad():
                _, scores_teacher = model_mlp[teacher](f_1shot[teacher])
            f_student, scores_student = model_mlp[student](f_1shot[student])

            loss = (loss_mse_fn(f_student, f_2shot)
                    + loss_class_fn(scores_student, y_query)
                    + kl(scores_student, scores_teacher))
            loss.backward()
            optimizer[student].step()
            avg_loss[student] += loss.item()

        if (i + 1) % print_freq == 0:
            print(
                'E{:d}|B {:d}/{:d}|L1 {:f}|L2 {:f}'.format(
                    epoch, i + 1, num_batches,
                    avg_loss[0] / float(i + 1), avg_loss[1] / float(i + 1))
            )
        total_it += 1
    return total_it

def test_loop(test_loader, model_1shot,model_mlp):
    """Evaluate the 1-shot models and the MLP heads on ``test_loader``.

    Each episode ``x`` is passed unchanged to every 1-shot model.
    ``model_mlp[k]`` is applied to the features produced by ``model_1shot[k]``,
    so ``model_mlp`` must not be longer than ``model_1shot``. Accuracy is
    scored with ``correct`` (top-1, in percent).

    For each model this prints the mean accuracy over all episodes and
    ``1.96 * std / sqrt(#episodes)``.

    Returns:
        ``(acc_mean_few_shot, acc_mean_mlp)``: one mean accuracy per 1-shot
        model and one per MLP head.
    """
    iter_num = len(test_loader)
    acc_all_few_shot = [[] for _ in model_1shot]
    acc_all_mlp = [[] for _ in model_mlp]

    # Evaluation only: run without autograd so no computation graph is built
    # for every forward pass.
    with torch.no_grad():
        for x, _ in test_loader:
            f_1shot = []
            for k, net in enumerate(model_1shot):
                scores, feats, _ = net.set_forward_loss(x)
                f_1shot.append(feats)
                correct_this, count_this = correct(scores)
                acc_all_few_shot[k].append(correct_this / count_this * 100)
            for k, mlp in enumerate(model_mlp):
                _, scores = mlp(f_1shot[k])
                correct_this, count_this = correct(scores)
                acc_all_mlp[k].append(correct_this / count_this * 100)

    acc_mean_few_shot = []
    for accs in acc_all_few_shot:
        accs = np.asarray(accs)
        acc_mean = np.mean(accs)
        acc_std = np.std(accs)
        acc_mean_few_shot.append(acc_mean)
        print('--- %d Test few_shot Acc = %4.2f%% +- %4.2f%% ---' %(iter_num,  acc_mean, 1.96* acc_std/np.sqrt(iter_num)))

    acc_mean_mlp = []
    for k, accs in enumerate(acc_all_mlp):
        accs = np.asarray(accs)
        acc_mean = np.mean(accs)
        acc_std = np.std(accs)
        acc_mean_mlp.append(acc_mean)
        print('--- %d Test mlp%d Acc = %4.2f%% +- %4.2f%% ---' % (iter_num, k, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

    return acc_mean_few_shot, acc_mean_mlp

def correct(scores, n_way=5, n_query=16):
  """Count top-1 hits for one episode's query predictions.

  Ground-truth labels are class-major: ``n_query`` rows labelled 0, then
  ``n_query`` rows labelled 1, and so on up to ``n_way - 1``. The defaults
  (5-way, 16 queries per class) match the episode layout used elsewhere in
  this script.

  Args:
    scores: 2-D tensor of per-class scores, one row per query sample
      (``n_way * n_query`` rows).
    n_way: number of classes in the episode.
    n_query: number of query samples per class.

  Returns:
    ``(num_correct, num_queries)`` as ``(float, int)``.

  Raises:
    ValueError: if the number of score rows is not ``n_way * n_query``;
      otherwise numpy could silently broadcast (e.g. a single row) and
      miscount.
  """
  y_query = np.repeat(range(n_way), n_query)
  if scores.shape[0] != len(y_query):
    raise ValueError(
      'expected {} score rows ({}-way x {} query), got {}'.format(
        len(y_query), n_way, n_query, scores.shape[0]))
  # topk(k=1, dim=1, largest=True, sorted=True) -> predicted class per row.
  _, topk_labels = scores.detach().topk(1, 1, True, True)
  top1_pred = topk_labels.cpu().numpy()[:, 0]
  top1_correct = np.sum(top1_pred == y_query)
  return float(top1_correct), len(y_query)


def kl(y_s, y_t):
  """Distillation loss between student logits ``y_s`` and teacher logits ``y_t``.

  Softmax is taken over dim 1 for both. The student side is passed as
  log-probabilities and the teacher side as probabilities to the module-level
  ``loss_kl`` criterion. In the ``__main__`` section, ``loss_kl`` is
  ``nn.KLDivLoss(reduction='batchmean')``.
  """
  student_log_probs = F.log_softmax(y_s, dim=1)
  teacher_probs = F.softmax(y_t, dim=1)
  return loss_kl(student_log_probs, teacher_probs)


if __name__=='__main__':
  # Entry point: build loaders and models, load pretrained GNN checkpoints,
  # then train the two MLP heads with train().
  # NOTE: `params`, `loss_class_fn`, `loss_mse_fn` and `loss_kl` defined below
  # are module-level globals read directly by train / train_loop / kl.

  # set numpy random seed (only NumPy's RNG is seeded here; torch is not)
  np.random.seed(10)

  # parser argument
  params = parse_args('train')
  params.layers = [1, 1]
  print(params)

  # output and tensorboard dir
  params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
  params.model_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
  if not os.path.isdir(params.model_dir):
    os.makedirs(params.model_dir)

  # One output directory per MLP head: <model_dir>/mlps/mlp1 and .../mlp2.
  params.mlps_1_to_bi2_dir = os.path.join(params.model_dir, 'mlps')
  if not os.path.isdir(params.mlps_1_to_bi2_dir):
    os.makedirs(params.mlps_1_to_bi2_dir)
  params.mlps_1_to_bi2_mlp_name = ['mlp1', 'mlp2']
  params.mlps_1_to_bi2_mlp_dir = []
  for k in range(len(params.mlps_1_to_bi2_mlp_name)):
    params.mlps_1_to_bi2_mlp_dir.append(os.path.join(params.mlps_1_to_bi2_dir, params.mlps_1_to_bi2_mlp_name[k]))
    if not os.path.isdir(params.mlps_1_to_bi2_mlp_dir[k]):
      os.makedirs(params.mlps_1_to_bi2_mlp_dir[k])

  # dataloader
  print('\n--- prepare dataloader ---')
  print('{:d}shot'.format(params.n_shot_few))

  print('  train with single seen domain {}'.format(params.dataset))
  base_file = os.path.join(params.data_dir, params.dataset, params.base + '.json')
  val_file = os.path.join(params.data_dir, params.dataset, 'val.json')
  print(base_file)
  print(val_file)

  # model
  print('\n--- build model ---')
  image_size = 224

  if params.method in ['baseline']:
    # NOTE(review): this branch never defines model_1shot / model_2shot, which
    # are used unconditionally below (checkpoint loading and train()), so
    # --method baseline appears to end in a NameError. Confirm this script is
    # meant to be gnnnet-only.
    print('  pre-training the feature encoder {} using method {}'.format(params.model, params.method))
    base_datamgr    = SimpleDataManager(image_size, batch_size=16)
    base_loader     = base_datamgr.get_data_loader( base_file , aug=params.train_aug )
    val_datamgr     = SimpleDataManager(image_size, batch_size=64)
    val_loader      = val_datamgr.get_data_loader(val_file, aug=False)
    model           = BaselineTrain(model_dict[params.model], params.num_classes, tf_path=params.tf_dir)

  elif params.method in ['gnnnet']:
    print('  baseline training the model {} with feature encoder {}'.format(params.method, params.model))
    # If test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small.
    # NOTE(review): train_loop and correct() hard-code 5-way x 16 query labels,
    # so they only line up when this evaluates to 16 and n_way is 5.
    n_query = max(1, int(16 * params.test_n_way/params.train_n_way))

    # Training episodes carry n_shot_much support samples. train_loop slices
    # them down before feeding the n_shot_few models.
    train_2shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot_much)
    train_1shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot_few)
    base_datamgr            = SetDataManager(image_size, n_query = n_query, n_eposide=427, **train_2shot_params)
    base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

    # Validation episodes already have n_shot_few support samples, so
    # test_loop feeds them to the 1-shot models unchanged.
    test_1shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot_few)
    val_datamgr             = SetDataManager(image_size, n_query = n_query, n_eposide=107, **test_1shot_params)
    val_loader              = val_datamgr.get_data_loader(val_file, aug = False)

    # Two n_shot_few GNN models (one per MLP head) and one n_shot_much GNN model.
    model_1shot = []
    for k in range(2):
      model_1shot.append(GnnNet(model_dict[params.model], n_layer=params.layers, wdrop=params.wdrop, ft=params.maml,
                             rest=params.wrest,
                             fin_fc=params.fin_fc, mlp=False, tf_path=params.tf_dir, **train_1shot_params))
      model_1shot[k] = model_1shot[k].cuda()
    model_2shot = GnnNet(model_dict[params.model], n_layer=params.layers, wdrop=params.wdrop, ft=params.maml,
                             rest=params.wrest,
                             fin_fc=params.fin_fc, mlp=False, tf_path=params.tf_dir, **train_2shot_params)
    model_2shot = model_2shot.cuda()
  else:
    raise ValueError('Unknown method')

  # Two MLP heads. input_features=458 is presumably the GnnNet feature size;
  # TODO confirm against the GnnNet implementation.
  model_mlp = []
  for k in range(2):
    model_mlp.append(Mlp(input_features=458, n_way=5, nf=96, ratio=[5, 5, 5, 5], drop=False, rest=False, ft=False))
    model_mlp[k] = model_mlp[k].cuda()

  start_epoch = params.start_epoch
  stop_epoch = params.stop_epoch

  # Loss criteria used as globals by train_loop (class / feature) and kl().
  loss_class_fn = nn.CrossEntropyLoss()
  loss_mse_fn = nn.MSELoss()
  loss_kl = nn.KLDivLoss(reduction='batchmean')

  # Load pretrained weights: entries 0 and 1 -> model_1shot[0..1],
  # entry 2 -> model_2shot. Files are <model_dir>/<name>.tar and are expected
  # to contain a 'state' entry.
  print('load model')
  load_model_shots_file = [params.model_few_shot_1, params.model_few_shot_2, params.model_much_shot_1]
  for k in range(2):
    temp_dir = os.path.join(params.model_dir, load_model_shots_file[k] + '.tar')
    tmp_file = torch.load(temp_dir)
    model_1shot[k].load_state_dict(tmp_file['state'])

  temp_dir = os.path.join(params.model_dir, load_model_shots_file[2] + '.tar')
  tmp_file = torch.load(temp_dir)
  model_2shot.load_state_dict(tmp_file['state'])
  # training (train() returns 0; `model` is not used afterwards)
  print('\n--- start the training ---')
  model = train(base_loader, val_loader, model_1shot, model_2shot, model_mlp, start_epoch, stop_epoch)
