import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import ImageDataset
import numpy as np
from ray import tune
from get_loaders import *
from copy import deepcopy
# from utils import get_linear_schedule_with_warmup
from utils.utils import *
# from utils.inception_score import _init_inception
# from utils.fid_score import create_inception_graph, check_or_download_inception
from functions import train, validate, load_params, copy_params
import gan_models
from easydict import EasyDict


class SNGAN(tune.Trainable):
    """Ray Tune trainable that trains an SNGAN and reports its FID.

    The generator and discriminator come from ``gan_models.sngan_cifar10``.
    Each call to ``_train`` runs the project's ``train`` function once and
    then scores the generator (with its averaged parameters loaded) through
    the project's ``validate`` function.

    Expected ``config`` layout (see ``_setup``):

    * ``"task_config"``: mapping of model/data settings; wrapped in an
      ``EasyDict`` and handed to the networks, the dataset, ``train`` and
      ``validate``. Must provide ``init_type``.
    * ``"optimizer"``: optimizer class/factory, called as
      ``Opt(params, **optimizer_kwargs)``.
    * ``"args"``, ``"warmup"``, ``"decay_rate"``: reserved keys that are
      currently not used by this class (learning-rate scheduling is
      disabled) and are never forwarded to the optimizer.
    * every other key is forwarded verbatim as an optimizer keyword argument.
    """

    # Pre-computed FID statistics of the CIFAR-10 training set.
    _FID_STAT_PATH = '/nfs/data/yhxiong/fid_stat/fid_stats_cifar10_train.npz'

    # Config keys that are not optimizer hyper-parameters.
    _NON_OPTIMIZER_KEYS = frozenset(
        ("args", "task_config", "optimizer", "warmup", "decay_rate")
    )

    def _setup(self, config):
        """Build networks, data loader and optimizers from ``config``.

        Args:
            config: trial configuration as described in the class docstring.
                It is read but not modified.

        Raises:
            KeyError: if ``"optimizer"`` or ``"task_config"`` is missing.
            NotImplementedError: if ``task_config.init_type`` is not one of
                ``'normal'``, ``'orth'`` or ``'xavier_uniform'`` (raised while
                initialising the first Conv2d-like layer).
        """
        optimizer_cls = config["optimizer"]
        self.task_config = EasyDict(config["task_config"])

        # Plain attribute access instead of eval() on a constant string.
        self.gen_net = gan_models.sngan_cifar10.Generator(args=self.task_config).cuda()
        self.dis_net = gan_models.sngan_cifar10.Discriminator(args=self.task_config).cuda()
        self.global_step = 0

        # weight init
        self.gen_net.apply(self._init_weights)
        self.dis_net.apply(self._init_weights)

        # set up data_loader
        dataset = ImageDataset(self.task_config)
        self.train_loader = dataset.train

        # fid stat
        self.fid_stat = self._FID_STAT_PATH

        # Snapshot of the generator parameters that `train` receives as the
        # averaged-parameter buffer and that is loaded for validation.
        self.gen_avg_param = copy_params(self.gen_net)

        # Everything that is not a reserved key is an optimizer
        # hyper-parameter. Build a filtered copy rather than deleting keys so
        # the caller's config dict stays intact.
        optimizer_kwargs = {
            key: value
            for key, value in config.items()
            if key not in self._NON_OPTIMIZER_KEYS
        }

        self.gen_optimizer = optimizer_cls(
            filter(lambda p: p.requires_grad, self.gen_net.parameters()),
            **optimizer_kwargs
        )
        self.dis_optimizer = optimizer_cls(
            filter(lambda p: p.requires_grad, self.dis_net.parameters()),
            **optimizer_kwargs
        )

        # Learning-rate scheduling (warmup / decay) is currently disabled.
        self.lr_schedulers = None

    def _init_weights(self, module):
        """Initialise one sub-module; intended for ``nn.Module.apply``.

        Modules whose class name contains ``Conv2d`` are initialised
        according to ``self.task_config.init_type``; modules whose class name
        contains ``BatchNorm2d`` get weight ~ N(1.0, 0.02) and bias 0. Other
        modules are left untouched.

        Raises:
            NotImplementedError: for an unsupported ``init_type``.
        """
        classname = module.__class__.__name__
        if 'Conv2d' in classname:
            init_type = self.task_config.init_type
            if init_type == 'normal':
                nn.init.normal_(module.weight.data, 0.0, 0.02)
            elif init_type == 'orth':
                nn.init.orthogonal_(module.weight.data)
            elif init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight.data, 1.)
            else:
                # Report the value that was actually checked (task_config),
                # not an attribute of the unrelated `args` object.
                raise NotImplementedError('{} unknown inital type'.format(init_type))
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0.0)

    def _train(self):
        """Run one training call and evaluate the averaged generator.

        Returns:
            dict with ``"fid"`` (value returned by ``validate``) and
            ``"early_stop"`` (always ``False``).
        """
        self.global_step = train(
            self.task_config, self.global_step, self.gen_net, self.dis_net,
            self.gen_optimizer, self.dis_optimizer, self.gen_avg_param,
            self.train_loader, self.lr_schedulers,
        )

        # Temporarily swap the averaged parameters into the generator for
        # validation, then put the live training parameters back -- even if
        # validation raises, so the generator is never left in the swapped
        # state.
        backup_param = copy_params(self.gen_net)
        load_params(self.gen_net, self.gen_avg_param)
        try:
            fid_score = validate(self.task_config, self.fid_stat, self.gen_net)
        finally:
            load_params(self.gen_net, backup_param)

        # Inception-score reporting is currently disabled; only FID is returned.
        return {"fid": fid_score, "early_stop": False}

    def _save(self, checkpoint_dir):
        """Write a checkpoint to ``checkpoint_dir`` and return its path.

        The checkpoint holds the state dicts of both networks, of the
        generator with averaged parameters loaded, of both optimizers, and
        the current ``global_step``.
        """
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")

        # Materialise the averaged parameters as a regular state dict by
        # loading them into a throw-away copy of the generator.
        avg_gen_net = deepcopy(self.gen_net)
        load_params(avg_gen_net, self.gen_avg_param)
        torch.save({
            'gen_state_dict': self.gen_net.state_dict(),
            'dis_state_dict': self.dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': self.gen_optimizer.state_dict(),
            'dis_optimizer': self.dis_optimizer.state_dict(),
            # Persist the step counter so a restored trial continues from
            # where it stopped instead of restarting at 0.
            'global_step': self.global_step,
        }, checkpoint_path)
        del avg_gen_net
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Restore networks, optimizers, averaged params and step counter.

        Checkpoints written before ``global_step`` was stored are still
        accepted; in that case the current ``global_step`` is kept.
        """
        checkpoint = torch.load(checkpoint_path)
        self.gen_net.load_state_dict(checkpoint['gen_state_dict'])
        self.dis_net.load_state_dict(checkpoint['dis_state_dict'])
        self.gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        self.dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])

        # Rebuild the averaged-parameter snapshot from its saved state dict.
        avg_gen_net = deepcopy(self.gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        self.gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        self.global_step = checkpoint.get('global_step', self.global_step)
