import numpy as np
from helpers.utils import to_torch
from torch.utils.data import DataLoader, TensorDataset
from helpers.hsic import MMR
from helpers.kernel import RBFKernel
import torch.optim as optim
import torch
import torch.nn as nn

import pytorch_lightning as pl


class GenNet(nn.Module):
    """Fully connected network with five linear layers and LeakyReLU(0.4).

    The same activation module instance is placed after each of the first
    four linear layers; the final linear layer has no activation.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_size=16):
        super().__init__()
        self.activation = nn.LeakyReLU(0.4)
        self.hidden_size = hidden_size

        width = self.hidden_size
        # Layers are created front to back: input layer, three hidden
        # layers, then the output projection.
        stack = [nn.Linear(input_dim, width), self.activation]
        for _ in range(3):
            stack.extend((nn.Linear(width, width), self.activation))
        stack.append(nn.Linear(width, output_dim))
        self.net = nn.Sequential(*stack)

    def init_weights_inner(self, m):
        """Re-initialise one module; only exact ``nn.Linear`` layers are touched."""
        if type(m) is not nn.Linear:
            return
        # Weights are drawn from U(-1, 1); the degenerate U(0, 0) range
        # sets every bias to zero.
        nn.init.uniform_(m.weight, a=-1, b=1)
        nn.init.uniform_(m.bias, a=0, b=0)

    def init_weights(self):
        """Apply ``init_weights_inner`` to every submodule of ``self.net``."""
        self.net.apply(self.init_weights_inner)

    def forward(self, z):
        """Return the output of the layer stack for input ``z``."""
        return self.net(z)


class GenNetY(nn.Module):
    """Three-layer fully connected network with Tanh activations and one output.

    Args:
        Z_dim: Number of input features.
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, Z_dim, hidden_size=64):
        super().__init__()
        self.activation = nn.Tanh()
        self.hidden_size = hidden_size

        width = self.hidden_size
        # Creation order: input layer, hidden layer, scalar output head.
        first = nn.Linear(Z_dim, width)
        middle = nn.Linear(width, width)
        head = nn.Linear(width, 1)
        self.net = nn.Sequential(first, self.activation, middle, self.activation, head)

    def init_weights_inner(self, m):
        """Redraw the weights of an exact ``nn.Linear`` from U(-1, 1).

        Biases are left as they are.
        """
        if type(m) is not nn.Linear:
            return
        nn.init.uniform_(m.weight, a=-1, b=1)

    def init_weights(self):
        """Apply ``init_weights_inner`` to every submodule of ``self.net``."""
        self.net.apply(self.init_weights_inner)

    def forward(self, z):
        """Return the output of the layer stack for input ``z``."""
        return self.net(z)


class VAE(pl.LightningModule):
    """Variational autoencoder over the rows of ``X`` built from MLPs.

    The encoder is five Linear+LeakyReLU pairs followed by two linear heads
    (``fc_mu`` and ``fc_var``); the decoder is four Linear+LeakyReLU pairs
    followed by a linear projection back to the input width.

    Args:
        X: Training data; converted with ``to_torch`` and indexed as
            ``shape[1]`` to obtain the feature count.
        z_dim: Latent dimension.
        hidden_size: Width of all hidden layers.
        lr: Learning rate handed to Adam.
    """

    def __init__(self, X, z_dim, hidden_size=32, lr=1e-2):
        super().__init__()
        self.X = to_torch(X)
        self.lr = lr
        self.x_dim = self.X.shape[1]
        self.z_dim = z_dim

        self.activation = nn.LeakyReLU()
        self.hidden_size = hidden_size
        width = self.hidden_size

        # Encoder trunk: every linear layer, including the last, is
        # followed by the shared activation.
        enc_widths = [self.x_dim] + [width] * 5
        enc_layers = []
        for n_in, n_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers.extend((nn.Linear(n_in, n_out), self.activation))
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder: four activated layers, then a plain linear output layer.
        dec_widths = [z_dim] + [width] * 4
        dec_layers = []
        for n_in, n_out in zip(dec_widths[:-1], dec_widths[1:]):
            dec_layers.extend((nn.Linear(n_in, n_out), self.activation))
        dec_layers.append(nn.Linear(width, self.x_dim))
        self.decoder = nn.Sequential(*dec_layers)

        self.fc_mu = nn.Linear(width, z_dim)
        self.fc_var = nn.Linear(width, z_dim)

        self.mse_loss = nn.MSELoss(reduction='none')

        # Learnable log-scale; it is only consumed when a caller passes it
        # to gaussian_likelihood (training_step does not use it).
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def encode(self, X, return_var=True):
        """Encode ``X``.

        Returns ``mu`` alone when ``return_var`` is false, otherwise the pair
        ``(mu, log_var)`` where ``log_var`` is the output of ``fc_var``.
        """
        hidden = self.encoder(X)
        mu = self.fc_mu(hidden)
        if return_var:
            return mu, self.fc_var(hidden)
        return mu

    def gaussian_likelihood(self, x_hat, logscale, x):
        """Log-density of ``x`` under Normal(x_hat, exp(logscale)), summed over dim 1."""
        normal = torch.distributions.Normal(x_hat, torch.exp(logscale))
        return normal.log_prob(x).sum(dim=1)

    def kl_divergence(self, z, mu, std):
        """Single-sample KL estimate ``log q(z|x) - log p(z)`` summed over dim 1.

        ``q`` is Normal(mu, std) and ``p`` is a standard normal of the same shape.
        """
        prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        posterior = torch.distributions.Normal(mu, std)
        per_dim = posterior.log_prob(z) - prior.log_prob(z)
        return per_dim.sum(dim=1)

    def decode(self, z):
        """Map latent codes ``z`` through the decoder."""
        return self.decoder(z)

    def training_step(self, X, batch_idx):
        """Compute and log the training loss for one batch.

        ``X`` is the batch produced by the dataloader; its first element is
        the data tensor. The returned loss is the batch mean of the KL term
        plus the summed squared reconstruction error.
        """
        x = X[0]

        # Posterior parameters; std is recovered from the log-variance head.
        mu, log_var = self.encode(x)
        std = torch.exp(log_var / 2)

        # Reparameterised sample from q(z|x), then reconstruct.
        z = torch.distributions.Normal(mu, std).rsample()
        x_hat = self.decoder(z)

        recon_loss = self.mse_loss(x_hat, x).sum(dim=1)
        kl = self.kl_divergence(z, mu, std)

        elbo = (kl + recon_loss).mean()
        mean_recon = recon_loss.mean()

        metrics = {
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': mean_recon,
            'reconstruction': mean_recon,
        }
        self.log_dict(metrics, prog_bar=True, on_step=False, on_epoch=True)

        return elbo

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored data tensor in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.X)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class AutoEncoder(pl.LightningModule):
    """MLP autoencoder trained with reconstruction MSE plus an MMR penalty.

    The penalty is ``MMR`` (from ``helpers.hsic``) evaluated on the residuals
    of the latent code after linearly regressing it on ``A`` with an
    intercept, weighted by ``lmd``.

    Args:
        X: Observations; converted with ``to_torch`` and indexed as
            ``shape[1]`` to obtain the feature count.
        A: Second variable paired row-wise with ``X``; converted with
            ``to_torch``.
        z_dim: Latent dimension.
        hidden_size: Width of all hidden layers.
        lr: Learning rate handed to Adam.
        lmd: Weight of the MMR term in the total loss.
        accelerator: Stored on the instance for backward compatibility. It is
            no longer used to place tensors; helper tensors are created on
            the device of the data they are combined with.
    """

    def __init__(self, X, A, z_dim, hidden_size=32, lr=1e-2, lmd=1e-2, accelerator='mps'):
        super().__init__()
        self.X = to_torch(X)
        self.A = to_torch(A)
        self.lr = lr
        self.lmd = lmd
        self.x_dim = self.X.shape[1]
        self.z_dim = z_dim
        self.accelerator = accelerator

        self.activation = nn.LeakyReLU()
        self.hidden_size = hidden_size

        self.encoder = nn.Sequential(nn.Linear(self.x_dim, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, z_dim)
                                     )

        self.decoder = nn.Sequential(nn.Linear(z_dim, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                     nn.Linear(self.hidden_size, self.x_dim)
                                     )

        self.kernel_a = RBFKernel(1)
        self.mse_loss = nn.MSELoss()

    def encode(self, x):
        """Return the latent code of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Encode and then decode ``x``, returning the reconstruction."""
        z = self.encoder(x)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss for one ``(x, a)`` batch."""
        x, a = batch
        loss = self.loss(x, a)
        self.log('loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def get_res_Z(self, X, A, predZ=None, ret_W=False):
        """Residuals of the latent code after least-squares regression on ``A``.

        Args:
            X: Inputs, encoded only when ``predZ`` is not supplied.
            A: Regressors; a column of ones is prepended as an intercept.
            predZ: Optional precomputed latent codes for ``X``.
            ret_W: When true, also return the fitted coefficient matrix.

        Returns:
            ``resZ`` or ``(resZ, W)`` depending on ``ret_W``.
        """
        # Create the intercept column on the same device as A. Using the
        # configured accelerator string here broke whenever the batch lived
        # on a different device (e.g. CPU inference, CUDA training) or the
        # string was a Lightning accelerator name such as 'gpu'.
        ones = torch.ones(size=[A.shape[0], 1], device=A.device)
        A_1 = torch.hstack([ones, A])
        if predZ is None:
            predZ = self.encoder(X)

        # Ordinary least squares via the normal equations.
        W = torch.linalg.inv(A_1.T @ A_1) @ A_1.T @ predZ
        predZA = A_1 @ W
        resZ = predZ - predZA

        if ret_W:
            return resZ, W
        else:
            return resZ

    def get_MMR_loss(self, X, A, predZ=None):
        """Return ``MMR`` of the latent residuals against ``A`` using ``kernel_a``."""
        resZ = self.get_res_Z(X, A, predZ)

        return MMR(resZ, A, kernel_A=self.kernel_a)

    def get_mse_loss(self, X, return_predZ=False):
        """Reconstruction MSE for ``X``; optionally also return the latent codes."""
        predZ = self.encoder(X)
        predX = self.decoder(predZ)

        if return_predZ:
            return self.mse_loss(predX, X), predZ
        else:
            return self.mse_loss(predX, X)

    def loss(self, X, A):
        """Total loss: reconstruction MSE plus ``lmd`` times the MMR term."""
        # Reuse the latent codes from the reconstruction pass so the encoder
        # runs only once per batch.
        mse_loss, predZ = self.get_mse_loss(X, return_predZ=True)
        MMR_loss = self.get_MMR_loss(X, A, predZ=predZ)

        return mse_loss + self.lmd * MMR_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored ``X`` and ``A`` tensors in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.X, self.A)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class EncoderOracle(pl.LightningModule):
    """MLP regressor from ``X`` to ``Z`` trained with mean squared error.

    Args:
        X: Inputs; converted with ``to_torch``, feature count read from ``shape[1]``.
        Z: Targets; converted with ``to_torch``, target width read from ``shape[1]``.
        hidden_size: Width of all hidden layers.
        lr: Learning rate handed to Adam.
        lmd: Stored on the instance; not read anywhere in this class.
    """

    def __init__(self, X, Z, hidden_size=32, lr=1e-2, lmd=1e-2):
        super().__init__()
        self.X = to_torch(X)
        self.Z = to_torch(Z)
        self.lr = lr
        self.lmd = lmd
        self.x_dim = self.X.shape[1]
        self.z_dim = self.Z.shape[1]

        self.activation = nn.LeakyReLU()
        self.hidden_size = hidden_size
        width = self.hidden_size

        # Four activated layers followed by a plain linear output layer.
        widths = [self.x_dim] + [width] * 4
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(n_in, n_out), self.activation))
        layers.append(nn.Linear(width, self.z_dim))
        self.encoder = nn.Sequential(*layers)

        self.mse_loss = nn.MSELoss()

    def encode(self, x):
        """Return the predicted ``Z`` for inputs ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Run the encoder network on ``x``."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(x, z)`` batch."""
        inputs, targets = batch
        batch_loss = self.loss(inputs, targets)
        self.log('loss', batch_loss, prog_bar=True, on_step=False, on_epoch=True)
        return batch_loss

    def loss(self, X, Z):
        """Mean squared error between the encoding of ``X`` and ``Z``."""
        return self.mse_loss(self.encode(X), Z)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored ``X`` and ``Z`` tensors in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.X, self.Z)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class DecoderOracle(pl.LightningModule):
    """MLP regressor from ``Z`` to ``X`` trained with mean squared error.

    Args:
        X: Targets; converted with ``to_torch``, target width read from ``shape[1]``.
        Z: Inputs; converted with ``to_torch``, feature count read from ``shape[1]``.
        hidden_size: Width of all hidden layers.
        lr: Learning rate handed to Adam.
        lmd: Stored on the instance; not read anywhere in this class.
    """

    def __init__(self, X, Z, hidden_size=32, lr=1e-2, lmd=1e-2):
        super().__init__()
        self.X = to_torch(X)
        self.Z = to_torch(Z)
        self.lr = lr
        self.lmd = lmd
        self.x_dim = self.X.shape[1]
        self.z_dim = self.Z.shape[1]

        self.activation = nn.LeakyReLU()
        self.hidden_size = hidden_size
        width = self.hidden_size

        # Four activated layers followed by a plain linear output layer.
        widths = [self.z_dim] + [width] * 4
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(n_in, n_out), self.activation))
        layers.append(nn.Linear(width, self.x_dim))
        self.decoder = nn.Sequential(*layers)

        self.mse_loss = nn.MSELoss()

    def decode(self, x):
        """Run the decoder network on its input (named ``x`` in this signature)."""
        return self.decoder(x)

    def forward(self, z):
        """Return the predicted ``X`` for latent inputs ``z``."""
        return self.decode(z)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(x, z)`` batch."""
        targets, latents = batch
        batch_loss = self.loss(targets, latents)
        self.log('loss', batch_loss, prog_bar=True, on_step=False, on_epoch=True)
        return batch_loss

    def loss(self, X, Z):
        """Mean squared error between the decoding of ``Z`` and ``X``."""
        return self.mse_loss(self.decode(Z), X)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored ``X`` and ``Z`` tensors in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.X, self.Z)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class AdditiveMLP(pl.LightningModule):
    """Additive regressor ``Y ~ f_Z(Z) + f_V(V)`` with two separate MLPs.

    Args:
        Y: Regression targets; converted with ``to_torch``.
        Z: First input block; converted with ``to_torch``, width read from
            ``shape[1]``.
        V: Second input block; converted with ``to_torch``.
        lr: Learning rate handed to Adam.
    """

    def __init__(self, Y, Z, V, lr=1e-2):
        super().__init__()
        self.Y = to_torch(Y)
        self.Z = to_torch(Z)
        self.V = to_torch(V)
        self.lr = lr
        self.z_dim = self.Z.shape[1]

        self.activation = nn.LeakyReLU()
        self.hidden_size = 32

        self.f_Z = nn.Sequential(nn.Linear(self.z_dim, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, 1)
                                 )

        # NOTE(review): f_V is sized with z_dim as well, so V must have the
        # same number of columns as Z — confirm against callers.
        self.f_V = nn.Sequential(nn.Linear(self.z_dim, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                                 nn.Linear(self.hidden_size, 1)
                                 )

        self.mse_loss = nn.MSELoss()

    def forward(self, Z, V=None):
        """Return ``f_Z(Z)`` alone, or ``f_Z(Z) + f_V(V)`` when ``V`` is given."""
        if V is None:
            return self.f_Z(Z)
        return self.f_Z(Z) + self.f_V(V)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(Y, Z, V)`` batch."""
        Y, Z, V = batch
        loss = self.loss(Y, Z, V)
        self.log('loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def loss(self, Y, Z, V):
        """Mean squared error between the additive prediction and ``Y``."""
        pred_Y = self.forward(Z, V).flatten()

        # Flatten the target too: a column-shaped Y of shape (N, 1) would
        # otherwise broadcast against the (N,) prediction into an (N, N)
        # error matrix and yield a silently wrong loss.
        return self.mse_loss(pred_Y, Y.flatten())

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored ``Y``, ``Z`` and ``V`` tensors in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.Y, self.Z, self.V)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class MLP(pl.LightningModule):
    """MLP regressor from ``A`` to a scalar ``Y`` trained with mean squared error.

    Args:
        Y: Regression targets; converted with ``to_torch``.
        A: Inputs; converted with ``to_torch``, feature count read from
            ``shape[1]``.
        lr: Learning rate handed to Adam.
    """

    def __init__(self, Y, A, lr=1e-2):
        super().__init__()
        self.Y = to_torch(Y)
        self.A = to_torch(A)
        self.lr = lr
        self.A_dim = self.A.shape[1]

        self.activation = nn.LeakyReLU()
        self.hidden_size = 32

        self.f = nn.Sequential(nn.Linear(self.A_dim, self.hidden_size), self.activation,
                               nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                               nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                               nn.Linear(self.hidden_size, self.hidden_size), self.activation,
                               nn.Linear(self.hidden_size, 1)
                               )

        self.mse_loss = nn.MSELoss()

    def forward(self, A):
        """Return the network output for inputs ``A`` (one column per row)."""
        return self.f(A)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(Y, A)`` batch."""
        Y, A = batch
        loss = self.loss(Y, A)
        self.log('loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def loss(self, Y, A):
        """Mean squared error between the prediction for ``A`` and ``Y``."""
        pred_Y = self.forward(A).flatten()

        # Flatten the target too: a column-shaped Y of shape (N, 1) would
        # otherwise broadcast against the (N,) prediction into an (N, N)
        # error matrix and yield a silently wrong loss.
        return self.mse_loss(pred_Y, Y.flatten())

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Wrap the stored ``Y`` and ``A`` tensors in a ``TensorDataset``."""
        self.train_dataset = TensorDataset(self.Y, self.A)

    def train_dataloader(self):
        """Return a shuffling dataloader with batch size 256."""
        return DataLoader(self.train_dataset, batch_size=256, shuffle=True)


class CFEstimator:
    """Bias-corrected predictor that combines a linear model with a network.

    Args:
        additive_NN: Callable applied to a tensor of latent inputs; its
            output is averaged to produce a prediction.
        LR_PredZA: Object exposing ``predict`` on a ``(1, n_features)`` array
            (presumably a fitted scikit-learn regressor — confirm against
            callers). Stored as ``self.lr``.
    """

    def __init__(self, additive_NN, LR_PredZA):
        self.additive_NN = additive_NN
        self.lr = LR_PredZA
        self.mean_predY = None
        self.mean_Y = None

    def fit_bias(self, predZ, Y):
        """Record the mean network output on ``predZ`` and the mean of ``Y``.

        Their difference is subtracted from later predictions as a bias
        correction.
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            self.mean_predY = self.additive_NN(to_torch(predZ)).mean().item()
        # Store a plain float so later arithmetic does not depend on whether
        # Y was a numpy array or a tensor (possibly on another device).
        self.mean_Y = float(Y.mean())

    def predict(self, A_star, V):
        """Return the bias-corrected mean prediction for ``A_star``.

        Args:
            A_star: Value(s) reshaped to a single row and passed to the
                linear model.
            V: Array added to the linear model's output before it is fed to
                the network.

        Raises:
            RuntimeError: If ``fit_bias`` has not been called yet.
        """
        if self.mean_predY is None or self.mean_Y is None:
            raise RuntimeError("CFEstimator.fit_bias must be called before predict")

        pred_ZA = self.lr.predict(np.array(A_star).reshape(1, -1))
        with torch.no_grad():
            mean_pred = self.additive_NN(to_torch(pred_ZA + V)).mean()
        ret = mean_pred - (self.mean_predY - self.mean_Y)
        return ret.item()
