import os
import json
import time
import pprint
import argparse
import datetime
from tqdm import tqdm
from easydict import EasyDict
from tensorboardX import SummaryWriter
from sklearn.metrics import f1_score

import torch
from torch.autograd import Variable
from torch_geometric.data import DataLoader
from torch_geometric.datasets import PPI

from torch_geometric.datasets import Planetoid, Amazon, Coauthor, CoraFull, Reddit, PPI
from torch_geometric.utils import add_self_loops

from .snae_solver import SNAESolver
from prototype.utils.misc import makedir, create_logger, get_logger, AverageMeter, accuracy, load_state_model, load_state_optimizer,\
                         parse_config, set_seed, param_group_all, modify_state, save_load_split, gen_uniform_60_20_20_split, load_state_variable, gen_uniform_80_80_20_split
from prototype.model import model_entry
from prototype.optimizer import optim_entry
from prototype.lr_scheduler import scheduler_entry

try:
    from prototype.nas import controller_entry
except ModuleNotFoundError:
    print('prototype.nas not detected yet, install nas module first.')

class SPNASPPISolver(SNAESolver):
    """SANE-style architecture-search solver specialised for the PPI dataset.

    Builds the supernet, optimizer, LR scheduler and PPI train/val/test loaders
    from a config file, and delegates architecture updates to a NAS controller
    obtained from ``controller_entry``. PPI is treated here as multi-label
    classification: the loss is ``BCEWithLogitsLoss`` and the metric reported
    under the name ``top1`` / ``Prec@1`` is micro-averaged F1 (the names are
    kept for compatibility with the rest of the framework).
    """

    def __init__(self, config_file):
        self.config_file = config_file
        self.config = parse_config(config_file)
        self.setup_env()
        self.build_model()
        self.build_optimizer()
        self.build_lr_scheduler()
        self.build_data()

        # set up NAS controller; it shares the supernet and logger with this solver
        self.controller = controller_entry(self.config.nas)
        self.controller.set_supernet(self.model)
        self.controller.set_logger(self.logger)

        self.controller.init_optimizer()

    def setup_env(self):
        """Create output dirs and loggers, load an optional checkpoint, and seed RNGs.

        Directories ``checkpoints``, ``events`` and ``results`` are created next
        to the config file. If ``config.saver.pretrain`` exists, the checkpoint at
        ``pretrain.path`` is loaded (on CPU) into ``self.state``; otherwise
        ``self.state`` starts fresh with ``last_iter=0`` and ``last_epoch=-1``.
        """
        # directories
        self.path = EasyDict()
        self.path.root_path = os.path.dirname(self.config_file)
        self.path.save_path = os.path.join(self.path.root_path, 'checkpoints')
        self.path.event_path = os.path.join(self.path.root_path, 'events')
        self.path.result_path = os.path.join(self.path.root_path, 'results')
        for directory in (self.path.save_path, self.path.event_path, self.path.result_path):
            makedir(directory)

        self.tb_logger = SummaryWriter(self.path.event_path)

        # logger
        create_logger(os.path.join(self.path.root_path, 'log.txt'))
        self.logger = get_logger(__name__)
        self.logger.info(f'config: {pprint.pformat(self.config)}')

        # load pretrain checkpoint
        if hasattr(self.config.saver, 'pretrain'):
            self.state = torch.load(self.config.saver.pretrain.path, map_location='cpu')
            self.logger.info(f"Recovering from {self.config.saver.pretrain.path}, keys={list(self.state.keys())}")
            if hasattr(self.config.saver.pretrain, 'ignore'):
                self.state = modify_state(self.state, self.config.saver.pretrain.ignore)
            # checkpoints without resume counters are treated as a fresh start
            if 'last_iter' not in self.state:
                self.state['last_iter'] = 0
            if 'last_epoch' not in self.state:
                self.state['last_epoch'] = -1
        else:
            self.state = {'last_iter': 0, 'last_epoch': -1}

        # seeding
        self.seed_base: int = int(self.config.seed_base)
        self.seed: int = self.seed_base
        set_seed(seed=self.seed)

    def build_model(self):
        """Build the supernet and restore its weights / arch parameters from ``self.state``."""
        self.model = model_entry(self.config.model)

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])

        if getattr(self.model, 'arch_parameters', False) and 'arch_parameters' in self.state:
            arch_parameters = self.model.arch_parameters()
            for i, state in enumerate(self.state['arch_parameters']):
                load_state_variable(arch_parameters[i], state)

    def _build_optimizer(self, opt_config, model):
        """Build an optimizer for ``model`` from ``opt_config``.

        When ``opt_config.no_wd`` is truthy, conv/linear biases and BN weights and
        biases get ``weight_decay=0``. Entries in ``opt_config.pconfig`` override
        these per-group settings. Note: ``opt_config.kwargs.params`` is
        overwritten in place.
        """
        # make param_groups
        pconfig = {}

        if opt_config.get('no_wd', False):
            pconfig['conv_b'] = {'weight_decay': 0.0}
            pconfig['linear_b'] = {'weight_decay': 0.0}
            pconfig['bn_w'] = {'weight_decay': 0.0}
            pconfig['bn_b'] = {'weight_decay': 0.0}

        if 'pconfig' in opt_config:
            pconfig.update(opt_config['pconfig'])

        param_group, _ = param_group_all(model, pconfig)
        opt_config.kwargs.params = param_group
        return optim_entry(opt_config)

    def build_optimizer(self):
        """Build ``self.optimizer`` for the supernet and restore its state if checkpointed."""
        self.optimizer = self._build_optimizer(self.config.optimizer, self.model)
        if 'optimizer' in self.state:
            load_state_optimizer(self.optimizer, self.state['optimizer'])

    def _build_lr_scheduler(self, lr_config, optimizer):
        """Build an LR scheduler resuming from ``self.state['last_epoch']`` (mutates ``lr_config.kwargs``)."""
        lr_config.kwargs.optimizer = optimizer
        lr_config.kwargs.last_epoch = self.state['last_epoch']
        return scheduler_entry(lr_config)

    def build_lr_scheduler(self):
        """Build ``self.lr_scheduler`` for ``self.optimizer``."""
        self.lr_scheduler = self._build_lr_scheduler(self.config.lr_scheduler, self.optimizer)

    def build_data(self):
        """Build the PPI train/val/test loaders (specific to PPI).

        The dataset root comes from ``config.data.root``. It defaults to the
        previously hard-coded ``'/pathtodata/PPI'``. If ``config.data.max_epoch``
        is unset or falsy, it falls back to ``config.lr_scheduler.kwargs.T_max``.
        """
        if not getattr(self.config.data, 'max_epoch', False):
            self.config.data.max_epoch = self.config.lr_scheduler.kwargs.T_max

        root = getattr(self.config.data, 'root', '/pathtodata/PPI')
        loaders = {}
        for split in ('train', 'val', 'test'):
            split_cfg = self.config.data[split]
            dataset = PPI(root, split=split)
            loaders[split] = DataLoader(dataset, batch_size=split_cfg.batch_size, shuffle=split_cfg.shuffle)

        self.train_data = {'loader': loaders['train']}
        self.val_data = {'loader': loaders['val']}
        self.test_data = {'loader': loaders['test']}

    def build_finetune_dataset(self, max_epoch=None):
        """Rebuild the PPI loaders, optionally overriding ``config.data.max_epoch`` first."""
        if max_epoch is not None:
            self.config.data.max_epoch = max_epoch

        self.build_data()

    def _pre_train(self, model):
        """Reset training meters, put ``model`` in train mode and set up the criterion."""
        print_freq = self.config.saver.print_freq
        self.meters = EasyDict()
        self.meters.batch_time = AverageMeter(print_freq)
        self.meters.step_time = AverageMeter(print_freq)
        self.meters.data_time = AverageMeter(print_freq)
        self.meters.losses = AverageMeter(print_freq)
        self.meters.top1 = AverageMeter(print_freq)
        self.meters.top5 = AverageMeter(print_freq)

        model.train()

        self.num_classes = self.config.model.kwargs.get('out_dim', 1000)
        self.topk = 5 if self.num_classes >= 5 else self.num_classes
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.mixup = self.config.get('mixup', 1.0)
        if self.mixup < 1.0:
            self.logger.info('using mixup with alpha of: {}'.format(self.mixup))

        # share same criterion with controller
        self.controller.set_criterion(self.criterion)

    def _train(self, model):
        """Run one ``controller.step`` per epoch, then validate/test and checkpoint once.

        Epochs whose first step is below ``self.state['last_iter']`` are skipped,
        which allows resuming.

        Returns:
            dict with ``best_top1_val`` and ``best_top1_test`` (micro-F1). Both
            are 0 when the post-training evaluation does not run.
        """
        self._pre_train(model=model)
        # NOTE(review): _pre_train just switched the model to train mode; this puts it
        # back in eval mode for the whole search loop. Kept as-is -- confirm it is intended.
        model.eval()

        max_epoch = self.config.data.max_epoch
        iter_per_epoch = len(self.train_data['loader'])
        total_step = iter_per_epoch * max_epoch
        end = time.time()

        best_prec1_val, best_prec1_test = 0, 0
        # Sentinels: they stay -1 when no epoch actually runs (max_epoch == 0, or all
        # epochs skipped on resume). The evaluation below is then skipped instead of
        # failing with UnboundLocalError.
        epoch, curr_step = -1, -1

        for epoch in tqdm(range(0, max_epoch)):
            start_step = epoch * iter_per_epoch

            if start_step < self.state['last_iter']:
                continue

            self.lr_scheduler.step()
            # lr_scheduler.get_lr()[0] is the main lr
            current_lr = self.lr_scheduler.get_lr()[0]

            curr_step = start_step

            # architecture step: optimize alphas for one epoch
            self.controller.step(self.train_data['loader'], self.val_data['loader'], current_lr, self.optimizer)

            # batch_time records the duration of a whole epoch
            self.meters.batch_time.update(time.time() - end)

            # training logger (every epoch)
            if getattr(model, 'arch_parameters', False):
                self.tb_logger.add_histogram('na_alphas', model.na_alphas, curr_step)
                self.tb_logger.add_histogram('sc_alphas', model.sc_alphas, curr_step)
                self.tb_logger.add_histogram('la_alphas', model.la_alphas, curr_step)

            # ETA = remaining epochs x average epoch time (batch_time is per epoch, not per iter)
            remain_secs = (max_epoch - epoch - 1) * self.meters.batch_time.avg
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))
            log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                    f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                    f'Remaining Time {remain_time} ({finish_time})'
            self.logger.info(log_msg)

            end = time.time()

        # testing after training (runs once, for the last epoch)
        # NOTE(review): the val_epoch_freq condition and best-so-far tracking look written
        # for a per-epoch check; here the final evaluation is skipped when max_epoch is not
        # a multiple of val_epoch_freq. Kept as-is -- confirm intent.
        if curr_step >= 0 and (epoch + 1) % self.config.saver.val_epoch_freq == 0:
            metrics = self._validate(model=model)
            loss_val = metrics['loss']
            prec1_val = metrics['top1']

            metrics = self._evaluate(model=model)
            loss_test = metrics['loss']
            prec1_test = metrics['top1']

            # record the test score at the best validation score
            if prec1_val > best_prec1_val:
                best_prec1_val = prec1_val
                best_prec1_test = prec1_test

            # testing logger
            self.tb_logger.add_scalar('loss_val', loss_val, curr_step)
            self.tb_logger.add_scalar('acc1_val', prec1_val, curr_step)
            self.tb_logger.add_scalar('loss_test', loss_test, curr_step)
            self.tb_logger.add_scalar('acc1_test', prec1_test, curr_step)

            # save ckpt
            if self.config.saver.save_many:
                ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
            else:
                ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'

            self.state['model'] = model.state_dict()
            self.state['optimizer'] = self.optimizer.state_dict()
            self.state['last_epoch'] = epoch
            self.state['last_iter'] = curr_step
            if getattr(model, 'arch_parameters', False):
                self.state['arch_parameters'] = model.arch_parameters()

            torch.save(self.state, ckpt_name)
            genotype = model.genotype()
            self.logger.info('genotype = %s', genotype)

        res = [f'genotype={self.model.genotype()}']

        result_filename = os.path.join(self.path.result_path, 'searched_result.txt')
        with open(result_filename, 'w+') as fh:
            fh.write('\n'.join(res))

        self.logger.info('searched res for {} saved in {}'.format(self.config.data.task, result_filename))

        return {'best_top1_val': best_prec1_val, 'best_top1_test': best_prec1_test}

    @torch.no_grad()
    def _run_ppi_eval(self, model, loader):
        """Evaluate ``model`` on ``loader``: mean BCE loss and micro-F1 over the whole split.

        Micro-F1 is computed once over all concatenated predictions rather than
        averaged per batch, because micro-F1 does not decompose into a weighted
        mean of per-batch scores. The loss is weighted by the number of rows of
        ``data.y`` in each batch. The model is returned to train mode afterwards.

        Raises:
            RuntimeError: if ``loader`` yields no samples.
        """
        batch_time = AverageMeter(0)

        model.eval()
        criterion = torch.nn.BCEWithLogitsLoss()
        val_iter = len(loader)
        loss_sum, node_count = 0.0, 0
        targets, preds = [], []
        end = time.time()

        for i, data in enumerate(loader):
            target = data.y
            logits = model(data)

            loss = criterion(logits, target)
            num = target.size(0)
            loss_sum += loss.item() * num
            node_count += num
            # sklearn needs CPU arrays; threshold logits at 0 (== sigmoid prob 0.5)
            targets.append(target.detach().cpu())
            preds.append((logits > 0).float().detach().cpu())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if (i+1) % self.config.saver.print_freq == 0:
                self.logger.info(f'Test: [{i+1}/{val_iter}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})')

        if node_count == 0:
            raise RuntimeError('evaluation loader yielded no samples')

        final_loss = loss_sum / node_count
        final_top1 = float(f1_score(torch.cat(targets).numpy(), torch.cat(preds).numpy(), average='micro'))
        total_num = float(node_count)

        self.logger.info(f' * Prec@1 {final_top1:.3f}\t \
            Loss {final_loss:.3f}\ttotal_num={total_num}')

        model.train()
        return {'loss': final_loss, 'top1': final_top1}

    @torch.no_grad()
    def _evaluate(self, model):
        """Evaluate ``model`` on the test split; returns ``{'loss', 'top1'}`` (top1 = micro-F1)."""
        return self._run_ppi_eval(model, self.test_data['loader'])

    @torch.no_grad()
    def _validate(self, model):
        """Evaluate ``model`` on the validation split; returns ``{'loss', 'top1'}`` (top1 = micro-F1)."""
        return self._run_ppi_eval(model, self.val_data['loader'])

    # Train/search the supernet; configuration is taken from self.config
    def train(self):
        self._train(model=self.model)

    def evaluate(self):
        """Evaluate the supernet on the test split and log ``model.get_prob_result()`` line by line."""
        self._evaluate(model=self.model)
        prob_result_message = self.model.get_prob_result()
        for line in prob_result_message.split('\n'):
            self.logger.info(line)

    def _resolve_subnet_settings(self):
        """Return the subnet settings to build from ``self.subnet.subnet_settings``.

        Priority:
        1. an explicit ``subnet_settings.genotype``;
        2. the genotype on the last line of ``subnet_settings.genotype_filename``
           (format ``genotype=<value>``, as written by ``_train``), which is also
           stored back into ``subnet_settings.genotype``;
        3. otherwise a randomly sampled setting from the controller.
        """
        settings = self.subnet.subnet_settings
        if getattr(settings, 'genotype', False):
            return settings
        if getattr(settings, 'genotype_filename', False):
            with open(settings.genotype_filename, 'r') as fh:
                result_line = fh.readlines()[-1]
            settings.genotype = result_line.split('=', 1)[1].strip()
            return settings
        return self.controller.sample_subnet_settings(sample_mode='random')

    # Evaluate one specific subnet; configuration is taken from self.subnet
    def evaluate_subnet(self):
        """Build the configured (or a random) subnet, evaluate it on test, optionally save weights."""
        self.subnet = self.controller.subnet
        assert self.subnet is not None

        self.save_subnet_weight = self.subnet.get('save_subnet_weight', False)

        subnet_settings = self._resolve_subnet_settings()

        # build subnet from supernet
        subnet_model = self.controller.build_active_subnet(subnet_settings)

        # evaluate
        metrics = self._evaluate(model=subnet_model)

        # evaluate logging
        top1 = round(metrics['top1'], 3)
        subnet = {'subnet_settings': self.subnet.subnet_settings, 'top1': top1}
        self.logger.info('Subnet with settings: {}\ttop1 {}'.format(subnet_settings, top1))
        self.logger.info('Evaluate_subnet\t{}'.format(json.dumps(subnet)))

        # save weights; setup_env never defines bignas_path, so fall back to the checkpoint dir
        if self.save_subnet_weight:
            save_dir = self.path.get('bignas_path', self.path.save_path)
            ckpt_name = f'{save_dir}/ckpt_{top1}.pth.tar'
            torch.save({'model': subnet_model.state_dict()}, ckpt_name)
        return subnet

    # Fine-tune one specific subnet; configuration is taken from self.subnet
    def finetune_subnet(self):
        """Fine-tune the configured (or a random) subnet and return before/after metrics.

        Fine-tuning is a fresh run: the resume counters are reset *before* the
        optimizer and scheduler are rebuilt (the scheduler reads
        ``self.state['last_epoch']``). They are restored afterwards, even on error.
        """
        self.subnet = self.controller.subnet
        assert self.subnet is not None
        assert self.subnet.subnet_settings is not None

        subnet_settings = self._resolve_subnet_settings()

        # build subnet from supernet
        subnet_model = self.controller.build_active_subnet(subnet_settings)

        # fine-tune restart: -1 is the "fresh" last_epoch, as in setup_env
        last_iter = self.state['last_iter']
        last_epoch = self.state['last_epoch']
        self.state['last_iter'] = 0
        self.state['last_epoch'] = -1
        try:
            # rebuild optimizer, scheduler and data for the subnet
            self.optimizer = self._build_optimizer(self.subnet.optimizer, subnet_model)
            self.lr_scheduler = self._build_lr_scheduler(self.subnet.lr_scheduler, self.optimizer)
            self.build_finetune_dataset(max_epoch=self.subnet.data.max_epoch)

            top1_val = round(self._validate(model=subnet_model)['top1'], 3)
            top1_test = round(self._evaluate(model=subnet_model)['top1'], 3)
            subnet = {'subnet_settings': self.subnet.subnet_settings, 'top1_val': top1_val, 'top1_test': top1_test}
            self.logger.info('Before finetune subnet {}'.format(json.dumps(subnet)))

            # fine-tuning
            metrics = self._train(model=subnet_model)
        finally:
            # restore the supernet's resume counters
            self.state['last_iter'] = last_iter
            self.state['last_epoch'] = last_epoch

        best_top1_val = metrics['best_top1_val']
        best_top1_test = metrics['best_top1_test']

        top1_val = round(self._validate(model=subnet_model)['top1'], 3)
        top1_test = round(self._evaluate(model=subnet_model)['top1'], 3)

        subnet = {'subnet_settings': self.subnet.subnet_settings, 'top1_val': top1_val, 'top1_test': top1_test, 'best_top1_val': best_top1_val, 'best_top1_test': best_top1_test}
        self.logger.info('After finetune subnet {}'.format(json.dumps(subnet)))
        return subnet

    def hpo_finetune_subnet(self):
        """Tune subnet fine-tuning hyper-parameters with hyperopt TPE.

        Each trial maximizes ``best_top1_val`` from ``finetune_subnet``. At the end,
        the test score of the best-validation trial is logged. Errors are logged
        with their traceback rather than propagated.
        """
        import hyperopt
        from hyperopt import fmin, tpe, hp, Trials, partial, STATUS_OK

        def generate_args(hpo_cfg):
            hpo_cfg = EasyDict(hpo_cfg)
            self.controller.subnet.optimizer.type = hpo_cfg.optimizer
            self.controller.subnet.optimizer.kwargs.lr = 10**hpo_cfg.learning_rate
            self.controller.subnet.optimizer.kwargs.weight_decay = 10**hpo_cfg.weight_decay
            # NOTE(review): key is 'hidden_sizer' while the search space uses 'hidden_size';
            # possibly a typo -- confirm which key build_active_subnet reads.
            self.controller.subnet.subnet_settings.hidden_sizer = hpo_cfg.hidden_size
            self.controller.subnet.subnet_settings.in_dropout = hpo_cfg.in_dropout / 10
            self.controller.subnet.subnet_settings.out_dropout = hpo_cfg.out_dropout / 10
            self.controller.subnet.subnet_settings.act = hpo_cfg.activation
            return hpo_cfg

        def objective(hpo_cfg):
            generate_args(hpo_cfg)
            subnet = self.finetune_subnet()
            val_acc = subnet['best_top1_val']
            test_acc = subnet['best_top1_test']
            return {
                'loss': -val_acc,
                'test_acc': test_acc,
                'status': STATUS_OK,
                'eval_time': round(time.time(), 2),
            }

        sane_space = {'model': 'SANE',
                'hidden_size': hp.choice('hidden_size', [16, 32, 64, 128, 256]),
                'learning_rate': hp.uniform("lr", -3, -2),
                'weight_decay': hp.uniform("wr", -5, -3),
                'optimizer': hp.choice('opt', ['Adagrad', 'Adam']),
                'in_dropout': hp.choice('in_dropout', [0, 1, 2, 3, 4, 5, 6]),
                'out_dropout': hp.choice('out_dropout', [0, 1, 2, 3, 4, 5, 6]),
                'activation': hp.choice('act', ['relu', 'elu'])
                }

        if self.config.data.task == 'PubMed':
            sane_space['hidden_size'] = hp.choice('hidden_size', [16, 32, 64])
        elif self.config.data.task == 'PPI':
            sane_space['learning_rate'] = hp.uniform("lr", -3, -1.6)
            sane_space['in_dropout'] = hp.choice('in_dropout', [0, 1])
            sane_space['out_dropout'] = hp.choice('out_dropout', [0, 1])
            sane_space['hidden_size'] = hp.choice('hidden_size', [64, 128, 256, 512, 1024])
        elif self.config.data.task == 'CiteSeer':
            sane_space['learning_rate'] = hp.uniform("lr", -2.5, -1.6)
            sane_space['weight_decay'] = hp.choice('wr', [-8])
            sane_space['in_dropout'] = hp.choice('in_dropout', [5])
            sane_space['out_dropout'] = hp.choice('out_dropout', [0])

        try:
            start = time.time()
            trials = Trials()

            # tune with validation acc, and report the test accuracy with the best validation acc
            best_trial = fmin(objective, sane_space, algo=partial(tpe.suggest, n_startup_jobs=int(self.config.hpo.hyper_epoch/5)),
                        max_evals=self.config.hpo.hyper_epoch, trials=trials)

            hpo_cfg = hyperopt.space_eval(sane_space, best_trial)
            self.logger.info(f'Best Config from HPO Search Space is {hpo_cfg}')

            record_time_res = []
            t_val_acc, t_test_acc = 0, 0

            # report the test acc with the best val acc
            for trial_result in trials.results:
                if -trial_result['loss'] > t_val_acc:
                    t_val_acc = -trial_result['loss']
                    t_test_acc = trial_result['test_acc']
                    record_time_res.append('%s,%s,%s' % (trial_result['eval_time'] - start, t_val_acc, t_test_acc))
            self.logger.info(f'Best_test_acc={t_test_acc}')
        except Exception as e:
            # deliberate best-effort boundary: log with traceback instead of crashing
            self.logger.exception('errror occured , error=%s', e)

def main():
    """Command-line entry point: build the solver from ``--config`` and run ``--phase``.

    Phases handled here: ``train_search`` (default), ``evaluate``,
    ``evaluate_subnet``, ``finetune_subnet``, ``hpo_finetune_subnet`` and
    ``sample_accuracy``. Any other value raises ``NotImplementedError``.
    """
    parser = argparse.ArgumentParser(description='Graph Neural archtecture search Solver')
    parser.add_argument('--config', required=True, type=str)
    parser.add_argument('--phase', default='train_search')

    args = parser.parse_args()
    # build solver
    solver = SPNASPPISolver(args.config)

    # evaluate or finetune or train_search
    if args.phase == 'hpo_finetune_subnet':
        solver.hpo_finetune_subnet()
    elif args.phase in ('evaluate_subnet', 'finetune_subnet'):
        if not hasattr(solver.config.saver, 'pretrain'):
            # Logger.warn is deprecated (removed in Python 3.13); use warning()
            solver.logger.warning(f'{args.phase} without resuming any solver checkpoints.')
        if args.phase == 'evaluate_subnet':
            solver.evaluate_subnet()
        else:
            solver.finetune_subnet()
    elif args.phase == 'train_search':
        # NOTE(review): last_epoch is saved 0-based, so this comparison is true even after
        # a completed run -- confirm the intended resume/completion check.
        if solver.state['last_epoch'] <= solver.config.data.max_epoch:
            solver.train()
        else:
            solver.logger.info('Training has been completed to max_epoch!')
    elif args.phase == 'evaluate':
        solver.evaluate()
    elif args.phase == 'sample_accuracy':
        # presumably inherited from SNAESolver; not defined in this class
        solver.sample_multiple_subnet_accuracy()
    else:
        raise NotImplementedError


if __name__ == '__main__':
    main()