from copy import deepcopy
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data

def variable(t: torch.Tensor, use_cuda=True, **kwargs):
    if torch.cuda.is_available() and use_cuda:
        t = t.cuda()
    return Variable(t, **kwargs)

class EWC(object):
    def __init__(self, model: nn.Module, dataset: list):

        self.model = model
        self.dataset = dataset

        # named_parameters :  (name, parameter) 조합
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = self._diag_fisher()

      # 1) Task A를 학습하고 난 후의 파라미터
        for n, p in deepcopy(self.params).items():
            self._means[n] = variable(p.data)

    def _diag_fisher(self):
      # 2) precision_matrices 0 초기화
        precision_matrices = {}
        for n, p in deepcopy(self.params).items():
            p.data.zero_()       # zero_ 0으로 채움
            precision_matrices[n] = variable(p.data)

			# 3) accumulate gradients to calculate fisher
        self.model.eval()
        for input in self.dataset:
            self.model.zero_grad()
            input = variable(input)
            output = self.model(input).view(1, -1)
            label = output.max(1)[1].view(-1)
            loss = F.nll_loss(F.log_softmax(output, dim=1), label)
            loss.backward()

     # 4) diag_fisher
     # diag_fisher는 Task A에게 현재 파라미터가 얼마나 중요한지를 보여주는 지표
     # diag_fisher를 계산하기 위해 Task A로부터 data를 샘플링하고 empirical Fisher information matrix를 계산함
     # 매개변수의 이전 학습 데이터에 대한 상관도(F를 계산할 때 Fisher information matrix 활용/
        # 어떤 random variable의 관측값으로부터 분포의 parameter에 대해 유추할 수 있는 정보의 양)
     # gradients의 제곱을 축적함
     # (parameter ** 2)/len(dataset) 
     # parameter.grad : 어떤 스칼라 값에 대해 parameter에 대해 변화도를 갖는 값/ 모델의 각 매개변수에 대한 gradeint
     # parameter.grad.data : tensor

            for n, p in self.model.named_parameters():
                precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        return precision_matrices

    # 5) penalty
    # diag_fisher의 diag*(Task A를 학습하고 난 후의 파라미터인 _means와 현재 파라미터의 차이의 제곱)를 통해 penalty 항을 만듦
    
    def penalty(self, model: nn.Module):
        loss = 0
        for n, p in model.named_parameters():
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss

def normal_train(model: nn.Module, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def ewc_train(model: nn.Module, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
              ewc: EWC, importance: float):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        # EWC의 loss= (1) Task B에 대한 loss + (2) importance * (3) penalty항: 매개변수의 이전 학습 데이터에 대한 상관도
        # (1) Task B에 대한 loss(현재 학습 데이터에 대한 비용함수) 
        # (2) importance(old task가 새로운 task와 비교하여 얼마나 중요한지를 결정): 람다역할
        # (3) penalty항: diag_fisher의 diag*(Task A를 학습하고 난 후의 파라미터인 _means와 현재 파라미터의 차이의 제곱)
        loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def test(model: nn.Module, data_loader: torch.utils.data.DataLoader):
    model.eval()
    correct = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        output = model(input)
        correct += (F.softmax(output, dim=1).max(dim=1)[1] == target).data.sum()
    return correct / len(data_loader.dataset)
