import os

import logging
import numpy as np
import torch
from pprint import pformat

from utils.darts.genotypes import Genotype
from utils.darts.model import NetworkCIFAR
from utils.darts.utils import count_parameters_in_MB
from utils.arg_helper import parse_arguments, get_config
from utils.logger import setup_logging
from utils.nas_helper import adj_to_genotype
from utils.sgd_helper import train

def main():
#  adj = torch.tensor([[0., 0., 3., 0., 0., 0.],
#                      [0., 0., 3., 4., 4., 5.],
#                      [0., 0., 0., 4., 5., 5.],
#                      [0., 0., 0., 0., 0., 0.],
#                      [0., 0., 0., 0., 0., 0.],
#                      [0., 0., 0., 0., 0., 0.]])
#  adj = torch.tensor([[0., 0., 5., 3., 0., 5.],
#                      [0., 0., 5., 6., 4., 3.],
#                      [0., 0., 0., 0., 0., 0.],
#                      [0., 0., 0., 0., 6., 0.],
#                      [0., 0., 0., 0., 0., 0.],
#                      [0., 0., 0., 0., 0., 0.]])
  # adj = torch.tensor([[0., 0., 3., 3., 4., 4.], # c100 best val
  #       [0., 0., 3., 0., 4., 3.],
  #       [0., 0., 0., 5., 0., 0.],
  #       [0., 0., 0., 0., 0., 0.],
  #       [0., 0., 0., 0., 0., 0.],
  #       [0., 0., 0., 0., 0., 0.]])
#  adj = torch.tensor([[0., 0., 3., 3., 3., 4.],
#        [0., 0., 4., 0., 4., 3.],
#        [0., 0., 0., 6., 0., 0.],
#        [0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0.]])
  # adj = torch.tensor([[0., 0., 1., 3., 3., 0.], #C10 small darts it_19 97.33 (2.67)
  #       [0., 0., 5., 3., 0., 4.],
  #       [0., 0., 0., 0., 0., 0.],
  #       [0., 0., 0., 0., 5., 0.],
  #       [0., 0., 0., 0., 0., 4.],
  #       [0., 0., 0., 0., 0., 0.]])
  adj = torch.tensor([[0., 0., 1., 3., 3., 0.], #C100 small darts it_20 
        [0., 0., 4., 4., 3., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 5.],
        [0., 0., 0., 0., 0., 4.],
        [0., 0., 0., 0., 0., 0.]])
  genotype = adj_to_genotype(adj)

  args = parse_arguments()
  config = get_config(args.config_file, is_test=args.test)
  np.random.seed(config.seed)
  torch.manual_seed(config.seed)
  torch.cuda.manual_seed_all(config.seed)

  log_file = os.path.join(config.save_dir, "log_exp_{}.txt".format(config.run_id))
  logger = setup_logging(args.log_level, log_file)
  logger.info("Writing log file to {}".format(log_file))
  logger.info(f"Preparing training, genotype=\n{genotype}")
  logger.info("Config =")
  print(">" * 80)
  logger.info(pformat(config))
  print("<" * 80)

  logger.info(genotype)

  num_classes = 10 if config.oracle.dataset == "CIFAR10" else 100
  print(f"num_classes {num_classes}")
  for c in range(30, 60):
    model = NetworkCIFAR(c, num_classes, 20, config.oracle.auxiliary, genotype)
    if count_parameters_in_MB(model) > 3.5:
      break
  model = NetworkCIFAR(c - 1, num_classes, 20, config.oracle.auxiliary, genotype)
  logger.info(f"In channels: {c - 1}, model size={count_parameters_in_MB(model)}M")
  train_acc, valid_acc = train(config, model, is_test=True, quiet=False)

  logger.info(f"Final train acc={train_acc}, test acc={valid_acc}")


if __name__ == '__main__':
  main()

