from dataset.nas_data import NASData
from dataset.darts_data import DARTSData

from model.generators.erdos_renyi import ErdosRenyi
from model.generators.gran import GRAN
from model.generators.darts import DartsGenerator
from model.evaluators.nas_bench import NASBench301
from model.evaluators.uniform import Uniform

from utils.nas_helper import fmt_graphs, genotype_to_adjs
from utils.darts.utils import save, load
from utils.darts.genotypes import Genotype

import torch
import torch.optim as optim

import numpy as np

from easydict import EasyDict as edict

import random

def _make_config(search_space):
  """Build the experiment config for the given architecture search space.

  Args:
    search_space: ``'darts'`` (generator ``DartsGenerator``, dataset
      ``DARTSData``) or ``'smalldarts'`` (generator ``GRAN``, dataset
      ``NASData``).

  Returns:
    An ``EasyDict`` holding the dataset, generator, NAS and training settings
    used by ``test``.

  Raises:
    ValueError: if ``search_space`` is not one of the values above.
  """
  if search_space == 'darts':
    generator_cls, data_cls = 'DartsGenerator', 'DARTSData'
  elif search_space == 'smalldarts':
    generator_cls, data_cls = 'GRAN', 'NASData'
  else:
    # Raising a bare string is itself a TypeError in Python 3, so the original
    # message was never shown; raise a real exception instead.
    raise ValueError(f"unsupported search space: {search_space!r}")

  return edict({
    'dataset': {
      'CIFAR_path': 'data/',
      'NB101_path': 'data/nasbench_only108.tfrecord',
      'NB301_path': 'data/nb_models/xgb_v1.0',
      'cls': data_cls,
    },
    'debug': True,
    'seed': 2,
    'evaluator': {'cls': 'Uniform'},
    'experimental': {
      'adj_95pct': False,
      'lab_95pct': False,
      'nb301_noise': False,
      'temperature': 1,
      'use_temperature': 'none',
    },
    'generator': {
      'block_size': 1,
      'cls': generator_cls,
      'device': 'cpu',
      'dimension_reduce': True,
      'edge_weight': 1.0,
      'embedding_dim': 24,
      'erdos_renyi_p': 0.25,
      'has_attention': True,
      'hidden_dim': 24,
      'max_num_nodes': 100,
      'name': 'GRAN',
      'num_GNN_layers': 7,
      'num_GNN_prop': 1,
      'num_canonical_order': 1,
      'num_edge_labels': 0,
      'num_mix_component': 10,
      'num_node_labels': 0,
      'sample_stride': 1,
      'search_space': search_space,
    },
    'nas': {
      'baseline': False,
      'ewma_alpha': 0.5,
      'explore_method': 'three',
      'explore_p': 0.1,
      'keep_top': 100,
      'max_nas_iterations': 20000000,
      'max_oracle_evaluations': 300,
      'reward': 'cdf',
      'sample_batch_size': 8,
    },
    'gpus': [0],
    'train': {
      'batch_size': 5,
      'display_iter': 10,
      'is_resume': False,
      'lr': 0.01,
      'lr_decay': 0.3,
      'lr_decay_epoch': [100000000],
      'max_epoch': 150,
      'momentum': 0.9,
      'num_workers': 0,
      'optimizer': 'Adam',
      'print_every': 5,
      'shuffle': True,
      'wd': 0.0,
    },
    'use_gpu': False,
  })


def test():
  """Train (or load) a graph generator and print likelihoods of reference archs.

  The behaviour is controlled by the ``cfg_*`` variables at the top:

  1. Build the config for ``cfg_search_space`` and seed ``random``, ``torch``
     and ``numpy``.
  2. If ``cfg_train_from_scratch``, build a training set and train on it:
     - with ``cfg_load_from_logs``, the set is parsed from ``cfg_logs_path``;
     - otherwise it is 100 Erdos-Renyi samples scored by ``NASBench301``.
     The generator is trained for ``config.train.max_epoch`` epochs and its
     weights are saved to ``cfg_weights_save_file``. Otherwise the weights
     are loaded from that file.
  3. For each architecture in ``remotearchs.c100_50etr_uniform_top8``, print
     what the generator returns when ``return_neg_log_prob`` is set on the
     batch. Then draw 10 samples from the generator.

  Raises:
    ValueError: if the search space, optimizer name, or training data is
      invalid (e.g. the log file yields no records).
  """
  # actual params we want to change in the test
  cfg_search_space = 'darts'  # ['darts', 'smalldarts']
  cfg_weights_save_file = "clogs/nas005weights1.pt"
  cfg_train_from_scratch = True
  cfg_load_from_logs = True
  cfg_logs_path = "clogs/nas005c100-50etr.txt"

  config = _make_config(cfg_search_space)

  # Explicit name -> class registries instead of eval() on config strings.
  generator_classes = {
    'DartsGenerator': DartsGenerator,
    'GRAN': GRAN,
    'ErdosRenyi': ErdosRenyi,
  }
  dataset_classes = {'DARTSData': DARTSData, 'NASData': NASData}

  random.seed(config.seed * 2)
  torch.manual_seed(config.seed * 3)
  np.random.seed(config.seed * 4)
  gran = generator_classes[config.generator.cls](config)
  # Constructed before branching even if unused: if its constructor draws from
  # the seeded RNGs, moving it would change the results of later steps.
  er = ErdosRenyi(config)

  if cfg_train_from_scratch:
    if cfg_load_from_logs:
      nb301 = None
      dataset = load_from_logs(config, cfg_logs_path)
      rewt = dataset.rewards()
    else:
      nb301 = NASBench301(config)
      dataset = dataset_classes[config.dataset.cls](config)
      samples = er.sample(100)
      rewards, _ = nb301.estimate(samples)
      dataset.append(list(zip(samples, rewards)))
      rewt = torch.tensor(rewards)

    # Fail early with a clear message instead of an empty-tensor .max() error
    # or a ZeroDivisionError when averaging the epoch loss.
    if len(dataset) == 0:
      raise ValueError("training dataset is empty; nothing to train on")
    print(f"mean={rewt.mean()}, std={rewt.std()}, max={rewt.max()}")

    train_loader = torch.utils.data.DataLoader(
      dataset,
      batch_size=config.train.batch_size,
      shuffle=config.train.shuffle,
      num_workers=config.train.num_workers,
      collate_fn=dataset.collate_fn,
      drop_last=False
    )

    params = [p for p in gran.parameters() if p.requires_grad]

    if config.train.optimizer == 'SGD':
      optimizer = optim.SGD(
        params,
        lr=config.train.lr,
        momentum=config.train.momentum,
        weight_decay=config.train.wd
      )
    elif config.train.optimizer == 'Adam':
      optimizer = optim.Adam(
        params, lr=config.train.lr, weight_decay=config.train.wd
      )
    else:
      raise ValueError("Non-supported optimizer!")

    lr_scheduler = optim.lr_scheduler.MultiStepLR(
      optimizer,
      milestones=config.train.lr_decay_epoch,
      gamma=config.train.lr_decay
    )

    for epoch in range(config.train.max_epoch):
      gran.train()

      train_loss = 0.
      for batch in train_loader:
        optimizer.zero_grad()
        loss = gran(batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

      # Report once every `print_every` epochs (epoch is 0-based, hence -1).
      if epoch % config.train.print_every == config.train.print_every - 1:
        print(f"Epoch={epoch}, Train loss={train_loss/len(train_loader)}")

      # Scheduler is stepped per epoch, after the optimizer steps.
      lr_scheduler.step()

    save(gran, cfg_weights_save_file)
  else:
    load(gran, cfg_weights_save_file)
    nb301 = None
    # load(gran, "logs/debug/2021-01-30-19-52-55_darts_nb301_300_acc_baseline_keeptop30_1234_effectiveness_417/gen_snapshot_iter41.pt")
    # load(gran, "logs/darts_ss/2021-02-13-05-11-25_debug_darts_ss_reduce_cell_12345_30482/gen_snapshot_iter38.pt")
    # load(gran, "logs/darts_ss/2021-02-13-06-47-40_debug_darts_ss_reduce_cell_12346_30527/gen_snapshot_iter38.pt")
  gran.eval()

  # *** beam samples ***
  # samples = gran.sample(10, sample_method="beam")
  # print("BEAM Samples")
  # print(fmt_graphs(config, samples, fmt="adj"))

  # if nb301 is None:
    # nb301 = NASBench301(config)
  # rewards, _ = nb301.estimate(samples)
  # beam_rewt = torch.tensor(rewards)
  # print("rewards: ", rewards)

  # *** print log likelihoods ***
  # dset2 = eval(config.dataset.cls)(config)
  # dset2.append(list(zip(samples, rewards)))
  # data_loader = torch.utils.data.DataLoader(
    # dset2,
    # batch_size=1,
    # shuffle=False,
    # num_workers=config.train.num_workers,
    # collate_fn=dset2.collate_fn,
    # drop_last=False
  # )
  # for batch in data_loader:
    # print(gran(batch))

  # *** print remote arch likelihoods ***
  from remotearchs import c100_50etr_uniform_top8
  top8 = c100_50etr_uniform_top8
  dset2 = dataset_classes[config.dataset.cls](config)
  # Every arch gets a constant reward of 1.0; presumably only the generator's
  # likelihood output matters here.
  dset2.append([(genotype_to_adjs(arch), 1.) for arch in top8])
  data_loader = torch.utils.data.DataLoader(
    dset2,
    batch_size=1,
    shuffle=False,
    num_workers=config.train.num_workers,
    collate_fn=dset2.collate_fn,
    drop_last=False
  )
  for batch in data_loader:
    # NOTE(review): assumes collate_fn yields (at least) two dicts -- confirm.
    batch[0]['return_neg_log_prob'] = True
    batch[1]['return_neg_log_prob'] = True
    nll, prob = gran(batch)
    print(nll)
    print(prob, torch.exp(-prob))

  samples = gran.sample(10)
  print("Random Samples")
  # print(fmt_graphs(config, samples, fmt="adj"))
  # *** print just sampled likelihoods ***
  # dset2 = eval(config.dataset.cls)(config)
  # dset2.append([(arch, 1.) for arch in samples])
  # data_loader = torch.utils.data.DataLoader(
    # dset2,
    # batch_size=1,
    # shuffle=False,
    # num_workers=config.train.num_workers,
    # collate_fn=dset2.collate_fn,
    # drop_last=False
  # )
  # for batch in data_loader:
    # batch[0]['return_neg_log_prob'] = True
    # batch[1]['return_neg_log_prob'] = True
    # nll, prob = gran(batch)
    # print(nll)
    # print(prob, torch.exp(-prob))

  # rewards, _ = nb301.estimate(samples)
  # print(rewards)

  # print(f"beam: mean={beam_rewt.mean()}, std={beam_rewt.std()}, max={beam_rewt.max()}")

  # rewt = torch.tensor(rewards)
  # print(f"rand: mean={rewt.mean()}, std={rewt.std()}, max={rewt.max()}")


def load_from_logs(config, path):
  """Build a dataset of (architecture, accuracy) pairs parsed from a log file.

  The log is scanned for three-line records:

    1. a line containing ``"Performance Estimate:"``;
    2. a line that is ``eval``-ed to obtain a genotype;
    3. a line whose third ``split(" ")`` token holds ``<key>=<value>``; the
       text after the first ``=`` is parsed with ``float``.

  Each complete record is converted with ``genotype_to_adjs`` and appended,
  one record per ``append`` call, to a fresh instance of the dataset class
  named by ``config.dataset.cls``. A marker line always (re)starts a record,
  even mid-record. Lines outside a record, and a record truncated at end of
  file, are ignored.

  Args:
    config: experiment config; ``config.dataset.cls`` names the dataset class.
    path: path of the log file to read.

  Returns:
    The populated dataset instance.
  """
  dataset_cls = eval(config.dataset.cls)
  dset = dataset_cls(config)

  # Which line of the current record we expect next:
  # None -> waiting for a marker, 'genotype' -> line 2, 'accuracy' -> line 3.
  expecting = None
  with open(path, "r") as log:
    for line in log:
      if "Performance Estimate:" in line:
        expecting = 'genotype'
        continue

      if expecting == 'genotype':
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from
        # the log file -- only parse logs you trust. The line is presumably
        # the repr of a `Genotype` (imported at module level).
        genotype = eval(line)
        expecting = 'accuracy'
      elif expecting == 'accuracy':
        third_token = line.split(" ")[2]
        accuracy = float(third_token.split("=")[1])
        dset.append([(genotype_to_adjs(genotype), accuracy)])
        expecting = None

  return dset


