from typing import List

import torch
import torch.nn as nn
from loguru import logger
from models.embedding_model.ae import Autoencoder
from project_utils.losses import *
from project_utils.model_utils import get_spn_mpe_output

# Module-level loss criteria, built once with PyTorch's default settings.
# `l2_loss` (mean-squared error) is used by NeuralNetworkThree.train_iter for
# the optional supervised and evidence terms; `bce_loss` is not referenced in
# the part of the file visible here and is kept for other users of the module.
bce_loss = nn.BCELoss()
l2_loss = nn.MSELoss()


class NeuralNetworkThree(nn.Module):
    """Feed-forward network on top of autoencoder embeddings, scored by a PGM.

    The raw features are encoded by an ``Autoencoder`` whose weights are loaded
    from disk. The encoding (optionally concatenated with PGM-derived features)
    goes through a stack of ``Linear -> ReLU [-> BatchNorm1d]`` layers followed
    by a linear output layer. ``train_iter`` / ``validate_iter`` turn the
    activated output into a loss by evaluating it with a PGM object.
    """

    def __init__(
        self,
        cfg,
        input_feature_size: int,
        pgm_feature_size: int,
        hidden_sizes: List[int],
        output_size: int,
        embedding_model_path: str,
        library_pgm,
    ):
        """
        Build the embedding model, the hidden stack and the output layer.

        :param cfg: Configuration object. This class reads ``no_dropout``,
        ``no_batchnorm``, ``dropout_rate`` (only when dropout is enabled),
        ``input_type`` here, and further loss-related options in
        ``train_iter`` / ``validate_iter``
        :param input_feature_size: Size passed to ``Autoencoder`` as its input
        dimension
        :param pgm_feature_size: Number of PGM features concatenated to the
        encoding when ``cfg.input_type == "dataSpn"``
        :param hidden_sizes: Width of each hidden layer, in order. For example
        ``[64, 128, 256]`` creates three hidden layers of those widths. An empty
        list connects the input directly to the output layer
        :param output_size: Number of units produced by the final linear layer
        :param embedding_model_path: Path of the saved ``state_dict`` loaded
        into the autoencoder
        :param library_pgm: Object handed to ``get_spn_mpe_output`` when the
        supervised loss is enabled
        :raises ValueError: If ``cfg.input_type`` is neither ``"data"`` nor
        ``"dataSpn"``
        """
        super().__init__()
        # NOTE: submodules are registered in the same order as before
        # (hidden_layers, embedding_model, output, dropout) so that
        # parameter/state-dict ordering is unchanged.
        self.hidden_layers = nn.ModuleList()
        self.no_dropout = cfg.no_dropout
        self.no_batchnorm = cfg.no_batchnorm
        self.embedding_model = Autoencoder(input_feature_size, hidden_layers=4)
        # Load onto CPU first: load_state_dict copies the values into the
        # module's own parameters, and this keeps checkpoints saved on a GPU
        # loadable on hosts without one.
        self.embedding_model.load_state_dict(
            torch.load(embedding_model_path, map_location="cpu")
        )
        encoding_size = self.embedding_model.encoding_size
        self.library_pgm = library_pgm
        self.cfg = cfg
        if self.cfg.input_type == "data":
            input_size = encoding_size
        elif self.cfg.input_type == "dataSpn":
            input_size = encoding_size + pgm_feature_size
        else:
            raise ValueError("Input type not supported for NN model 3")

        # Each hidden layer consumes the previous layer's width; tracking it in
        # `prev_size` also makes an empty `hidden_sizes` well defined.
        prev_size = input_size
        for size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, size))
            self.hidden_layers.append(nn.ReLU())
            if not self.no_batchnorm:
                self.hidden_layers.append(nn.BatchNorm1d(size))
            prev_size = size
        self.output = nn.Linear(prev_size, output_size)
        if not self.no_dropout:
            self.dropout = nn.Dropout(cfg.dropout_rate)

    def forward(self, x):
        """
        Encode the data, optionally prepend the PGM features, and run the
        hidden stack and output layer.

        :param x: A ``(pgm_input, data)`` pair. ``pgm_input`` may be ``None``,
        in which case only the encoded data is used
        :return: The raw (pre-activation) output of the final linear layer
        """
        pgm_input, x = x
        x = self.embedding_model.encode(x)
        if pgm_input is not None:
            # PGM features come first, the encoded data second.
            x = torch.cat((pgm_input, x), dim=1)
        for layer in self.hidden_layers:
            x = layer(x)
            # Dropout is applied right after every ReLU (i.e. before the
            # optional batch norm that follows it in the stack).
            if not self.no_dropout and isinstance(layer, nn.ReLU):
                x = self.dropout(x)
        model_output = self.output(x)
        return model_output

    def _build_model_input(self, data, data_pgm):
        """Return the ``(pgm_input, data)`` pair for ``forward`` per ``cfg.input_type``."""
        if self.cfg.input_type == "data":
            return (None, data)
        if self.cfg.input_type == "dataSpn":
            # First element is the PGM/SPN data, second is the raw data.
            return (data_pgm, data)
        raise ValueError("Input type not supported for NN model 3")

    def process_buckets_single_row_for_pgm(self, nn_output, true, buckets):
        """
        Build the tensor handed to the PGM from the network output.

        Args:
            nn_output (torch.Tensor): Activated network output.
            true (torch.Tensor): Ground-truth values, indexed with the same
                evidence bucket as ``nn_output``.
            buckets (dict): Must contain ``"evid"`` and ``"unobs"`` entries
                usable as tensor indices (assumed to be boolean masks or index
                tensors matching ``nn_output`` - TODO confirm with the caller).

        Returns:
            torch.Tensor: A copy of ``nn_output`` in which evidence positions
            hold the true values and unobserved positions hold NaN. All other
            positions (the query variables) keep the network's output.
            ``nn_output`` itself is not modified.
        """
        # clone() stays connected to the autograd graph, so gradients still
        # reach the network through the untouched positions. No explicit
        # requires_grad_(True) is needed; forcing it on an input that does not
        # track gradients would create a leaf tensor on which the in-place
        # writes below are rejected by autograd.
        final_sample = nn_output.clone()

        # Evidence variables are clamped to their observed values.
        indices = buckets["evid"]
        final_sample[indices] = true[indices]

        # Query positions are left as predicted; unobserved ones are marked NaN.
        indices = buckets["unobs"]
        final_sample[indices] = float("nan")
        return final_sample

    def train_iter(
        self,
        pgm,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Run one forward pass and compute the training loss.

        :param pgm: Object exposing ``evaluate(tensor)``; its result is negated
        (or exponentiated then negated when ``cfg.no_log_loss`` is set) to form
        the main loss term (presumably a log-likelihood - TODO confirm)
        :param data: Input features fed to the embedding model
        :param data_pgm: PGM features, used only when
        ``cfg.input_type == "dataSpn"``
        :param initial_data: Ground-truth values used for the evidence positions
        :param evid_bucket: Index/mask of evidence variables
        :param query_bucket: Index/mask of query variables
        :param unobs_bucket: Index/mask of unobserved variables
        :param attention_mask: Accepted for interface compatibility; not used
        by this model
        :return: The mean of the combined loss
        :raises ValueError: If the input type or activation function is not
        supported, or if the activated output contains NaN
        """
        model_output = self(self._build_model_input(data, data_pgm))

        activation = self.cfg.activation_function
        if activation == "sigmoid":
            output = torch.sigmoid(model_output)
        elif activation == "hard_sigmoid":
            m = nn.Hardsigmoid()
            output = m(model_output)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `output`.
            raise ValueError(
                f"Activation function not supported for NN model 3: {activation!r}"
            )
        if torch.isnan(output).any():
            logger.info(output)
            raise ValueError("Nan in output")

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_pgm(
            nn_output=output, true=initial_data, buckets=buckets
        )

        final_func_value = pgm.evaluate(output_for_pgm)
        if not self.cfg.no_log_loss:
            loss_from_pgm = -final_func_value
        else:
            loss_from_pgm = -torch.exp(final_func_value)

        loss = loss_from_pgm

        if self.cfg.add_supervised_loss:
            # The targets are computed outside autograd (numpy round trip), so
            # gradients only flow through the network's query outputs.
            outputs_for_pgm_np = output_for_pgm.detach().cpu().numpy()
            query_bucket_np = query_bucket.detach().cpu().numpy()
            mpe_outputs = get_spn_mpe_output(
                self.library_pgm, outputs_for_pgm_np, query_bucket_np
            )
            mpe_outputs = torch.tensor(mpe_outputs, device=self.cfg.device)
            query_pgm_outputs = mpe_outputs[query_bucket]
            query_nn_outputs = output[query_bucket]
            supervised_loss = l2_loss(query_nn_outputs, query_pgm_outputs)
            loss += self.cfg.supervised_loss_lambda * supervised_loss

        if self.cfg.add_evid_loss:
            # Penalize the network for disagreeing with the observed evidence.
            evidence_output = output[evid_bucket]
            evidence_true = initial_data[evid_bucket]
            evid_loss = l2_loss(evidence_output, evidence_true)
            loss += self.cfg.evid_lambda * evid_loss
        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss
        loss = loss.mean()
        return loss

    def validate_iter(
        self,
        pgm,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_pgm,
        all_buckets,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Run one forward pass for validation, record the results, and return
        the loss.

        :param pgm: Object exposing ``evaluate(tensor)`` (see ``train_iter``)
        :param all_unprocessed_data: List extended in place with
        ``initial_data`` rows
        :param all_nn_outputs: List extended in place with the activated
        network outputs
        :param all_outputs_for_pgm: List extended in place with the tensors
        passed to the PGM
        :param all_buckets: Mapping with ``"evid"``, ``"query"`` and
        ``"unobs"`` lists, each extended in place with the matching bucket
        :param data: Input features fed to the embedding model
        :param data_pgm: PGM features, used only when
        ``cfg.input_type == "dataSpn"``
        :param initial_data: Ground-truth values used for the evidence positions
        :param evid_bucket: Index/mask of evidence variables
        :param query_bucket: Index/mask of query variables
        :param unobs_bucket: Index/mask of unobserved variables
        :param attention_mask: Accepted for interface compatibility; not used
        by this model
        :return: The mean of the PGM loss (plus the entropy term when enabled)
        :raises ValueError: If the input type is not supported or the output
        contains NaN
        """
        model_output = self(self._build_model_input(data, data_pgm))

        # NOTE(review): unlike train_iter, any other activation_function value
        # leaves the raw model output untouched here - confirm whether
        # validation should reject it as well.
        if self.cfg.activation_function == "sigmoid":
            model_output = torch.sigmoid(model_output)
        elif self.cfg.activation_function == "hard_sigmoid":
            m = nn.Hardsigmoid()
            model_output = m(model_output)

        if torch.isnan(model_output).any():
            logger.info("Nan in output")
            logger.info(model_output)
            # Raise like train_iter does instead of calling exit(), which
            # would end the whole process with a success status.
            raise ValueError("Nan in output")

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_pgm(
            nn_output=model_output, true=initial_data, buckets=buckets
        )

        final_func_value = pgm.evaluate(output_for_pgm)
        all_nn_outputs.extend(model_output.detach().cpu().tolist())
        all_unprocessed_data.extend(initial_data.detach().cpu().tolist())
        all_outputs_for_pgm.extend(output_for_pgm.detach().cpu().tolist())
        for each_bucket in buckets:
            all_buckets[each_bucket].extend(
                buckets[each_bucket].detach().cpu().tolist()
            )

        if not self.cfg.no_log_loss:
            loss = -final_func_value
        else:
            loss = -torch.exp(final_func_value)

        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(model_output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss

        loss = loss.mean()
        return loss
