import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import MMDDistance
class RITE_MMD(AbstractModel):
    """Shared-representation outcome model with separate treated/control heads.

    The training loss combines a sample-weighted squared error, an
    ``MMDDistance`` term computed on the learned representation (scaled by
    ``alpha``), and two penalties scaled by ``beta``: a parameter
    regulariser (``regularization_loss``) and a squared-norm term on the
    representation (``hilbert_norm``).

    Config keys read here: ``alpha``, ``beta``, ``bn``,
    ``repre_layer_sizes`` and ``pred_layer_sizes``.
    """

    def __init__(self, config, dataset):
        """Build the representation network and the two outcome heads.

        Args:
            config: Mapping-like object providing the keys listed in the
                class docstring.
            dataset: Object whose ``size[1]`` gives the input feature count.
                NOTE(review): ``self.config`` / ``self.dataset`` are
                presumably set by ``AbstractModel.__init__`` -- confirm.
        """
        super().__init__(config, dataset)
        self.in_feature = self.dataset.size[1]

        self.alpha = self.config['alpha']
        self.beta = self.config['beta']
        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Optional batch-norm on the raw input, followed by the hidden stack
        # produced by the project helper ``get_linear_layers``.
        input_norm = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        self.repre_layers = nn.Sequential(
            *(input_norm
              + get_linear_layers(self.in_feature, self.repre_layer_sizes,
                                  self.bn, nn.ReLU))
        )

        # Heads are created treated-first so that module registration order
        # (and therefore parameter order / random initialisation order)
        # matches the layout existing checkpoints were saved with.
        self.pred_layers_treated = self._build_outcome_head('out1')
        self.pred_layers_control = self._build_outcome_head('out0')

        # Per-sample squared error; weighting and reduction are applied in
        # ``calculate_loss``.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.regularization_mse = nn.MSELoss(reduction='mean')
        self.mmd_distance = MMDDistance()

    def _build_outcome_head(self, out_name):
        """Return one outcome head ending in a 1-unit linear layer ``out_name``."""
        head = nn.Sequential(
            *get_linear_layers(self.repre_layer_sizes[-1],
                               self.pred_layer_sizes, False, nn.ReLU)
        )
        head.add_module(out_name, nn.Linear(self.pred_layer_sizes[-1], 1))
        return head

    def forward(self, x, t):
        """Predict the outcome under the treatment given in ``t``.

        The representation of ``x`` is cached on ``self.repre`` because
        ``calculate_loss`` reads it after calling this method.

        Args:
            x: Input features fed to the representation network.
            t: Treatment indicator; rows where ``t == 1`` take the treated
                head's output, all other rows take the control head's.
                NOTE(review): ``t`` must broadcast against the heads'
                ``(..., 1)`` output -- confirm the caller passes a column.

        Returns:
            Tensor of per-row predictions selected from the two heads.
        """
        self.repre = self.repre_layers(x)
        treated_out = self.pred_layers_treated(self.repre)
        control_out = self.pred_layers_control(self.repre)
        return torch.where(t == 1, treated_out, control_out)

    def get_repre(self, x, device):
        """Return the representation of ``x`` computed without gradients.

        Side effects: switches the whole model to eval mode (it is not
        switched back) and moves ``self.repre_layers`` to ``device`` in
        place.
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def hilbert_norm(self, repre):
        """Return the sum over rows of (squared L2 norm of the row minus 1)."""
        squared = torch.mul(repre, repre)
        return torch.sum(torch.sum(squared, dim=-1) - 1)

    def calculate_loss(self, x, t, y, w):
        """Compute the full training loss for one batch.

        Args:
            x: Input features.
            t: Treatment indicator (see ``forward``).
            y: Observed outcomes compared element-wise with the predictions.
            w: Per-sample weights applied to the squared error and passed
                on to ``MMDDistance``.

        Returns:
            Scalar loss: weighted squared error + ``alpha`` * MMD term
            + ``beta`` * (parameter regulariser + representation norm term).
        """
        pred = self.forward(x, t)
        mse_loss = torch.sum(self.mse_loss(pred, y) * w)
        # ``self.repre`` was populated by the forward pass just above.
        imb_loss = self.mmd_distance(self.repre, t, w)
        hilbert_loss = self.hilbert_norm(self.repre)
        loss = (mse_loss
                + self.alpha * imb_loss
                + self.regularization_loss() * self.beta
                + self.beta * hilbert_loss)
        return loss

    def regularization_loss(self):
        """Return the parameter regularisation term summed over the model.

        * 1-D parameters contribute ``sum(p * p) - 1``.
        * 2-D parameters contribute the mean squared difference between
          ``p.T @ p`` and the identity matrix.
        * Parameters of any other rank are ignored.

        Returns:
            Scalar tensor, or ``0`` if the model has no 1-D or 2-D
            parameters.
        """
        terms = []
        for param in self.parameters():
            if param.dim() == 1:
                terms.append(torch.sum(param * param) - 1)
            elif param.dim() == 2:
                # Build the identity directly on the parameter's device and
                # in its dtype: this avoids a CPU allocation plus transfer
                # on every call and stays correct when a sub-module has been
                # moved or cast independently of ``self.device``.
                identity = torch.eye(param.shape[1],
                                     device=param.device,
                                     dtype=param.dtype)
                gram = torch.matmul(param.T, param)
                terms.append(self.regularization_mse(gram, identity))
        # ``sum`` starts from 0, so an empty list gives 0 instead of None.
        return sum(terms)

