import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np


def _concat(xs):
  return torch.cat([x.view(-1) for x in xs])

class Architect(object):
    """Optimizes a model's architecture parameters with Adam.

    Depending on ``network_config['order']`` the architecture gradient is
    either the plain gradient of the evaluation loss (``order == 1``) or a
    second-order estimate built from a one-step unrolled copy of the model
    (``order == 2``).

    The wrapped ``model`` must provide ``get_parameters()``,
    ``get_arch_parameters()``, ``model_loss(inputs, labels)`` (a 4-tuple
    whose first item is the loss), ``model_loss_arch(inputs, labels)``
    (returning ``(loss, outputs)``) and ``new()`` (a fresh model able to
    load this model's ``state_dict``).
    """

    def __init__(self, model, network_config):
        """Store the model and the settings read from ``network_config``.

        Keys read: ``'weight_decay'``, ``'arch_weight_decay'``,
        ``'grad_clip'`` and ``'order'``.
        """
        self.weight_decay = network_config['weight_decay']
        self.arch_weight_decay = network_config['arch_weight_decay']
        # Stored for callers; not used by any method of this class.
        self.clip = network_config['grad_clip']
        self.model = model
        self.network_config = network_config
        self.order = network_config['order']

    def init_arch_optimizer(self, lr):
        """Create the Adam optimizer over the architecture parameters."""
        self.optimizer = torch.optim.Adam(
            self.model.get_arch_parameters(),
            lr=lr,
            betas=(0.5, 0.999),
            weight_decay=self.arch_weight_decay,
        )

    def step(self, inputs_train, labels_train, inputs_eval, labels_eval, eta):
        """Run one optimizer step on the architecture parameters.

        Args:
            inputs_train, labels_train: batch used for the unrolled weight
                step and the Hessian-vector product (``order == 2`` only).
            inputs_eval, labels_eval: batch whose loss drives the
                architecture update.
            eta: step size of the virtual weight update (``order == 2``).

        Raises:
            ValueError: if ``self.order`` is neither 1 nor 2.
        """
        self.optimizer.zero_grad()
        if self.order == 2:
            self._backward_step_unrolled(
                inputs_train, labels_train, inputs_eval, labels_eval, eta)
        elif self.order == 1:
            loss, _ = self.model.model_loss_arch(inputs_eval, labels_eval)
            loss.backward()
        else:
            raise ValueError('Unrecognized order.')

        self.optimizer.step()

    def _backward_step_unrolled(self, inputs_train, labels_train, inputs_eval, labels_eval, eta):
        """Write the second-order architecture gradient into ``self.model``.

        The evaluation loss is back-propagated through a one-step unrolled
        copy of the model; its architecture gradient is then corrected by
        ``eta`` times the finite-difference Hessian-vector product and
        stored in ``.grad`` of the real model's architecture parameters.
        """
        unrolled_model = self.compute_unrolled_model(inputs_train, labels_train, eta)
        unrolled_loss, _ = unrolled_model.model_loss_arch(inputs_eval, labels_eval)

        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.get_arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.get_parameters()]
        implicit_grads = self._hessian_vector_product(vector, inputs_train, labels_train)

        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # The gradients live on the throw-away unrolled copy; hand them over
        # to the parameters the optimizer actually updates.
        for v, g in zip(self.model.get_arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.detach()
            else:
                v.grad.data.copy_(g.data)

    def compute_unrolled_model(self, inputs, labels, eta):
        """Return a copy of the model after one plain gradient step.

        The copy's weights are ``theta - eta * (grad + weight_decay * theta)``
        where ``grad`` is the gradient of ``model_loss(inputs, labels)``.
        """
        loss, _outputs, _correct, _total = self.model.model_loss(inputs, labels)
        theta = _concat(self.model.get_parameters()).data

        grads = torch.autograd.grad(loss, self.model.get_parameters())
        dtheta = _concat(grads).data + self.weight_decay * theta
        return self._construct_model_from_theta(theta.sub(dtheta, alpha=eta))

    def _construct_model_from_theta(self, theta):
        """Build a new model whose weights are taken from the flat ``theta``.

        Parameters whose name starts with ``'a'`` are copied unchanged from
        ``self.model``; all others are filled from consecutive slices of
        ``theta`` in ``named_parameters()`` order.
        NOTE(review): this presumes names starting with 'a' are exactly the
        architecture parameters and that the remaining order matches
        ``get_parameters()`` -- confirm against the model class.

        The new model is placed on ``theta``'s device (the original code
        hard-coded ``.cuda()``, which fails on CPU-only hosts).
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            if k.startswith('a'):
                params[k] = v
            else:
                # numel() stays an int even for 0-dim parameters, unlike
                # np.prod(()) which yields the float 1.0 and breaks slicing.
                v_length = v.numel()
                params[k] = theta[offset:offset + v_length].view(v.size())
                offset += v_length

        assert offset == len(theta), 'theta length does not match the model weights'
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.to(theta.device)

    def _shift_weights(self, vector, scale):
        """Add ``scale * v`` in place to each weight tensor of ``self.model``."""
        for p, v in zip(self.model.get_parameters(), vector):
            p.data.add_(v, alpha=scale)

    def _hessian_vector_product(self, vector, inputs, labels, r=0.005):
        """Approximate d/d(alpha) of (grad_w loss . vector) by central differences.

        The architecture gradient of ``model_loss_arch(inputs, labels)`` is
        evaluated at ``w + R*vector`` and ``w - R*vector`` with
        ``R = r / ||vector||``; the result is their difference over ``2*R``.
        The model weights are restored before returning.

        Returns:
            A list with one tensor per architecture parameter.
        """
        # Global L2 norm from per-tensor norms; avoids a flattened copy.
        norm = torch.stack([v.norm() for v in vector]).norm().item()
        if norm == 0.0:
            # A zero direction has a zero product; also avoids R = inf / NaNs.
            return [torch.zeros_like(a) for a in self.model.get_arch_parameters()]
        R = r / norm

        # w -> w + R*v
        self._shift_weights(vector, R)
        loss, _ = self.model.model_loss_arch(inputs, labels)
        grads_p = torch.autograd.grad(loss, self.model.get_arch_parameters())

        # w + R*v -> w - R*v (a step of only R would land back on w and
        # make the 2*R divisor below wrong by a factor of two).
        self._shift_weights(vector, -2 * R)
        loss, _ = self.model.model_loss_arch(inputs, labels)
        grads_n = torch.autograd.grad(loss, self.model.get_arch_parameters())

        # w - R*v -> w: leave the model as we found it.
        self._shift_weights(vector, R)

        return [(gp - gn).div_(2 * R) for gp, gn in zip(grads_p, grads_n)]
    
