import os
import json
import time
import pprint
import argparse
import datetime
from tqdm import tqdm
from easydict import EasyDict
from tensorboardX import SummaryWriter
from sklearn.metrics import f1_score

import torch
from torch.autograd import Variable
from torch_geometric.data import DataLoader
from torch_geometric.datasets import PPI

from .snae_solver import SNAESolver
from prototype.utils.misc import makedir, create_logger, get_logger, AverageMeter, accuracy, load_state_model, load_state_optimizer,\
                         parse_config, set_seed, param_group_all, modify_state
from prototype.model import model_entry
from prototype.optimizer import optim_entry
from prototype.lr_scheduler import scheduler_entry

try:
    from prototype.nas.sane.controller import GNNController
except ModuleNotFoundError:
    print('prototype.spring.nas not detected yet, install spring module first.')

import pdb

class SNAESolver(SNAESolver):

    def __init__(self, config_file):
        """Build the complete search pipeline from a config file.

        Parses ``config_file``, runs the environment / model / optimizer /
        lr-scheduler / data builders in that order, then creates the NAS
        controller and hands it the supernet and the logger.

        Args:
            config_file: path given to ``parse_config``; ``setup_env`` also
                uses its directory as the root for all outputs.
        """
        self.config_file = config_file
        self.config = parse_config(config_file)

        # Order matters: later builders read attributes set by earlier ones
        # (state/logger -> model -> optimizer -> lr scheduler).
        for build_step in (
            self.setup_env,
            self.build_model,
            self.build_optimizer,
            self.build_lr_scheduler,
            self.build_data,
        ):
            build_step()

        # NAS controller: give it the supernet and logger, then call its
        # init_optimizer().
        self.controller = GNNController(self.config.nas)
        self.controller.set_supernet(self.model)
        self.controller.set_logger(self.logger)
        self.controller.init_optimizer()

    def setup_env(self):
        """Prepare output directories, loggers, resume state and the RNG seed.

        Sets ``self.path``, ``self.tb_logger``, ``self.logger``,
        ``self.state``, ``self.seed_base`` and ``self.seed``.
        """
        # All outputs live next to the config file.
        root_dir = os.path.dirname(self.config_file)
        self.path = EasyDict()
        self.path.root_path = root_dir
        self.path.save_path = os.path.join(root_dir, 'checkpoints')
        self.path.event_path = os.path.join(root_dir, 'events')
        self.path.result_path = os.path.join(root_dir, 'results')
        for out_dir in (self.path.save_path, self.path.event_path, self.path.result_path):
            makedir(out_dir)

        # TensorBoard writer pointed at the events directory.
        self.tb_logger = SummaryWriter(self.path.event_path)

        # create_logger is handed <root>/log.txt; the config is dumped once.
        create_logger(os.path.join(root_dir, 'log.txt'))
        self.logger = get_logger(__name__)
        self.logger.info(f'config: {pprint.pformat(self.config)}')

        # Resume state: either a fresh counter pair or a checkpoint loaded
        # onto the CPU, with missing counters filled in.
        saver_cfg = self.config.saver
        if not hasattr(saver_cfg, 'pretrain'):
            self.state = {'last_iter': 0, 'last_epoch': -1}
        else:
            pretrain_cfg = saver_cfg.pretrain
            self.state = torch.load(pretrain_cfg.path, 'cpu')
            self.logger.info(f"Recovering from {pretrain_cfg.path}, keys={list(self.state.keys())}")
            if hasattr(pretrain_cfg, 'ignore'):
                self.state = modify_state(self.state, pretrain_cfg.ignore)
            if 'last_iter' not in self.state:
                self.state['last_iter'] = 0
            if 'last_epoch' not in self.state:
                self.state['last_epoch'] = -1

        # Seeding: the active seed is simply the configured base seed.
        self.seed_base: int = int(self.config.seed_base)
        self.seed: int = self.seed_base
        set_seed(seed=self.seed)

    def build_model(self):
        """Instantiate the model from config and load weights from ``self.state``.

        When the state has a ``'model'`` entry, that entry is loaded;
        otherwise the whole state object is passed to ``load_state_model``.
        """
        self.model = model_entry(self.config.model)

        weights = self.state['model'] if 'model' in self.state else self.state
        load_state_model(self.model, weights)

    def _build_optimizer(self, opt_config, model):
        """Create an optimizer for ``model`` from ``opt_config``.

        Builds the per-parameter-type overrides (``pconfig``), asks
        ``param_group_all`` for the parameter groups, stores them in
        ``opt_config.kwargs.params`` (so ``opt_config`` is mutated) and
        returns ``optim_entry(opt_config)``.

        Args:
            opt_config: optimizer config; ``no_wd`` and ``pconfig`` are
                optional entries.
            model: model whose parameters are grouped.
        """
        # With ``no_wd`` set, these four groups get zero weight decay.
        if opt_config.get('no_wd', False):
            pconfig = {
                group: {'weight_decay': 0.0}
                for group in ('conv_b', 'linear_b', 'bn_w', 'bn_b')
            }
        else:
            pconfig = {}

        # Explicit per-group settings from the config take precedence.
        if 'pconfig' in opt_config:
            pconfig.update(opt_config['pconfig'])

        # The second return value is not needed here.
        param_group, _ = param_group_all(model, pconfig)
        opt_config.kwargs.params = param_group
        return optim_entry(opt_config)

    def build_optimizer(self):
        """Create ``self.optimizer`` and restore its state if one was loaded."""
        self.optimizer = self._build_optimizer(self.config.optimizer, self.model)
        if 'optimizer' not in self.state:
            return
        load_state_optimizer(self.optimizer, self.state['optimizer'])

    def _build_lr_scheduler(self, lr_config, optimizer):
        """Create an lr scheduler for ``optimizer`` from ``lr_config``.

        Writes ``optimizer`` and the current ``self.state['last_epoch']`` into
        ``lr_config.kwargs`` (mutating the config) before calling
        ``scheduler_entry``.
        """
        scheduler_kwargs = lr_config.kwargs
        scheduler_kwargs.optimizer = optimizer
        scheduler_kwargs.last_epoch = self.state['last_epoch']
        return scheduler_entry(lr_config)

    def build_lr_scheduler(self):
        """Create ``self.lr_scheduler`` for the current optimizer."""
        scheduler = self._build_lr_scheduler(self.config.lr_scheduler, self.optimizer)
        self.lr_scheduler = scheduler

    def build_data(self, root='/root/data/PPI'):
        """Create the PPI train/val/test loaders (specific to PPI).

        If ``config.data.max_epoch`` is missing or falsy it is filled in from
        ``config.lr_scheduler.kwargs.T_max``. Sets ``self.train_data``,
        ``self.val_data`` and ``self.test_data``, each a dict with a single
        ``'loader'`` entry.

        Args:
            root: directory passed to the ``PPI`` dataset constructor. The
                default is the location that used to be hard-coded, so
                existing callers are unaffected.
        """
        data_cfg = self.config.data
        if not getattr(data_cfg, 'max_epoch', False):
            data_cfg.max_epoch = self.config.lr_scheduler.kwargs.T_max

        # One dataset per split, all read from the same root directory.
        splits = ('train', 'val', 'test')
        datasets = {split: PPI(root, split=split) for split in splits}

        # Batch size and shuffling come from the per-split config section.
        loaders = {}
        for split in splits:
            split_cfg = getattr(data_cfg, split)
            loaders[split] = DataLoader(
                datasets[split],
                batch_size=split_cfg.batch_size,
                shuffle=split_cfg.shuffle,
            )

        self.train_data = {'loader': loaders['train']}
        self.val_data = {'loader': loaders['val']}
        self.test_data = {'loader': loaders['test']}

    def build_finetune_data(self, max_epoch=None):
        """Rebuild the PPI loaders for fine-tuning (specific to PPI).

        Args:
            max_epoch: when not ``None``, overwrites ``config.data.max_epoch``
                before the loaders are rebuilt.
        """
        epoch_override = max_epoch
        if epoch_override is not None:
            self.config.data.max_epoch = epoch_override

        self.build_data()

    def _pre_train(self, model):
        """Reset meters and training attributes before a training run.

        Creates ``self.meters``, puts ``model`` in train mode, sets
        ``self.num_classes``, ``self.topk``, ``self.criterion`` and
        ``self.mixup``, and hands the criterion to the controller.
        """
        window = self.config.saver.print_freq
        self.meters = EasyDict()
        for meter_name in ('batch_time', 'step_time', 'data_time', 'losses', 'top1', 'top5'):
            setattr(self.meters, meter_name, AverageMeter(window))

        model.train()

        # Read from the config but not used anywhere in this method.
        label_smooth = self.config.get('label_smooth', 0.0)
        self.num_classes = self.config.model.kwargs.get('num_classes', 1000)
        # top-k is capped at 5.
        self.topk = min(5, self.num_classes)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.mixup = self.config.get('mixup', 1.0)
        if self.mixup < 1.0:
            self.logger.info('using mixup with alpha of: {}'.format(self.mixup))

        # The controller uses the very same criterion object.
        self.controller.set_criterion(self.criterion)

    def _train(self, model):
        """Run the search/training loop on ``model`` and write the final genotype.

        For every epoch: steps the lr scheduler, calls the NAS controller's
        ``step`` once, then iterates the training loader doing forward / loss /
        backward / clipped optimizer step while updating the meters. Every
        ``config.saver.val_epoch_freq`` epochs it evaluates, logs to
        TensorBoard and saves a checkpoint. After the last epoch the genotype
        of ``self.model`` is written to a text file under
        ``self.path.result_path``.

        Args:
            model: network to optimise; ``model.genotype()`` is called when a
                checkpoint is saved.
        """
        self._pre_train(model=model)
        iter_per_epoch = len(self.train_data['loader'])
        total_step = iter_per_epoch * self.config.data.max_epoch
        end = time.time()

        for epoch in tqdm(range(0, self.config.data.max_epoch)):
            start_step = epoch * iter_per_epoch

            # skip whole epochs that start before the recorded last iteration
            if start_step < self.state['last_iter']:
                continue

            # NOTE(review): the scheduler is stepped before the optimizer and the
            # rate is read via get_lr(); confirm this matches the scheduler
            # classes returned by scheduler_entry.
            self.lr_scheduler.step()
            current_lr = self.lr_scheduler.get_lr()[0]

            # architecture step for optimizing alpha for one epoch
            self.controller.step(self.train_data['loader'], self.val_data['loader'], current_lr, self.optimizer, unrolled=self.config.nas.unrolled)

            for i, data in enumerate(self.train_data['loader']):
                curr_step = start_step + i

                # jumping over trained steps
                if curr_step < self.state['last_iter']:
                    continue

                # get_data: the whole batch object is the model input, its `y` the target
                inp, target = data, Variable(data.y)

                # measure data loading time
                self.meters.data_time.update(time.time() - end)

                # forward
                logits = model(inp)

                # clear gradient
                self.optimizer.zero_grad()

                # compute the loss
                loss = self.criterion(logits, target)

                # measure micro-F1 (tracked in the `top1` meter) and record loss
                # prec1, prec5 = accuracy(logits, target, topk=(1, self.topk))
                prec1 = f1_score(target, (logits > 0).float(), average='micro')

                reduced_loss = loss.clone()
                reduced_prec1 = prec1

                self.meters.losses.reduce_update(reduced_loss)
                self.meters.top1.reduce_update(reduced_prec1)
                # self.meters.top5.reduce_update(reduced_prec5)

                # backward pass
                loss.backward()

                # Clip Grad norm
                torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.grad_clip)

                # apply the parameter update
                self.optimizer.step()

                # measure elapsed time
                self.meters.batch_time.update(time.time() - end)

                # training logger
                if curr_step % self.config.saver.print_freq == 0:
                    self.tb_logger.add_scalar('loss_train', self.meters.losses.avg, curr_step)
                    self.tb_logger.add_scalar('acc1_train', self.meters.top1.avg, curr_step)
                    self.tb_logger.add_scalar('lr', current_lr, curr_step)

                    # remaining-time estimate: steps left times the average batch time
                    remain_secs = (total_step - curr_step) * self.meters.batch_time.avg
                    remain_time = datetime.timedelta(seconds=round(remain_secs))
                    finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))
                    log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                            f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                            f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                            f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                            f'Prec@1 {self.meters.top1.val:.3f} ({self.meters.top1.avg:.3f})\t' \
                            f'LR {current_lr:.6f}\t' \
                            f'Remaining Time {remain_time} ({finish_time})'
                    self.logger.info(log_msg)

                end = time.time()

            # testing during training (epoch-based frequency)
            # NOTE(review): `curr_step` is only bound inside the loop above, so an
            # empty training loader would raise NameError here.
            # if curr_step > 0 and curr_step % self.config.saver.val_freq == 0:
            if curr_step > 0 and (epoch + 1) % self.config.saver.val_epoch_freq == 0:
                metrics = self._evaluate(model=model)

                loss_val = metrics['loss']
                prec1 = metrics['top1']

                # testing logger
                self.tb_logger.add_scalar('loss_val', loss_val, curr_step)
                self.tb_logger.add_scalar('acc1_val', prec1, curr_step)

                # save ckpt: one file per step when save_many is set, else a single overwritten file
                if self.config.saver.save_many:
                    ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
                else:
                    ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
                self.state['model'] = model.state_dict()
                self.state['optimizer'] = self.optimizer.state_dict()
                self.state['last_epoch'] = epoch
                self.state['last_iter'] = curr_step
                torch.save(self.state, ckpt_name)
                genotype = model.genotype()
                self.logger.info('genotype = %s', genotype)

        # NOTE(review): the final genotype is taken from `self.model`, while the
        # per-checkpoint log above uses the `model` argument; confirm which one is
        # intended when `model` is a subnet.
        res = []
        res.append(f'genotype={self.model.genotype()}')

        # result file name carries the task name and a timestamp
        result_filename =  os.path.join(self.path.result_path,
                        f'{self.config.data.task}-searched_res-{time.strftime("%Y%m%d-%H%M%S")}.txt')
        with open(result_filename, 'w+') as file:
            file.write('\n'.join(res))
            # the explicit close is redundant inside the `with` block
            file.close()

        self.logger.info('searched res for {} saved in {}'.format(self.config.data.task, result_filename))

    # @torch.no_grad()
    def _evaluate(self, model):
        """Evaluate ``model`` on ``self.test_data['loader']``.

        For each batch the BCE-with-logits loss and the micro-averaged F1
        (logits thresholded at 0) are recorded with weight ``batch.size(0)``.
        The model is put in eval mode for the pass and switched back to train
        mode before returning.

        NOTE(review): gradients are not disabled during this pass.

        Returns:
            dict with ``'loss'`` and ``'top1'`` (the F1 score) aggregated over
            the loader.
        """
        timer_meter = AverageMeter(0)
        loss_meter = AverageMeter(0)
        f1_meter = AverageMeter(0)

        model.eval()
        bce = torch.nn.BCEWithLogitsLoss()
        loader = self.test_data['loader']
        num_batches = len(loader)
        tick = time.time()

        for seen, batch in enumerate(loader, start=1):
            labels = Variable(batch.y)
            logits = model(batch)

            # loss and micro-F1 for this batch
            batch_loss = bce(logits, labels)
            batch_f1 = f1_score(labels, (logits > 0).float(), average='micro')
            weight = batch.size(0)
            loss_meter.update(batch_loss.item(), weight)
            f1_meter.update(batch_f1.item(), weight)

            # per-batch wall-clock time
            timer_meter.update(time.time() - tick)
            tick = time.time()
            if seen % self.config.saver.print_freq == 0:
                self.logger.info(f'Test: [{seen}/{num_batches}]\tTime {timer_meter.val:.3f} ({timer_meter.avg:.3f})')

        # Aggregate through tensors: sum = avg * count, then divide by count.
        total_num = torch.Tensor([loss_meter.count])
        loss_sum = torch.Tensor([loss_meter.avg * loss_meter.count])
        top1_sum = torch.Tensor([f1_meter.avg * f1_meter.count])

        final_loss = loss_sum.item() / total_num.item()
        final_top1 = top1_sum.item() / total_num.item()

        self.logger.info(f' * Prec@1 {final_top1:.3f}\t \
            Loss {final_loss:.3f}\ttotal_num={total_num.item()}')

        model.train()
        return {'loss': final_loss, 'top1': final_top1}

    # Train the supernet; the configuration is taken from self.config
    def train(self):
        """Run the training/search loop on the supernet ``self.model``."""
        supernet = self.model
        self._train(model=supernet)

    # Evaluate a specific subnet; the configuration is taken from self.subnet
    def evaluate_subnet(self):
        """Evaluate one subnet of the supernet and optionally save its weights.

        The subnet description comes from ``self.controller.subnet``. If its
        ``subnet_settings.genotype`` is ``None`` a random subnet is sampled
        from the controller, otherwise the configured settings are used.

        Returns:
            dict with ``'subnet_settings'`` (the settings stored on the
            controller's subnet config) and ``'top1'`` (rounded to 3
            decimals).

        Raises:
            AssertionError: if the controller has no subnet configured.
        """
        self.subnet = self.controller.subnet
        assert self.subnet is not None

        self.save_subnet_weight = self.subnet.get('save_subnet_weight', False)

        # sample an architecture when none is given in the subnet config
        if self.subnet.subnet_settings.genotype is None:
            subnet_settings = self.controller.sample_subnet_settings(sample_mode='random')
        else:
            subnet_settings = self.subnet.subnet_settings

        # build subnet from supernet
        subnet_model = self.controller.build_active_subnet(subnet_settings)

        # evaluate
        metrics = self._evaluate(model=subnet_model)

        # evaluate logging
        top1 = round(metrics['top1'], 3)
        # NOTE(review): the record keeps the configured settings even when a
        # random subnet was sampled above; confirm that is intended.
        subnet = {'subnet_settings': self.subnet.subnet_settings, 'top1': top1}
        self.logger.info('Subnet with settings: {}\ttop1 {}'.format(subnet_settings, top1))
        self.logger.info('Evaluate_subnet\t{}'.format(json.dumps(subnet)))

        # save weights
        if self.save_subnet_weight:
            # `setup_env` never defines `path.bignas_path`, so reading it
            # directly raised AttributeError; fall back to the checkpoint
            # directory when it is absent.
            save_dir = getattr(self.path, 'bignas_path', None)
            if save_dir is None:
                save_dir = self.path.save_path
            ckpt_name = f'{save_dir}/ckpt_{top1}.pth.tar'
            torch.save({'model': subnet_model.state_dict()}, ckpt_name)
        return subnet

    # Fine-tune a specific subnet; the configuration is taken from self.subnet
    def finetune_subnet(self):
        """Fine-tune one subnet of the supernet, evaluating before and after.

        The subnet description comes from ``self.controller.subnet``; a random
        subnet is sampled when its ``subnet_settings.genotype`` is ``None``.
        The optimizer, lr scheduler and data loaders are rebuilt from the
        subnet config before training.

        Returns:
            dict with ``'subnet_settings'`` and the post-finetune ``'top1'``
            (rounded to 3 decimals).
        """
        self.subnet = self.controller.subnet
        assert self.subnet is not None

        # choose the architecture: sampled at random, or the configured one
        if self.subnet.subnet_settings.genotype is None:
            subnet_settings = self.controller.sample_subnet_settings(sample_mode='random')
        else:
            subnet_settings = self.subnet.subnet_settings

        subnet_model = self.controller.build_active_subnet(subnet_settings)

        # Fresh optimizer / scheduler / data for the subnet.
        # NOTE(review): the scheduler is built while self.state['last_epoch']
        # still holds its pre-reset value; confirm this is intended.
        self.optimizer = self._build_optimizer(self.subnet.optimizer, subnet_model)
        self.lr_scheduler = self._build_lr_scheduler(self.subnet.lr_scheduler, self.optimizer)
        self.build_finetune_data(max_epoch=self.subnet.data.max_epoch)

        def evaluate_and_log(template):
            # Evaluate the subnet, log the JSON record and hand it back.
            score = round(self._evaluate(model=subnet_model)['top1'], 3)
            record = {'subnet_settings': self.subnet.subnet_settings, 'top1': score}
            self.logger.info(template.format(json.dumps(record)))
            return record

        evaluate_and_log('Before finetune subnet {}')

        # Restart the counters for fine-tuning, remembering the old values.
        resumed_iter, resumed_epoch = self.state['last_iter'], self.state['last_epoch']
        self.state['last_iter'] = 0
        self.state['last_epoch'] = 0

        self._train(model=subnet_model)

        # Add the earlier progress back onto the fine-tuning counters.
        self.state['last_iter'] += resumed_iter
        self.state['last_epoch'] += resumed_epoch

        return evaluate_and_log('After finetune subnet {}')


def main():
    """Command-line entry point: build a solver and run the requested phase.

    ``--config`` (required) is the solver config path; ``--phase`` is one of
    ``train_search`` (default), ``evaluate_subnet`` or ``finetune_subnet``.

    Raises:
        NotImplementedError: if ``--phase`` is not one of the phases above.
    """
    parser = argparse.ArgumentParser(description='Graph Neural archtecture search Solver')
    parser.add_argument('--config', required=True, type=str)
    parser.add_argument('--phase', default='train_search')

    args = parser.parse_args()
    # build solver
    solver = SNAESolver(args.config)

    # evaluate or finetune or train_search
    if args.phase in ['evaluate_subnet', 'finetune_subnet']:
        if not hasattr(solver.config.saver, 'pretrain'):
            # `Logger.warn` is a deprecated alias; use `warning`.
            solver.logger.warning(f'{args.phase} without resuming any solver checkpoints.')
        if args.phase == 'evaluate_subnet':
            solver.evaluate_subnet()
        else:
            solver.finetune_subnet()
    elif args.phase == 'train_search':
        # NOTE(review): `<=` also starts a run when last_epoch equals
        # max_epoch; confirm the intended boundary.
        if solver.state['last_epoch'] <= solver.config.data.max_epoch:
            solver.train()
        else:
            solver.logger.info('Training has been completed to max_epoch!')
    else:
        raise NotImplementedError(f'unsupported phase: {args.phase}')

# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()