from typing import List

import torch
import torch.nn as nn
from loguru import logger
from models.ste import STEFunction
from project_utils.losses import *
from project_utils.model_utils import get_spn_mpe_output
from project_utils.profiling import pytorch_profile


class NeuralNetworkTwo(nn.Module):
    """
    Feed-forward network whose (activated) outputs are scored by a PGM via ``pgm.evaluate``.

    The architecture is selected by ``num_hidden_layers`` (``cfg.student_layers`` for a
    student, ``cfg.teacher_layers`` for a teacher):

    * ``> 0``  -- MLP built from ``hidden_sizes`` followed by a linear output layer;
    * ``== 0`` -- a single linear layer from the input to the output;
    * ``== -1`` -- no input connections: a free parameter vector of size ``output_size``
      is optimized directly (the input passed to ``forward`` is ignored).
    """

    def __init__(
        self,
        cfg,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        supervised_loss_lambda: float,
        library_pgm,
        is_teacher: bool = False,
    ):
        """
        Initializes a neural network with configurable layers.

        :param cfg: Configuration object containing network options.
        :param input_size: Number of features in the input data.
        :param hidden_sizes: List of neuron counts for each hidden layer.
        :param output_size: Number of neurons in the output layer.
        :param supervised_loss_lambda: Stored as ``self.supervised_loss_lambda``.
        :param library_pgm: External library or tool specific parameter; forwarded to
            ``supervised_loss`` when ``cfg.add_supervised_loss`` is set.
        :param is_teacher: Build the teacher variant: uses ``cfg.teacher_layers`` and
            disables dropout and batch norm.
        :raises ValueError: On an unsupported hidden activation, an invalid number of
            hidden layers, or (with ``cfg.dual_network``) an unsupported ``cfg.loss_dn``.
        """
        super().__init__()
        self.library_pgm = library_pgm
        self.cfg = cfg
        self.num_hidden_layers = (
            cfg.teacher_layers if is_teacher else cfg.student_layers
        )
        self._set_activation_func(cfg)
        if is_teacher:
            logger.info(f"Initializing teacher model with {hidden_sizes} layers")
            logger.info("No regularization for teacher model")

        # Layer definition
        layers = []
        last_hidden_layer_size = self.init_hidden_layers(
            cfg,
            input_size,
            hidden_sizes,
            output_size,
            is_teacher,
            layers,
        )
        if self.num_hidden_layers > -1:
            self.output_layer = nn.Linear(last_hidden_layer_size, output_size)
        elif self.num_hidden_layers == -1:
            # The output itself is the trainable parameter vector.
            self.output_layer = nn.Parameter(torch.randn(output_size))
            # NOTE: dtype=float means torch.float64. This is a plain attribute, not a
            # registered buffer, so Module.to()/.cuda() will not move it.
            self.input_values = torch.ones(
                1,
                self.output_layer.shape[0],
                dtype=float,
                device=cfg.data_device,
            )
        else:
            # Defensive: init_hidden_layers already rejects values below -1.
            raise ValueError("Invalid number of hidden layers")

        # Teachers never use dropout.
        if is_teacher:
            self.dropout = None
            self.no_dropout = True
        else:
            self.no_dropout = cfg.no_dropout
            self.dropout = None if cfg.no_dropout else nn.Dropout(cfg.dropout_rate)

        self.ste = STEFunction.apply if cfg.use_ste else None
        self.supervised_loss_lambda = supervised_loss_lambda
        self.device = cfg.device
        self.initialize_weights()

        if cfg.dual_network:
            self.loss_dn_name = cfg.loss_dn
            if cfg.loss_dn == "mse":
                self.train_loss = nn.MSELoss()
            elif cfg.loss_dn == "bce":
                self.train_loss = nn.BCEWithLogitsLoss()
            else:
                raise ValueError("Loss not supported for Student Network")

    def update_output_layer(self, updated_output_layer):
        """
        Overwrite the output layer's values in place without recording gradients.

        Only meaningful when ``output_layer`` is an ``nn.Parameter``
        (``num_hidden_layers == -1``). ``nn.Linear`` has no ``copy_`` method, so
        other configurations raise ``AttributeError``.
        """
        with torch.no_grad():
            self.output_layer.copy_(updated_output_layer)

    def init_hidden_layers(
        self,
        cfg,
        input_size,
        hidden_sizes,
        output_size,
        is_teacher,
        layers,
    ):
        """
        Build the hidden stack and return the width feeding the output layer.

        Appends ``Linear -> activation [-> BatchNorm1d]`` for every entry of
        ``hidden_sizes`` to ``layers``. Batch norm is skipped for teachers or when
        ``cfg.no_batchnorm`` is set. The layers are always constructed, which keeps
        RNG consumption unchanged. They are only registered as ``self.hidden_layers``
        when ``num_hidden_layers > 0``.

        :param output_size: Unused; kept for interface compatibility.
        :return: ``hidden_sizes[-1]`` (MLP), ``input_size`` (linear model), or ``-1``
            (parameter-only model).
        :raises ValueError: If ``num_hidden_layers`` is below -1, or if it is positive
            while ``hidden_sizes`` is empty.
        """
        for i, hidden_size in enumerate(hidden_sizes):
            in_features = input_size if i == 0 else hidden_sizes[i - 1]
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(self.hidden_activation())
            if not cfg.no_batchnorm and not is_teacher:
                layers.append(nn.BatchNorm1d(hidden_size))

        if self.num_hidden_layers > 0:
            if not hidden_sizes:
                raise ValueError(
                    "hidden_sizes must be non-empty when num_hidden_layers > 0"
                )
            self.hidden_layers = nn.Sequential(*layers)
            return hidden_sizes[-1]
        if self.num_hidden_layers == 0:
            # Linear model straight from the input features.
            return input_size
        if self.num_hidden_layers == -1:
            # Model without any connections; the output layer is a bare parameter.
            return -1
        # Previously this fell through to an UnboundLocalError.
        raise ValueError("Invalid number of hidden layers")

    def _set_activation_func(self, cfg):
        """
        Store the hidden activation name and its ``nn`` module class.

        :raises ValueError: If ``cfg.hidden_activation_function`` is neither
            ``"relu"`` nor ``"leaky_relu"``.
        """
        self.hidden_activation_function = cfg.hidden_activation_function
        if cfg.hidden_activation_function == "relu":
            self.hidden_activation = nn.ReLU
        elif cfg.hidden_activation_function == "leaky_relu":
            self.hidden_activation = nn.LeakyReLU
        else:
            raise ValueError(
                "Hidden activation function not supported: "
                f"{cfg.hidden_activation_function}"
            )

    def initialize_weights(self):
        """
        Initializes the weights of the neural network.

        Hidden ``nn.Linear`` layers get Kaiming-normal weights (fan_in, using the
        configured activation name) and zero biases. A linear output layer gets
        Xavier-uniform weights and a zero bias. The ``-1`` parameter vector keeps its
        ``torch.randn`` initialization.
        """
        if self.num_hidden_layers > 0:
            for layer in self.hidden_layers:
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(
                        layer.weight,
                        mode="fan_in",
                        nonlinearity=self.hidden_activation_function,
                    )
                    nn.init.zeros_(layer.bias)
        if self.num_hidden_layers > -1:
            nn.init.xavier_uniform_(self.output_layer.weight)
            nn.init.zeros_(self.output_layer.bias)

    def get_features(self, x, layer_index):
        """
        Run ``x`` through ``self.hidden_layers`` and stop early at ``layer_index``.

        The counter advances after every ``nn.Linear``. As a result:

        * ``layer_index == 0`` returns the first Linear's output (before activation);
        * ``layer_index == k >= 1`` returns the output of the k-th activation
          (before any batch norm).

        NOTE(review): index 0 vs k >= 1 return different stages; confirm this is
        intended. If the index is never reached, the full hidden-stack output is
        returned. Requires ``num_hidden_layers > 0``; otherwise ``hidden_layers``
        does not exist.

        :param x: The input to the neural network model.
        :param layer_index: Index selecting where to stop (see above).
        :return: The intermediate features.
        """
        i = 0
        for layer in self.hidden_layers:
            x = layer(x)
            if i == layer_index:
                return x
            if isinstance(layer, nn.Linear):
                i += 1
        return x

    def forward(self, x):
        """
        Apply the hidden stack (with dropout after each hidden activation when
        enabled) and the output layer.

        :param x: Input batch. It is ignored when ``num_hidden_layers == -1``; the
            output is then ``input_values * output_layer``, shape ``(1, output_size)``.
        :return: The raw (pre-activation) model output.
        :raises ValueError: If ``num_hidden_layers`` is invalid.
        """
        if self.num_hidden_layers > 0:
            for layer in self.hidden_layers:
                x = layer(x)
                # Dropout follows every hidden activation. Previously this only
                # matched nn.ReLU, so leaky_relu models silently had no dropout.
                if not self.no_dropout and isinstance(layer, self.hidden_activation):
                    x = self.dropout(x)
            return self.output_layer(x)
        if self.num_hidden_layers == 0:
            return self.output_layer(x)
        if self.num_hidden_layers == -1:
            # The weights act as the inputs and are optimized directly (LP-like).
            return self.input_values * self.output_layer
        raise ValueError("Invalid number of hidden layers")

    def process_buckets_single_row_for_pgm(self, nn_output, true, buckets):
        """
        Build the PGM input from the network output and the bucket selectors.

        Evidence positions are overwritten with values from ``true``. Unobserved
        positions are then set to NaN. All other positions keep the network's
        prediction.

        :param nn_output: Network predictions (after activation).
        :param true: Ground-truth tensor indexable like ``nn_output``.
        :param buckets: Dict with at least ``"evid"`` and ``"unobs"`` entries used to
            index ``nn_output``. These are presumably boolean masks; confirm with the
            data loader.
        :return: A new tensor with ``requires_grad`` set; ``nn_output`` is not modified.
        """
        final_sample = nn_output.clone().requires_grad_(True)
        evid_indices = buckets["evid"]
        unobs_indices = buckets["unobs"]

        final_sample[evid_indices] = true[evid_indices]
        final_sample[unobs_indices] = float("nan")
        return final_sample

    def train_iter(
        self,
        pgm,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
        return_mean=True,
    ):
        """
        Compute the unsupervised (PGM-based) training loss for one batch.

        :return: The mean loss if ``return_mean``, otherwise the unreduced loss.
        :raises ValueError: If the activated network output contains NaN.
        """
        input_to_model = self._select_input(data, data_pgm)
        output = self._apply_activation(self(input_to_model))
        self._raise_if_nan(output)

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_pgm(
            nn_output=output, true=initial_data, buckets=buckets
        )

        loss = self._calculate_loss(pgm, output_for_pgm, output, initial_data, buckets)
        return loss.mean() if return_mean else loss

    def supervised_train_iter(
        self,
        pgm,
        psuedo_labels,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
    ):
        """
        Compute the mean supervised (student) loss against pseudo labels.

        The output activation is applied only for the ``"mse"`` loss. The ``"bce"``
        loss (``BCEWithLogitsLoss``) expects raw logits. Requires ``cfg.dual_network``
        so that ``loss_dn_name`` and ``train_loss`` exist.

        :raises ValueError: If the network output contains NaN.
        """
        input_to_model = self._select_input(data, data_pgm)
        model_output = self(input_to_model)
        if self.loss_dn_name == "mse":
            output = self._apply_activation(model_output)
        else:
            output = model_output
        self._raise_if_nan(output)
        sup_loss = self._calculate_student_loss(output, psuedo_labels, query_bucket)
        return sup_loss.mean()

    @torch.no_grad()
    def supervised_validate_iter(
        self,
        pgm,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
        return_mean=True,
    ):
        """
        Evaluate the PGM loss without gradients.

        :return: ``(loss, output_for_pgm, initial_data, buckets)``.
            NOTE(review): when ``return_mean`` is true, ``loss`` is a 1-tuple
            ``(loss.mean(),)``. This looks like a stray comma but is kept because
            callers may index it.
        :raises ValueError: If the activated network output contains NaN.
        """
        input_to_model = self._select_input(data, data_pgm)
        model_output = self._apply_activation(self(input_to_model))
        self._raise_if_nan(model_output)

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_pgm(
            nn_output=model_output, true=initial_data, buckets=buckets
        )

        loss = self._calculate_loss(
            pgm, output_for_pgm, model_output, initial_data, buckets
        )
        if return_mean:
            loss = (loss.mean(),)
        return loss, output_for_pgm, initial_data, buckets

    def validate_iter(
        self,
        pgm,
        all_unprocessed_data,
        all_nn_outputs,
        all_outputs_for_pgm,
        all_buckets,
        data,
        data_pgm,
        initial_data,
        evid_bucket,
        query_bucket,
        unobs_bucket,
        attention_mask,
        return_mean=True,
    ):
        """
        Evaluate the PGM loss and append per-sample results to the accumulators.

        The activated outputs, ``initial_data``, the PGM inputs and each bucket
        tensor are converted to Python lists and appended to the matching
        accumulator. ``all_buckets`` must have ``"evid"``, ``"query"`` and
        ``"unobs"`` list entries.

        :return: The mean loss if ``return_mean``, otherwise the unreduced loss.
        :raises ValueError: If the activated network output contains NaN.
        """
        input_to_model = self._select_input(data, data_pgm)
        model_output = self._apply_activation(self(input_to_model))
        self._raise_if_nan(model_output)

        buckets = {"evid": evid_bucket, "query": query_bucket, "unobs": unobs_bucket}
        output_for_pgm = self.process_buckets_single_row_for_pgm(
            nn_output=model_output, true=initial_data, buckets=buckets
        )

        loss = self._calculate_loss(
            pgm, output_for_pgm, model_output, initial_data, buckets
        )
        all_nn_outputs.extend(model_output.detach().cpu().tolist())
        all_unprocessed_data.extend(initial_data.detach().cpu().tolist())
        all_outputs_for_pgm.extend(output_for_pgm.detach().cpu().tolist())
        for bucket_name, bucket in buckets.items():
            all_buckets[bucket_name].extend(bucket.detach().cpu().tolist())
        return loss.mean() if return_mean else loss

    def _raise_if_nan(self, output):
        """
        Log ``output`` and raise if it contains any NaN.

        Replaces the previous ``exit()`` calls in the validation paths so every
        iterator fails the same way.

        :raises ValueError: If ``output`` contains a NaN.
        """
        if torch.isnan(output).any():
            logger.info("Nan in output")
            logger.info(output)
            raise ValueError("Nan in output")

    def _select_input(self, data, data_pgm):
        """
        Choose the model input according to ``cfg.input_type``.

        ``"data"`` selects ``data``, ``"spn"`` selects ``data_pgm``, and
        ``"dataSpn"`` concatenates ``(data_pgm, data)`` along dim 1.

        :raises ValueError: For any other input type.
        """
        if self.cfg.input_type == "data":
            return data
        elif self.cfg.input_type == "spn":
            return data_pgm
        elif self.cfg.input_type == "dataSpn":
            return torch.cat((data_pgm, data), dim=1)
        else:
            raise ValueError("Input type not supported for NN model 2")

    def _apply_activation(self, model_output):
        """
        Apply the configured output activation, then the straight-through
        estimator if ``cfg.use_ste``.

        :raises ValueError: If ``cfg.activation_function`` is not ``"sigmoid"`` or
            ``"hard_sigmoid"``.
        """
        if self.cfg.activation_function == "sigmoid":
            model_output = torch.sigmoid(model_output)
        elif self.cfg.activation_function == "hard_sigmoid":
            # Functional form avoids constructing an nn.Hardsigmoid on every call.
            model_output = nn.functional.hardsigmoid(model_output)
        else:
            raise ValueError("Activation function not supported for NN model 2")
        if self.cfg.use_ste:
            model_output = self.ste(model_output)
        return model_output

    def _calculate_student_loss(self, output, psuedo_labels, query_bucket):
        """
        Supervised loss on query positions, plus optional configured extra terms.

        Extra terms are accumulated out of place so that broadcasting against the
        (scalar) base loss cannot fail and autograd is unaffected.
        """
        query_output = output[query_bucket]
        query_psuedo_labels = psuedo_labels[query_bucket]
        sup_loss = self.train_loss(query_output, query_psuedo_labels)

        if self.cfg.add_supervised_loss:
            # NOTE(review): ``output`` is passed twice here, whereas _calculate_loss
            # passes (output_for_pgm, output); confirm this is intended.
            sup_loss = sup_loss + supervised_loss(
                output,
                output,
                query_bucket,
                self.library_pgm,
                self.cfg.supervised_loss_lambda,
                self.device,
            )

        if self.cfg.add_entropy_loss:
            sup_loss = sup_loss + entropy_loss(output, self.cfg.entropy_lambda)

        return sup_loss

    def _calculate_loss(self, pgm, output_for_pgm, output, initial_data, buckets):
        """
        PGM loss plus the optional regularizers enabled in ``cfg``.

        The base term is ``-pgm.evaluate(output_for_pgm)``, or
        ``-exp(pgm.evaluate(...))`` when ``cfg.no_log_loss`` is set.
        """
        query_bucket, evid_bucket = buckets["query"], buckets["evid"]
        final_func_value = pgm.evaluate(output_for_pgm)
        if self.cfg.no_log_loss:
            loss = -torch.exp(final_func_value)
        else:
            loss = -final_func_value

        if self.cfg.add_supervised_loss:
            loss = loss + supervised_loss(
                output_for_pgm,
                output,
                query_bucket,
                self.library_pgm,
                self.cfg.supervised_loss_lambda,
                self.device,
            )

        if self.cfg.add_distance_loss_evid_ll:
            loss = loss + distance_loss(
                pgm,
                output_for_pgm,
                initial_data,
                buckets,
                self.cfg.evid_lambda,
            )

        if self.cfg.add_evid_loss:
            # Previously read the never-assigned self.evid_lambda (AttributeError).
            loss = loss + evid_loss(
                output, evid_bucket, initial_data, self.cfg.evid_lambda
            )

        if self.cfg.add_entropy_loss:
            loss = loss + entropy_loss(output, self.cfg.entropy_lambda)

        return loss
