import torch.nn as nn
import torch_concepts.nn as pyc_nn
from src.models.baselines.base import BaseModel
from src.utils.expression_utils import kan_expression, store_eq
from src.models.modules.selector import SelectorModel
from src.models.modules.kan_predictor import KANPredictor
from src.models.modules.symbolic_predictor import SymbolicPredictor
import os

class KANSymbolicCBM(BaseModel):
    """Concept bottleneck model whose task predictor is a memory of KANs.

    The input is encoded with ``BaseModel.encode``, mapped to concept scores by a
    linear concept bottleneck (no activation is applied there), and a
    ``SelectorModel`` scores ``memory_size`` predictor slots from the latent
    representation. The task predictor starts as a ``KANPredictor``; calling
    :meth:`get_learned_equations` with ``fine_tuned=False`` replaces it with a
    ``SymbolicPredictor`` built from the equations extracted from the KANs.
    """

    def __init__(self, 
                 output_size,
                 c_names,
                 y_names,
                 task, 
                 task_penalty,
                 activation='ReLU',
                 int_prob=0.1,
                 int_idxs=None,
                 noise=None,
                 embedding_size=16,
                 latent_size=128,
                 c_groups=None,
                 memory_size=7,
                 hard_concepts=False,
                 encoder=None,
                 mc_approx=1,
                 selector_model='linear',
                 concept_loss_form=nn.BCELoss(),
                 backbone_latent_size=None,
                 concept_type='binary',
                 known_equations=None,
                 disjoint_training=False,
                 decay_rate='cosine',
                 embedding_memory=False,
                 concept_penalty=1.0,
                 regularize=True,
                 widths=None,
                 device='cpu',
                 speed_up_training=True,
                 **kwargs
                 ):
        """Build the concept bottleneck, the selector and the KAN predictor.

        Args:
            widths: layer widths of each KAN. When ``None`` it defaults to
                ``[n_concepts, n_concepts + 1, n_concepts + 1, 1]`` for a single
                output and to ``[n_concepts, output_size]`` otherwise. A given
                value is copied into a new list so that in-place changes made
                downstream can never reach the caller's object.
            memory_size: number of predictor slots scored by the selector.
            selector_model, activation, decay_rate: forwarded to ``SelectorModel``.
            regularize, device, speed_up_training: forwarded to ``KANPredictor``.
            mc_approx, embedding_size, embedding_memory: stored on the instance.
            concept_loss_form, known_equations, **kwargs: accepted for interface
                compatibility; not used by this class.

        The remaining arguments are forwarded positionally to ``BaseModel``.
        """

        super().__init__(
            output_size,
            c_names,
            y_names,
            task,
            task_penalty,
            hard_concepts,
            activation,
            int_prob,
            int_idxs,
            noise,
            latent_size,
            c_groups,
            encoder,
            backbone_latent_size,
            concept_type,
            disjoint_training,
            concept_penalty
        )

        self.embedding_size = embedding_size
        self.has_concepts = True
        self.y_names = list(y_names)
        self.output_size = output_size
        self.backbone_latent_size = backbone_latent_size
        self.activation = activation
        self.embedding_memory = embedding_memory
        self.show_explanations = False
        self.equations_for_explanations_ready = False
        self.regularize = regularize
        # Becomes True once the KAN predictor is replaced by a SymbolicPredictor.
        self.symbolic_predictors = False
        self.device = device
        self.speed_up_training = speed_up_training

        if widths is not None:
            # Copy (and accept tuples): the list is handed to the predictor and
            # later sliced/concatenated with lists in get_symbolic_equivalent.
            self.widths = list(widths)
        elif self.output_size == 1:
            # Approach suggested by the authors of KAN
            self.widths = [
                len(self.c_names),
                len(self.c_names) + 1,
                len(self.c_names) + 1,
                self.output_size]
        else:
            # For multi-output tasks, we use a smaller architecture since the number of parameters
            # grows quickly with the number of outputs and this slows down the auto-symbolic search.
            # NOTE: this is a design choice made according to the hardware and different architectures can be used.
            self.widths = [len(self.c_names), self.output_size]

        # KAN hyper-parameters passed to KANPredictor as `grid` and `k`.
        grid_size = 5
        k = 3
        auto_save = False

        self.mc_approx = mc_approx
        self.memory_size = memory_size

        # Instantiate the selector
        self.classifier_selector = SelectorModel(
            input_size=self.backbone_latent_size,
            output_size=self.memory_size,
            n_outputs=self.output_size,
            model_type=selector_model,
            activation=activation,
            decay_rate=decay_rate,
        )

        self.bottleneck = pyc_nn.LinearConceptBottleneck(
            backbone_latent_size,
            self.c_names,
            activation=nn.Identity(), # we will later apply a sigmoid if the concept is boolean
        )

        # KAN predictor instantiation
        self.predictor = KANPredictor(
            widths=self.widths,
            grid=grid_size,
            k=k,
            memory_size=self.memory_size,
            device=self.device,
            speed_up_training=self.speed_up_training,
            auto_save=auto_save,
            y_names=self.y_names,
            regularize=self.regularize,
        )

    @staticmethod
    def _as_kan_widths(widths):
        """Return a new list where every plain integer ``n`` becomes ``[n, 0]``.

        ``kan_expression`` is called with ``[n, 0]`` pairs (see the output layer
        built in ``get_symbolic_equivalent``). ``self.widths`` is built from plain
        ints in ``__init__``, but may presumably hold such pairs after the KAN
        library normalises it in place, so both forms are accepted here.
        Entries that are already lists/tuples are kept (converted to lists).
        """
        return [list(w) if isinstance(w, (list, tuple)) else [w, 0] for w in widths]

    def setup_kan_grid(self, grid_inputs):
        """Forward ``grid_inputs`` to the predictor's ``setup_kan_grid``."""
        self.predictor.setup_kan_grid(grid_inputs)

    def prune(self):
        """Prune the predictor (delegates to ``self.predictor.prune``)."""
        self.predictor.prune()

    def allow_symbolic_extraction(self):
        """Enable symbolic extraction on the predictor (``allow_symbolic``)."""
        self.predictor.allow_symbolic()

    def get_learned_equations(self, log_dir, fine_tuned=False):
        """Extract and report the equations learned by the predictor.

        When ``fine_tuned`` is False:

        1. The equations are read from the current (KAN) predictor.
        2. Their input variables are renamed to the concept names.
        3. The equations of each output are re-keyed positionally by
           ``self.y_names``.
        4. The predictor is replaced by a ``SymbolicPredictor`` built from them.

        In both cases the (possibly new) predictor's ``get_learned_equations`` is
        then called with ``log_dir`` and ``fine_tuned``.
        """
        if not fine_tuned:
            # Get dictionary of learned equations
            equations, variable_names = self.predictor.get_learned_equations(log_dir)

            # If c_names is not none, substitute the variable names with c_names
            if self.c_names is not None and self.y_names is not None:
                # Substitute all variables at once: sequential `subs` calls would
                # re-substitute a concept name that coincides with a later KAN
                # variable name (e.g. renaming x_1->x_2 and then x_2->x_1).
                substitutions = dict(zip(variable_names, self.c_names))
                substituted_equations = {}
                for out_name, expr_dict in equations.items():
                    # NOTE: the original equation ids are discarded; keys are the
                    # y_names taken in the dict's iteration order.
                    substituted_equations[out_name] = {
                        self.y_names[idx]: expr.subs(substitutions, simultaneous=True)
                        for idx, expr in enumerate(expr_dict.values())
                    }
                equations = substituted_equations

            # Instantiate the symbolic predictor
            self.predictor = SymbolicPredictor(
                equations=equations,
                c_names=self.c_names
            )
            self.symbolic_predictors = True

        self.predictor.get_learned_equations(log_dir, fine_tuned=fine_tuned)

    ###### Forward and loss methods ######
    def forward(self, input):
        """Run encoder, concept bottleneck, selector and predictor.

        Args:
            input: batch as consumed by ``BaseModel.encode`` (structure defined
                there).

        Returns:
            dict with ``'y_hat'`` (predictor output), ``'c_hat'`` (processed
            concept predictions) and ``'selection_dist'`` (from the selector).
        """
        latent, x_concepts, c_true, int_idxs = self.encode(input)

        ## Concept encoder and concept processing block ##
        c_hat, _ = self.bottleneck(x_concepts)

        c_hat, input_concepts = self._process_concepts(c_hat, c_true, int_idxs)

        ## Selector block ##
        selector_output = self.classifier_selector(latent, global_step=self.global_step)
        selector_probs = selector_output['selector_probs'] # [batch_size, memory_size, n_samples]
        selection_dist = selector_output['selection_dist']

        ## Equation execution block ##
        predictor_output = self.predictor(selector_probs, input_concepts)

        return {
            'y_hat': predictor_output['y_hat'],
            'c_hat': c_hat,
            'selection_dist': selection_dist
        }

    def loss(self, y_hat, y, c_hat=None, c=None, *args, **kwargs):
        """Return the concept-based loss, plus KAN regularisation while still a KAN.

        The regularisation term promotes sparsity in the KAN layers and is
        skipped once the predictor has been replaced by a SymbolicPredictor.
        """
        loss = self.concept_based_loss(y_hat, y, c_hat, c)
        if not self.symbolic_predictors:
            # Out-of-place add: don't mutate the tensor returned by the base class
            # (in-place ops on autograd outputs can break backward).
            loss = loss + self.predictor.regularization_term()
        return loss

    def get_symbolic_equivalent(self, log_dir=None):
        """Store the abstract symbolic form of the KAN predictor.

        Builds the abstract expression (operators not defined) of a KAN with the
        model's hidden widths and a single output, so that it is comparable to
        the expressions produced for the other models, and passes it to
        ``store_eq`` with ``log_dir``. When ``log_dir`` is given, per-memory-slot
        equations are also written under ``<log_dir>/memory_slots``.
        Returns None.
        """
        single_output_widths = self._as_kan_widths(self.widths[:-1]) + [[1, 0]]
        equation = kan_expression(single_output_widths)

        # Generate the abstract (operators are not defined) symbolic equivalent of the kan used by the model.
        store_eq(equation, log_dir)

        # Store equations for each memory slot
        if log_dir is not None:
            memory_eq_dir = os.path.join(log_dir, "memory_slots")
            os.makedirs(memory_eq_dir, exist_ok=True)
            self._store_memory_equations(memory_eq_dir)

    def _store_memory_equations(self, dir):
        """Write the equations of each memory slot under ``dir``.

        For a SymbolicPredictor (has ``trainable_equations``), each equation set
        gets ``memory_slot_<i>/equations.txt`` plus one ``store_eq`` file per
        equation. For a KANPredictor (has ``kans``), each KAN gets the abstract
        KAN expression and, when the KAN exposes a non-None ``symbolic_fun``, its
        ``symbolic_formula()`` equations. Otherwise a ``no_equations.txt`` note is
        written.
        """
        # Check if predictor has symbolic equations
        if hasattr(self.predictor, 'trainable_equations'):
            # SymbolicPredictor - store learned equations
            for mem_idx, set_name in enumerate(sorted(self.predictor.trainable_equations)):
                mem_dir = os.path.join(dir, f"memory_slot_{mem_idx}")
                os.makedirs(mem_dir, exist_ok=True)

                # Create text file for this memory slot
                text_file = os.path.join(mem_dir, "equations.txt")
                with open(text_file, "w", encoding="utf-8") as f:
                    f.write(f"Memory Slot {mem_idx} (Set: {set_name})\n")
                    f.write("=" * 60 + "\n\n")

                    # Store each equation in this memory slot
                    for eq_idx, eq_name in enumerate(self.predictor.equation_names[set_name]):
                        eq_module = self.predictor.trainable_equations[set_name][eq_name]
                        equation_expr = eq_module.sympy_expr

                        # Store in pickle format
                        store_eq(equation_expr, mem_dir, idx=eq_idx)

                        # Write to text file
                        f.write(f"Equation {eq_idx} ({eq_name}):\n")
                        f.write(f"  Expression: {equation_expr}\n")
                        f.write(f"  Parameters: {eq_module.get_param_values()}\n")
                        f.write(f"  Current form: {eq_module.get_equation_string()}\n")
                        f.write("\n")

        elif hasattr(self.predictor, 'kans'):
            # KANPredictor - the abstract KAN structure is identical for every
            # slot, so build it once.
            equation = kan_expression(self._as_kan_widths(self.widths))
            for mem_idx, kan_layer in enumerate(self.predictor.kans):
                mem_dir = os.path.join(dir, f"memory_slot_{mem_idx}")
                os.makedirs(mem_dir, exist_ok=True)

                # Store in pickle format
                store_eq(equation, mem_dir, idx=0)

                # Create text file for this memory slot
                text_file = os.path.join(mem_dir, "equations.txt")
                with open(text_file, "w", encoding="utf-8") as f:
                    f.write(f"Memory Slot {mem_idx}\n")
                    f.write("=" * 60 + "\n\n")
                    f.write(f"KAN Architecture: {self.widths}\n")
                    f.write(f"Abstract KAN Expression:\n{equation}\n\n")

                    # If the KAN has been symbolified, try to get the actual equations
                    if getattr(kan_layer, 'symbolic_fun', None) is not None:
                        # Best effort: a failure is recorded in the text file
                        # instead of aborting the export of the other slots.
                        try:
                            learned_equations = kan_layer.symbolic_formula()[0]

                            f.write("Learned Symbolic Equations:\n")
                            for eq_idx, (eq, y_name) in enumerate(zip(learned_equations, self.y_names)):
                                f.write(f"  {y_name}: {eq}\n")
                                # idx 0 holds the abstract expression, so offset by one.
                                store_eq(eq, mem_dir, idx=eq_idx + 1)
                        except Exception as e:
                            f.write(f"Could not extract symbolic formulas: {e}\n")
                    else:
                        f.write("KAN has not been symbolified yet.\n")
        else:
            # Unknown predictor type
            no_equations_file = os.path.join(dir, "no_equations.txt")
            with open(no_equations_file, "w", encoding="utf-8") as f:
                f.write("No equations available in memory yet.\n")
                f.write(f"Predictor type: {type(self.predictor).__name__}\n")