import torch
import torch.nn as nn
from utils import weights_init
import numpy as np
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation.

    Args:
        input_dim: Number of input features given to ``nn.Linear``.
        output_dim: Number of output features.
        norm: ``'bn'``, ``'in'`` or ``'ln'`` apply a normalization module to
            the linear output; ``'sn'`` / ``'wn'`` wrap the linear layer
            itself in spectral / weight normalization; ``'none'`` disables
            normalization.
        activation: One of ``'relu'``, ``'lrelu'``, ``'prelu'``, ``'selu'``,
            ``'tanh'``, ``'sigmoid'``, ``'elu'``, ``'swish'`` or ``'none'``.

    Raises:
        AssertionError: If ``norm`` or ``activation`` is not a supported name.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer ('sn'/'wn' re-parametrize its weight)
        fc = nn.Linear(input_dim, output_dim, bias=use_bias)
        if norm == 'sn':
            fc = nn.utils.spectral_norm(fc)
        elif norm == 'wn':
            fc = nn.utils.weight_norm(fc)
        self.fc = fc

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(norm_dim)
        elif norm in ('none', 'sn', 'wn'):
            self.norm = None
        else:
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            raise AssertionError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'elu':
            self.activation = nn.ELU()
        elif activation == 'swish':
            self.activation = nn.SiLU()
        elif activation == 'none':
            self.activation = None
        else:
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            raise AssertionError("Unsupported activation: {}".format(activation))

    def forward(self, x):
        """Apply linear -> (norm) -> (activation) and return the result."""
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out

class MLP(nn.Module):
    """Stack of ``LinearBlock`` layers applied to a flattened input.

    The stack is: one ``input_dim -> dim`` block, ``n_blk - 2`` blocks of
    ``dim -> dim`` (none when ``n_blk <= 2``), and a final ``dim -> output_dim``
    block that has neither normalization nor activation.

    Args:
        input_dim: Feature size of the flattened input.
        output_dim: Feature size of the output.
        dim: Hidden width.
        n_blk: Nominal number of blocks (see above).
        norm: Normalization name forwarded to the non-final blocks.
        activ: Activation name forwarded to the non-final blocks.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        layers = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        layers.extend(
            LinearBlock(dim, dim, norm=norm, activation=activ)
            for _ in range(n_blk - 2)
        )
        # no output activations
        layers.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the first dimension and run the stack."""
        flat = x.view(x.size(0), -1)
        return self.model(flat)

class GaussianFourierProjection(nn.Module):
    """Gaussian random Fourier features for encoding time steps.

    ``embed_dim // 2`` frequencies are drawn from a standard normal and
    multiplied by ``scale`` at construction time. The output concatenates the
    sine and cosine of ``2 * pi * x * W`` along the last dimension.

    NOTE(review): ``W`` is created with ``requires_grad=True``, so it is a
    trainable parameter. An earlier comment here described the weights as
    fixed / non-trainable -- confirm which behaviour is intended.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Frequencies are sampled once, at initialization.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=True)

    def forward(self, x):
        """Return ``cat(sin(angles), cos(angles))`` for the inputs ``x``."""
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class DiscreteTreatmentParentNet(nn.Module):
    """Quantile predictor with one target embedder and one covariate MLP per treatment.

    For every sample, the target embedding and covariate embedding belonging
    to that sample's treatment index are summed and passed through a shared
    MLP; a sigmoid maps the result into (0, 1).

    Args:
        model_config: Reads ``target_embed_dim``, ``hidden_dim``, ``n_layers``,
            ``norm``, ``activ`` and ``init_type``.
        data_config: Reads ``discrete_range`` (number of per-treatment
            branches), ``target_dim`` and ``covariate_dim``.
    """

    def __init__(self, model_config, data_config):
        super().__init__()
        # One target embedder per treatment value.
        self.phis = nn.ModuleList()
        for _ in range(data_config.discrete_range):
            self.phis.append(LinearBlock(data_config.target_dim, model_config.target_embed_dim,
                                         model_config.norm, model_config.activ))
        # One covariate network per treatment value.
        self.model = nn.ModuleList()
        for _ in range(data_config.discrete_range):
            self.model.append(MLP(data_config.covariate_dim, model_config.target_embed_dim,
                                  model_config.hidden_dim, model_config.n_layers,
                                  model_config.norm, model_config.activ))

        # Shared head mapping the fused embedding to a single logit.
        self.net = MLP(model_config.target_embed_dim, 1, model_config.hidden_dim, model_config.n_layers,
                       model_config.norm, model_config.activ)
        # Only the per-treatment branches go through weights_init; self.net
        # keeps the default initialization of its layers.
        self.phis.apply(weights_init(model_config.init_type))
        self.model.apply(weights_init(model_config.init_type))

    def forward(self, covariate, treatment, targets):
        """Return one value in (0, 1) per sample as a flat tensor.

        ``treatment`` is cast to ``long`` and flattened, then used to pick,
        per sample, which branch's output to keep. Every branch is evaluated
        on the full batch before the selection.
        """
        # Row/column indices shared by both gathers; arange builds the row
        # index directly on the target device.
        rows = torch.arange(treatment.size(0), device=covariate.device)
        cols = treatment.long().view(-1)

        # Stacked along dim 1: one slice per treatment branch.
        target_embed = torch.stack([phi(targets) for phi in self.phis], dim=1)
        h = target_embed[rows, cols]

        covariate_embed = torch.stack([branch(covariate) for branch in self.model], dim=1)
        outs = covariate_embed[rows, cols]

        outs = self.net(outs + h)
        return torch.sigmoid(outs).view(-1)

import torch.nn.functional as F
class CBN(nn.Module):
    """Batch normalization whose gain and bias are predicted from a conditioning input.

    ``x`` is normalized with ``F.batch_norm`` (no affine parameters) using the
    ``stored_mean`` / ``stored_var`` buffers as running statistics, then scaled
    by ``1 + gain(y)`` and shifted by ``bias(y)``.

    Args:
        output_size: Number of features being normalized.
        input_size: Feature size of the conditioning input ``y``.
        which_linear: Factory called as ``which_linear(input_size, output_size)``
            to build the gain and bias layers.
        eps: Epsilon passed to ``F.batch_norm`` to avoid dividing by zero.
        momentum: Momentum of the running-statistics update.
        cross_replica, mybn, norm_style: Stored on the module; not read by
            ``forward``.
    """

    def __init__(self, output_size, input_size, which_linear=nn.Linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', ):
        super(CBN, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and modulate it with gain/bias computed from ``y``."""
        # Calculate conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1)
        bias = self.bias(y).view(y.size(0), -1)
        # Honour the configured momentum (previously a hard-coded 0.1 was
        # passed here, silently ignoring the constructor argument).
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, self.momentum, self.eps)
        return out * gain + bias

    def extra_repr(self):
        """Return the size / cross_replica summary shown in the module repr."""
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)

class ConditionalMLP(nn.Module):
    """Five-layer MLP whose four hidden layers use conditional batch norm.

    Each hidden stage is ``Linear -> CBN(., y) -> ELU``; the last layer is a
    plain ``Linear`` to ``output_dim``.

    Args:
        input_dim: Feature size of ``x``.
        output_dim: Feature size of the output.
        cond_dim: Feature size of the conditioning input ``y``.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, input_dim, output_dim, cond_dim, hidden_dim=100):
        super().__init__()
        # Registers fc1/bn1/act1 ... fc4/bn4/act4 in that order.
        in_features = input_dim
        for stage in range(1, 5):
            setattr(self, 'fc{}'.format(stage), nn.Linear(in_features, hidden_dim))
            setattr(self, 'bn{}'.format(stage), CBN(hidden_dim, cond_dim))
            setattr(self, 'act{}'.format(stage), nn.ELU())
            in_features = hidden_dim
        self.fc5 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, y):
        """Run ``x`` through the four conditioned stages and the output layer."""
        hidden = self.act1(self.bn1(self.fc1(x), y))
        hidden = self.act2(self.bn2(self.fc2(hidden), y))
        hidden = self.act3(self.bn3(self.fc3(hidden), y))
        hidden = self.act4(self.bn4(self.fc4(hidden), y))
        return self.fc5(hidden)



class ConcatNet(nn.Module):
    """MLP that re-embeds a conditioning input and concatenates it at every layer.

    Each layer owns a ``tau_fcs`` block that embeds ``tau`` and an ``fcs``
    block applied to ``cat([x, embedded_tau], 1)``. The first tau embedding
    has width ``tau_embed_dim``; all later ones have width ``hidden_dim``.
    The final ``fcs`` block maps to ``out_dim`` with no normalization or
    activation. There are ``max(n_layers - 2, 0) + 2`` layers in total.

    ``self.dropouts`` is created as an empty ``ModuleList`` and is not used
    by ``forward``.
    """

    def __init__(self, covariate_dim, tau_embed_dim, cond_dim, n_layers, hidden_dim, norm, activ, out_dim):
        super().__init__()
        self.fcs = nn.ModuleList()
        self.tau_fcs = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        # Width of the tau embedding consumed by each layer, first to last.
        tau_widths = [tau_embed_dim] + [hidden_dim] * (max(n_layers - 2, 0) + 1)
        last = len(tau_widths) - 1
        x_width = covariate_dim
        for depth, tau_width in enumerate(tau_widths):
            if depth == last:
                self.fcs.append(LinearBlock(x_width + tau_width, out_dim, 'none', 'none'))
            else:
                self.fcs.append(LinearBlock(x_width + tau_width, hidden_dim, norm, activ))
            self.tau_fcs.append(LinearBlock(cond_dim, tau_width, norm, activ))
            x_width = hidden_dim

    def forward(self, x, tau):
        """Return the output of the last layer for inputs ``x`` conditioned on ``tau``."""
        for main_block, tau_block in zip(self.fcs, self.tau_fcs):
            x = main_block(torch.cat((x, tau_block(tau)), dim=1))
        return x


class DiscreteTreatmentChildNet(nn.Module):
    """Shared covariate encoder followed by one ``ConcatNet`` head per treatment.

    Args:
        model_config: Reads ``tau_embed_dim``, ``n_layers``, ``hidden_dim``,
            ``norm``, ``activ`` and ``init_type``.
        data_config: Reads ``covariate_dim``, ``discrete_range`` (number of
            heads) and ``target_dim`` (output width of each head).
    """

    def __init__(self, model_config, data_config):
        super().__init__()

        # Shared encoder: covariates conditioned on tau (cond_dim is fixed to 1).
        self.covariate_net = ConcatNet(data_config.covariate_dim, model_config.tau_embed_dim, 1, model_config.n_layers,
                                       model_config.hidden_dim, model_config.norm, model_config.activ,
                                       out_dim=model_config.hidden_dim)
        # One head per treatment value.
        self.model = nn.ModuleList()
        for _ in range(data_config.discrete_range):
            self.model.append(ConcatNet(model_config.hidden_dim, model_config.hidden_dim, 1,
                                        model_config.n_layers+1, model_config.hidden_dim//2,
                                        model_config.norm, model_config.activ, out_dim=data_config.target_dim))
        self.model.apply(weights_init(model_config.init_type))
        self.covariate_net.apply(weights_init(model_config.init_type))

    def forward(self, covariate, treatment, tau, return_pair=False):
        """Return the output of each sample's own treatment head.

        Every head is evaluated on the full batch; ``treatment`` (cast to
        ``long`` and flattened) then selects one head output per sample.

        Returns:
            ``out`` when ``return_pair`` is false; otherwise
            ``(h_pos, h_neg, out)`` where ``h_pos`` / ``h_neg`` are the encoder
            outputs of the rows whose treatment equals 1 / 0.
        """
        h = self.covariate_net(covariate, tau)

        # Stacked along dim 1: one slice per treatment head.
        per_head = torch.stack([head(h, tau) for head in self.model], dim=1)
        # arange builds the row index directly on the target device.
        rows = torch.arange(treatment.size(0), device=covariate.device)
        out = per_head[rows, treatment.long().view(-1)]
        if not return_pair:
            return out

        flat_treatment = treatment.view(-1)
        h_pos = h[flat_treatment == 1]
        h_neg = h[flat_treatment == 0]
        return h_pos, h_neg, out

class ContinuousTreatmentParentNet(nn.Module):
    """Quantile predictor that re-injects treatment and target at every layer.

    The covariate, treatment and targets are concatenated and embedded by
    ``linear_input``. Each following layer embeds the treatment and the
    targets again, concatenates both with the running hidden state and applies
    a ``LinearBlock``. The last block outputs one value per sample, which is
    passed through a sigmoid.

    :param model_config: reads ``hidden_dim``, ``n_layers``, ``norm``, ``activ``
    :param data_config: reads ``covariate_dim``, ``treatment_dim``, ``target_dim``
    """

    def __init__(self, model_config, data_config):
        super().__init__()
        width = model_config.hidden_dim
        norm, activ = model_config.norm, model_config.activ
        treat_dim, target_dim = data_config.treatment_dim, data_config.target_dim

        self.fcs = nn.ModuleList()
        self.treat_fcs = nn.ModuleList()
        self.target_fcs = nn.ModuleList()
        self.linear_input = LinearBlock(data_config.covariate_dim + treat_dim + target_dim,
                                        width, norm, activ)

        def add_side_embedders():
            # Treatment and target embedders paired with the fusion block just added.
            self.treat_fcs.append(LinearBlock(treat_dim, width, norm, activ))
            self.target_fcs.append(LinearBlock(target_dim, width, norm, activ))

        for _ in range(model_config.n_layers - 1):
            self.fcs.append(LinearBlock(width * 3, width, norm, activ))
            add_side_embedders()
        # Output layer: single unit, no normalization or activation.
        self.fcs.append(LinearBlock(width * 3, 1, 'none', 'none'))
        add_side_embedders()

    def forward(self, covariate, treatment, targets):
        """Return one value in (0, 1) per sample as a flat tensor."""
        h = self.linear_input(torch.cat((covariate, treatment, targets), dim=1))
        for fuse, embed_treat, embed_target in zip(self.fcs, self.treat_fcs, self.target_fcs):
            treat_repr = embed_treat(treatment)
            target_repr = embed_target(targets)
            h = fuse(torch.cat((h, treat_repr, target_repr), dim=1))
        return torch.sigmoid(h).view(-1)

class ContinuousTreatmentChildNet(nn.Module):
    """Network that re-injects ``tau`` and the treatment at every layer.

    The flattened covariate, treatment and ``tau`` are concatenated and
    embedded by ``linear_input``. Each following layer embeds ``tau`` and the
    treatment again, concatenates both with the running hidden state and
    applies the next entry of ``fcs``. The last entry of ``fcs`` is a plain
    ``nn.Linear`` with a single output unit.

    Args:
        model_config: Reads ``hidden_dim``, ``n_layers``, ``norm``, ``activ``.
        data_config: Reads ``covariate_dim`` and ``treatment_dim``.
    """

    def __init__(self, model_config, data_config):
        super().__init__()
        width = model_config.hidden_dim
        norm, activ = model_config.norm, model_config.activ
        treat_dim = data_config.treatment_dim

        self.fcs = nn.ModuleList()
        self.tau_fcs = nn.ModuleList()
        self.treat_fcs = nn.ModuleList()
        # "+ 1" accounts for the single tau column.
        self.linear_input = LinearBlock(data_config.covariate_dim + treat_dim + 1, width, norm, activ)

        def add_side_embedders():
            # tau and treatment embedders paired with the fusion layer just added.
            self.tau_fcs.append(LinearBlock(1, width, norm, activ))
            self.treat_fcs.append(LinearBlock(treat_dim, width, norm, activ))

        for _ in range(model_config.n_layers - 1):
            self.fcs.append(LinearBlock(width * 3, width, norm, activ))
            add_side_embedders()

        self.fcs.append(nn.Linear(width * 3, 1))
        add_side_embedders()

    def forward(self, covariate, treatment, tau, return_pair=False):
        """Return the final linear output for each sample.

        All three inputs are reshaped to ``(len(input), -1)`` first.
        ``return_pair`` is accepted but not used by this implementation.
        """
        covariate = covariate.view(len(covariate), -1)
        treatment = treatment.view(len(treatment), -1)
        tau = tau.view(len(tau), -1)

        h = self.linear_input(torch.cat((covariate, treatment, tau), dim=1))
        for embed_tau, embed_treat, fuse in zip(self.tau_fcs, self.treat_fcs, self.fcs):
            tau_repr = embed_tau(tau)
            treat_repr = embed_treat(treatment)
            h = fuse(torch.cat((h, tau_repr, treat_repr), dim=1))
        return h


class MultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""Multiple Kernel Maximum Mean Discrepancy (MK-MMD), as used in
    `Learning Transferable Features with Deep Adaptation Networks (ICML 2015) <https://arxiv.org/pdf/1502.02791>`_

    Let :math:`\{z_i^s\}_{i=1}^{n_s}` and :math:`\{z_i^t\}_{i=1}^{n_t}` be activations produced from
    samples of the source distribution P and the target distribution Q. The MK-MMD between P and Q is

    .. math::
        D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k},

    where the kernel :math:`k` belongs to the family

    .. math::
        \mathcal{K} \triangleq \{ k=\sum_{u=1}^{m}\beta_{u} k_{u} \}

    and every :math:`k_{u}` is a single kernel. With the kernel trick it is estimated as

    .. math::
        \hat{D}_k(P, Q) &=
        \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s})\\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t})\\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}).\\

    Args:
        kernels (tuple(torch.nn.Module)): kernel functions; each is called on the concatenated
            features and their results are summed.
        linear (bool): whether to use the linear version of DAN. Default: False

    Inputs:
        - z_s (tensor): activations from the source domain, :math:`z^s`
        - z_t (tensor): activations from the target domain, :math:`z^t`

    Shape:
        - Inputs: :math:`(minibatch, *)`  where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{s}` and :math:`z^{t}` must have the same shape.

    .. note::
        The loss divides by ``minibatch - 1``, so a minibatch of size 1 raises ``ZeroDivisionError``.
    """

    def __init__(self, kernels, linear= False):
        super(MultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        self.linear = linear
        # Weight matrix turning the kernel matrix into the loss; built lazily
        # in forward and reused while the batch size stays the same.
        self.index_matrix = None

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        n_source = int(z_s.size(0))
        stacked = torch.cat([z_s, z_t], dim=0)
        weights = _update_index_matrix(n_source, self.index_matrix, self.linear)
        self.index_matrix = weights.to(z_s.device)

        # Add up the matrix of each kernel
        gram = sum(kernel(stacked) for kernel in self.kernels)
        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version
        diagonal_correction = 2. / float(n_source - 1)
        return (gram * self.index_matrix).sum() + diagonal_correction


def _update_index_matrix(batch_size: int, index_matrix= None,
                         linear= True) -> torch.Tensor:
    r"""
    Update the `index_matrix` which convert `kernel_matrix` to loss.
    If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`.
    Else return a new tensor with shape (2 x batch_size, 2 x batch_size).
    """
    if index_matrix is None or index_matrix.size(0) != batch_size * 2:
        index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)
        if linear:
            for i in range(batch_size):
                s1, s2 = i, (i + 1) % batch_size
                t1, t2 = s1 + batch_size, s2 + batch_size
                index_matrix[s1, s2] = 1. / float(batch_size)
                index_matrix[t1, t2] = 1. / float(batch_size)
                index_matrix[s1, t2] = -1. / float(batch_size)
                index_matrix[s2, t1] = -1. / float(batch_size)
        else:
            for i in range(batch_size):
                for j in range(batch_size):
                    if i != j:
                        index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))
                        index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))
            for i in range(batch_size):
                for j in range(batch_size):
                    index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)
                    index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size)
    return index_matrix
