import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from ..utils.diff_utils import _concat

import pdb

class GNNPPIController(object):
    """Architecture controller for a differentiable GNN supernet.

    Owns the optimizer of the architecture parameters returned by
    ``model.arch_parameters()`` and updates them either with a first-order
    step (gradient of the validation loss) or with the unrolled second-order
    approximation (one virtual weight update followed by a finite-difference
    Hessian-vector product). It also enumerates or randomly samples
    sub-network genotypes from the supernet.

    The model, logger and criterion are not passed to ``__init__``; the solver
    injects them with ``set_supernet``, ``set_logger`` and ``set_criterion``,
    and must call ``init_optimizer`` after ``set_supernet``.
    """

    def __init__(self, config):
        self.config = config
        self.build_sane_settings()

    def build_sane_settings(self):
        """Read search hyper-parameters from ``self.config``.

        ``config`` is read by attribute (``momentum``, ``weight_decay``,
        ``unrolled``) and through ``.get('subnet', None)``, so it must support
        both access styles.
        """
        # Momentum / weight decay of the *network* optimizer; only used to
        # emulate its update when building the unrolled model.
        self.network_momentum = self.config.momentum
        self.network_weight_decay = self.config.weight_decay
        # True -> second-order (unrolled) update, False -> first-order update.
        self.unrolled = self.config.unrolled

        # Optional sub-network sampling settings; when present, ``step`` is a no-op.
        self.subnet = self.config.get('subnet', None)

    # Called by the solver to attach the supernet model.
    def set_supernet(self, model):
        self.model = model

    # Called by the solver to attach the logger.
    def set_logger(self, logger):
        self.logger = logger

    # Called by the solver to attach the objective (loss) function.
    def set_criterion(self, criterion):
        self.criterion = criterion

    def init_optimizer(self):
        """Create the Adam optimizer for the architecture parameters.

        Requires ``set_supernet`` to have been called first.
        """
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.config.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.config.arch_weight_decay,
        )

    def step(self, train_data, val_data, eta, network_optimizer):
        """Perform one update of the architecture parameters.

        Args:
            train_data: training batch, only used on the unrolled path, where it
                must be a mapping with ``'input'`` and ``'target'`` keys.
            val_data: validation data. On the first-order path it is iterated
                and every item must expose ``.y``; on the unrolled path it must
                be a mapping with ``'input'`` and ``'target'`` keys.
            eta: network learning rate used for the virtual one-step update.
            network_optimizer: optimizer of the network weights; its
                ``momentum_buffer`` state, if present, is used when unrolling.

        Does nothing when sub-network settings are configured.
        """
        if self.subnet is not None:
            return

        self.optimizer.zero_grad()
        if self.unrolled:
            self._backward_step_unrolled(train_data, val_data, eta, network_optimizer)
        else:
            self._backward_step(val_data)
        self.optimizer.step()

    def _data_loss(self, model, data):
        """Return ``criterion(model(data['input']), target)`` for a dict batch.

        The target is ``data['target']`` squeezed and cast to ``long``.
        """
        inp = data['input']
        target = data['target'].squeeze().long()
        return self.criterion(model(inp), target)

    def _add_sparse_penalty(self, loss):
        """Return ``loss`` plus optional p-norm penalties on the architecture weights.

        Active only when ``config.sparse`` is truthy. Each of the ``na``/``sc``/``la``
        weight groups is penalised unless its ``*_sparse`` flag is falsy; the
        norm order ``p`` defaults to 2 and the coefficient ``lambda`` to 0.001.
        """
        sparse = getattr(self.config, 'sparse', False)
        if not sparse:
            return loss

        p = getattr(sparse, 'norm', 2)
        _lambda = getattr(sparse, 'lambda', 0.001)
        for flag, weights_attr in (('na_sparse', 'na_weights'),
                                   ('sc_sparse', 'sc_weights'),
                                   ('la_sparse', 'la_weights')):
            if getattr(sparse, flag, True):
                # Out-of-place add: do not mutate a tensor that is part of the graph.
                loss = loss + _lambda * torch.norm(getattr(self.model, weights_attr), p=p)
        return loss

    def _backward_step(self, val_data):
        """First-order update: accumulate validation-loss gradients over ``val_data``.

        ``backward`` is called once per item, so gradients accumulate across all
        items until ``step`` applies them.
        """
        for data in val_data:
            inp, target = data, Variable(data.y)
            loss = self.criterion(self.model(inp), target)
            loss = self._add_sparse_penalty(loss)
            loss.backward()

    def _backward_step_unrolled(self, train_data, val_data, eta, network_optimizer):
        """Second-order update: write the unrolled architecture gradient into ``self.model``.

        1. Build ``w' = w - eta * (momentum + dL_train/dw + wd * w)``.
        2. Back-propagate the validation loss of the model with weights ``w'``.
        3. Subtract ``eta`` times the finite-difference Hessian-vector product.
        4. Copy the result into ``.grad`` of ``self.model.arch_parameters()``.
        """
        # One virtual step on the training loss gives w'.
        unrolled_model = self._compute_unrolled_model(train_data, eta, network_optimizer)

        unrolled_loss = self._data_loss(unrolled_model, val_data)
        unrolled_loss.backward()

        # dL_val(w', alpha)/d alpha
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        # dL_val(w', alpha)/d w' -- the vector of the Hessian-vector product.
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, train_data)
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # Hand the final gradient to the real model so self.optimizer can apply it.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    def _compute_unrolled_model(self, train_data, eta, network_optimizer):
        """Return a fresh model whose weights are one virtual SGD step from ``self.model``.

        ``self.model`` itself is not modified.
        """
        loss = self._data_loss(self.model, train_data)

        theta = _concat(self.model.parameters()).data  # current weights w, flattened
        try:
            moment = _concat(
                network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()
            ).mul_(self.network_momentum)
        except (AttributeError, KeyError, TypeError):
            # No usable momentum state: optimizer without momentum buffers,
            # buffers not created yet, or buffers stored as None.
            moment = torch.zeros_like(theta)
        # Training-loss gradient plus the L2 weight-decay term.
        dtheta = (_concat(torch.autograd.grad(loss, self.model.parameters())).data
                  + self.network_weight_decay * theta)
        # w' = w - eta * (moment + dtheta)  (Eq. 7 in the paper)
        return self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta))

    def _construct_model_from_theta(self, theta):
        """Return ``self.model.new()`` loaded with the flat weight vector ``theta``.

        ``theta`` is split in ``named_parameters()`` order. Every other
        ``state_dict`` entry is copied from ``self.model``.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            # numel() is an int even for 0-dim parameters (np.prod(()) is the float 1.0).
            v_length = v.numel()
            params[k] = theta[offset: offset + v_length].view(v.size())
            offset += v_length

        assert offset == theta.numel()
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        # TODO to_device()
        return model_new

    def _hessian_vector_product(self, vector, train_data, r=1e-2):
        """Approximate ``(d^2 L_train / d alpha d w) . vector`` by central differences.

        The weights are perturbed to ``w +/- R * vector``, where ``R = r / ||vector||``,
        and ``dL_train/d alpha`` is evaluated at both points. The weights of
        ``self.model`` are restored before returning.

        Returns:
            list of tensors ``(g+ - g-) / (2R)``, one per architecture parameter.
        """
        R = (r / _concat(vector).norm()).item()

        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)  # w+ = w + R * v
        loss = self._data_loss(self.model, train_data)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())  # dL_train/d alpha at w+

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)  # w- = w - R * v (undo +R, then subtract R)
        loss = self._data_loss(self.model, train_data)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())  # dL_train/d alpha at w-

        # Restore the original weights w.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]

    def sample_subnet_settings(self, sample_mode='random', subnet_settings=None):
        """Return ``subnet_settings``, or draw them from the model when ``None``."""
        if subnet_settings is None:
            subnet_settings = self.model.sample_active_subnet(sample_mode=sample_mode)
        return subnet_settings

    def build_active_subnet(self, subnet_settings):
        """Delegate to ``self.model.build_active_subnet``."""
        return self.model.build_active_subnet(subnet_settings)

    def get_subnet_weight(self, subnet_settings=None):
        """Delegate to ``self.model.sample_active_subnet_weights``."""
        subnet = self.model.sample_active_subnet_weights(subnet_settings)
        return subnet

    def get_all_subnets(self):
        """Enumerate and log every genotype in the search space.

        The space is the Cartesian product of the value lists in
        ``self.model.dynamic_settings``. Each combination is joined with
        ``'||'``, so the values must be strings.

        Returns:
            dict mapping a running index to ``{'genotype': <str>}``, or ``None``
            when the space holds more than one million entries.
        """
        import json
        import copy
        import itertools

        kwargs_configs = copy.deepcopy(self.model.dynamic_settings)
        self.logger.info(json.dumps(kwargs_configs))

        subnet_table = {}
        count = 0
        for kwarg in itertools.product(*kwargs_configs.values()):
            subnet_settings = '||'.join(kwarg)
            self.logger.info(subnet_settings)
            subnet_table[count] = {'genotype': subnet_settings}
            count += 1
            if count > 1000000:
                self.logger.info('total subnet number surpass 1 million')
                return None

        self.logger.info('total subnet number {}'.format(count))
        return subnet_table

    def get_subnet(self, sample_mode='random', subnet_settings=None):
        """Return ``subnet_settings``, or sample one from the model when ``None``."""
        if subnet_settings is None:
            subnet_settings = self.model.sample_subnet(sample_mode=sample_mode)
        return subnet_settings

    def sample_subnets(self):
        """Build a table of sub-network genotypes according to ``self.subnet``.

        Keys read from ``self.subnet`` are ``subnet_count`` (default 500) and
        ``subnet_sample_mode`` (default ``'random'``). In ``'traverse'`` mode the
        result of ``get_all_subnets`` is returned, which may be ``None``.
        Otherwise the space is enumerated once to log its size. Then
        ``subnet_count`` genotypes are drawn with ``get_subnet``, and Python's
        global ``random`` module is re-seeded before each draw.

        Returns:
            dict mapping a running index to ``{'genotype': ...}``.
        """
        import json
        import time
        import random

        assert self.subnet is not None
        self.subnet_count = self.subnet.get('subnet_count', 500)
        self.subnet_sample_mode = self.subnet.get('subnet_sample_mode', 'random')

        if self.subnet_sample_mode == 'traverse':
            return self.get_all_subnets()
        # Called only to log the size of the search space.
        self.get_all_subnets()

        subnet_table = {}
        self.logger.info('subnet count {}'.format(self.subnet_count))
        seed = int(time.time() * 10000) % 10000
        self.logger.info('seed {}'.format(seed))

        for count in range(self.subnet_count):
            seed += 1
            random.seed(seed)
            subnet_table[count] = {'genotype': self.get_subnet()}

        self.logger.info(json.dumps(subnet_table))
        return subnet_table