import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable

from ..utils.diff_utils import _concat

class GNNController(object):
    """Controller that updates a supernet's architecture parameters.

    The architecture parameters (``model.arch_parameters()``) are optimized
    with Adam, using either the first-order gradient of the validation loss
    or the second-order "unrolled" DARTS approximation. When a fixed
    ``subnet`` is configured, :meth:`step` is a no-op.
    """

    def __init__(self, config):
        self.config = config
        self.build_sane_settings()

    def build_sane_settings(self):
        """Read search hyper-parameters from ``self.config``.

        ``momentum`` and ``weight_decay`` are the settings of the network
        (weight) optimizer; they are used to simulate its one-step update in
        the unrolled approximation. ``subnet`` is optional (default ``None``).
        """
        self.network_momentum = self.config.momentum
        self.network_weight_decay = self.config.weight_decay

        # Optional fixed subnet; when set, step() skips architecture updates.
        self.subnet = self.config.get('subnet', None)

    # The model is created in the solver and injected here.
    def set_supernet(self, model):
        self.model = model

    # The logger is created in the solver and injected here.
    def set_logger(self, logger):
        self.logger = logger

    # The objective (criterion) is created in the solver and injected here.
    def set_criterion(self, criterion):
        self.criterion = criterion

    # Requires set_supernet() to have been called first.
    def init_optimizer(self):
        """Create the Adam optimizer over the supernet's architecture parameters."""
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.config.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.config.arch_weight_decay,
        )

    def step(self, train_data, val_data, eta, network_optimizer, unrolled):
        """Perform one update of the architecture parameters.

        Args:
            train_data: training batch; only used by the unrolled path.
            val_data: validation data whose loss drives the update.
            eta: step size of the simulated weight update (unrolled path only).
            network_optimizer: optimizer of the network weights; its momentum
                buffers, if present, are used by the unrolled path.
            unrolled: if true, use the second-order approximation; otherwise
                use the first-order gradient of the validation loss.

        Does nothing when a fixed ``subnet`` is configured.
        """
        if self.subnet is not None:
            return

        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(train_data, val_data, eta, network_optimizer)
        else:
            self._backward_step(val_data)
        self.optimizer.step()

    def _dict_batch_loss(self, model, batch):
        """Return ``criterion(model(batch['input']), target)`` for a dict batch.

        ``batch['target']`` is squeezed and cast to ``long`` before use.
        """
        inp = batch['input']
        target = batch['target'].squeeze().long()
        return self.criterion(model(inp), target)

    def _backward_step(self, val_data):
        """Accumulate first-order gradients of the validation loss.

        Iterates over ``val_data`` and calls ``backward()`` once per batch, so
        gradients are summed over all batches. Each batch is fed to the model
        as-is and its labels are read from ``batch.y`` (presumably graph
        batches, e.g. torch_geometric ``Batch`` -- verify against the caller).
        """
        for batch in val_data:
            logit = self.model(batch)
            loss = self.criterion(logit, batch.y)
            loss.backward()

    def _backward_step_unrolled(self, train_data, val_data, eta, network_optimizer):
        """Set the supernet's architecture gradients via second-order DARTS.

        grad_alpha = dL_val(w', alpha)/dalpha - eta * H, where w' is one
        simulated weight update on ``train_data`` and H is a central-difference
        estimate of (d^2 L_train / dalpha dw) . dL_val/dw'.

        NOTE(review): unlike ``_backward_step``, this path expects dict batches
        with 'input' and 'target' keys -- confirm the GNN data pipeline
        provides them.
        """
        # One simulated weight update gives w'.
        unrolled_model = self._compute_unrolled_model(train_data, eta, network_optimizer)

        # Validation loss and its gradients at (w', alpha).
        unrolled_loss = self._dict_batch_loss(unrolled_model, val_data)
        unrolled_loss.backward()

        dalpha = [v.grad for v in unrolled_model.arch_parameters()]  # dL_val/dalpha
        # dL_val/dw'. Assumes parameters() excludes the architecture
        # parameters -- TODO confirm against the model construction.
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, train_data)
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # Copy the final architecture gradient onto the real supernet, which
        # is what self.optimizer steps.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.detach()
            else:
                v.grad.data.copy_(g.data)

    def _compute_unrolled_model(self, train_data, eta, network_optimizer):
        """Return a copy of the supernet after one simulated weight update.

        w' = w - eta * (momentum * buf + dL_train/dw + weight_decay * w), where
        ``buf`` is ``network_optimizer.state[p]['momentum_buffer']``. If any
        parameter has no such buffer (e.g. before the optimizer's first step,
        or an optimizer that keeps none), the momentum term is zero.
        """
        loss = self._dict_batch_loss(self.model, train_data)

        params = list(self.model.parameters())
        theta = _concat(params).data  # w
        buffers = [network_optimizer.state[p].get('momentum_buffer') for p in params]
        if all(b is not None for b in buffers):
            moment = _concat(buffers).mul_(self.network_momentum)
        else:
            moment = torch.zeros_like(theta)
        # dL_train/dw plus the L2 (weight-decay) term.
        dtheta = _concat(torch.autograd.grad(loss, params)).data + self.network_weight_decay * theta
        # One-step update, giving w' for Eq. 7 of the DARTS paper.
        return self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta))

    def _construct_model_from_theta(self, theta):
        """Return ``self.model.new()`` loaded with weights unpacked from ``theta``.

        The flat ``theta`` is split, in ``named_parameters()`` order, into
        tensors shaped like the supernet's parameters; every other
        ``state_dict`` entry is copied from ``self.model``.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for name, param in self.model.named_parameters():
            # numel() is an int even for 0-dim parameters; np.prod(()) is 1.0,
            # which is not a valid slice index.
            length = param.numel()
            params[name] = theta[offset: offset + length].view(param.size())
            offset += length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        # TODO to_device()
        return model_new

    def _perturb_weights(self, vector, scale):
        """In place, add ``scale * v`` to each supernet weight and its matching ``v``."""
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=scale)

    def _hessian_vector_product(self, vector, train_data, r=1e-2):
        """Approximate (d^2 L_train / dalpha dw) . vector by central differences.

        Returns ``[(g+ - g-) / (2R)]``, where ``g±`` is dL_train/dalpha at
        ``w ± R * vector`` and ``R = r / ||vector||``. The supernet weights are
        restored afterwards, even if a forward/backward pass raises.
        """
        norm = _concat(vector).norm().item()
        if norm == 0:
            # H . 0 == 0; also avoids dividing by a zero norm.
            return [torch.zeros_like(a) for a in self.model.arch_parameters()]
        R = r / norm

        self._perturb_weights(vector, R)  # w+
        shift = R
        try:
            # dL_train/dalpha at w+
            grads_p = torch.autograd.grad(
                self._dict_batch_loss(self.model, train_data),
                self.model.arch_parameters())
            self._perturb_weights(vector, -2 * R)  # w-
            shift = -R
            # dL_train/dalpha at w-
            grads_n = torch.autograd.grad(
                self._dict_batch_loss(self.model, train_data),
                self.model.arch_parameters())
        finally:
            # Undo whichever perturbation is outstanding to restore w.
            self._perturb_weights(vector, -shift)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]

    def sample_subnet_settings(self, sample_mode='random', subnet_settings=None):
        """Return ``subnet_settings``, sampling new ones from the supernet if it is None."""
        if subnet_settings is None:
            subnet_settings = self.model.sample_active_subnet(sample_mode=sample_mode)
        return subnet_settings

    def build_active_subnet(self, subnet_settings):
        """Delegate to ``self.model.build_active_subnet``."""
        return self.model.build_active_subnet(subnet_settings)

    def get_subnet_weight(self, subnet_settings=None):
        """Delegate to ``self.model.sample_active_subnet_weights``."""
        subnet = self.model.sample_active_subnet_weights(subnet_settings)
        return subnet