from typing import List

import torch
import torch.nn as nn
from loguru import logger
from project_utils.losses import *
from project_utils.model_utils import get_spn_mpe_output
from project_utils.profiling import pytorch_profile

# Loss modules built once at import time and shared by everything in this module.
# `l2_loss` is the mean-squared-error criterion used for the optional supervised
# term in `NeuralNetworkFour.train_iter`; `bce_loss` is not referenced by the code
# visible in this file (it may be used by importers of this module).
bce_loss = nn.BCELoss()
l2_loss = nn.MSELoss()


class NeuralNetworkFour(nn.Module):
    """
    Fully connected network that emits one value per query variable and is
    trained against an SPN.

    The network output (one column per query variable) is merged with the
    evidence values and NaN markers for unobserved variables into a full-width
    tensor, which is then handed to ``spn.evaluate``; the negated SPN value is
    the loss.
    """

    def __init__(
        self,
        cfg,
        input_size: int,
        hidden_sizes: List[int],
        num_variables,
        library_spn,
        num_query_variables: int,
    ):
        """
        Build the hidden stack (Linear -> ReLU [-> BatchNorm1d] per entry of
        ``hidden_sizes``) and the final linear output layer.

        :param cfg: Configuration object. ``no_dropout``, ``no_batchnorm`` and (when
        dropout is enabled) ``dropout_rate`` are read here; the training/validation
        methods additionally read ``input_type``, ``activation_function``,
        ``no_log_loss``, ``add_supervised_loss``, ``supervised_loss_lambda``,
        ``add_entropy_loss`` and ``entropy_lambda``.
        :param input_size: Number of input features of the first linear layer.
        :param hidden_sizes: Width of each hidden layer, in order. May be empty, in
        which case the output layer is connected directly to the input.
        :param num_variables: Width of the tensor that is assembled for the SPN
        (one column per variable).
        :param library_spn: Object forwarded to ``get_spn_mpe_output`` when the
        supervised loss is enabled; it is not used otherwise.
        :param num_query_variables: Number of query variables; this is also the
        number of output units of the network.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_layers = nn.ModuleList()
        self.no_dropout = cfg.no_dropout
        self.no_batchnorm = cfg.no_batchnorm
        self.num_variables = num_variables
        self.library_spn = library_spn
        output_size = num_query_variables
        self.num_query_variables = num_query_variables

        # Track the width feeding the next layer so that an empty
        # `hidden_sizes` simply yields a single linear map input -> output.
        in_features = input_size
        for width in hidden_sizes:
            self.hidden_layers.append(nn.Linear(in_features, width))
            self.hidden_layers.append(nn.ReLU())
            if not self.no_batchnorm:
                self.hidden_layers.append(nn.BatchNorm1d(width))
            in_features = width
        self.output = nn.Linear(in_features, output_size)

        # `self.dropout` only exists when dropout is enabled; `forward` guards
        # every access with the same flag.
        if not self.no_dropout:
            self.dropout = nn.Dropout(cfg.dropout_rate)
        self.cfg = cfg

    def forward(self, x):
        """
        Run the hidden stack and the output layer.

        Dropout (when enabled) is applied directly after every ReLU, i.e. before
        the optional batch-norm layer that follows it.

        :param x: Input tensor whose last dimension is ``input_size``.
        :return: Raw (pre-activation) outputs, one column per query variable.
        """
        for layer in self.hidden_layers:
            x = layer(x)
            if not self.no_dropout and isinstance(layer, nn.ReLU):
                x = self.dropout(x)
        return self.output(x)

    def _select_input(self, data, data_spn):
        """Pick the network input according to ``cfg.input_type``."""
        input_type = self.cfg.input_type
        if input_type == "data":
            return data
        if input_type == "spn":
            return data_spn
        if input_type == "dataSpn":
            # First half is spn data, second half is data
            return torch.cat((data_spn, data), dim=1)
        raise ValueError(
            f"Input type not supported for NeuralNetworkFour: {input_type!r}"
        )

    def _activate(self, logits, strict):
        """
        Apply the activation named by ``cfg.activation_function``.

        With ``strict=True`` an unrecognised name raises ``ValueError``; with
        ``strict=False`` the logits are returned unchanged.
        """
        name = self.cfg.activation_function
        if name == "sigmoid":
            return torch.sigmoid(logits)
        if name == "hard_sigmoid":
            return nn.Hardsigmoid()(logits)
        if strict:
            raise ValueError(f"Activation function not supported: {name!r}")
        return logits

    def _raise_on_nan(self, output):
        """Log the tensor and raise ``ValueError`` if it contains any NaN."""
        if torch.isnan(output).any():
            logger.info("Nan in output")
            logger.info(output)
            raise ValueError("Nan in output")

    def _spn_loss(self, spn_value):
        """Negate the SPN value; exponentiate it first when ``cfg.no_log_loss`` is set."""
        if not self.cfg.no_log_loss:
            return -spn_value
        return -torch.exp(spn_value)

    def process_buckets_single_row_for_spn(self, nn_output, true, buckets):
        """
        Assemble the full-width SPN input from evidence, unobserved markers and
        network predictions.

        :param nn_output: Network predictions indexed as ``[example, k]``; the
        first dimension is taken as the number of examples and the result is
        created on the same device and with the same dtype.
        :param true: Tensor holding the reference values; only the positions
        selected by ``buckets["evid"]`` are copied.
        :param buckets: Mapping with the keys ``"evid"``, ``"query"`` and
        ``"unobs"``. ``"query"`` must be a 2-D mask over
        ``(example, variable)``; the other two are used to index the result
        directly (presumably masks of the same shape -- confirm against the
        data loader).
        :return: Tensor of shape ``(num_examples, num_variables)`` with evidence
        positions copied from ``true``, unobserved positions set to NaN, and the
        k-th query position of each row (in increasing column order) set to
        ``nn_output[row, k]``. Positions in none of the buckets stay 0.
        """
        num_examples = nn_output.shape[0]
        final_sample = torch.zeros(
            (num_examples, self.num_variables),
            dtype=nn_output.dtype,
            device=nn_output.device,
        )

        # Evidence: copy the known values (cast so an integer-valued `true`
        # can be written into the floating-point buffer).
        evid_indices = buckets["evid"]
        final_sample[evid_indices] = true[evid_indices].to(dtype=final_sample.dtype)

        # Unobserved variables are marked with NaN.
        final_sample[buckets["unobs"]] = float("nan")

        # Query variables: the k-th selected column of a row receives the k-th
        # network output of that row. The running count of selected columns
        # gives k; a modulo on the column index would collide for arbitrary
        # query sets (e.g. columns 1 and 3 with two outputs both map to 1).
        # This ordering matches the row-major order of boolean-mask indexing.
        query_mask = buckets["query"].bool()
        rows, cols = torch.where(query_mask)
        rank_in_row = query_mask.cumsum(dim=1)[rows, cols] - 1
        final_sample[rows, cols] = nn_output[rows, rank_in_row]

        return final_sample

    def train_iter(
        self,
        spn,
        data,
        data_spn,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Compute the training loss for one batch.

        :param spn: Object exposing ``evaluate(tensor)``; its result is negated
        (after ``exp`` when ``cfg.no_log_loss`` is set) to form the loss.
        :param data: Network input used when ``cfg.input_type`` is ``"data"`` or
        ``"dataSpn"``.
        :param data_spn: Network input used when ``cfg.input_type`` is ``"spn"`` or
        ``"dataSpn"``.
        :param initial_data: Reference values from which evidence is copied.
        :param evid_bucket: Evidence selector.
        :param query_bucket: Query mask.
        :param unobs_bucket: Unobserved selector.
        :param attention_mask: Unused by this model; kept for a uniform interface.
        :return: Scalar loss tensor (assuming the entropy term, when enabled,
        is itself a scalar).
        :raises ValueError: On an unsupported ``cfg.input_type`` or
        ``cfg.activation_function``, or when the activated output contains NaN.
        """
        input_to_model = self._select_input(data, data_spn)
        output = self._activate(self(input_to_model), strict=True)
        self._raise_on_nan(output)

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_spn = self.process_buckets_single_row_for_spn(
            nn_output=output, true=initial_data, buckets=buckets
        )

        loss = self._spn_loss(spn.evaluate(output_for_spn))

        if self.cfg.add_supervised_loss:
            mpe_outputs = get_spn_mpe_output(
                self.library_spn,
                output_for_spn.detach().cpu().numpy(),
                query_bucket.detach().cpu().numpy(),
            )
            # Match the prediction's dtype/device so the MSE operands agree.
            mpe_outputs = torch.tensor(
                mpe_outputs, dtype=output.dtype, device=output.device
            )
            # Mask indexing flattens in row-major order; restore the
            # (example, query) layout so targets line up with `output`.
            query_spn_outputs = mpe_outputs[query_bucket].reshape(output.shape)
            supervised_loss = l2_loss(output, query_spn_outputs)
            # Out-of-place add: do not mutate the SPN loss tensor.
            loss = loss + self.cfg.supervised_loss_lambda * supervised_loss

        loss = loss.mean()
        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss
        return loss

    def validate_iter(
        self,
        spn,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_spn,
        all_buckets,
        data,
        data_spn,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Compute the validation loss for one batch and record the batch's tensors.

        :param spn: Object exposing ``evaluate(tensor)``.
        :param all_unprocessed_data: List extended in place with ``initial_data`` rows.
        :param all_nn_outputs: List extended in place with the activated network outputs.
        :param all_outputs_for_spn: List extended in place with the assembled SPN inputs.
        :param all_buckets: Mapping with list values under the keys ``"evid"``,
        ``"query"`` and ``"unobs"``; each list is extended in place.
        :param data: See ``train_iter``.
        :param data_spn: See ``train_iter``.
        :param initial_data: Reference values from which evidence is copied.
        :param evid_bucket: Evidence selector.
        :param query_bucket: Query mask.
        :param unobs_bucket: Unobserved selector.
        :param attention_mask: Unused by this model; kept for a uniform interface.
        :return: Mean loss over the batch.
        :raises ValueError: On an unsupported ``cfg.input_type`` or when the
        network output contains NaN (previously this path called ``exit()``).
        """
        input_to_model = self._select_input(data, data_spn)
        # Non-strict: validation has always passed raw outputs through when the
        # activation name is unrecognised, so that behaviour is preserved.
        model_output = self._activate(self(input_to_model), strict=False)
        # Fail before building the SPN input or touching the accumulators.
        self._raise_on_nan(model_output)

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_spn = self.process_buckets_single_row_for_spn(
            nn_output=model_output, true=initial_data, buckets=buckets
        )

        final_func_value = spn.evaluate(output_for_spn)
        all_nn_outputs.extend(model_output.detach().cpu().tolist())
        all_unprocessed_data.extend(initial_data.detach().cpu().tolist())
        all_outputs_for_spn.extend(output_for_spn.detach().cpu().tolist())
        for bucket_name, bucket in buckets.items():
            all_buckets[bucket_name].extend(bucket.detach().cpu().tolist())

        loss = self._spn_loss(final_func_value)

        # Unlike train_iter, the entropy term is added before averaging here.
        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(model_output, self.cfg.entropy_lambda)
            loss = loss + entropy_loss

        loss = loss.mean()
        return loss
