import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import WassDistance

def squared_euclidean_distance(x):
    """Return the pairwise squared Euclidean distance matrix of the rows of ``x``.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so the whole
    matrix is obtained from a single matrix product.

    Args:
        x: 2-D tensor of shape ``(n, d)`` (``torch.mm`` requires 2-D input).

    Returns:
        ``(n, n)`` tensor whose ``[i, j]`` entry is ``||x[i] - x[j]||^2``.
    """
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)  # (n, 1)
    distance = sq_norms + sq_norms.T - 2 * torch.mm(x, x.T)
    # The expansion suffers from floating-point cancellation: entries that should
    # be exactly 0 (e.g. the diagonal) can come out slightly negative, which would
    # push Gaussian/Epanechnikov kernel values above 1. Distances are >= 0.
    return torch.clamp(distance, min=0)

def gaussian_kernel_matrix(x, sigma=10):
    """Return the pairwise Gaussian (RBF) kernel matrix over the rows of ``x``.

    Entry ``[i, j]`` is ``exp(-||x[i] - x[j]||^2 / (2 * sigma^2))``.

    Args:
        x: 2-D tensor of shape ``(n, d)``.
        sigma: kernel bandwidth.

    Returns:
        ``(n, n)`` kernel matrix.
    """
    denom = 2 * sigma ** 2
    return torch.exp(-squared_euclidean_distance(x) / denom)

def epanechnikov_kernel_matrix(x, sigma=2.0):
    """Return the pairwise (unnormalised) Epanechnikov kernel matrix over the rows of ``x``.

    Entry ``[i, j]`` is ``max(0, 1 - ||x[i] - x[j]||^2 / sigma^2)``, so pairs
    farther apart than ``sigma`` get weight 0.

    Args:
        x: 2-D tensor of shape ``(n, d)``.
        sigma: kernel support radius.

    Returns:
        ``(n, n)`` kernel matrix with entries in ``[0, 1]`` for non-negative distances.
    """
    scaled = squared_euclidean_distance(x) / (sigma ** 2)
    weights = 1 - scaled
    return weights.clamp(min=0)

class Ours1(AbstractModel):
    """Two-head outcome model trained with a kernel-weighted, propensity-scaled pairwise loss.

    Sub-networks (sizes come from ``config``):

    * ``repre_layers`` -- shared representation of the covariates.
    * ``pred_layers_treated`` / ``pred_layers_control`` -- outcome heads on top of the
      representation, each ending in a 1-unit ``nn.Linear``.
    * ``prop_layers`` -- propensity network on the raw covariates (a sigmoid and a clip
      are applied in :meth:`calculate_loss`).
    * ``kernel_layers`` -- embedding whose pairwise distances define the similarity
      kernel used in :meth:`calculate_loss`.
    """

    # Bandwidth passed to the kernel helpers in calculate_loss (overrides their defaults).
    kernel_sigma = 3.0
    # Propensity scores are clipped to this range so that 1 / ps stays bounded. Tune as needed.
    ps_clip = (0.1, 0.9)

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        # Feature dimension of x: used as the BatchNorm1d / first Linear input size.
        self.in_feature = self.dataset.size[1]
        # 0 selects the Gaussian kernel; any other value selects Epanechnikov.
        self.kernel = self.config['kernel']
        self.alpha = self.config['alpha']  # weight of the second loss term
        # theta / beta / phi are not used inside this class; presumably read elsewhere — verify.
        self.theta = self.config['theta']
        self.beta = self.config['beta']
        self.phi = self.config['phi']

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']
        self.prop_layer_sizes = self.config['prop_layer_sizes']
        self.kernel_layer_sizes = self.config['kernel_layer_sizes']

        # Sub-networks are built in a fixed order so seeded parameter initialisation
        # stays reproducible.
        self.repre_layers = self._covariate_mlp(self.repre_layer_sizes)
        self.kernel_layers = self._covariate_mlp(self.kernel_layer_sizes)

        self.pred_layers_treated = nn.Sequential(
            *get_linear_layers(self.repre_layer_sizes[-1], self.pred_layer_sizes, False, nn.ReLU))
        self.pred_layers_treated.add_module('out1', nn.Linear(self.pred_layer_sizes[-1], 1))
        self.pred_layers_control = nn.Sequential(
            *get_linear_layers(self.repre_layer_sizes[-1], self.pred_layer_sizes, False, nn.ReLU))
        self.pred_layers_control.add_module('out0', nn.Linear(self.pred_layer_sizes[-1], 1))

        # NOTE(review): unlike the outcome heads, no 1-unit output layer is appended here;
        # calculate_loss appears to assume prop_layer_sizes[-1] == 1. If get_linear_layers
        # ends with the activation, the sigmoid input is non-negative (ps_1 >= 0.5) — confirm.
        self.prop_layers = self._covariate_mlp(self.prop_layer_sizes)

        # mse_loss and wass_distance are not used inside this class; presumably used elsewhere.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.wass_distance = WassDistance(eps=0.01, max_iter=30, device=self.device)
        self.sigmoid = nn.Sigmoid()

    def _covariate_mlp(self, layer_sizes):
        """Build ``[BatchNorm1d (if bn)] + get_linear_layers(...)`` over the raw covariates."""
        layers = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        layers += get_linear_layers(self.in_feature, layer_sizes, self.bn, nn.ReLU)
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Return one outcome prediction per row of ``x``, choosing the head by ``t``.

        Args:
            x: covariates, fed to ``repre_layers``.
            t: treatment indicators of shape ``(n,)`` or ``(n, 1)``.

        Returns:
            1-D tensor of length ``n`` (also for ``n == 1``).

        NOTE(review): rows with ``t == 0`` are routed to ``pred_layers_treated`` and all
        other rows to ``pred_layers_control``. ``calculate_loss`` compares this output for
        unit i with observed outcomes of units under the *other* treatment, which suggests
        it is meant as the opposite-arm (counterfactual) prediction. Confirm before using
        ``forward`` as a factual predictor.
        """
        t = t.reshape(-1)
        x_repre = self.repre_layers(x)
        out_treated = self.pred_layers_treated(x_repre).squeeze(-1)
        out_control = self.pred_layers_control(x_repre).squeeze(-1)
        return torch.where(t == 0, out_treated, out_control)

    def get_repre(self, x, device):
        """Return ``repre_layers(x)`` computed on ``device`` without tracking gradients.

        Side effects: switches the whole model to eval mode (it is not switched back)
        and moves ``repre_layers`` to ``device`` in place (``nn.Module.to`` mutates
        the module).
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def calculate_loss(self, x, t, y, w):
        """Compute the kernel-weighted pairwise training loss for one batch.

        With ``K`` the kernel over ``kernel_layers(x)`` and ``tau_i = forward(x, t)[i]``:

        * term1_i = sum_j |tau_i - y_j| * K_ij / ps[i, 1 - t_j] over pairs with t_i != t_j
        * term2_i = tau_i * sum_j sign(y_i - y_j) * K_ij / ps[i, t_j] over pairs with t_i == t_j

        Under each mask, the propensity entry equals ``ps[i, t_i]``.

        Args:
            x: covariates.
            t: treatment indicators with values 0/1, shape ``(n,)`` or ``(n, 1)``.
            y: observed outcomes, shape ``(n,)`` or ``(n, 1)``.
            w: unused; kept for interface compatibility.

        Returns:
            Scalar tensor: ``mean(term1 + alpha * term2)``.
        """
        t = t.reshape(-1)
        y = y.reshape(-1)
        t_idx = t.long()  # integer copy for indexing ps; t itself may be float

        lo, hi = self.ps_clip
        ps_1 = torch.clip(self.sigmoid(self.prop_layers(x)).reshape(-1), lo, hi)
        ps = torch.stack((1 - ps_1, ps_1), dim=1)  # (n, 2): column k holds P(T = k | x)

        taus_pred = self.forward(x, t).reshape(-1)
        kernel_fn = gaussian_kernel_matrix if self.kernel == 0 else epanechnikov_kernel_matrix
        kernel = kernel_fn(self.kernel_layers(x), sigma=self.kernel_sigma)  # (n, n)

        same_t = (t[:, None] == t[None, :]).float()
        diff_t = 1.0 - same_t

        abs_err = torch.abs(taus_pred[:, None] - y[None, :])  # [i, j] = |tau_i - y_j|
        y_order = torch.sign(y[:, None] - y[None, :])  # [i, j] = sign(y_i - y_j)

        # Indexing columns with a length-n index yields an (n, n) matrix:
        # ps[:, idx][i, j] == ps[i, idx[j]].
        term1 = (abs_err / ps[:, 1 - t_idx] * diff_t * kernel).sum(dim=1)
        term2 = (y_order / ps[:, t_idx] * same_t * kernel).sum(dim=1) * taus_pred

        return torch.mean(term1 + self.alpha * term2)
