from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from project_utils.losses import *


class NeuralNetworkOne(nn.Module):
    """
    Fully connected network that emits one pair of logits per output variable, plus the
    training/validation steps that score its (softmaxed) predictions with an external evaluator
    object exposing ``evaluate(tensor)``.

    Each entry of ``hidden_sizes`` contributes ``Linear -> ReLU [-> BatchNorm1d]`` to the hidden
    stack; when dropout is enabled it is applied right after every ReLU.
    """

    def __init__(self, cfg, input_size: int, hidden_sizes: List[int], output_size: int):
        """
        Build the hidden stack, the output layer and the optional dropout module.

        :param cfg: Configuration object. The constructor reads ``cfg.no_dropout``,
        ``cfg.no_batchnorm`` and, only when dropout is enabled, ``cfg.dropout_rate``. The object is
        also stored on ``self.cfg`` because the training/validation steps read further options.
        :param input_size: Number of input features of the first linear layer.
        :param hidden_sizes: Width of each hidden layer, in order. For example,
        ``hidden_sizes = [64, 128, 256]`` creates three hidden layers with 64, 128 and 256 units.
        An empty list connects the input directly to the output layer.
        :param output_size: Number of output variables; the output layer has ``2 * output_size``
        units, i.e. one pair of logits per variable.
        """
        super().__init__()
        self.hidden_layers = nn.ModuleList()
        self.no_dropout = cfg.no_dropout
        self.no_batchnorm = cfg.no_batchnorm

        # Track the width feeding the next layer so no index arithmetic is needed and an empty
        # ``hidden_sizes`` degrades gracefully to a single linear map.
        in_features = input_size
        for width in hidden_sizes:
            self.hidden_layers.append(nn.Linear(in_features, width))
            self.hidden_layers.append(nn.ReLU())
            if not self.no_batchnorm:
                self.hidden_layers.append(nn.BatchNorm1d(width))
            in_features = width

        self.output = nn.Linear(in_features, 2 * output_size)
        if not self.no_dropout:
            self.dropout = nn.Dropout(cfg.dropout_rate)
        self.cfg = cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through the hidden stack (with optional dropout) and the output layer.

        :param x: Input tensor whose last dimension matches the first linear layer.
        :return: Raw logits with ``2 * output_size`` entries in the last dimension.
        """
        for layer in self.hidden_layers:
            x = layer(x)
            # Dropout follows each activation (and therefore precedes any batch norm layer).
            if not self.no_dropout and isinstance(layer, nn.ReLU):
                x = self.dropout(x)
        return self.output(x)

    def _select_input(self, data, data_pgm):
        """Return the network input selected by ``cfg.input_type``; raise ``ValueError`` otherwise."""
        input_type = self.cfg.input_type
        if input_type == "data":
            return data
        if input_type == "spn":
            return data_pgm
        if input_type == "dataSpn":
            # First half is spn data, second half is data
            return torch.cat((data_pgm, data), dim=1)
        raise ValueError("Input type not supported")

    def _pair_probabilities(self, data, data_pgm):
        """Run the network and return ``(pair_logits, softmax_output)``, both shaped (batch, n_vars, 2)."""
        num_examples = data.shape[0]
        model_output = self(self._select_input(data, data_pgm))
        # Reshape so that every variable owns one pair of logits, keeping the batch dimension.
        pair_logits = model_output.view(num_examples, -1, 2)
        softmax_output = F.softmax(pair_logits, dim=2)
        if torch.isnan(softmax_output).any():
            logger.error("Nan in output")
            logger.error(softmax_output)
            # Raise instead of calling exit(): the caller decides how to abort, and the process
            # no longer terminates with a success status on a numerical failure.
            raise RuntimeError("Nan in output")
        return pair_logits, softmax_output

    def _loss_from_evaluator(self, final_func_value):
        """Turn the evaluator output into a loss: negated as-is, or negated after ``exp``."""
        if self.cfg.no_log_loss:
            return -torch.exp(final_func_value)
        return -final_func_value

    def process_buckets_single_row_for_spn(self, nn_output, true, buckets):
        """
        Overwrite the network probabilities at evidence and unobserved positions.

        :param nn_output: Per-variable probabilities whose last dimension has size 2 (index 1 is
        the probability of the variable being 1). The tensor itself is not modified.
        :param true: Observed values, indexed with the ``"evid"`` mask to pick which of the two
        slots is set to 1.
        :param buckets: Mapping that must contain ``"evid"`` and ``"unobs"`` index masks over the
        leading dimensions of ``true``.
        :return: A modified copy of ``nn_output`` flagged with ``requires_grad=True``. At evidence
        positions the slot selected by the observed value is 1 (the other slot keeps the network
        value); at unobserved positions both slots are NaN.
        """
        # Work on a clone: the in-place writes below must not touch the tensor autograd saved for
        # the softmax, and gradients still flow back through the clone.
        final_sample = nn_output.clone()
        # Only drop singleton dimensions when the ranks do not already line up (legacy single-row
        # inputs). An unconditional squeeze would remove the batch dimension of a batch of one and
        # break the mask indexing below.
        if final_sample.dim() != true.dim() + 1:
            final_sample = final_sample.squeeze()

        # Evidence: if the observed value is 0 set slot 0 to 1, if it is 1 set slot 1 to 1.
        evid_mask = buckets["evid"]
        final_sample[evid_mask, true[evid_mask].long()] = 1

        # Unobserved: both slots become NaN.
        # NOTE(review): presumably the evaluator treats NaN as "no value" - confirm there.
        unobs_mask = buckets["unobs"]
        final_sample[unobs_mask, :] = float("nan")

        # Set the flag last: a leaf tensor that already requires grad cannot be written in place,
        # which would make this method fail for inputs without autograd history.
        return final_sample.requires_grad_(True)

    def train_iter(
        self,
        pgm_loss,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Compute the training loss for one batch.

        The loss is the negated evaluator output (or negated ``exp`` of it when
        ``cfg.no_log_loss`` is set), plus ``cfg.evid_lambda`` times a cross-entropy term on the
        evidence variables when ``cfg.add_evid_loss`` is set, averaged over the batch, plus an
        entropy term when ``cfg.add_entropy_loss`` is set.

        NOTE: ``cfg.add_supervised_loss`` has no effect here; no supervised term is implemented.

        :param pgm_loss: Object whose ``evaluate(tensor)`` scores the processed probabilities.
        :param data: Batch used as network input for ``"data"``/``"dataSpn"``; its first dimension
        is always used as the batch size.
        :param data_pgm: Batch used as network input for ``"spn"``/``"dataSpn"``.
        :param initial_data: Observed values, indexed per (example, variable).
        :param evid_bucket: Mask of evidence variables per example.
        :param query_bucket: Mask of query variables per example (only forwarded in ``buckets``).
        :param unobs_bucket: Mask of unobserved variables per example.
        :param attention_mask: Unused; kept so the signature stays unchanged.
        :return: The loss tensor.
        :raises ValueError: If ``cfg.input_type`` is not one of ``"data"``, ``"spn"``, ``"dataSpn"``.
        :raises RuntimeError: If the softmax output contains NaN.
        """
        pair_logits, softmax_output = self._pair_probabilities(data, data_pgm)

        # softmax_output is of shape (batch_size, num_features, 2)
        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_spn(
            nn_output=softmax_output, true=initial_data, buckets=buckets
        )
        loss = self._loss_from_evaluator(pgm_loss.evaluate(output_for_pgm))

        if self.cfg.add_evid_loss:
            rows, cols = torch.nonzero(evid_bucket, as_tuple=True)
            # With no evidence at all the mean-reduced loss would be NaN, so skip the term.
            if rows.numel() > 0:
                # NLL expects log-probabilities; log_softmax on the logits is the numerically
                # stable way to obtain them.
                log_probs = F.log_softmax(pair_logits, dim=2)[rows, cols]
                # Keep the labels on the same device as the predictions instead of forcing CUDA.
                labels = initial_data[rows, cols].long().to(log_probs.device)
                evid_loss = F.nll_loss(log_probs, labels)
                # Out-of-place add so the evaluator-derived tensor is never mutated.
                loss = loss + self.cfg.evid_lambda * evid_loss

        loss = loss.mean()
        if self.cfg.add_entropy_loss:
            entropy_loss = entropy_loss_function(
                softmax_output[:, :, 1], self.cfg.entropy_lambda
            )
            loss = loss + entropy_loss
        return loss

    def validate_iter(
        self,
        pgm,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_pgm,
        all_buckets,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Compute the validation loss for one batch and append the batch results to the given
        accumulators.

        NOTE: ``cfg.add_supervised_loss`` has no effect here; no supervised term is implemented.

        :param pgm: Object whose ``evaluate(tensor)`` scores the processed probabilities.
        :param all_unprocessed_data: List extended in place with ``initial_data`` rows.
        :param all_nn_outputs: List extended in place with the softmax output per example.
        :param all_outputs_for_pgm: List extended in place with slot 1 (value for x = 1) of the
        processed probabilities per example.
        :param all_buckets: Mapping with ``"evid"``, ``"query"`` and ``"unobs"`` lists, each
        extended in place with the corresponding mask rows.
        :param data: Batch used as network input for ``"data"``/``"dataSpn"``; its first dimension
        is always used as the batch size.
        :param data_pgm: Batch used as network input for ``"spn"``/``"dataSpn"``.
        :param initial_data: Observed values, indexed per (example, variable).
        :param evid_bucket: Mask of evidence variables per example.
        :param query_bucket: Mask of query variables per example.
        :param unobs_bucket: Mask of unobserved variables per example.
        :param attention_mask: Unused; kept so the signature stays unchanged.
        :return: The batch-averaged loss.
        :raises ValueError: If ``cfg.input_type`` is not one of ``"data"``, ``"spn"``, ``"dataSpn"``.
        :raises RuntimeError: If the softmax output contains NaN.
        """
        _, softmax_output = self._pair_probabilities(data, data_pgm)
        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}

        output_for_pgm = self.process_buckets_single_row_for_spn(
            nn_output=softmax_output, true=initial_data, buckets=buckets
        )
        final_func_value = pgm.evaluate(output_for_pgm)

        all_nn_outputs.extend(softmax_output.detach().cpu().tolist())
        all_unprocessed_data.extend(initial_data.detach().cpu().tolist())
        # Take the value corresponding to the x = 1
        all_outputs_for_pgm.extend(output_for_pgm[:, :, 1].detach().cpu().tolist())
        for bucket_name, bucket in buckets.items():
            all_buckets[bucket_name].extend(bucket.detach().cpu().tolist())

        loss = self._loss_from_evaluator(final_func_value)

        if self.cfg.add_entropy_loss:
            # Same argument as in train_iter (probability of x = 1), not the raw logits.
            entropy_loss = entropy_loss_function(
                softmax_output[:, :, 1], self.cfg.entropy_lambda
            )
            loss = loss + entropy_loss

        return loss.mean()
