import numpy as np
import torch
from torch.autograd import Variable, grad
import sys
import torch.nn.functional as F

def _concat(xs):
    """Flatten every tensor in ``xs`` and concatenate them into one 1-D tensor.

    ``reshape`` is used instead of ``view`` so that non-contiguous tensors
    (e.g. transposes) are accepted; it still returns a view without copying
    whenever ``view`` would have succeeded.
    """
    return torch.cat([x.reshape(-1) for x in xs])


class Architect(object):
    """Optimizer wrapper for the architecture parameters of a searchable model.

    Holds an Adam optimizer over ``model.arch_parameters()`` and updates those
    parameters from a validation loss, optionally regularized by the auxiliary
    "MLC" loss (``mlc_loss``). Also provides finite-difference helpers
    (``_construct_model_from_theta``, ``_hessian_vector_product``) that are not
    called from within this class.
    """

    def __init__(self, model, criterion, args):
        """Create the architecture optimizer.

        Args:
            model: searchable network; must expose ``arch_parameters()`` and be
                callable with an ``update_type`` argument.
            criterion: loss function called as ``criterion(logits, target)``.
            args: config namespace read for ``momentum``, ``weight_decay``,
                ``arch_learning_rate`` and ``arch_weight_decay``.
        """
        # Weight-optimizer hyper-parameters; kept for callers, unused in this class.
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay,
        )
        self.criterion = criterion

    def _train_loss(self, model, input, target, update_type):
        """Return ``criterion(model(input, update_type=update_type), target)`` without autocast."""
        logits = model(input, update_type=update_type)
        loss = self.criterion(logits, target)
        return loss

    def _val_loss(self, model, input, target, update_type, amp_autocast):
        """Return the criterion loss on ``(input, target)`` computed inside ``amp_autocast()``.

        ``update_type`` is passed positionally as the model's second argument.
        """
        with amp_autocast():
            logits = model(input, update_type)
            loss = self.criterion(logits, target)
        return loss

    def step(self, input_valid, target_valid, epoch, update_type, amp_autocast, use_mlc_loss=True, mlc_loss_weight=1.):
        """Run one optimizer step on the architecture parameters.

        Gradients are zeroed before the backward pass and again after the step,
        so no architecture gradient is left behind for other optimizers.
        """
        self.optimizer.zero_grad()
        self._backward_step(input_valid, target_valid, epoch, update_type, amp_autocast, use_mlc_loss=use_mlc_loss, mlc_loss_weight=mlc_loss_weight)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def zero_hot(self, norm_weights):
        """Return ``mean(log(norm_weights)) + log(2)``.

        NOTE(review): presumably ``norm_weights`` are probabilities (values in
        (0, 1]); a zero entry makes the result ``-inf``. For weights all equal
        to 0.5 the result is exactly 0. Not called within this class.
        """
        valid_loss = torch.log(norm_weights)
        base_entropy = torch.log(torch.tensor(2).float())
        aux_loss = torch.mean(valid_loss) + base_entropy
        return aux_loss

    def mlc_loss(self, arch_param, amp_autocast):
        """Auxiliary loss on the architecture logits.

        For each tensor in ``arch_param``: ``logsumexp`` over the last dim, then
        the mean over the remaining dims. The per-tensor values are summed.
        Returns the Python float ``0.`` when ``arch_param`` is empty.
        """
        with amp_autocast():
            aux_loss = 0.
            for logits in arch_param:
                aux_loss = aux_loss + torch.mean(torch.logsumexp(logits, dim=-1))
        return aux_loss

    def mlc_pos_loss(self, arch_param):
        """Two-sided logsumexp loss around the row-wise maximum of ``arch_param``.

        Entries that reach the row maximum (after softmax, so ties all count)
        form the "positive" set. ``neg_loss = logsumexp(non-max logits)`` and
        ``pos_loss = logsumexp(-max logits)``. Minimizing their mean sum pushes
        the maximum logits up and the rest down. Entries outside each set are
        masked out by subtracting 1e12 before ``logsumexp``.
        """
        act_param = F.softmax(arch_param, dim=-1)
        thr = act_param.max(dim=-1, keepdim=True)[0]
        y_true = (act_param >= thr)
        # Flip the sign of the positive entries: s -> -s where y_true, else s.
        arch_param_new = (1 - 2 * y_true) * arch_param
        y_pred_neg = arch_param_new - y_true * 1e12    # keep only the non-max entries
        y_pred_pos = arch_param_new - ~y_true * 1e12   # keep only the (negated) max entries
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        aux_loss = torch.mean(neg_loss) + torch.mean(pos_loss)
        return aux_loss

    def mlc_loss2(self, arch_param):
        """Return the mean of ``arch_param``.

        This is what the former ``log(exp(x))`` computed mathematically. Using
        ``x`` directly avoids ``exp`` overflowing to ``inf`` (and underflowing to
        ``log(0) = -inf``) for large-magnitude logits.
        """
        return torch.mean(arch_param)

    def _backward_step(self, input_valid, target_valid, epoch, update_type, amp_autocast, use_mlc_loss=True, mlc_loss_weight=1.):
        """Backpropagate the validation loss, plus the epoch-ramped MLC term if enabled."""
        cls_loss = self._val_loss(self.model, input_valid, target_valid, update_type, amp_autocast)
        if use_mlc_loss:
            ssr_normal = self.mlc_loss(self.model.arch_parameters(), amp_autocast)
            # Linear ramp: the auxiliary weight equals mlc_loss_weight at epoch 100
            # and keeps growing afterwards (no clamp). Plain scalar mul/add is not
            # affected by autocast, so no autocast context is needed here.
            weights = mlc_loss_weight * epoch / 100
            loss = cls_loss + weights * ssr_normal
        else:
            loss = cls_loss
        loss.backward()

    def _construct_model_from_theta(self, theta):
        """Return a fresh ``self.model.new()`` whose parameters are read from ``theta``.

        ``theta`` is a flat 1-D tensor holding every parameter of ``self.model``
        in ``named_parameters()`` order. Non-parameter state (buffers) is copied
        from ``self.model``. The new model is moved to CUDA.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for name, param in self.model.named_parameters():
            # numel() is an int even for 0-d parameters; np.prod(()) would be the
            # float 1.0, which is not a valid slice bound.
            length = param.numel()
            params[name] = theta[offset: offset + length].view(param.size())
            offset += length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

    def _hessian_vector_product(self, vector, input, target, r=1e-2, update_type=None):
        """Central finite-difference estimate of d/d(arch) [ grad_w L_train . vector ].

        The weights are shifted by ``+R*vector`` and ``-R*vector``, with
        ``R = r / ||vector||``. Gradients of the training loss w.r.t. the
        architecture parameters are taken at both points and differenced. The
        weights are shifted back afterwards.

        Args:
            vector: list of tensors aligned with ``self.model.parameters()``.
            input, target: training batch.
            r: finite-difference scale.
            update_type: forwarded to the model. NOTE(review): callers that omit
                it pass ``None`` — confirm the model accepts that.

        Returns:
            A list with one tensor per architecture parameter.
        """
        norm = _concat(vector).norm().item()
        if norm == 0:
            # Zero direction -> zero product. Without this guard R would be inf and
            # inf * 0 = nan would be written permanently into the weights.
            return [torch.zeros_like(a) for a in self.model.arch_parameters()]
        R = r / norm

        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss = self._train_loss(self.model, input, target, update_type)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)
        loss = self._train_loss(self.model, input, target, update_type)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Undo the net -R shift so the weights return to their original values.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
