import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch
import torch.nn.functional as F
import tqdm


class LogitsBatchIterator:
    """Yield consecutive slices of ``logits`` holding at most ``batch_size`` items.

    The object is its own iterator. Calling ``iter()`` on it rewinds to the
    beginning, so the same instance can be looped over repeatedly.
    """

    def __init__(self, logits, batch_size):
        self.logits = logits
        self.batch_size = batch_size
        self.current_index = 0  # start offset of the next slice

    def __iter__(self):
        # Rewind so every new ``for`` loop sees the full sequence.
        self.current_index = 0
        return self

    def __next__(self):
        total = len(self.logits)
        start = self.current_index
        if start >= total:
            raise StopIteration
        # The last slice may be shorter than batch_size.
        stop = min(start + self.batch_size, total)
        batch = self.logits[start:stop]
        self.current_index = stop
        return batch


class backbone(nn.Module):
    """Feature extractor plus a linear classification head.

    Also provides helpers that flatten, or set, the gradients and parameters
    of every parameter (in ``self.parameters()`` order) as one 1-D tensor.
    The helpers are adapted from mammoth's backbone utilities.

    Args:
        model: module producing ``z_dim``-sized features from the input.
        cls_output_dim: number of logits produced by the linear head.
        z_dim: feature size of ``model``'s output (input size of the head).
    """

    def __init__(self, model, cls_output_dim, z_dim=512):
        super(backbone, self).__init__()
        self.model = model
        self.linear = nn.Linear(z_dim, cls_output_dim)

    def forward(self, x, return_type="logits"):
        """Return raw features or classification logits for ``x``.

        Args:
            x: input accepted by ``self.model``.
            return_type: ``"features"`` returns ``model(x)``; ``"logits"``
                returns ``linear(relu(model(x)))``.

        Raises:
            ValueError: for any other ``return_type``.
        """
        if return_type == "features":
            return self.model(x)
        elif return_type == "logits":
            return self.linear(nn.functional.relu(self.model(x)))
        else:
            raise ValueError(
                "return_type must be either 'features' or 'logits'"
            )

    def get_grads(self) -> torch.Tensor:
        """
        Returns all the gradients concatenated in a single 1-D tensor.

        Raises AttributeError if some parameter has no ``.grad`` yet
        (e.g. before the first backward pass).

        Returns:
            gradients tensor
        """
        return torch.cat(self.get_grads_list())

    def get_grads_list(self):
        """
        Returns a list containing the gradients (one flattened tensor per
        parameter, in ``self.parameters()`` order).

        Returns:
            gradients list
        """
        return [pp.grad.view(-1) for pp in self.parameters()]

    # https://github.com/aimagelab/mammoth/blob/170cea9de1a75c5b22c297fbb425b4e8aab2bd7d/backbone/__init__.py#L99
    def get_params(self) -> torch.Tensor:
        """
        Returns all the parameters concatenated in a single tensor.

        Returns:
            parameters tensor
        """
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def set_grads(self, new_grads: torch.Tensor) -> None:
        """
        Sets every parameter's ``.grad`` from a single flat tensor.

        Args:
            new_grads: 1-D tensor laid out like ``get_params()``, i.e. the
                flattened gradients of ``self.parameters()`` in order.

        Raises:
            AssertionError: if ``new_grads`` does not have that size.
        """
        total = sum(pp.numel() for pp in self.parameters())
        assert new_grads.size() == torch.Size([total])
        progress = 0
        for pp in self.parameters():
            numel = pp.numel()
            # Each .grad becomes a view into new_grads, not a copy.
            pp.grad = new_grads[progress : progress + numel].view(pp.size())
            progress += numel


class CRNN(nn.Module):
    """Convolutional front-end, then an LSTM, then a linear classifier.

    Each entry of ``cnn_embed_dims`` adds one stage of
    Conv1d(kernel 3, padding 1) -> BatchNorm1d -> ReLU -> MaxPool1d(2).
    The logits come from the LSTM output at the final time step.

    Args:
        in_channels: number of input channels (the last axis of ``x``).
        num_classes: number of output logits.
        cnn_embed_dims: conv stage widths, as a list or a single int.
        embed_dims: LSTM hidden size.
    """

    def __init__(
        self,
        in_channels=6,
        num_classes=2,
        cnn_embed_dims=[64, 64],
        embed_dims=50,
    ):
        super(CRNN, self).__init__()
        widths = (
            [cnn_embed_dims]
            if isinstance(cnn_embed_dims, int)
            else cnn_embed_dims
        )

        conv_stages = []
        prev_width = in_channels
        for width in widths:
            conv_stages.extend(
                [
                    nn.Conv1d(
                        prev_width, width, kernel_size=3, stride=1, padding=1
                    ),
                    nn.BatchNorm1d(width),
                    nn.ReLU(),
                    nn.MaxPool1d(2),
                ]
            )
            prev_width = width

        # Registration order (lstm, fc, cnn) determines state_dict key order
        # and the order in which init_weights visits modules.
        self.lstm = nn.LSTM(
            input_size=widths[-1],
            hidden_size=embed_dims,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dims, num_classes)
        self.cnn = nn.Sequential(*conv_stages)
        self.init_weights()

    def init_weights(self):
        """Re-initialise all weights.

        Conv1d/Linear weights get Kaiming-normal init and LSTM weights get
        orthogonal init. All biases are zeroed and BatchNorm scales are set
        to one.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LSTM):
                for name, param in module.named_parameters():
                    if "weight" in name:
                        nn.init.orthogonal_(param)
                    elif "bias" in name:
                        nn.init.zeros_(param)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` must be 3-D; the unpacking below enforces this. Its last axis
        holds the channels, which are moved to dim 1 for Conv1d.
        """
        _batch, _steps, _channels = x.shape
        feats = self.cnn(x.transpose(1, 2))
        seq_out, _ = self.lstm(feats.transpose(1, 2))
        return self.fc(seq_out[:, -1])


class CVAE(nn.Module):
    """Convolutional VAE with a linear classifier on the first latent dims.

    Encoder: each entry of ``hidden_dims`` adds a stride-2 Conv1d followed by
    a ReLU. A global max-pool over time then feeds linear heads for ``mu``
    and ``log_var``.

    Decoder: a Linear layer plus Unflatten reshapes to
    ``(hidden_dims[-1], num_timestep // 2**n)``. Stride-2 ConvTranspose1d
    stages then each double the length, and a final BatchNorm1d is applied.

    ``classify`` applies a Linear layer to ``z[:, :cls_dim]``.

    Data is laid out as (batch, time, channels); the channel axis is moved
    to dim 1 internally for the convolutions. With ``n = len(hidden_dims)``,
    the reconstruction has ``(num_timestep // 2**n) * 2**n`` steps. That
    equals ``num_timestep`` only when ``num_timestep`` is a multiple of
    ``2**n``.
    """

    def __init__(
        self,
        num_timestep,
        input_channels,
        hidden_dims,
        latent_dim,
        cls_dim,
        class_num,
    ):
        super(CVAE, self).__init__()

        self.hidden_dims = hidden_dims
        self.num_timestep = num_timestep
        self.input_channels = input_channels
        self.cls_dim = cls_dim

        # Encoder: stride-2 Conv1d + ReLU per stage (no BatchNorm here).
        encoder_list = []
        for i, out_channels in enumerate(hidden_dims):
            in_channels = input_channels if i == 0 else hidden_dims[i - 1]
            encoder_list.append(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            encoder_list.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_list)

        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        self.classifier = nn.Linear(cls_dim, class_num)

        # Time length at the decoder input. The product below is
        # parenthesised so the Linear output size always equals the
        # Unflatten size, even when num_timestep is not a multiple of
        # 2 ** len(hidden_dims).
        reduced_len = num_timestep // (2 ** len(hidden_dims))
        reversed_hidden_dims = list(reversed(hidden_dims))

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dims[-1] * reduced_len),
            nn.Unflatten(
                dim=-1,
                unflattened_size=(hidden_dims[-1], reduced_len),
            ),
        ]

        # ConvTranspose1d(k=3, s=2, p=1, output_padding=1) maps length L to 2L.
        for in_channels, out_channels in zip(
            reversed_hidden_dims[:-1], reversed_hidden_dims[1:]
        ):
            decoder_layers.append(
                nn.ConvTranspose1d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
            decoder_layers.append(nn.BatchNorm1d(out_channels))
            decoder_layers.append(nn.ReLU())

        # Final stage maps back to the input channel count.
        decoder_layers.append(
            nn.ConvTranspose1d(
                reversed_hidden_dims[-1],
                input_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )
        decoder_layers.append(nn.BatchNorm1d(input_channels))

        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, log_var)`` for ``x`` of shape (batch, time, channels).

        Side effect: stores the encoder output length in ``self.x_size``.
        """
        x = x.permute(0, 2, 1)
        x = self.encoder(x)
        self.x_size = x.size(2)
        x = F.max_pool1d(x, self.x_size).squeeze(2)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def classify(self, z):
        """Return class logits computed from the first ``cls_dim`` dims of ``z``."""
        return self.classifier(z[:, : self.cls_dim])

    def decode(self, z):
        """Decode latents ``z`` to a (batch, time, channels) reconstruction."""
        x = self.decoder(z)
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


class CNNTimeseriesVAE(nn.Module):
    """Convolutional VAE for multichannel time series.

    Encoder: each entry of ``hidden_dims`` adds a stride-2 Conv1d followed by
    BatchNorm1d and ReLU. A global max-pool over time then feeds linear heads
    for ``mu`` and ``log_var``.

    Decoder: a Linear layer plus Unflatten reshapes to
    ``(hidden_dims[-1], num_timestep // 2**n)``. Stride-2 ConvTranspose1d
    stages then each double the length, and a final BatchNorm1d is applied.

    Data is laid out as (batch, time, channels); the channel axis is moved
    to dim 1 internally for the convolutions. With ``n = len(hidden_dims)``,
    the reconstruction has ``(num_timestep // 2**n) * 2**n`` steps. That
    equals ``num_timestep`` only when ``num_timestep`` is a multiple of
    ``2**n``.
    """

    def __init__(self, num_timestep, input_channels, hidden_dims, latent_dim):
        super(CNNTimeseriesVAE, self).__init__()

        self.hidden_dims = hidden_dims
        self.num_timestep = num_timestep
        self.input_channels = input_channels

        encoder_list = []
        for i, out_channels in enumerate(hidden_dims):
            in_channels = input_channels if i == 0 else hidden_dims[i - 1]
            encoder_list.append(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            encoder_list.append(nn.BatchNorm1d(out_channels))
            encoder_list.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_list)

        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Time length at the decoder input. The product below is
        # parenthesised so the Linear output size always equals the
        # Unflatten size, even when num_timestep is not a multiple of
        # 2 ** len(hidden_dims).
        reduced_len = num_timestep // (2 ** len(hidden_dims))
        reversed_hidden_dims = list(reversed(hidden_dims))

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dims[-1] * reduced_len),
            nn.Unflatten(
                dim=-1,
                unflattened_size=(hidden_dims[-1], reduced_len),
            ),
        ]

        # ConvTranspose1d(k=3, s=2, p=1, output_padding=1) maps length L to 2L.
        for in_channels, out_channels in zip(
            reversed_hidden_dims[:-1], reversed_hidden_dims[1:]
        ):
            decoder_layers.append(
                nn.ConvTranspose1d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
            decoder_layers.append(nn.BatchNorm1d(out_channels))
            decoder_layers.append(nn.ReLU())

        # Final stage maps back to the input channel count.
        decoder_layers.append(
            nn.ConvTranspose1d(
                reversed_hidden_dims[-1],
                input_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )
        decoder_layers.append(nn.BatchNorm1d(input_channels))

        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, log_var)`` for ``x`` of shape (batch, time, channels).

        Side effect: stores the encoder output length in ``self.x_size``.
        """
        x = x.permute(0, 2, 1)
        x = self.encoder(x)
        self.x_size = x.size(2)
        x = F.max_pool1d(x, self.x_size).squeeze(2)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latents ``z`` to a (batch, time, channels) reconstruction."""
        x = self.decoder(z)
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def _vae_loss(self, batch, alpha):
        """Return ``alpha * sum-MSE + (1 - alpha) * KL(q(z|x) || N(0, I))``."""
        recon_batch, mu, log_var = self(batch)
        recon_loss = F.mse_loss(recon_batch, batch, reduction="sum")
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return alpha * recon_loss + (1 - alpha) * kld_loss

    def fit(
        self,
        x: torch.Tensor,
        epochs,
        batch_size,
        valid_x=None,
        lr=0.001,
        alpha=0.5,
        device=None,
    ):
        """Train with Adam on ``alpha * recon + (1 - alpha) * KL``.

        Args:
            x: training data (batch, time, channels). It is split into
                consecutive, unshuffled mini-batches each epoch.
            epochs: number of passes over ``x``.
            batch_size: mini-batch size.
            valid_x: optional validation data. It is scored once per epoch,
                in eval mode, with the same weighted loss.
            lr: Adam learning rate.
            alpha: weight of the summed-MSE reconstruction term. The KL
                term gets ``1 - alpha``.
            device: training device. Defaults to CUDA when available,
                otherwise CPU.

        Returns:
            ``(train_loss, valid_loss)``: per-epoch summed losses divided by
            the number of samples. ``valid_loss`` is empty when ``valid_x``
            is None. The model is left in train mode.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        self.to(device)

        if valid_x is not None:
            valid_x = valid_x.to(device)

        train_loss = []
        valid_loss = []
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for _ in tqdm.tqdm(range(epochs)):
            self.train()
            total_loss = 0.0
            for start in range(0, len(x), batch_size):
                local_x = x[start : start + batch_size].to(device)
                optimizer.zero_grad()
                loss = self._vae_loss(local_x, alpha)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()
            train_loss.append(total_loss / len(x))

            if valid_x is not None:
                # eval() so BatchNorm uses (and does not update) running stats.
                self.eval()
                with torch.no_grad():
                    loss = self._vae_loss(valid_x, alpha)
                valid_loss.append(loss.item() / len(valid_x))

        self.train()
        return train_loss, valid_loss


def joint_uncond_simple(
    decoder,
    classifier,
    device,
    M=10,
    Nalpha=25,
    Nbeta=100,
    z_dim=8,
    K=1,
    L=7,
    eps=1e-8,
    is_logit=False,
    is_rom=False,
):
    """Monte-Carlo estimate of the negative information flow from alpha to y.

    Each latent row is ``z = [alpha, beta]``. Here ``alpha`` (K dims) is
    shared by a batch of ``Nbeta`` rows and ``beta`` (L dims) is drawn per
    row. Both come from ``np.random.randn``, so the global NumPy RNG is used.
    The row assignment lines up only when ``z_dim == K + L``.

    For each of ``Nalpha`` alphas the batch is decoded, ``.permute(0, 2, 1)``
    is applied, and the result is classified. ``classifier`` must return a
    2-tuple; the first element is used when ``is_rom`` and the second
    otherwise. The class scores are averaged over beta to give p(y | alpha).
    The function accumulates the mean of ``sum p log p`` and the marginal
    ``q``, then computes ``I = mean_alpha[sum p log p] - sum q log q``.

    Args:
        decoder: maps a (Nbeta, z_dim) float tensor to a 3-D tensor.
        classifier: returns a 2-tuple of per-row class scores.
        device: torch device for the latents and the accumulator.
        M: number of classes.
        Nalpha: number of alpha samples (must be >= 1).
        Nbeta: number of beta samples per alpha.
        z_dim: latent dimension.
        K: number of alpha dims.
        L: number of beta dims.
        eps: added inside log() to avoid log(0).
        is_logit: apply softmax over dim 1 to the selected scores first.
        is_rom: select the classifier's first output instead of the second.

    Returns:
        ``(-I, info)``. ``info`` holds ``"xhat"`` and ``"yhat"`` from the last
        alpha sample; ``yhat`` is taken after the softmax when ``is_logit``.
    """
    mutual_info = 0.0
    marginal = torch.zeros(M).to(device)  # running estimate of p(y)

    for _ in range(Nalpha):
        alpha = np.random.randn(K)
        latents = np.zeros((Nbeta, z_dim))
        for row in latents:
            # One randn(L) draw per row, after alpha: the same RNG order.
            row[:K] = alpha
            row[K:] = np.random.randn(L)

        z = torch.from_numpy(latents).float().to(device)
        xhat = decoder(z).permute(0, 2, 1)
        rom_out, form_out = classifier(xhat)
        yhat = rom_out if is_rom else form_out
        if is_logit:
            # Turn logits into probabilities so the log terms are defined.
            yhat = F.softmax(yhat, dim=1)

        alpha_weight = 1.0 / float(Nalpha)
        cond = (1.0 / float(Nbeta)) * yhat.sum(dim=0)  # p(y | alpha)
        mutual_info = (
            mutual_info + alpha_weight * (cond * torch.log(cond + eps)).sum()
        )
        marginal = marginal + alpha_weight * cond

    mutual_info = mutual_info - (marginal * torch.log(marginal + eps)).sum()
    return -mutual_info, {"xhat": xhat, "yhat": yhat}


class CNN_v1(nn.Module):
    """1-D CNN encoder with two heads, returning ``(rom, form)``.

    Each entry of ``cnn_embed_dims`` adds one stage of
    Conv1d(kernel 3, padding 1) -> BatchNorm1d -> ReLU -> MaxPool1d(2).
    The result is then max-pooled over the whole remaining time axis.

    - ``rom``: a Linear layer over all pooled features, giving
      ``rom_classes`` outputs.
    - ``form``: a sigmoid over a Linear layer on the first
      ``form_embed_dim`` pooled features, giving a single output in (0, 1).
      This needs ``cnn_embed_dims[-1] >= form_embed_dim``.

    Input ``x`` is 3-D with channels on the last axis; they are moved to
    dim 1 for Conv1d.
    """

    def __init__(
        self,
        in_channels=6,
        rom_classes=4,
        form_embed_dim=10,
        cnn_embed_dims=[64, 64],
    ):
        super(CNN_v1, self).__init__()
        widths = (
            [cnn_embed_dims]
            if isinstance(cnn_embed_dims, int)
            else cnn_embed_dims
        )

        conv_stages = []
        prev_width = in_channels
        for width in widths:
            conv_stages.extend(
                [
                    nn.Conv1d(
                        prev_width, width, kernel_size=3, stride=1, padding=1
                    ),
                    nn.BatchNorm1d(width),
                    nn.ReLU(),
                    nn.MaxPool1d(2),
                ]
            )
            prev_width = width

        # Registration order (fc_rom, fc_form, cnn) determines state_dict key order.
        self.form_embed_dim = form_embed_dim
        self.fc_rom = nn.Linear(widths[-1], rom_classes)
        self.fc_form = nn.Linear(form_embed_dim, 1)
        self.cnn = nn.Sequential(*conv_stages)
        self.init_weights()

    def init_weights(self):
        """Re-initialise all weights.

        Conv1d/Linear weights get Kaiming-normal init. Any LSTM weights would
        get orthogonal init. All biases are zeroed and BatchNorm scales are
        set to one.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LSTM):
                for name, param in module.named_parameters():
                    if "weight" in name:
                        nn.init.orthogonal_(param)
                    elif "bias" in name:
                        nn.init.zeros_(param)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return ``(rom_logits, form_score)`` for ``x``."""
        feats = self.cnn(x.permute(0, 2, 1))
        # Global max-pool over the remaining time steps.
        pooled = F.max_pool1d(feats, feats.size(2)).squeeze(2)
        form = torch.sigmoid(self.fc_form(pooled[:, : self.form_embed_dim]))
        rom = self.fc_rom(pooled)
        return rom, form


class MLP_2layer(nn.Module):
    """Two-layer MLP on a flattened sequence, returning ``(None, form)``.

    The forward pass is ``Linear(timestep * in_channels, 30) -> ReLU ->
    dropout(0.5) -> Linear(30, 1) -> sigmoid``. The output pair mirrors the
    ``(rom, form)`` pair of ``CNN_v1``, with no ``rom`` head.
    """

    def __init__(self, timestep=150, in_channels=6):
        super(MLP_2layer, self).__init__()
        self.fc1 = nn.Linear(timestep * in_channels, 30)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(30, 1)

    def forward(self, x: torch.Tensor):
        """Return ``(None, form)``, where ``form`` has shape (batch, 1) and lies in (0, 1).

        ``x`` must flatten to ``timestep * in_channels`` features per sample.
        """
        x = x.reshape((x.shape[0], -1))
        x = self.fc1(x)
        x = self.relu1(x)
        # F.dropout defaults to training=True; tie it to the module mode so
        # that eval() actually disables dropout.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        form = torch.sigmoid(x)
        return None, form


class ContrastiveCNN(nn.Module):
    """1-D CNN encoder that maps a sequence to a ``latent_dim`` embedding.

    Each entry of ``hidden_dims`` adds one stage of
    Conv1d(kernel 3, stride 2, padding 1) -> MaxPool1d(3, stride 1, padding 1)
    -> ReLU. The stages are followed by a global max-pool over time and a
    Linear projection.
    Input ``x`` is 3-D with channels on the last axis.
    """

    def __init__(self, num_timestep, input_channels, hidden_dims, latent_dim):
        super(ContrastiveCNN, self).__init__()

        self.hidden_dims = hidden_dims
        self.num_timestep = num_timestep
        self.input_channels = input_channels

        channels = [input_channels] + list(hidden_dims)
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages += [
                nn.Conv1d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ]

        self.encoder = nn.Sequential(*stages)
        self.fc_embedding = nn.Linear(hidden_dims[-1], latent_dim)

    def forward(self, x):
        """Return the (batch, latent_dim) embedding of ``x``."""
        h = self.encoder(x.permute(0, 2, 1))
        # Global max-pool over the remaining time steps.
        h = F.max_pool1d(h, h.size(2)).squeeze(2)
        return self.fc_embedding(h)


import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings, then apply dropout.

    The ``pe`` buffer has shape (1, max_len, d_model). Even columns hold
    ``sin(pos * w_i)`` and odd columns hold ``cos(pos * w_i)``, where
    ``w_i = 10000 ** (-2i / d_model)``. Odd ``d_model`` is supported.
    ``forward`` expects ``x`` of shape (batch, seq_len, d_model) with
    ``seq_len <= max_len``.
    """

    def __init__(self, d_model, dropout=0, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``."""
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class CNNAttention(nn.Module):
    """Optional 1-D CNN front-end, stacked self-attention, then an embedding.

    With ``cnn_hidden_dims``, each entry adds one stage of
    Conv1d(kernel 3, stride 2, padding 1) -> MaxPool1d(3, stride 1,
    padding 1) -> ReLU. Without it, a per-timestep Linear layer maps
    ``input_channels`` to ``latent_dim``. ``num_attn_layers`` self-attention
    layers follow. The result is max-pooled over time, projected to
    ``latent_dim``, and passed through dropout.

    Input ``x`` is expected as (batch, time, channels). Attention runs with
    ``batch_first=True``, so each sample attends only over its own time
    steps.
    """

    def __init__(
        self,
        num_timestep,
        input_channels,
        cnn_hidden_dims,
        num_attn_layers,
        attn_num_heads,
        latent_dim,
        dropout=0.5,
    ):
        super(CNNAttention, self).__init__()

        self.num_timestep = num_timestep
        self.input_channels = input_channels
        self.using_cnn = False
        self.dropout = nn.Dropout(dropout)
        # CNN module definition
        if cnn_hidden_dims is not None:
            encoder_list = []
            for i, out_channels in enumerate(cnn_hidden_dims):
                in_channels = (
                    input_channels if i == 0 else cnn_hidden_dims[i - 1]
                )
                encoder_list.append(
                    nn.Conv1d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    )
                )
                encoder_list.append(
                    nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
                )
                encoder_list.append(nn.ReLU())

            self.cnn_module = nn.Sequential(*encoder_list)
            self.using_cnn = True
        else:
            self.cnn_module = nn.Linear(input_channels, latent_dim)
            cnn_hidden_dims = [latent_dim]

        self.hidden_dims = cnn_hidden_dims
        # batch_first=True: x is (N, L, E); otherwise dim 0 would be treated
        # as the sequence and attention would mix different samples.
        self.attention_modules = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=cnn_hidden_dims[-1],
                    num_heads=attn_num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(num_attn_layers)
            ]
        )

        self.fc_embedding = nn.Linear(cnn_hidden_dims[-1], latent_dim)

    def forward(self, x):
        """Return the (batch, latent_dim) embedding of ``x`` (batch, time, channels)."""
        if self.using_cnn:
            x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L) for Conv1d
            x = self.cnn_module(x)
            x = x.permute(0, 2, 1)  # back to (N, L', C')
        else:
            x = self.cnn_module(x)  # per-timestep Linear on (N, L, C)
        # Self-attention: query, key and value are all x.
        for attn in self.attention_modules:
            x, _ = attn(x, x, x)

        # Global max-pool over time, then project to the embedding.
        x = F.max_pool1d(x.permute(0, 2, 1), x.size(1)).squeeze(2)
        z = self.fc_embedding(x)
        z = self.dropout(z)
        return z


class SimpleMLP(nn.Module):
    """Fully connected projection head.

    Args:
        input_embedding_dim: size of the input features.
        hidden_layers: hidden layer widths. ``None`` (or an empty list) gives
            a single ``Linear(input_embedding_dim, projection_dim)``.
        projection_dim: output size.
        is_linear: if True, no activations are inserted and the head is a
            composition of Linear layers. Otherwise a ReLU follows every
            hidden Linear.

    NOTE(review): earlier revisions omitted the ReLU after the first hidden
    Linear. Their ``state_dict`` keys (Sequential indices) therefore differ
    from this layout when ``is_linear=False``.
    """

    def __init__(
        self,
        input_embedding_dim,
        hidden_layers,
        projection_dim,
        is_linear=False,
    ):
        super(SimpleMLP, self).__init__()
        if hidden_layers is None:
            fcs = [nn.Linear(input_embedding_dim, projection_dim)]
        else:
            widths = [input_embedding_dim] + list(hidden_layers)
            fcs = []
            for d_in, d_out in zip(widths[:-1], widths[1:]):
                fcs.append(nn.Linear(d_in, d_out))
                if not is_linear:
                    fcs.append(nn.ReLU())
            fcs.append(nn.Linear(widths[-1], projection_dim))

        self.model = nn.Sequential(*fcs)

    def forward(self, x):
        """Project ``x`` (last dim ``input_embedding_dim``) to ``projection_dim``."""
        return self.model(x)
