from torch import nn
from src.models.encoders.linear import LinearEncoder
from src.models.encoders.mlp import MLPEncoder
from src.utils.expression_utils import store_eq
import torch

class BlackBoxPredictor(nn.Module):
    """Bank of ``memory_size`` encoders whose outputs are mixed per sample.

    Every memory slot holds one encoder (``LinearEncoder`` when ``linear`` is
    true, otherwise an ``MLPEncoder`` with one hidden layer).  ``forward``
    runs all encoders on the same input and combines their outputs with the
    weights given in ``prob_per_classifier``.

    Args:
        memory_size: Number of encoders (memory slots) to create.
        c_names: Forwarded unchanged as ``input_size`` to every encoder.
            NOTE(review): presumably the number of input concepts despite
            the name -- confirm against the encoder signatures.
        output_size: Forwarded as ``output_size`` to every encoder.
        activation: Forwarded as ``activation`` to every encoder.
        latent_size: Forwarded as ``hidden_size`` to ``MLPEncoder``; not
            passed to any encoder when ``linear`` is true.
        linear: Build ``LinearEncoder`` slots instead of ``MLPEncoder`` ones.
    """

    def __init__(self,
                 memory_size,
                 c_names,
                 output_size,
                 activation,
                 latent_size,
                 linear=False
                ):
        super().__init__()

        self.memory_size = memory_size
        self.c_names = c_names
        self.output_size = output_size
        self.show_explanations = False
        # Read by _get_explanations; it must exist before show_explanations is
        # ever switched on, otherwise that path dies with an AttributeError.
        self.equations_for_explanations_ready = False
        self.activation = activation
        self.latent_size = latent_size
        self.linear = linear

        # One independent encoder per memory slot.
        self.memory_of_predictors = nn.ModuleList(
            [self._build_predictor() for _ in range(memory_size)]
        )

    def _build_predictor(self):
        """Create the encoder for a single memory slot."""
        if self.linear:
            return LinearEncoder(
                input_size=self.c_names,
                output_size=self.output_size,
                activation=self.activation,
            )
        return MLPEncoder(
            input_size=self.c_names,
            output_size=self.output_size,
            hidden_size=self.latent_size,
            activation=self.activation,
            num_layers=1,  # one hidden layer
        )

    def _get_explanations(self, prob_per_classifier, y_hat):
        """Return per-sample explanations, or placeholders when disabled.

        Args:
            prob_per_classifier: Tensor whose first dimension gives the
                number of placeholder entries returned.
            y_hat: Currently unused.

        Returns:
            A list of ``None`` with one entry per element of the first
            dimension of ``prob_per_classifier`` when ``show_explanations``
            is false; otherwise ``None`` (explanations are not built yet).

        Raises:
            NotImplementedError: If ``show_explanations`` is true, the
                equations are not ready, and no ``_setup_string_equations``
                hook is available on this object.
        """
        if not self.show_explanations:
            return [None] * prob_per_classifier.size(0)

        # TODO: to be completed once the kan implementation is stable
        if not self.equations_for_explanations_ready:
            setup = getattr(self, '_setup_string_equations', None)
            if setup is None:
                raise NotImplementedError(
                    "Explanations are not implemented for BlackBoxPredictor: "
                    "no '_setup_string_equations' hook is defined."
                )
            setup()
        return None

    def forward(self, prob_per_classifier, input_concepts):
        """Run every encoder and mix their outputs.

        Args:
            prob_per_classifier: 4-D tensor indexed by the einsum below as
                (batch, memory slot, output, sample).
            input_concepts: Input passed unchanged to every encoder.

        Returns:
            A dict with a single key ``'y_hat'`` holding a tensor indexed as
            (batch, output, sample).
        """
        # Stack the per-slot outputs along a new memory dimension:
        # (bsz, memory_size, n_outputs), given that each encoder returns a
        # (bsz, n_outputs) tensor as the einsum subscripts require.
        eq_outputs = torch.stack(
            [layer(input_concepts) for layer in self.memory_of_predictors],
            dim=1,
        )

        # Sum over memory slots of weight * encoder output, separately for
        # each output and sample.
        # NOTE(review): an earlier comment described prob_per_classifier as
        # (bsz, n_outputs, memory_size, n_samples), but the subscripts 'bmys'
        # read axis 1 as memory and axis 2 as outputs -- confirm which layout
        # the callers actually provide.
        y_hat = torch.einsum('bmys,bmy->bys', prob_per_classifier, eq_outputs)

        return {
            'y_hat': y_hat,
        }