import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm,weight_norm
import numpy as np




class AdaBlock(nn.Module):
    """Linear + LeakyReLU(0.1) layer conditioned on ``y`` by concatenation.

    ``y`` is first projected to ``input_dim`` features, concatenated with
    ``x`` along dim 1, then mapped to ``output_dim`` features.
    """

    def __init__(self, input_dim, cond_dim, output_dim):
        super().__init__()
        # Input width is x's features plus the projected condition's features.
        self.linear = nn.Linear(2 * input_dim, output_dim)
        self.style_linear = nn.Linear(cond_dim, input_dim)
        self.act = nn.LeakyReLU(0.1)

    def forward(self, x, y):
        style = self.style_linear(y)
        joint = torch.cat([x, style], dim=1)
        return self.act(self.linear(joint))



class EstimatorNet(nn.Module):
    """MLP over concatenated ``[covariate, treatment, outcome]`` with a sigmoid head.

    Architecture: Linear+SiLU input layer, ``n_layers - 1`` residual blocks,
    then a Linear to ``out_dim`` features squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, cov_dim, treat_dim, out_dim, n_layers, hidden_dim=512):
        super().__init__()
        in_features = cov_dim + treat_dim + out_dim
        self.input_layer = nn.Sequential(nn.Linear(in_features, hidden_dim), nn.SiLU())
        self.model = nn.ModuleList(
            [ResidualBlock(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        )
        self.final = nn.Linear(hidden_dim, out_dim)

    def forward(self, covariate, treatment, outcome):
        features = torch.cat([covariate, treatment, outcome], dim=1)
        hidden = self.input_layer(features)
        for residual_block in self.model:
            hidden = residual_block(hidden)
        logits = self.final(hidden)
        return torch.sigmoid(logits)

class DiscreteEstimatorNet(nn.Module):
    """Binary-treatment estimator with one sigmoid head per treatment arm.

    A shared Linear+SiLU layer embeds ``[covariate, treatment, outcome]``.
    Two independent heads (``model`` for treatment 0, ``model2`` for
    treatment 1) each map the embedding to a single value in (0, 1). Each
    row returns the head selected by its own treatment value.
    """

    def __init__(self, cov_dim, treat_dim, out_dim, n_layers, hidden_dim=512):
        super(DiscreteEstimatorNet, self).__init__()

        self.input_layer = nn.Sequential(nn.Linear(cov_dim + treat_dim + out_dim, hidden_dim), nn.SiLU())
        # Attribute names are kept so state_dict keys stay checkpoint-compatible.
        self.model = self._make_head(hidden_dim, n_layers)
        self.model2 = self._make_head(hidden_dim, n_layers)

    @staticmethod
    def _make_head(hidden_dim, n_layers):
        """Build one arm: ``n_layers - 1`` residual blocks, then Linear(->1) + Sigmoid."""
        layers = [ResidualBlock(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        layers += [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, covariate, treatment, outcome):
        """Return ``(batch, 1)`` head outputs chosen row-wise by ``treatment``.

        ``treatment`` is cast to long and used as an index into the two heads.
        It is presumably ``(batch, 1)`` holding 0/1 (TODO confirm); any other
        value raises an IndexError.
        """
        x = torch.cat([covariate, treatment, outcome], dim=1)
        x = self.input_layer(x)
        out = torch.stack([self.model(x), self.model2(x)], 1)  # (batch, 2, 1)
        # Build the row index directly on the target device instead of
        # materialising a CPU LongTensor from a Python range on every call.
        rows = torch.arange(x.shape[0], device=x.device)
        arm = treatment.long().reshape(-1)  # one arm index per row
        return out[rows, arm]





class ResidualBlock(nn.Module):
    """Two Linear -> LayerNorm -> SiLU stages with an identity skip scaled by 1/sqrt(2).

    NOTE: the skip adds the raw input to the output, so ``input_dim`` must
    equal ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.linear2 = nn.Linear(output_dim, output_dim)
        self.act = nn.SiLU()
        self.norm1 = nn.LayerNorm(output_dim)
        self.norm2 = nn.LayerNorm(output_dim)

    def forward(self, x):
        hidden = x
        for linear, norm in ((self.linear1, self.norm1), (self.linear2, self.norm2)):
            hidden = self.act(norm(linear(hidden)))
        return (x + hidden) / math.sqrt(2)

class AdaResidualBlock(nn.Module):
    """Pre-activation residual block conditioned on ``y`` at both linear layers.

    Each stage applies LeakyReLU(0.2), concatenates ``y`` along dim 1 and
    applies a Linear. The result is added to the input and scaled by
    1/sqrt(2), so ``input_dim`` must equal ``output_dim``.
    """

    def __init__(self, input_dim, cond_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim + cond_dim, output_dim)
        self.linear2 = nn.Linear(output_dim + cond_dim, output_dim)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, y):
        hidden = self.linear1(torch.cat([self.act(x), y], 1))
        hidden = self.linear2(torch.cat([self.act(hidden), y], 1))
        return (x + hidden) / math.sqrt(2)

class RegressorNet(nn.Module):
    """MLP regressor over separately embedded covariate, treatment and ``tau`` inputs.

    Each input gets its own Linear+SiLU embedding of width ``hidden_dim``.
    The three embeddings are concatenated, fused by another Linear+SiLU, passed
    through ``n_layers - 1`` residual blocks and projected to ``out_dim``.
    """

    def __init__(self, cov_dim, treat_dim, out_dim, n_layers, hidden_dim=512):
        super().__init__()
        self.cov_input = nn.Sequential(nn.Linear(cov_dim, hidden_dim), nn.SiLU())
        self.treat_input = nn.Sequential(nn.Linear(treat_dim, hidden_dim), nn.SiLU())
        self.tau_input = nn.Sequential(nn.Linear(out_dim, hidden_dim), nn.SiLU())
        self.input_layer = nn.Sequential(nn.Linear(3 * hidden_dim, hidden_dim), nn.SiLU())
        blocks, tau_layers = [], []
        for _ in range(n_layers - 1):
            blocks.append(ResidualBlock(hidden_dim, hidden_dim))
            tau_layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.SiLU()])
        self.model = nn.ModuleList(blocks)
        # NOTE(review): tau_model is never used in forward(); it is kept only so
        # parameter initialisation order and state_dict keys stay unchanged.
        self.tau_model = nn.ModuleList(tau_layers)
        self.final = nn.Linear(hidden_dim, out_dim)

    def forward(self, covariate, treatment, tau):
        embedded = torch.cat(
            [self.cov_input(covariate), self.treat_input(treatment), self.tau_input(tau)],
            dim=1,
        )
        hidden = self.input_layer(embedded)
        for block in self.model:
            hidden = block(hidden)
        return self.final(hidden)

class DiscreteRegressorNet(nn.Module):
    """Binary-treatment regressor with one output head per treatment arm.

    Covariate, treatment and ``tau`` each get a Linear+SiLU embedding. The
    concatenated embeddings are fused by a Linear+SiLU layer and fed to two
    heads (``model`` for treatment 0, ``model2`` for treatment 1). Each row
    returns the head selected by its own treatment value.
    """

    def __init__(self, cov_dim, treat_dim, out_dim, n_layers, hidden_dim=512):
        super(DiscreteRegressorNet, self).__init__()
        self.cov_input = nn.Sequential(nn.Linear(cov_dim, hidden_dim), nn.SiLU())
        self.treat_input = nn.Sequential(nn.Linear(treat_dim, hidden_dim), nn.SiLU())
        self.tau_input = nn.Sequential(nn.Linear(out_dim, hidden_dim), nn.SiLU())
        self.input_layer = nn.Sequential(nn.Linear(hidden_dim * 3, hidden_dim), nn.SiLU())
        # Attribute names are kept so state_dict keys stay checkpoint-compatible.
        self.model = self._make_head(hidden_dim, out_dim, n_layers)
        self.model2 = self._make_head(hidden_dim, out_dim, n_layers)

    @staticmethod
    def _make_head(hidden_dim, out_dim, n_layers):
        """Build one arm: ``n_layers - 1`` residual blocks, then Linear(-> out_dim)."""
        layers = [ResidualBlock(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        layers += [nn.Linear(hidden_dim, out_dim)]
        return nn.Sequential(*layers)

    def forward(self, covariate, treatment, tau):
        """Return ``(batch, out_dim)`` head outputs chosen row-wise by ``treatment``.

        ``treatment`` is both embedded and cast to long as a head index. It is
        presumably ``(batch, 1)`` holding 0/1 (TODO confirm); any other value
        raises an IndexError.
        """
        cov_emb = self.cov_input(covariate)
        treat_emb = self.treat_input(treatment)
        tau_emb = self.tau_input(tau)
        x = self.input_layer(torch.cat([cov_emb, treat_emb, tau_emb], dim=1))
        out = torch.stack([self.model(x), self.model2(x)], 1)  # (batch, 2, out_dim)
        # Build the row index directly on the target device instead of
        # materialising a CPU LongTensor from a Python range on every call.
        rows = torch.arange(x.shape[0], device=x.device)
        arm = treatment.long().reshape(-1)  # one arm index per row
        return out[rows, arm]


class ResBlk(nn.Module):
    """Pre-activation 2D residual block with optional 2x down/upsampling.

    Residual path: [norm1] -> actv -> conv1 -> [resample] -> [norm2] -> actv
    -> conv2. Shortcut: [1x1 conv when dim_in != dim_out] -> [resample].
    The two paths are summed and scaled by 1/sqrt(2).

    Args:
        dim_in: input channels.
        dim_out: output channels.
        actv: activation module. Defaults to a fresh ``nn.LeakyReLU(0.2)``
            per block. A module instance in the signature would be built once
            and shared by every ResBlk, and so would its hooks.
        normalize: apply affine InstanceNorm2d before each activation.
        downsample: 2x average-pool both paths.
        upsample: 2x nearest-neighbour upsample both paths.
    """

    def __init__(self, dim_in, dim_out, actv=None,
                 normalize=True, downsample=False, upsample=False):
        super().__init__()
        self.actv = nn.LeakyReLU(0.2) if actv is None else actv
        self.normalize = normalize
        self.downsample = downsample
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Create the convolutions, optional norms and the learned shortcut."""
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            # conv1 keeps dim_in channels, so norm2 is also sized dim_in.
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Project channels if needed, then resample to match the residual path."""
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return x

    def _residual(self, x):
        """Pre-activation conv path, resampling between the two convolutions."""
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        # Rescale the sum of the two branches to keep roughly unit variance.
        return x / math.sqrt(2)



class ConvEstimatorNet(nn.Module):
    """Strided-conv image estimator producing one value in (0, 1) per sample.

    The scalar treatment is projected to a ``size x size`` map and stacked
    along channels between the covariate and outcome images. The result is
    fed through stride-2 convolutions (channel width doubling, capped at
    ``max_dim``) and a final 4x4 conv + sigmoid.

    Args:
        size: spatial size of the (square) inputs. For power-of-two sizes
            >= 8 the final conv sees a 4x4 map and outputs 1x1.
        hidden_dim: channels of the first conv.
        in_channels: total channels after concatenation, i.e. covariate
            channels + 1 (treatment map) + outcome channels. The default 7
            presumably means 3 + 1 + 3 (TODO confirm).
        max_dim: upper bound on intermediate conv widths.
    """

    def __init__(self, size=32, hidden_dim=32, in_channels=7, max_dim=512):
        super(ConvEstimatorNet, self).__init__()
        self.treat_proj = nn.Linear(1, size**2)
        use_dim = hidden_dim
        model = [nn.Conv2d(in_channels, use_dim, 4, 2, 1), nn.LeakyReLU(0.1)]
        for _ in range(int(np.log2(size)) - 3):
            # Cap BEFORE building the conv so its out_channels match the next
            # layer's in_channels. Capping only afterwards crashed once the
            # width exceeded max_dim.
            next_dim = min(max_dim, use_dim * 2)
            model += [nn.Conv2d(use_dim, next_dim, 4, 2, 1), nn.LeakyReLU(0.1)]
            use_dim = next_dim

        model += [nn.Conv2d(use_dim, 1, 4, 1, 0), nn.Sigmoid()]
        self.model = nn.Sequential(*model)
        self.size = size
        print(self.model)

    def forward(self, covariate, treatment, outcome):
        """Return a ``(batch, 1)`` tensor of sigmoid scores.

        ``treatment`` must have a trailing dim of 1 (``treat_proj`` is
        Linear(1, size**2)). ``covariate`` and ``outcome`` are presumably
        ``(batch, C, size, size)`` images.
        """
        size = self.size
        treatment = self.treat_proj(treatment).view(len(treatment), 1, size, size)
        x = torch.cat([covariate, treatment, outcome], dim=1)
        return self.model(x).view(-1, 1)

class ConvRegressorNet(nn.Module):
    """Encoder/decoder conv regressor mapping images + scalars to an output image.

    Treatment and ``tau`` scalars are each projected to a ``size x size`` map
    and concatenated after the covariate channels. Stride-2 convolutions
    downsample (width doubling, capped at ``max_dim``); mirrored stride-2
    transposed convolutions (width halving) restore the input resolution.

    Args:
        size: spatial size of the (square) inputs.
        hidden_dim: channels of the first conv.
        in_channels: covariate channels + 2 (treatment and tau maps). The
            default 5 presumably means 3 + 1 + 1 (TODO confirm).
        out_channels: channels of the predicted image.
        max_dim: upper bound on encoder conv widths.
    """

    def __init__(self, size=32, hidden_dim=32, in_channels=5, out_channels=3, max_dim=512):
        super().__init__()
        self.treat_proj = nn.Sequential(nn.Linear(1, size**2))
        self.tau_proj = nn.Sequential(nn.Linear(1, size**2))

        self.size = size
        n_scales = int(np.log2(size)) - 3
        use_dim = hidden_dim
        model = [nn.Conv2d(in_channels, use_dim, 4, 2, 1), nn.LeakyReLU(0.1)]
        for _ in range(n_scales):
            # Cap BEFORE building the conv so its out_channels match the next
            # layer's in_channels. Capping only afterwards crashed once the
            # width exceeded max_dim.
            next_dim = min(max_dim, use_dim * 2)
            model += [nn.Conv2d(use_dim, next_dim, 4, 2, 1), nn.LeakyReLU(0.1)]
            use_dim = next_dim
        for _ in range(n_scales):
            model += [nn.ConvTranspose2d(use_dim, use_dim // 2, 4, 2, 1), nn.LeakyReLU(0.1)]
            use_dim = use_dim // 2

        model += [nn.ConvTranspose2d(use_dim, out_channels, 4, 2, 1)]
        self.model = nn.Sequential(*model)
        print(self.model)

    def forward(self, covariate, treatment, tau):
        """Return the predicted ``(batch, out_channels, size, size)`` image.

        ``treatment`` and ``tau`` must have a trailing dim of 1 (their
        projections are Linear(1, size**2)).
        """
        size = self.size
        treatment = self.treat_proj(treatment).view(len(covariate), 1, size, size)
        tau = self.tau_proj(tau).view(len(covariate), 1, size, size)
        x = torch.cat([covariate, treatment, tau], dim=1)
        return self.model(x)


class GaussianFourierProjection(nn.Module):
    """Gaussian random Fourier features for encoding scalar inputs (e.g. time steps).

    Maps ``x`` to ``cat([sin(2*pi*x*W), cos(2*pi*x*W)], dim=-1)``, where ``W``
    holds ``embed_dim // 2`` frequencies drawn from N(0, scale**2) once at
    construction. The output width is ``2 * (embed_dim // 2)``.

    NOTE: ``W`` is not trainable and is not saved in ``state_dict``, so a
    reloaded model only reproduces it if the RNG is seeded identically before
    construction.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Fixed random frequencies, registered as a non-persistent buffer.
        # Being a buffer makes them follow .to()/.cuda() with the module
        # instead of being re-copied from CPU on every forward call.
        # persistent=False keeps state_dict keys unchanged for existing
        # checkpoints.
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale, persistent=False)

    def forward(self, x):
        # x presumably has shape (batch, 1) so it broadcasts against
        # (1, embed_dim // 2) -- TODO confirm with callers.
        # .to() is a no-op once the module lives on x's device.
        x_proj = x * self.W[None, :].to(x.device) * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Dense(nn.Module):
    """Linear layer whose output gets trailing singleton dims for broadcasting.

    With ``one_D`` true the output gains one trailing dim (``(..., C, 1)``),
    otherwise two (``(..., C, 1, 1)``), so it can be added to 1D or 2D
    feature maps respectively.
    """

    def __init__(self, input_dim, output_dim, one_D):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)
        self.one_D = one_D

    def forward(self, x):
        projected = self.dense(x)
        return projected[..., None] if self.one_D else projected[..., None, None]


class UnetRegressor(nn.Module):
    """U-Net regressor whose feature maps are shifted by a treatment embedding.

    ``T`` (and, when given, ``cond``) are embedded with Gaussian Fourier
    features followed by a Linear layer and a swish activation. The resulting
    vector is projected by a ``Dense`` layer and added to the output of every
    encoder/decoder convolution before its GroupNorm.
    """

    def __init__(self, channels=(32, 64, 128, 256), embed_dim=256, one_D=False, out_dims=3, conditional=True,
                 input_dim=3):
        """Build the encoder/decoder.

        Args:
            channels: widths of the four encoder stages (any indexable of
                length >= 4; a tuple default avoids a shared mutable default).
                All four must be divisible by 32 for the GroupNorm layers.
            embed_dim: width of each treatment / ``cond`` embedding.
            one_D: if True use ``Conv1d``/``ConvTranspose1d`` on inputs of
                shape ``(batch, input_dim, length)``; otherwise 2D
                convolutions on ``(batch, input_dim, H, W)``.
            out_dims: number of output channels.
            conditional: if True, the stage projections expect the
                treatment embedding concatenated with the embedding of
                ``cond`` (an estimate of the noise variable used to predict
                the outcome), so ``forward`` must be given ``cond``.
            input_dim: number of channels of ``X``.
        """
        super().__init__()

        # Dense(one_D=True) reshapes to (B, C, 1), which only broadcasts
        # against 1D feature maps, so the conv family must follow one_D.
        if one_D:
            self.conv = nn.Conv1d
            self.tconv = nn.ConvTranspose1d
        else:
            self.conv = nn.Conv2d
            self.tconv = nn.ConvTranspose2d
        self.treatment_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                                nn.Linear(embed_dim, embed_dim))
        self.tau_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                          nn.Linear(embed_dim, embed_dim))

        if conditional:  # the cond embedding is concatenated to the treatment one
            embed_dim = embed_dim * 2

        # Encoding layers where the resolution decreases
        self.conv1 = self.conv(input_dim, channels[0], 3, stride=1, padding=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0], one_D)
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = self.conv(channels[0], channels[1], 3, stride=2, padding=1, bias=False)
        self.dense2 = Dense(embed_dim, channels[1], one_D)
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = self.conv(channels[1], channels[2], 3, stride=2, padding=1, bias=False)
        self.dense3 = Dense(embed_dim, channels[2], one_D)
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = self.conv(channels[2], channels[3], 3, stride=2, padding=1, bias=False)
        self.dense4 = Dense(embed_dim, channels[3], one_D)
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases; tconv3..1 take the
        # concatenation of the previous stage and the matching encoder skip.
        self.tconv4 = self.tconv(channels[3], channels[2], 4, stride=2, padding=1, bias=False)
        self.dense5 = Dense(embed_dim, channels[2], one_D)
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = self.tconv(channels[2] + channels[2], channels[1], 4, stride=2, bias=False, padding=1)
        self.dense6 = Dense(embed_dim, channels[1], one_D)
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = self.tconv(channels[1] + channels[1], channels[0], 4, stride=2, bias=False, padding=1)
        self.dense7 = Dense(embed_dim, channels[0], one_D)
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = self.tconv(channels[0] + channels[0], out_dims, 3, stride=1, padding=1)

    def act(self, x):
        """Swish activation ``x * sigmoid(x)``.

        A method rather than a lambda attribute, so the module can be pickled.
        """
        return x * torch.sigmoid(x)

    def forward(self, X, T, cond=None):
        """Predict the outcome map.

        Args:
            X: input of shape ``(batch, input_dim, *spatial)``. Each spatial
                size must be divisible by 8 so the three stride-2 encoder
                stages and the three x2 transposed convolutions line up at
                the skip concatenations.
            T: treatment values, presumably ``(batch, 1)`` (TODO confirm).
            cond: optional noise estimate laid out like ``T``. It must be
                given exactly when the model was built with
                ``conditional=True``.

        Returns:
            Tensor of shape ``(batch, out_dims, *spatial)``.
        """
        embed = self.act(self.treatment_embedder(T.float()))
        if cond is not None:
            cond = self.act(self.tau_embedder(cond.float()))
            embed = torch.cat((embed, cond), -1)

        # Encoding path: conv -> add embedding -> GroupNorm -> swish
        h1 = self.conv1(X)
        h1 += self.dense1(embed)
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path with skip connections from the encoder
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))
        return h


