import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import ot
from torch.nn.init import normal_

class SinkhornDistance(nn.Module):
    r"""Entropy-regularized optimal-transport cost between two point clouds.

    Both clouds get uniform marginal weights (``1 / (P + 1e-3)`` per point) and
    the dual potentials are found with log-domain Sinkhorn iterations.

    Args:
        eps (float): entropic regularization strength.
        max_iter (int): upper bound on the number of Sinkhorn iterations.
        reduction (str, optional): ``'none'`` keeps one cost per batch element,
            ``'mean'`` averages them, ``'sum'`` adds them. Default: ``'none'``.
        device (str or torch.device, optional): device the marginal weights are
            moved to. Default: ``'cpu'``.

    Shape:
        - Input: :math:`(N, P_1, D)` and :math:`(N, P_2, D)`, or unbatched
          :math:`(P_1, D)` and :math:`(P_2, D)`.
        - Output: tuple ``(cost, pi, C)``; ``cost`` is :math:`(N)` or :math:`()`
          depending on ``reduction``; ``pi`` (transport plan) and ``C`` (cost
          matrix) are :math:`(N, P_1, P_2)`.
    """

    def __init__(self, eps, max_iter, reduction='none', device='cpu'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.device = device

    def _uniform_marginal(self, batch_size, n_points):
        """Uniform weights of shape (batch_size, n_points), squeezed, on ``self.device``."""
        weights = torch.full((batch_size, n_points), 1.0 / (n_points + 1e-3), dtype=torch.float)
        return weights.squeeze().to(self.device)

    def forward(self, x, y):
        cost_mat = self._cost_matrix(x, y)
        batch = 1 if x.dim() == 2 else x.shape[0]
        mu = self._uniform_marginal(batch, x.shape[-2])
        nu = self._uniform_marginal(batch, y.shape[-2])

        # The log-marginals do not change across iterations.
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        tol = 1e-1  # stop once the mean L1 change of u drops below this

        for _ in range(self.max_iter):
            prev_u = u
            u = u + self.eps * (log_mu - torch.logsumexp(self.M(cost_mat, u, v), dim=-1))
            v = v + self.eps * (log_nu - torch.logsumexp(self.M(cost_mat, u, v).transpose(-2, -1), dim=-1))
            if (u - prev_u).abs().sum(-1).mean().item() < tol:
                break

        # Transport plan pi = diag(a) * K * diag(b), recovered from the potentials.
        plan = torch.exp(self.M(cost_mat, u, v))
        cost = (plan * cost_mat).sum(dim=(-2, -1))

        if self.reduction == 'mean':
            return cost.mean(), plan, cost_mat
        if self.reduction == 'sum':
            return cost.sum(), plan, cost_mat
        return cost, plan, cost_mat

    def M(self, C, u, v):
        """Log-domain kernel: $M_{ij} = (-c_{ij} + u_i + v_j) / \\epsilon$."""
        row_pot = u.unsqueeze(-1)
        col_pot = v.unsqueeze(-2)
        return (-C + row_pot + col_pot) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        """Pairwise cost $\\sum_d |x_{id} - y_{jd}|^p$ between the two clouds."""
        return (x.unsqueeze(-2) - y.unsqueeze(-3)).abs().pow(p).sum(-1)

    @staticmethod
    def ave(u, u1, tau):
        """Convex combination ``tau * u + (1 - tau) * u1`` (extrapolation helper)."""
        return tau * u + (1 - tau) * u1

class ScaleMixtureGaussian(object):
    """Two-component, zero-mean Gaussian scale mixture.

    Density: ``pi * N(0, sigma1) + (1 - pi) * N(0, sigma2)``.

    Args:
        pi: mixing weight of the first component, expected in [0, 1].
        sigma1: standard deviation of the first component.
        sigma2: standard deviation of the second component.
    """

    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)

    def log_prob(self, input):
        """Return the summed log-density of ``input`` under the mixture.

        Computed in log space with ``logaddexp``. Exponentiating the component
        log-densities first underflows to 0 (and log(0) = -inf) as soon as
        ``|input|`` is a few tens of standard deviations away from 0.
        """
        lp1 = self.gaussian1.log_prob(input)
        lp2 = self.gaussian2.log_prob(input)
        weight = torch.as_tensor(self.pi, dtype=lp1.dtype, device=lp1.device)
        # log(0) = -inf is fine here: logaddexp(-inf, a) == a.
        log_w1 = torch.log(weight)
        log_w2 = torch.log1p(-weight)
        return torch.logaddexp(log_w1 + lp1, log_w2 + lp2).sum()

class Gaussian(object):
    """Diagonal Gaussian parameterized by a mean ``mu`` and a pre-softplus ``rho``.

    The standard deviation is ``sigma = softplus(rho) = log(1 + exp(rho))``.

    Args:
        mu: mean tensor (typically an ``nn.Parameter``).
        rho: unconstrained scale tensor with the same shape as ``mu``.
    """

    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho
        # Kept for backward compatibility; sample() no longer relies on these.
        self.normal = torch.distributions.Normal(0, 1)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def sigma(self):
        # softplus equals log1p(exp(rho)) but does not overflow to inf for large rho.
        return F.softplus(self.rho)

    def sample(self):
        """Draw one reparameterized sample ``mu + sigma * eps`` with ``eps ~ N(0, 1)``."""
        # Draw noise on rho's own device/dtype instead of a global cuda/cpu guess.
        epsilon = torch.randn_like(self.rho)
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        """Return the summed log-density of ``input`` under this diagonal Gaussian."""
        sigma = self.sigma
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(sigma)
                - ((input - self.mu) ** 2) / (2 * sigma ** 2)).sum()

class BayesianLinear(nn.Module):
    """Linear layer with a factorized Gaussian variational posterior (Bayes-by-backprop style).

    Weights and bias each have a mean (``*_mu``) and an unconstrained scale
    (``*_rho``) parameter, wrapped in ``Gaussian`` objects. A two-component
    ``ScaleMixtureGaussian`` prior is used for both.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight parameters
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.empty(out_features, in_features).uniform_(-5, -4))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)
        # Bias parameters
        self.bias_mu = nn.Parameter(torch.empty(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.empty(out_features).uniform_(-5, -4))
        self.bias = Gaussian(self.bias_mu, self.bias_rho)
        # Prior distributions: mixture of N(0, e^0) and N(0, e^-6) with equal weight.
        self.PI = 0.5
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.SIGMA_1 = torch.FloatTensor([math.exp(-0)]).to(self.device)
        self.SIGMA_2 = torch.FloatTensor([math.exp(-6)]).to(self.device)

        self.weight_prior = ScaleMixtureGaussian(self.PI, self.SIGMA_1, self.SIGMA_2)
        self.bias_prior = ScaleMixtureGaussian(self.PI, self.SIGMA_1, self.SIGMA_2)
        # Filled by forward(); read by the caller to build the ELBO/KL term.
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Apply the layer.

        Args:
            input: tensor passed to ``F.linear``.
            sample (bool): draw weights from the posterior even in eval mode.
            calculate_log_probs (bool): compute log-prior / log-posterior even in eval mode.

        Returns:
            ``F.linear(input, weight, bias)`` with either sampled weights
            (training mode or ``sample=True``) or the posterior means.
        """
        # Sample during training (or on request); use the posterior mean otherwise.
        # The previous version had these branches swapped.
        if self.training or sample:
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training or calculate_log_probs:
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        return F.linear(input, weight, bias)

def get_linear_layers(in_dim, layer_sizes, bn=False, activation=None):
    """Build a flat list of ``[Linear, (BatchNorm1d), (activation)]`` modules per layer.

    NOTE: this module defines ``get_linear_layers`` twice; the later definition
    is the one bound at import time.

    Args:
        in_dim (int): input width of the first Linear layer.
        layer_sizes (Sequence[int]): output width of each Linear layer (list or tuple).
        bn (bool): append a ``BatchNorm1d`` after each Linear layer.
        activation (callable | None): zero-argument factory (e.g. ``nn.ReLU``);
            one instance is appended after each Linear/BatchNorm pair.

    Returns:
        list[nn.Module]: modules in application order; empty if ``layer_sizes`` is empty.
    """
    sizes = list(layer_sizes)  # also accept tuples and other sequences
    modules = []
    for fan_in, fan_out in zip([in_dim] + sizes, sizes):
        modules.append(nn.Linear(fan_in, fan_out))
        if bn:
            modules.append(nn.BatchNorm1d(fan_out))
        if activation is not None:
            modules.append(activation())
    return modules

class MLPLayers(nn.Module):
    """Stack of ``[Dropout -> Linear -> (BatchNorm1d) -> (activation)]`` blocks.

    Args:
        layers (list[int]): layer widths; ``[a, b, c]`` builds Linear(a, b) and Linear(b, c).
        dropout (float): dropout probability applied before every Linear layer.
        activation (str | type[nn.Module] | None): activation after every Linear
            layer. Accepts 'sigmoid', 'tanh', 'relu', 'leakyrelu' or 'none'
            (case-insensitive), an ``nn.Module`` subclass, or None.
        bn (bool): insert ``BatchNorm1d`` after every Linear layer.
        init_method (str | None): if not None, Linear biases are zeroed. With
            'norm', Linear weights are also drawn from a normal with mean 0 and std 0.01.
    """

    def __init__(
        self, layers, dropout=0.0, activation="relu", bn=False, init_method=None
    ):
        super(MLPLayers, self).__init__()
        self.layers = layers
        self.dropout = dropout
        self.activation = activation
        self.use_bn = bn
        self.init_method = init_method

        mlp_modules = []
        for input_size, output_size in zip(self.layers[:-1], self.layers[1:]):
            mlp_modules.append(nn.Dropout(p=self.dropout))
            mlp_modules.append(nn.Linear(input_size, output_size))
            if self.use_bn:
                mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
            # The module previously called an `activation_layer` helper that is
            # neither defined nor imported here, which raised NameError.
            activation_func = self._activation_layer(self.activation)
            if activation_func is not None:
                mlp_modules.append(activation_func)

        self.mlp_layers = nn.Sequential(*mlp_modules)
        if self.init_method is not None:
            self.apply(self.init_weights)

    @staticmethod
    def _activation_layer(activation_name):
        """Build the activation module described by ``activation_name`` (None means no activation)."""
        if activation_name is None:
            return None
        if isinstance(activation_name, str):
            name = activation_name.lower()
            if name == "none":
                return None
            factories = {
                "sigmoid": nn.Sigmoid,
                "tanh": nn.Tanh,
                "relu": nn.ReLU,
                "leakyrelu": nn.LeakyReLU,
            }
            if name in factories:
                return factories[name]()
        elif isinstance(activation_name, type) and issubclass(activation_name, nn.Module):
            return activation_name()
        raise NotImplementedError(
            "activation function {} is not implemented".format(activation_name)
        )

    def init_weights(self, module):
        """Initialize Linear layers: N(0, 0.01) weights when init_method == 'norm', zero biases."""
        if isinstance(module, nn.Linear):
            if self.init_method == "norm":
                normal_(module.weight.data, 0, 0.01)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, input_feature):
        return self.mlp_layers(input_feature)

def set_color(log, color, highlight=True):
    """Wrap ``log`` in ANSI escape codes for the given foreground color.

    Args:
        log (str): text to colorize.
        color (str): one of black, red, green, yellow, blue, pink, cyan, white;
            any other value falls back to white.
        highlight (bool): use the bold (``1;``) attribute instead of normal (``0;``).

    Returns:
        str: ``"\\033[<attr>;3<index>m" + log + "\\033[0m"``.
    """
    color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
    try:
        index = color_set.index(color)
    except ValueError:  # unknown color -> white; don't swallow unrelated errors
        index = len(color_set) - 1
    prev_log = "\033["
    if highlight:
        prev_log += "1;3"
    else:
        prev_log += "0;3"
    prev_log += str(index) + "m"
    return prev_log + log + "\033[0m"

def get_linear_layers(in_dim, layer_sizes, bn=False, activation=None):
    """Build a flat list of ``[Linear, (BatchNorm1d), (activation)]`` modules per layer.

    NOTE: this module defines ``get_linear_layers`` twice; this later
    definition is the one bound at import time.

    Args:
        in_dim (int): input width of the first Linear layer.
        layer_sizes (Sequence[int]): output width of each Linear layer (list or tuple).
        bn (bool): append a ``BatchNorm1d`` after each Linear layer.
        activation (callable | None): zero-argument factory (e.g. ``nn.ReLU``);
            one instance is appended after each Linear/BatchNorm pair.

    Returns:
        list[nn.Module]: modules in application order; empty if ``layer_sizes`` is empty.
    """
    sizes = list(layer_sizes)  # also accept tuples and other sequences
    modules = []
    for fan_in, fan_out in zip([in_dim] + sizes, sizes):
        modules.append(nn.Linear(fan_in, fan_out))
        if bn:
            modules.append(nn.BatchNorm1d(fan_out))
        if activation is not None:
            modules.append(activation())
    return modules

# Adapted from https://github.com/dfdazac/wassdistance
class WassDistance(nn.Module):
    """Entropy-regularized OT cost between treated and control representations.

    Rows of ``repre`` with ``treats == 1`` form the source cloud and rows with
    ``treats == 0`` form the target cloud. Each cloud's marginal weights are its
    ``probs`` values normalized to sum to 1.

    Args:
        eps (float): entropic regularization strength.
        max_iter (int): upper bound on the number of Sinkhorn iterations.
        device: device the marginal weights are moved to.
        reduction (str): 'none' | 'mean' | 'sum' applied to the final cost.
    """

    def __init__(self, eps, max_iter, device=None, reduction='none'):
        super(WassDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.device = device

    def forward(self, repre, treats, probs):
        """Return the Sinkhorn cost, or the float ``1e10`` if either group is empty.

        Args:
            repre: (N, D) representations.
            treats: treatment indicators, (N,) or (N, 1).
            probs: per-sample weights indexed with the same mask as ``repre``
                (presumably shape (N,) — confirm against callers).
        """
        # reshape(-1) instead of squeeze(): for N == 1, squeeze() yields a 0-dim
        # mask, which adds a dimension to the result instead of selecting rows.
        treated_mask = (treats == 1).reshape(-1)
        control_mask = (treats == 0).reshape(-1)
        x = repre[treated_mask, :]
        y = repre[control_mask, :]
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x_points == 0 or y_points == 0:
            return 1e10

        x_weights = probs[treated_mask]
        y_weights = probs[control_mask]
        C = self._cost_matrix(x, y)  # Wasserstein cost function

        # Numerators are detached; the normalizing sums are not (kept as before).
        mu = (x_weights.clone().detach().requires_grad_(False) / x_weights.sum()).to(self.device)
        nu = (y_weights.clone().detach().requires_grad_(False) / y_weights.sum()).to(self.device)

        u = torch.zeros_like(mu, device=self.device)
        v = torch.zeros_like(nu, device=self.device)
        # Stopping criterion on the mean L1 change of u.
        thresh = 1e-1

        # Sinkhorn iterations
        for _ in range(self.max_iter):
            u1 = u  # useful to check the update
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            if err.item() < thresh:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost

    def M(self, C, u, v):
        """Modified cost for log-domain updates: $M_{ij} = (-c_{ij} + u_i + v_j) / \\epsilon$."""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        """Returns the matrix of $|x_i-y_j|^p$ (summed over the feature dimension)."""
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        """Barycenter subroutine, used by kinetic acceleration through extrapolation."""
        return tau * u + (1 - tau) * u1

class MMDDistance(nn.Module):
    """Multi-kernel Gaussian MMD between treated and control representations.

    Rows of ``repre`` with ``treats == 1`` are the source sample and rows with
    ``treats == 0`` the target sample. ``probs`` is accepted for interface
    compatibility but not used.

    Args:
        kernel_mul (float): ratio between consecutive kernel bandwidths.
        kernel_num (int): number of Gaussian kernels summed together.
        fix_sigma (float | None): fixed base bandwidth; if falsy, the mean
            off-diagonal squared distance is used.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super().__init__()
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num
        self.fix_sigma = fix_sigma

    def forward(self, repre, treats, probs):
        src = repre[(treats == 1).squeeze(), :]
        tgt = repre[(treats == 0).squeeze(), :]
        n_src = int(src.size()[0])
        n_tgt = int(tgt.size()[0])

        K = self.guassian_kernel(src, tgt)

        # Row sums of each kernel block, scaled so the total equals
        # mean(K_ss) - mean(K_st) - mean(K_ts) + mean(K_tt).
        k_ss = torch.div(K[:n_src, :n_src], n_src * n_src).sum(dim=1).view(1, -1)  # source <-> source
        k_st = torch.div(K[:n_src, n_src:], -n_src * n_tgt).sum(dim=1).view(1, -1)  # source <-> target
        k_ts = torch.div(K[n_src:, :n_src], -n_tgt * n_src).sum(dim=1).view(1, -1)  # target <-> source
        k_tt = torch.div(K[n_src:, n_src:], n_tgt * n_tgt).sum(dim=1).view(1, -1)  # target <-> target

        return (k_ss + k_st).sum() + (k_ts + k_tt).sum()

    def guassian_kernel(self, source, target):
        """Sum of ``kernel_num`` Gaussian kernels over the concatenated samples."""
        n_samples = int(source.size()[0]) + int(target.size()[0])
        stacked = torch.cat([source, target], dim=0)

        # Pairwise squared Euclidean distances |x - y|^2 via broadcasting.
        sq_dist = ((stacked.unsqueeze(0) - stacked.unsqueeze(1)) ** 2).sum(2)

        # Base bandwidth, then a geometric ladder of bandwidths centred on it.
        if self.fix_sigma:
            base = self.fix_sigma
        else:
            base = torch.sum(sq_dist.data) / (n_samples ** 2 - n_samples)
        base = base / self.kernel_mul ** (self.kernel_num // 2)
        bandwidths = [base * (self.kernel_mul ** k) for k in range(self.kernel_num)]

        # Gaussian kernel exp(-|x-y|^2 / bandwidth), summed over all bandwidths.
        return sum(torch.exp(-sq_dist / bw) for bw in bandwidths)

def mmd_distance(repre, treats, probs):
    '''Linear MMD between treated and control representation means.

    Computes ``sum((2 * p * mean_treated - 2 * (1 - p) * mean_control) ** 2)``,
    where ``mean_treated`` averages rows with ``treats == 1`` and
    ``mean_control`` averages rows with ``treats == 0``. An empty group
    contributes a zero mean instead of crashing or producing NaN.

    Args:
        repre: (N, D) representations.
        treats: treatment indicators, (N,) or (N, 1).
        probs: treated-group weight ``p``; broadcast against the (D,) means.

    Returns:
        Scalar tensor.
    '''
    mask = treats.reshape(-1)
    treated = repre[mask == 1]
    control = repre[mask == 0]

    zeros = repre.new_zeros(repre.shape[1:])
    # Previously the treated mean was bound to `mean_control` (and vice versa),
    # so `probs` weighted the wrong group; an all-control batch also crashed
    # on torch.mean(0).
    mean_treated = treated.mean(dim=0) if treated.shape[0] > 0 else zeros
    mean_control = control.mean(dim=0) if control.shape[0] > 0 else zeros

    return torch.sum(torch.square(2.0 * probs * mean_treated - 2.0 * (1.0 - probs) * mean_control))

def cal_wass(rep_0, rep_1, out_0, out_1, t, yf, device, hparams):
    """Optimal-transport discrepancy between two groups of representations.

    The ground cost is ``hparams['ot_scale'] * ot.dist(rep_0, rep_1)``. When
    ``hparams['gamma'] > 0``, it is augmented with distances between
    counterfactual predictions and observed outcomes. The transport plan is
    computed on the detached cost with uniform marginals, so gradients flow
    only through the cost matrix.

    Args:
        rep_0, rep_1: representations of group 0 and group 1 (group 0 is
            presumably control, per the comments below — confirm).
        out_0, out_1: predicted outcomes under t=0 / t=1 for all samples.
        t: treatment indicators used to split ``out_*`` and ``yf``.
        yf: observed (factual) outcomes.
        device: device for the uniform marginal weights.
        hparams (dict): reads 'ot_scale', 'gamma', 'ot' ('ot' or 'uot'),
            'epsilon' and, for 'uot', 'kappa'.

    Returns:
        Scalar tensor ``sum(plan * cost)``.

    Raises:
        ValueError: if ``hparams['ot']`` is neither 'ot' nor 'uot'. Previously
            this printed an error and then failed on an undefined ``gamma``.
    """
    method = hparams['ot']
    if method not in ('ot', 'uot'):
        raise ValueError(
            "ERROR: The hparams.ot is not correctly defined (expected 'ot' or 'uot', got {!r})".format(method))

    dist = hparams['ot_scale'] * ot.dist(rep_0, rep_1)

    if hparams['gamma'] > 0:
        pred_0_cf = out_1[t == 0]  # predicted outcome for samples in control group given t == 1
        pred_1_cf = out_0[t == 1]  # predicted outcome for samples in treated group given t == 0
        yf_1 = yf[t == 1]
        yf_0 = yf[t == 0]

        dist_10 = ot.dist(pred_0_cf, yf_1)
        dist_01 = ot.dist(yf_0, pred_1_cf)

        dist += hparams['gamma'] * (dist_01 + dist_10)

    # Uniform marginals over each group.
    a = torch.ones(len(rep_0), device=device) / len(rep_0)
    b = torch.ones(len(rep_1), device=device) / len(rep_1)

    if method == 'ot':
        gamma = ot.sinkhorn(
            a,
            b,
            dist.detach(),
            reg=hparams.get('epsilon'),
            stopThr=1e-4)
    else:  # 'uot'
        gamma = ot.unbalanced.sinkhorn_unbalanced(
            a,
            b,
            dist.detach(),
            reg=hparams.get('epsilon'),
            stopThr=1e-6,
            reg_m=hparams.get('kappa'))

    loss_wass = torch.sum(gamma * dist)
    return loss_wass

class MLPDiffusion(nn.Module):
    """Conditional DDPM-style noise-prediction MLP.

    Predicts the noise added to a feature vector ``x`` at diffusion step ``t``.
    The prediction is conditioned on a latent ``z`` sampled from an encoder over
    ``(x, treatment, y)``.

    Args:
        config (dict): reads 'device', 'num_units' (hidden width) and
            'n_steps' (number of diffusion steps).
        dataset: object whose ``size[1]`` gives the feature width of ``x``.
    """

    def __init__(self, config, dataset):

        super(MLPDiffusion, self).__init__()

        self.config = config
        self.dataset = dataset
        self.device = config['device']
        self.num_units = self.config['num_units']
        self.n_steps = self.config['n_steps']
        self.in_feature = self.dataset.size[1]

        # Sigmoid-shaped beta schedule mapped into [1e-5, 5e-3].
        betas = torch.linspace(-6, 6, self.n_steps).to(self.device)
        self.betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
        self.alphas = 1 - self.betas
        self.alphas_prod = torch.cumprod(self.alphas, 0)

        # alpha_bar shifted by one step, with a leading 1.
        self.alphas_prod_p = torch.cat([torch.tensor([1]).float().to(self.device), self.alphas_prod[:-1]], 0)
        # sqrt(alpha_bar_t)
        self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod)
        # Quantities used by the forward / reverse process formulas below.
        self.one_minus_alphas_bar_log = torch.log(1 - self.alphas_prod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod)

        # Noise predictor; input is [x, z] so its width is 2 * in_feature.
        self.linears = nn.ModuleList(
            [
                nn.Linear(self.in_feature * 2, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.in_feature),
            ]
        )
        # One step embedding per hidden Linear layer.
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(self.n_steps, self.num_units),
                nn.Embedding(self.n_steps, self.num_units),
                nn.Embedding(self.n_steps, self.num_units),
            ]
        )

        # Encoder over [x, treatment, y]; the +2 assumes treatment and y are one column each.
        self.encoder = nn.Sequential(
                nn.Linear(self.in_feature + 2, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.num_units),
                nn.ReLU(),
                nn.Linear(self.num_units, self.in_feature),
        )
        self.mean = nn.Linear(self.in_feature, 1)
        self.variance = nn.Linear(self.in_feature, 1)

    def generation_loss(self, x, treatment, y):
        """Encode (x, treatment, y) into per-sample scalar mean/scale and a regularization loss."""
        new_x = torch.cat([x, treatment, y], 1)
        u = self.encoder(new_x)
        mean = self.mean(u).squeeze(dim=-1)
        variance = self.variance(u).squeeze(dim=-1)
        # NOTE(review): `variance` is an unconstrained linear output, so
        # log(1 / variance) is NaN whenever it is <= 0.
        loss = torch.sum(torch.log(1./variance) + 1./2 * (variance ** 2 + mean ** 2))
        return mean, variance, loss

    def generation_z(self, x, mean, variance):
        """Sample z = eps * variance + mean, broadcasting the per-sample scalars across x's columns."""
        width = x.shape[-1]
        e = torch.randn_like(x).to(self.device)
        variance = variance.unsqueeze(-1).repeat(1, width)
        mean = mean.unsqueeze(-1).repeat(1, width)
        return e * variance + mean

    def forward(self, x, treatment, y, t):
        """Predict the noise in ``x`` at step(s) ``t``, conditioned on (treatment, y)."""
        mean, variance, g_loss = self.generation_loss(x, treatment, y)
        z = self.generation_z(x, mean, variance)
        x = torch.cat([x, z], dim=-1)
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t.squeeze(-1))
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)

        x = self.linears[-1](x)

        return x

    def calculate_loss(self, x, treatment, y):
        """Denoising loss: squared error between true and predicted noise at random steps."""
        batch_size = x.shape[0]

        # Antithetic step sampling: pair each t with n_steps - 1 - t.
        if batch_size % 2 == 0:
            t = torch.randint(0, self.n_steps, size=(batch_size // 2,)).to(self.device)
            t = torch.cat([t, self.n_steps - 1 - t], dim=0)  # t has shape (bz,)
        else:
            t = torch.randint(0, self.n_steps, size=(batch_size // 2 + 1,)).to(self.device)
            t = torch.cat([t[:-1], self.n_steps - 1 - t], dim=0)  # t has shape (bz,)

        t = t.unsqueeze(-1)  # t has shape (bz, 1)
        # Coefficient of x0: sqrt(alpha_bar_t)
        a = self.alphas_bar_sqrt[t]

        # Coefficient of eps: sqrt(1 - alpha_bar_t)
        aml = self.one_minus_alphas_bar_sqrt[t]

        # Sample the random noise eps
        e = torch.randn_like(x).to(self.device)

        x = x * a + e * aml

        # Predict the noise at step t
        output = self.forward(x, treatment, y, t)

        # Squared error against the true noise (summed, not averaged)
        return (e - output).square().sum()

    # forward process
    def q_x(self, x_0, t):
        """Sample x_t ~ q(x_t | x_0)."""
        noise = torch.randn_like(x_0).to(self.device)
        alphas_t = self.alphas_bar_sqrt[t]
        alphas_1_m_t = self.one_minus_alphas_bar_sqrt[t]

        return (alphas_t * x_0 + alphas_1_m_t * noise)

    # generating process
    @torch.no_grad()
    def p_sample(self, x, t, treatment=None, y=None):
        """One reverse step x_t -> x_{t-1}.

        ``treatment`` and ``y`` are required because ``forward`` conditions on
        them. The previous version called ``self.forward(x, t)``, which always
        raised TypeError. Note: this switches the module to eval mode.
        """
        if treatment is None or y is None:
            raise TypeError("p_sample() requires `treatment` and `y`: forward() conditions on them")
        self.eval()
        t = torch.tensor([t]).to(self.device)
        coeff = self.betas[t] / self.one_minus_alphas_bar_sqrt[t]
        eps_theta = self.forward(x, treatment, y, t)
        # Posterior mean
        mean = (1 / (1 - self.betas[t]).sqrt()) * (x - (coeff * eps_theta))
        z = torch.randn_like(x).to(self.device)
        sigma_t = self.betas[t].sqrt()
        # Draw the sample
        sample = mean + sigma_t * z

        return (sample)

    # backward process
    def p_sample_loop(self, x, treatment=None, y=None):
        """Recover x[T-1], x[T-2], ..., x[0] from x[T] ~ N(0, I) shaped like ``x``; returns all states."""
        cur_x = torch.randn(x.shape).to(self.device)
        x_seq = [cur_x]
        for i in reversed(range(self.n_steps)):
            cur_x = self.p_sample(cur_x, i, treatment, y)
            x_seq.append(cur_x)

        return x_seq

    def generation(self, x, treatment=None, y=None):
        """Run the full reverse process and return only x[0]."""
        return self.p_sample_loop(x, treatment, y)[-1]

    @torch.no_grad()
    def get_eta(self, x, treatment, y):
        mean, variance, g_loss = self.generation_loss(x, treatment, y)
        z = self.generation_z(x, mean, variance)
        return z

    @torch.no_grad()
    def get_noise_eta(self, x, treatment, y, a):
        """Like get_eta, but with the encoder mean shifted by ``a``."""
        mean, variance, g_loss = self.generation_loss(x, treatment, y)
        z = self.generation_z(x, mean + a, variance)

        return z