# This file is for experimental usage
import os, sys, torch, random
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch.nn as nn

from utils  import obtain_accuracy
from models import CellStructure
from log_utils import time_string

def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
  """Correlate one-shot estimates of every architecture with benchmark accuracies.

  Two estimates are computed for each architecture returned by
  ``CellStructure.gen_all(model.op_names, model.max_nodes, False)``:

  1. The sum of the log-softmax of ``model.arch_parameters`` over the
     (edge, operation) pairs that the architecture selects.
  2. The accuracy of ``model`` on a single batch drawn from ``xloader`` after
     ``model.set_cal_mode('dynamic', arch)``.

  Both are compared, through the Pearson correlation of ``np.corrcoef``, with
  the ``'valid-accuracy'`` reported by ``api`` for ``'cifar10-valid'`` and the
  ``'test-accuracy'`` reported for ``'cifar10'``. The correlations are printed,
  not returned.

  Side effects: the model weights are snapshotted on entry and restored on
  exit (also when an exception is raised). The train/eval flag set through
  ``model.train(cal_mode)`` and the ``'dynamic'`` calculation mode are NOT
  reverted. Python's global ``random`` generator is re-seeded with ``seed``.
  Inputs and targets are moved with ``.cuda()``, so a CUDA device is required.

  Args:
    model: network exposing ``arch_parameters``, ``op_names``, ``max_nodes``,
      ``edge2index`` and ``set_cal_mode``; calling it must return a 2-tuple
      whose second element holds the class logits.
    xloader: iterable yielding ``(inputs, targets)`` batches; it is restarted
      whenever it is exhausted.
    api: benchmark object providing ``query_index_by_arch`` and
      ``get_more_info``.
    cal_mode: forwarded to ``model.train``; also selects the 'Train'/'Eval'
      label used in the progress messages.
    seed: seed used for shuffling the list of architectures.

  Returns:
    A tuple ``(archs, probs, accuracies)``: the shuffled architectures, their
    summed log-probabilities, and their single-batch accuracies, all aligned
    by index.
  """
  weights = deepcopy(model.state_dict())
  model.train(cal_mode)
  mode_name = 'Train' if cal_mode else 'Eval'
  try:
    with torch.no_grad():
      arch_log_probs = nn.functional.log_softmax(model.arch_parameters, dim=-1)
      archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
      probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
      random.seed(seed)
      random.shuffle(archs)
      num_archs = len(archs)

      # Pass 1: ground-truth accuracies and the log-probability of each architecture.
      for arch in archs:
        arch_index = api.query_index_by_arch(arch)
        metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
        gt_accs_10_valid.append(metrics['valid-accuracy'])
        metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
        gt_accs_10_test.append(metrics['test-accuracy'])
        select_logits = []
        for i, node_info in enumerate(arch.nodes):
          for op, xin in node_info:
            # Edge keys are written as '<target node>-<source node>', targets starting at 1.
            node_str = '{:}<-{:}'.format(i+1, xin)
            op_index = model.op_names.index(op)
            select_logits.append(arch_log_probs[model.edge2index[node_str], op_index])
        probs.append(sum(select_logits).item())
      cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1]
      cor_prob_test  = np.corrcoef(probs, gt_accs_10_test )[0,1]
      print('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))

      # Pass 2: single-batch accuracy of the shared-weight model for each architecture.
      loader_iter = iter(xloader)
      for idx, arch in enumerate(archs):
        model.set_cal_mode('dynamic', arch)
        try:
          inputs, targets = next(loader_iter)
        except StopIteration:
          # Only exhaustion restarts the loader; any other error must propagate.
          loader_iter = iter(xloader)
          inputs, targets = next(loader_iter)
        _, outputs = model(inputs.cuda())
        _, preds   = torch.max(outputs, dim=-1)
        correct = (preds == targets.cuda()).float()
        accuracies.append(correct.mean().item())
        # Report every 500 architectures and once more after the last one.
        if idx != 0 and (idx % 500 == 0 or idx + 1 == num_archs):
          cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1]
          cor_accs_test  = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1]
          print('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, num_archs, mode_name, cor_accs_valid, cor_accs_test))
  finally:
    # Forward passes may have altered the model state (presumably normalization
    # statistics when cal_mode is truthy); always put the snapshot back.
    model.load_state_dict(weights)
  return archs, probs, accuracies
