import os
import os.path as osp
import sys
import time
import glob
import numpy as np
import torch

import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch import cat
import pickle
from sklearn.metrics import f1_score

from torch.autograd import Variable
# import utils
# from model_search import Network
# from architect import Architect
# from utils import gen_uniform_60_20_20_split, save_load_split
# from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
# from logging_util import init_logger
from torch_geometric.data import DataLoader

from torch_geometric.datasets import Planetoid, Amazon, Coauthor, CoraFull, Reddit, PPI

from torch_geometric.utils import add_self_loops
from sklearn.model_selection import StratifiedKFold


import os
import argparse
from easydict import EasyDict
from tensorboardX import SummaryWriter
import pprint
import time
import datetime
import torch
import json

import torch.nn.functional as F

from .base_solver import BaseSolver
from prototype.utils.misc import makedir, create_logger, get_logger, AverageMeter, accuracy, load_state_model, load_state_optimizer,\
                         mixup_data, mix_criterion, parse_config, set_seed

from prototype.model import model_entry
from prototype.optimizer import optim_entry
from prototype.lr_scheduler import scheduler_entry

from tqdm import tqdm

try:
    from prototype.nas.sane.controller import GNNController
except ModuleNotFoundError as err:
    # Defer the failure: this module stays importable, but SNAESolver cannot be
    # constructed (GNNController is unbound) until the controller is installed.
    print(f'prototype.nas.sane.controller not detected yet ({err}), install spring module first.',
          file=sys.stderr)


class SNAESolver(BaseSolver):
    """Supernet search solver for graph neural architecture search.

    Builds the supernet, its optimizer, the data loaders and the LR scheduler
    from a config file, and wires them to a ``GNNController`` whose ``step``
    is called before every weight update in ``train()``.
    """

    def __init__(self, config_file):
        """Build every training component described by ``config_file``.

        Args:
            config_file: path to the solver config. Its directory is used as the
                root for checkpoints, tensorboard events, results and the log.
        """
        self.config_file = config_file
        self.config = parse_config(config_file)
        self.setup_env()
        self.build_model()
        self.build_optimizer()
        # build_data must run before build_lr_scheduler: it may derive
        # config.data.max_iter, which the scheduler falls back to.
        self.build_data()
        self.build_lr_scheduler()

        # set up NAS controller
        self.controller = GNNController(self.config.nas)
        self.controller.set_supernet(self.model)
        self.controller.set_logger(self.logger)
        self.controller.init_optimizer()

    def setup_env(self):
        """Create output dirs and loggers, restore a checkpoint and seed RNGs."""
        # directories
        self.path = EasyDict()
        self.path.root_path = os.path.dirname(self.config_file)
        self.path.save_path = os.path.join(self.path.root_path, 'checkpoints')
        self.path.event_path = os.path.join(self.path.root_path, 'events')
        self.path.result_path = os.path.join(self.path.root_path, 'results')
        makedir(self.path.save_path)
        makedir(self.path.event_path)
        makedir(self.path.result_path)

        self.tb_logger = SummaryWriter(self.path.event_path)

        # logger
        create_logger(os.path.join(self.path.root_path, 'log.txt'))
        self.logger = get_logger(__name__)
        self.logger.info(f'config: {pprint.pformat(self.config)}')

        # load pretrain checkpoint
        if hasattr(self.config.saver, 'pretrain'):
            pretrain_path = self.config.saver.pretrain.path
            self.state = torch.load(pretrain_path, map_location='cpu')
            self.logger.info(f"Recovering from {pretrain_path}, keys={list(self.state.keys())}")
            if hasattr(self.config.saver.pretrain, 'ignore'):
                # modify_state was used here but never imported by this module.
                from prototype.utils.misc import modify_state
                self.state = modify_state(self.state, self.config.saver.pretrain.ignore)
            self.state.setdefault('last_iter', 0)
        else:
            self.state = {'last_iter': 0}

        # seed every RNG with the configured base seed
        self.seed_base: int = int(self.config.seed_base)
        self.seed: int = self.seed_base
        set_seed(seed=self.seed)

    def build_model(self):
        """Instantiate the supernet on GPU, report its size and load weights."""
        # These helpers were used here but never imported by this module.
        from prototype.utils.misc import count_params, count_flops

        self.model = model_entry(self.config.model)
        self.model.cuda()

        count_params(self.model)
        # FLOP counting feeds an image-shaped (1, 3, H, W) tensor; only attempt
        # it when the data config defines an input size (a graph config may not).
        input_size = getattr(self.config.data, 'input_size', None)
        if input_size is not None:
            count_flops(self.model, input_shape=[1, 3, input_size, input_size])

        if 'model' in self.state:
            load_state_model(self.model, self.state['model'])
        elif hasattr(self.config.saver, 'pretrain'):
            # the pretrain file is a bare state_dict rather than a solver checkpoint
            load_state_model(self.model, self.state)
        # otherwise self.state is just {'last_iter': 0}: nothing to load

    def build_optimizer(self):
        """Build the weight optimizer, with optional per-parameter-type groups."""
        # param_group_all was used here but never imported by this module.
        from prototype.utils.misc import param_group_all

        opt_config = self.config.optimizer
        opt_config.kwargs.lr = self.config.lr_scheduler.kwargs.base_lr

        # make param_groups; the keys presumably name conv/linear biases and
        # BN weights/biases, which are excluded from weight decay with no_wd
        pconfig = {}
        if opt_config.get('no_wd', False):
            for key in ('conv_b', 'linear_b', 'bn_w', 'bn_b'):
                pconfig[key] = {'weight_decay': 0.0}
        if 'pconfig' in opt_config:
            pconfig.update(opt_config['pconfig'])

        param_group, _ = param_group_all(self.model, pconfig)
        opt_config.kwargs.params = param_group

        self.optimizer = optim_entry(opt_config)
        if 'optimizer' in self.state:
            load_state_optimizer(self.optimizer, self.state['optimizer'])

    def build_lr_scheduler(self):
        """Build the LR scheduler, resuming from ``self.state['last_iter']``."""
        kwargs = self.config.lr_scheduler.kwargs
        # without an explicit max_iter, fall back to config.data.max_iter,
        # which build_data derives from max_epoch when it is not set
        if not getattr(kwargs, 'max_iter', False):
            kwargs.max_iter = self.config.data.max_iter
        kwargs.optimizer = self.optimizer
        kwargs.last_iter = self.state['last_iter']
        self.lr_scheduler = scheduler_entry(self.config.lr_scheduler)

    def build_data(self):
        """Build the train/val/test loaders. Only the PPI task is implemented.

        The dataset root defaults to ``'../data/PPI'`` and may be overridden by
        ``config.data.root``. When ``config.data.max_iter`` is unset but
        ``max_epoch`` is set, ``max_iter`` is derived as
        ``len(train_loader) * max_epoch``.

        Raises:
            NotImplementedError: for the 'Cora' task (no loader yet).
            RuntimeError: for any other unknown task.
        """
        data_cfg = self.config.data
        task = data_cfg.task
        if task == 'Cora':
            raise NotImplementedError('data loading for the Cora task is not implemented yet')
        if task != 'PPI':
            raise RuntimeError('unknown task: {}'.format(task))

        root = getattr(data_cfg, 'root', '../data/PPI')
        train_dataset = PPI(root, split='train')
        val_dataset = PPI(root, split='val')
        test_dataset = PPI(root, split='test')

        train_loader = DataLoader(train_dataset, batch_size=data_cfg.train.batch_size,
                                  shuffle=data_cfg.train.shuffle)
        val_loader = DataLoader(val_dataset, batch_size=data_cfg.val.batch_size,
                                shuffle=data_cfg.val.shuffle)
        test_loader = DataLoader(test_dataset, batch_size=data_cfg.test.batch_size,
                                 shuffle=data_cfg.test.shuffle)

        self.train_data = {'loader': train_loader}
        self.val_data = {'loader': val_loader}
        self.test_data = {'loader': test_loader}

        # the LR scheduler counts iterations; translate an epoch budget into one
        if not getattr(data_cfg, 'max_iter', None) and getattr(data_cfg, 'max_epoch', None):
            data_cfg.max_iter = len(train_loader) * data_cfg.max_epoch

    def pre_train(self):
        """Create the meters, switch to train mode and set up the shared criterion."""
        self.meters = EasyDict()
        print_freq = self.config.saver.print_freq
        for name in ('batch_time', 'step_time', 'data_time', 'losses', 'top1', 'top5'):
            setattr(self.meters, name, AverageMeter(print_freq))

        self.model.train()

        self.num_classes = self.config.model.kwargs.get('num_classes', 1000)
        # top-k accuracy cannot use a k larger than the number of classes
        self.topk = min(5, self.num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

        self.mixup = self.config.get('mixup', 1.0)
        if self.mixup < 1.0:
            self.logger.info('using mixup with alpha of: {}'.format(self.mixup))

        # share same criterion with controller
        self.controller.set_criterion(self.criterion)

    def train(self):
        """Run the search loop: a controller step, then a weight step per batch.

        Steps are numbered from 1. ``self.state['last_iter']`` is the number of
        steps already completed (0 for a fresh run); those steps are skipped
        when resuming from a checkpoint.
        """
        self.pre_train()
        loader = self.train_data['loader']
        max_epoch = self.config.data.max_epoch
        iter_per_epoch = len(loader)
        total_step = iter_per_epoch * max_epoch
        last_iter = self.state['last_iter']
        start_step = last_iter + 1
        print_freq = self.config.saver.print_freq
        # NOTE(review): 5.0 mirrors the usual DARTS/SANE grad_clip default — confirm.
        # A falsy value (0 / None) disables clipping.
        grad_clip = self.config.get('grad_clip', 5.0)

        if self.controller.valid_before_train:
            # NOTE(review): not defined in this class; presumably provided by BaseSolver.
            self.evaluate_specific_subnets(start_step, total_step)

        curr_step = last_iter
        end = time.time()
        for epoch in tqdm(range(max_epoch)):
            # every step of this epoch was already trained before resuming
            if (epoch + 1) * iter_per_epoch <= last_iter:
                continue

            # set_epoch exists on DistributedSampler, not on the default
            # samplers a plain DataLoader uses
            sampler = getattr(loader, 'sampler', None)
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(epoch)

            for i, data in enumerate(loader):
                curr_step = epoch * iter_per_epoch + i + 1

                # jumping over steps trained before the checkpoint
                if curr_step <= last_iter:
                    end = time.time()
                    continue

                # measure data loading time
                self.meters.data_time.update(time.time() - end)

                # get_data
                # NOTE(review): assumes batches support data['input'] / data['target'];
                # torch_geometric batches usually expose x / y / edge_index — confirm.
                inp, target = data['input'], data['target']
                target = target.squeeze().long()
                # TODO transfer input data to device

                # TODO: split() is not defined in this module; it must return
                # the (train, val) data used by the architecture step.
                train_data, val_data = split(data)

                self.lr_scheduler.step(curr_step)
                # lr_scheduler.get_lr()[0] is the main lr
                current_lr = self.lr_scheduler.get_lr()[0]

                # architecture step (optimizing alpha)
                self.controller.step(train_data, val_data, current_lr, self.optimizer)

                # weight step: clear any gradients left by the controller step
                self.optimizer.zero_grad()
                logits = self.model(inp)
                loss = self.criterion(logits, target)

                # measure accuracy and record loss (detached: meters need values only)
                prec1, prec5 = accuracy(logits, target, topk=(1, self.topk))
                self.meters.losses.reduce_update(loss.detach().clone())
                self.meters.top1.reduce_update(prec1.clone())
                self.meters.top5.reduce_update(prec5.clone())

                loss.backward()
                if grad_clip:
                    nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                self.optimizer.step()

                # measure elapsed time
                self.meters.batch_time.update(time.time() - end)

                if curr_step % print_freq == 0:
                    self._log_train_progress(curr_step, total_step, current_lr)

                end = time.time()

            # testing during training
            if curr_step > 0 and (epoch + 1) % self.config.saver.val_epoch_freq == 0:
                self._validate_and_save(curr_step)
                # keep validation time out of the next step's data time
                end = time.time()

    def _log_train_progress(self, curr_step, total_step, current_lr):
        """Write the training meters to tensorboard and the text log (rank 0 only)."""
        # this solver never sets self.dist; a missing one is treated as rank 0
        if getattr(getattr(self, 'dist', None), 'rank', 0) != 0:
            return

        self.controller.show_subnet_log()

        self.tb_logger.add_scalar('loss_train', self.meters.losses.avg, curr_step)
        self.tb_logger.add_scalar('acc1_train', self.meters.top1.avg, curr_step)
        self.tb_logger.add_scalar('acc5_train', self.meters.top5.avg, curr_step)
        self.tb_logger.add_scalar('lr', current_lr, curr_step)

        remain_secs = (total_step - curr_step) * self.meters.batch_time.avg
        remain_time = datetime.timedelta(seconds=round(remain_secs))
        finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))
        log_msg = f'Iter: [{curr_step}/{total_step}]\t' \
                  f'Time {self.meters.batch_time.val:.3f} ({self.meters.batch_time.avg:.3f})\t' \
                  f'Data {self.meters.data_time.val:.3f} ({self.meters.data_time.avg:.3f})\t' \
                  f'Loss {self.meters.losses.val:.4f} ({self.meters.losses.avg:.4f})\t' \
                  f'Prec@1 {self.meters.top1.val:.3f} ({self.meters.top1.avg:.3f})\t' \
                  f'Prec@5 {self.meters.top5.val:.3f} ({self.meters.top5.avg:.3f})\t' \
                  f'LR {current_lr:.6f}\t' \
                  f'Remaining Time {remain_time} ({finish_time})'
        self.logger.info(log_msg)

    def _validate_and_save(self, curr_step):
        """Evaluate on the validation loader, log the metrics and save a checkpoint."""
        metrics = self.evaluate()

        # testing logger
        self.tb_logger.add_scalar('loss_val', metrics['loss'], curr_step)
        self.tb_logger.add_scalar('acc1_val', metrics['top1'], curr_step)
        self.tb_logger.add_scalar('acc5_val', metrics['top5'], curr_step)

        # save ckpt
        if self.config.saver.save_many:
            ckpt_name = f'{self.path.save_path}/ckpt_{curr_step}.pth.tar'
        else:
            ckpt_name = f'{self.path.save_path}/ckpt.pth.tar'
        self.state['model'] = self.model.state_dict()
        self.state['optimizer'] = self.optimizer.state_dict()
        # number of completed steps; train() resumes right after it
        self.state['last_iter'] = curr_step
        torch.save(self.state, ckpt_name)

    def evaluate(self):
        """Evaluate the model on the validation loader without tracking gradients.

        Returns:
            dict with the 'loss', 'top1' and 'top5' averages accumulated over
            the loader (NaN when the loader yields no samples).
        """
        batch_time = AverageMeter(0)
        losses = AverageMeter(0)
        top1 = AverageMeter(0)
        top5 = AverageMeter(0)
        # same k as training when pre_train ran; k must not exceed num_classes
        topk = getattr(self, 'topk', 5)

        self.model.eval()
        criterion = torch.nn.CrossEntropyLoss()
        val_iter = len(self.val_data['loader'])
        end = time.time()

        # no backward pass here, so skip building the autograd graph
        with torch.no_grad():
            for i, data in enumerate(self.val_data['loader']):
                # get_data (TODO: move inputs to the model's device)
                inp, target = data['input'], data['target']
                target = target.squeeze().long()

                logits = self.model(inp)

                # measure accuracy and record loss
                loss = criterion(logits, target)
                prec1, prec5 = accuracy(logits, target, topk=(1, topk))
                num = inp.size(0)
                losses.update(loss.item(), num)
                top1.update(prec1.item(), num)
                top5.update(prec5.item(), num)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % self.config.saver.print_freq == 0:
                    self.logger.info(f'Test: [{i+1}/{val_iter}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})')

        total_num = losses.count
        if total_num > 0:
            final_loss = losses.avg
            final_top1 = top1.avg
            final_top5 = top5.avg
        else:
            self.logger.warning('validation loader yielded no samples; metrics are NaN')
            final_loss = final_top1 = final_top5 = float('nan')

        self.logger.info(f' * Prec@1 {final_top1:.3f}\tPrec@5 {final_top5:.3f}\t'
                         f'Loss {final_loss:.3f}\ttotal_num={total_num}')

        self.model.train()
        return {'loss': final_loss, 'top1': final_top1, 'top5': final_top5}

    def evaluate_subnet(self):
        """Intended to evaluate a specific subnet configured via ``self.subnet``.

        Not implemented yet: it only logs a warning.
        """
        self.logger.warning('evaluate_subnet is not implemented yet; nothing was evaluated.')

    def finetune_subnet(self):
        """Intended to finetune a specific subnet configured via ``self.subnet``.

        Not implemented yet: it only logs a warning.
        """
        self.logger.warning('finetune_subnet is not implemented yet; nothing was finetuned.')


def main():
    """Command-line entry point: build the solver and run the requested phase.

    Phases:
        evaluate_subnet / finetune_subnet: call the matching solver method.
        train_search (alias: train_supnet, the CLI default): run the search
            loop unless the restored state already reached config.data.max_iter.

    Raises:
        NotImplementedError: for any other ``--phase`` value.
    """
    parser = argparse.ArgumentParser(description='Graph Neural archtecture search Solver')
    parser.add_argument('--config', required=True, type=str)
    parser.add_argument('--phase', default='train_supnet')
    args = parser.parse_args()

    # build solver
    solver = SNAESolver(args.config)

    # evaluate or finetune or train_search
    if args.phase in ('evaluate_subnet', 'finetune_subnet'):
        if not hasattr(solver.config.saver, 'pretrain'):
            solver.logger.warning(f'{args.phase} without resuming any solver checkpoints.')
        if args.phase == 'evaluate_subnet':
            solver.evaluate_subnet()
        else:
            solver.finetune_subnet()
    elif args.phase in ('train_search', 'train_supnet'):
        # the resume point lives in the solver state, not in the data config
        last_iter = solver.state.get('last_iter', 0)
        max_iter = getattr(solver.config.data, 'max_iter', None)
        if max_iter is None or last_iter < max_iter:
            solver.train()
        else:
            solver.logger.info('Training has been completed to max_iter!')
    else:
        raise NotImplementedError(f'unknown phase: {args.phase}')

# Script entry point.
# NOTE(review): this module uses a relative import (`from .base_solver ...`),
# so it presumably must be run as part of its package (e.g. `python -m ...`)
# rather than as a standalone script — confirm.
if __name__ == '__main__':
    main()
