import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import WassDistance,MLPDiffusion

class DFHTE(nn.Module):
    """Two-head outcome model over a shared representation of augmented covariates.

    The network input is the observed covariates ``x`` concatenated along the
    last dimension with a tensor of generated covariates, which is why
    ``in_feature`` is ``2 * config['in_feature']``. A shared representation
    network (optionally preceded by BatchNorm) feeds separate treated and
    control prediction heads, each ending in a single-output linear layer.

    Generated covariates are Gaussian noise, uniform noise on [-1.5, 1.5], or
    samples from the ``MLPDiffusion`` sub-model ``self.g``, selected by
    ``config['covariate']``.

    Training loss (see :meth:`calculate_loss`)::

        sum(w * (y_hat - y)^2) + alpha * WassDistance(repre, t, w) + beta * g_loss
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Observed covariates are concatenated with an equally wide block of
        # generated covariates before entering the network.
        self.in_feature = config['in_feature'] * 2
        self.alpha = config['alpha']  # weight of the representation-imbalance term
        self.beta = config['beta']  # weight of the generator (self.g) loss term
        self.bn = config['bn']
        self.device = config['device']
        self.repre_layer_sizes = config['repre_layer_sizes']
        self.pred_layer_sizes = config['pred_layer_sizes']

        self.repre_layers = nn.Sequential(*(([nn.BatchNorm1d(self.in_feature)] if self.bn else [])
                                             + get_linear_layers(self.in_feature, self.repre_layer_sizes, self.bn, nn.ReLU)))

        self.pred_layers_treated = nn.Sequential(*get_linear_layers(self.repre_layer_sizes[-1],
                                                                    self.pred_layer_sizes, False, nn.ReLU))
        # Module names 'out1' / 'out0' are part of the state_dict keys.
        self.pred_layers_treated.add_module('out1', nn.Linear(self.pred_layer_sizes[-1], 1))
        self.pred_layers_control = nn.Sequential(*get_linear_layers(self.repre_layer_sizes[-1],
                                                                    self.pred_layer_sizes, False, nn.ReLU))
        self.pred_layers_control.add_module('out0', nn.Linear(self.pred_layer_sizes[-1], 1))

        # Element-wise squared error so per-unit weights can be applied before summing.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.wass_distance = WassDistance(eps=0.01, max_iter=30, device=self.device)
        self.g = MLPDiffusion(config)

    def forward(self, x, t):
        """Predict the factual outcome for each unit.

        Args:
            x: augmented input (observed covariates concatenated with generated
                covariates); last dimension must be ``self.in_feature``.
            t: treatment indicator. Entries equal to 1 select the treated head;
                all other entries select the control head. Assumed to broadcast
                against the ``(batch, 1)`` head outputs, e.g. shape
                ``(batch, 1)`` -- TODO confirm against callers.

        Returns:
            Per-unit predictions chosen element-wise from the two heads.

        Side effect: stores the shared representation in ``self.repre``; the
        imbalance term in :meth:`calculate_loss` reads it.
        """
        self.repre = self.repre_layers(x)
        treat_output = self.pred_layers_treated(self.repre)
        control_output = self.pred_layers_control(self.repre)
        return torch.where(t == 1, treat_output, control_output)

    def get_repre(self, x, device):
        """Return the shared representation of ``x`` without tracking gradients.

        The module runs in eval mode for this call. The caller's train/eval
        mode is restored afterwards.

        NOTE(review): :meth:`forward` feeds ``repre_layers`` the *augmented*
        input (``x`` concatenated with generated covariates), but here ``x`` is
        passed unchanged. Callers presumably must pass an already-augmented
        tensor -- confirm against call sites.

        NOTE: ``repre_layers`` is moved to ``device`` in place. The prediction
        heads and ``self.g`` are not moved.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.repre_layers.to(device)(x.to(device))
        finally:
            # Restore the caller's mode. Otherwise a call made during training
            # (e.g. for logging) silently leaves BatchNorm in eval mode.
            self.train(was_training)

    def _predict(self, x, t):
        """Augment ``x`` with generated covariates and run :meth:`forward`."""
        covariates = self.generation(x)
        return self.forward(torch.cat([x, covariates], dim=-1), t)

    def calculate_loss(self, x, t, y, w):
        """Return the weighted training loss for one batch.

        The loss is the sum of three terms:

        - ``sum(w * (pred - y)^2)``: per-unit weighted squared error;
        - ``alpha * wass_distance(repre, t, w)``: imbalance term computed on the
          representation cached by :meth:`forward`;
        - ``beta * g.calculate_loss(x)``: the generator's own loss.

        Args:
            x: observed covariates (not yet augmented).
            t: treatment indicator (see :meth:`forward`).
            y: observed outcomes, same shape as the predictions.
            w: per-unit weights broadcastable against the squared error.
                Presumably ``(batch, 1)`` -- TODO confirm.
        """
        logits = self._predict(x, t)
        mse_loss = torch.sum(self.mse_loss(logits, y) * w)
        # reshape(-1) instead of squeeze(): squeeze() collapses a batch of one
        # to a 0-d tensor, while reshape keeps a 1-D vector of length batch.
        imb_loss = self.wass_distance(self.repre, t, w.reshape(-1))
        g_loss = self.g.calculate_loss(x)
        return mse_loss + self.alpha * imb_loss + self.beta * g_loss

    def generation(self, x):
        """Generate extra covariates to concatenate with ``x``.

        ``config['covariate']`` selects the source:

        - ``'gaussian'``: standard-normal noise shaped like ``x``;
        - ``'uniform'``: uniform noise on [-1.5, 1.5] shaped like ``x``;
        - anything else: ``self.g.generation(x)``, sampled with ``self.g`` in
          eval mode.

        Noise tensors are created directly on ``self.device`` with the default
        dtype.

        NOTE(review): the ``self.g`` branch is not wrapped in ``no_grad``, so
        gradients may flow from the outcome loss into the generator unless
        ``g.generation`` detaches internally -- confirm this is intended.
        """
        kind = self.config['covariate']
        if kind == 'gaussian':
            return torch.randn(x.shape, device=self.device)
        if kind == 'uniform':
            return torch.empty(x.shape, device=self.device).uniform_(-1.5, 1.5)

        # Sample in eval mode, then restore g's previous mode. This keeps the
        # later g.calculate_loss(x) call in calculate_loss in training mode.
        was_training = self.g.training
        self.g.eval()
        try:
            return self.g.generation(x)
        finally:
            self.g.train(was_training)

    def rct_loss(self, x, t, y):
        """Return the unweighted summed squared error of the factual predictions.

        No imbalance or generator term is added.
        """
        logits = self._predict(x, t)
        return torch.sum(self.mse_loss(logits, y))