import torch
import os
import h5py

# from methods import backbone
from methods.backbone import model_dict
from data.datamgr import SimpleDataManager
from options import parse_args, get_best_file, get_assigned_file
from methods.gnnnet_mlp import GnnNet
from methods.gnn_mlp import Mlp
import data.feature_loader as feat_loader
import random
import numpy as np


# extract and save image features
def save_features_cat(model, data_loader, featurefile):
  """Extract features with every backbone in `model` and save them to HDF5.

  Each batch is moved to CUDA, passed through every backbone, and the outputs
  are concatenated along dim 1 (in the order of `model`).

  Args:
    model: sequence of backbones; each is called as `m(x)` on the CUDA batch.
    data_loader: loader yielding `(images, labels)` batches; its
      `batch_size` is used to preallocate the datasets.
    featurefile: path of the HDF5 file to create (overwritten if present).

  The file receives three datasets:
    'all_feats'  -- one row per image (created after the first batch, once
                    the feature width is known),
    'all_labels' -- the matching labels,
    'count'      -- number of rows actually written. Both datasets are
                    preallocated to len(data_loader) * batch_size rows, so
                    rows at or beyond `count` are never written.
  """
  max_count = len(data_loader) * data_loader.batch_size
  # `with` guarantees the HDF5 file is flushed and closed even if a forward
  # pass raises partway through extraction.
  with h5py.File(featurefile, 'w') as f:
    all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
    all_feats = None
    count = 0
    # Pure inference: don't let autograd record the backbone graphs.
    with torch.no_grad():
      for i, (x, y) in enumerate(data_loader):
        if (i % 10) == 0:
          print('    {:d}/{:d}'.format(i, len(data_loader)))
        x = x.cuda()
        feats = torch.cat([m(x) for m in model], dim=1)
        if all_feats is None:
          # Created lazily: the concatenated feature width is only known now.
          all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
        n = feats.size(0)
        all_feats[count:count + n] = feats.detach().cpu().numpy()
        all_labels[count:count + n] = y.cpu().numpy()
        count += n

    count_var = f.create_dataset('count', (1,), dtype='i')
    count_var[0] = count


# evaluate using features
def feature_evaluation_cat(cl_data_file, model, mlp, n_way=5, n_support=5, n_query=15, feat_dim=None):
  """Sample one few-shot episode and return the ensemble's query accuracy.

  Args:
    cl_data_file: mapping from class id to an indexable collection of stored
      feature vectors (the concatenation of every backbone's features).
    model: sequence of metric models. `model[j].set_forward(..., is_feature=True)`
      receives the j-th slice of width `feat_dim` of the episode features.
    mlp: sequence of heads; `mlp[j]` is applied to the second output of
      `model[j].set_forward`, and its second output is taken as scores.
    n_way: number of classes sampled for the episode.
    n_support: support samples per class.
    n_query: query samples per class (also assigned to each `model[j].n_query`).
    feat_dim: width of each model's feature slice. Defaults to the total
      stored width divided by `len(model)`, i.e. equal-width slices in
      concatenation order.

  Returns:
    Accuracy in percent of the argmax of the heads' averaged scores, against
    labels `np.repeat(range(n_way), n_query)`.

  Raises:
    IndexError: if a sampled class has fewer than n_support + n_query features.
  """
  # random.sample requires a sequence; dict views raise TypeError on 3.11+.
  select_class = random.sample(list(cl_data_file.keys()), n_way)
  z_all = []
  for cl in select_class:
    img_feat = cl_data_file[cl]
    perm_ids = np.random.permutation(len(img_feat)).tolist()
    z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)])
  # Presumably (n_way, n_support + n_query, total_width) when each stored
  # feature squeezes to 1-D -- the slicing below requires rank >= 3.
  z_all = torch.from_numpy(np.array(z_all))

  if feat_dim is None:
    feat_dim = z_all.size(-1) // len(model)

  scores = []
  for j, (net, head) in enumerate(zip(model, mlp)):
    net.n_query = n_query
    _, feats = net.set_forward(z_all[:, :, feat_dim * j:feat_dim * (j + 1)], is_feature=True)
    _, head_scores = head(feats)
    scores.append(head_scores)
  # Ensemble by averaging the heads' scores (== 0.5 * (s0 + s1) for two models).
  scores_mlp = sum(scores) / len(scores)
  pred = scores_mlp.detach().cpu().numpy().argmax(axis=1)
  y = np.repeat(range(n_way), n_query)
  acc = np.mean(pred == y) * 100
  return acc


if __name__ == '__main__':

  # parse argument
  params = parse_args('test')
  # NOTE(review): this overrides any parsed --layers value; confirm intended.
  params.layers = [1, 1]
  print(params)
  print('Testing! {} shots on {} dataset with {} ({})'.format(params.n_shot, params.dataset, params.name, params.method))
  remove_featurefile = True
  print('\nStage 1: saving features')
  print('  build dataset')
  image_size = 224
  split = params.split
  loadfile = os.path.join(params.data_dir, params.dataset, split + '.json')
  print(loadfile)
  datamgr = SimpleDataManager(image_size, batch_size=64)
  data_loader = datamgr.get_data_loader(loadfile, aug=False)

  print('  build feature encoder')
  checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
  # The same two checkpoints provide the backbones (Stage 1) and GNNs (Stage 2).
  load_model_shots_file = [params.save_gnn_1, params.save_gnn_2]
  model = []
  for shots_file in load_model_shots_file:
    backbone = model_dict[params.model]().cuda()
    tmp = torch.load(os.path.join(checkpoint_dir, shots_file + '.tar'))
    # Checkpoints store their weights under either 'state' or 'model_state'.
    try:
      state = tmp['state']
    except KeyError:
      state = tmp['model_state']
    # Keep only keys containing "feature." (with that substring removed) so
    # they can be loaded strictly into a bare backbone; drop everything else.
    for key in list(state.keys()):
      if "feature." in key:
        state[key.replace("feature.", "")] = state.pop(key)
      else:
        state.pop(key)
    backbone.load_state_dict(state)
    backbone.eval()
    model.append(backbone)

  print('  extract and save features...')
  featurefile = os.path.join(checkpoint_dir, params.split + "_" + params.dataset + ".hdf5")
  os.makedirs(os.path.dirname(featurefile), exist_ok=True)
  try:
    save_features_cat(model, data_loader, featurefile)

    print('\nStage 2: evaluate')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    print('  build metric-based model')
    if params.method != 'gnnnet':
      raise ValueError('Unknown method')
    model = []
    for shots_file in load_model_shots_file:
      net = GnnNet(model_dict[params.model], n_layer=params.layers, wdrop=params.wdrop, ft=params.maml,
                   rest=params.wrest, fin_fc=params.fin_fc, mlp=False, **few_shot_params)
      net = net.cuda()
      net.eval()
      tmp = torch.load(os.path.join(checkpoint_dir, shots_file + '.tar'))
      try:
        net.load_state_dict(tmp['state'])
      except RuntimeError:
        print('warning! RuntimeError when load_state_dict()!')
        net.load_state_dict(tmp['state'], strict=False)
      except KeyError:
        # TODO(review): checkpoints keyed 'model_state' get their 'running*'
        # buffers squeezed (presumably BatchNorm stats saved with singleton
        # dims) and are loaded non-strictly; revisit this workaround.
        for k in tmp['model_state']:
          if 'running' in k:
            tmp['model_state'][k] = tmp['model_state'][k].squeeze()
        net.load_state_dict(tmp['model_state'], strict=False)
      model.append(net)

    load_model_mlp_file = [params.save_mlp_1, params.save_mlp_2]
    mlp = []
    for mlp_name in load_model_mlp_file:
      # NOTE(review): n_way is fixed to 5 (not params.test_n_way) and
      # input_features=458 is a magic constant; both must agree with the
      # saved MLP checkpoints -- confirm before evaluating other n-way setups.
      head = Mlp(input_features=458, n_way=5, nf=96, ratio=[5, 5, 5, 5], drop=False, rest=False, ft=False)
      head = head.cuda()
      head.eval()
      tmp_mlp = torch.load(os.path.join(checkpoint_dir, mlp_name + '.tar'))
      head.load_state_dict(tmp_mlp['state'])
      mlp.append(head)

    print('  load saved feature file')
    cl_data_file = feat_loader.init_loader(featurefile)

    print('  evaluate')
    iter_num = 10000
    acc_all = [feature_evaluation_cat(cl_data_file, model, mlp, n_query=15, **few_shot_params)
               for _ in range(iter_num)]

    print('  get statics')
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    # Half-width of the normal-approximation 95% confidence interval of the mean.
    print('  %d test iterations: Acc_m10 = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
  finally:
    # Remove the temporary feature file even if extraction/evaluation fails.
    if remove_featurefile and os.path.exists(featurefile):
      os.remove(featurefile)