from __future__ import annotations
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple, Optional
from model.base import Siren
from .base_net import BaseNet, SinePE, extract_mesh
import logging
# NOTE(review): basicConfig runs at import time and configures the root logger
# (when it has no handlers yet), so importing this module has a process-wide
# side effect. Consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
from .trajectory_estimator_stable import StableTrajectoryEstimator
OMEGA_0 = 30  # Frequency factor passed as both first_omega_0 and hidden_omega_0 of the Siren backbone

class InverseEncoder(BaseNet):
    """
    InverseEncoder class for inferring latent age (t) from deformations.

    Inherits:
        BaseNet: Provides positional encoding and input preparation utilities.

    Attributes:
        teacher (StableTrajectoryEstimator): Pre-trained teacher model for deformation prediction.
        network (Siren): Backbone network for encoding deformations and spatial coordinates.
    """
    def __init__(self, cfg, teacher: StableTrajectoryEstimator):
        super().__init__(cfg)
        logger.info("Initializing InverseEncoder...")

        # Validate configuration
        required_keys = ["InFeatures", "HiddenFeatures", "HiddenLayers", "CovariateNames"]
        self.validate_cfg(cfg, required_keys)

        # Extract configuration parameters  
        self.covariate_names = cfg.get("CovariateNames", [])
        # Initialize attributes for compatibility with Student_v8
        self.dict_idx_cov = {i: cov for i, cov in enumerate(self.covariate_names)}
        self.dict_cov_idx = {cov: i for i, cov in enumerate(self.covariate_names)}

        # Read positional encoding configuration for InverseEncoder
        pe_config = cfg.get("PosEncConfig", {})
        
        self.pos_encoder = SinePE(
            num_encoding_functions=pe_config["pe_dims"],
            include_input=pe_config["include_input"],
            log_sampling=pe_config["log_sampling"]
        ) if pe_config['enabled'] else nn.Identity()

        in_dims_for_inverse_encoder = self.in_dim * 2
        

        list_covariates       = cfg.get("CovariateNames", [])
        output_dim = len(list_covariates)  # Predict latent age t
        hidden_features       = cfg.get("HiddenFeatures",      512)
        hidden_layers         = cfg.get("HiddenLayers",         6)
        self.batch_size = cfg.get("BatchSize", 0)
        self.num_of_points = cfg.get("SamplesPerScene", 0)
        
        # Teacher model (StudentForward_v0)
        self.teacher = teacher
        #self._freeze_teacher()
        self.sampling_range = 1.2 if self.in_dim == 3 else 2.0

        # Backbone network
        self.network = Siren(
            in_features=in_dims_for_inverse_encoder,  # Input includes spatial coordinates and deformation
            hidden_features=hidden_features,
            hidden_layers=hidden_layers,
            out_features=output_dim,
            outermost_linear=True,
            first_omega_0=OMEGA_0,
            hidden_omega_0=OMEGA_0,
            zero_init_last_layer=False,
            is_first= False
        )
    @property
    def device(self):
        """Get the device of the model."""
        return next(self.parameters()).device

    def extract_template_mesh(self, N: int = 512, truncate: bool = False, scale: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Extracts the template mesh points from the teacher model.

        Args:
            dim (int): Dimensionality of the mesh (default: 3).
            scale (float): Scale factor for the mesh (default: 2).
            N (int): Number of samples along each dimension (default: 512).

        Returns:
            torch.Tensor: Template mesh points.
        """
        #self.template_mesh = extract_mesh(self.teacher, scale=scale, N=N)[0].cuda()
        
        if 'age' in self.covariate_names:
            scale = 0.72 
        else:
            scale = scale if scale is not None else self.sampling_range
        verts, faces, _, _ = extract_mesh(self.teacher, scale=scale, N=N, truncate=truncate)
        self.template_mesh = verts.cuda()
        return verts, faces

    def validate_cfg(self, cfg: dict, required_keys: list[str]):
        """
        Validates the configuration dictionary.

        Args:
            cfg (dict): Configuration dictionary.
            required_keys (list[str]): List of required keys.

        Raises:
            ValueError: If any required key is missing.
        """
        for key in required_keys:
            if key not in cfg:
                raise ValueError(f"Missing required configuration key: {key}")

    def _freeze_teacher(self):
        """
        Freezes the parameters of the teacher model to prevent updates during training.
        """
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()

    # def sample_random_input(self, spatial_dim: int = 3) -> dict:
    #     """
    #     Randomly samples spatial coordinates and covariates.

    #     Args:
    #         num_points (int): Number of points to sample.
    #         spatial_dim (int): Dimensionality of spatial coordinates (default: 3 for x, y, z).

    #     Returns:
    #         dict: Randomly sampled input containing 'coords' and 'covariates'.
    #     """
    #     coords = (torch.rand((self.batch_size, self.num_of_points, spatial_dim), device=self.device) * 2 - 1) * 3  # Sample points in [-1, 1] range
    #     covariates = {name: ((torch.rand((self.batch_size, self.num_of_points, 1), device=self.device) - 0.5)) * 4 for name in self.covariate_names}  # Generate random gt for all covariates
    #     return {"coords": coords, "covariates": covariates}

    # def sample_random_input(self) -> dict:
    #     """
    #     Randomly samples spatial coordinates from the template mesh bounding box,
    #     supporting both 2D and 3D shapes.

    #     Returns:
    #         dict: Randomly sampled input containing 'coords' and 'covariates'.
    #     """
    #     # Get bounding box from template mesh
    #     bbox_min = self.template_mesh.min(dim=0)[0]  # [2/3]
    #     bbox_max = self.template_mesh.max(dim=0)[0]  # [2/3]
    #     dim = bbox_min.shape[-1]  # 2 or 3
        
    #     # Sample points uniformly within bounding box
    #     coords = []
    #     for d in range(dim):
    #         # [B*N, 1]
    #         coord_d = torch.rand(self.batch_size * self.num_of_points, 1).cuda() * \
    #                  (bbox_max[d] - bbox_min[d]) + bbox_min[d]
    #         coords.append(coord_d)
    #     coords = torch.cat(coords, dim=-1)  # [B*N, 2/3]
    #     coords = coords.reshape(self.batch_size, self.num_of_points, -1)  # [B, N, 2/3]

    #     # Generate random covariates
    #     covariates = {name: ((torch.rand((self.batch_size, self.num_of_points, 1)).cuda() - 0.5) * 6) 
    #                  for name in self.covariate_names}

    #     return {"coords": coords, "covariates": covariates}


    def sample_random_input(self, bbox_scale: float = 1.5) -> dict:
        """
        Randomly samples spatial coordinates from the template mesh bounding box,
        supporting both 2D and 3D shapes.
        
        Args:
            bbox_scale: Scale factor to expand bounding box (default: 1.5)

        Returns:
            dict: Randomly sampled input containing 'coords' and 'covariates'.
        """
        # Get bounding box from template mesh
        bbox_min = self.template_mesh.min(dim=0)[0]  # [2/3]
        bbox_max = self.template_mesh.max(dim=0)[0]  # [2/3]
        dim = bbox_min.shape[-1]  # 2 or 3
        
        # Expand bounding box with scale factor
        bbox_center = (bbox_min + bbox_max) / 2  # [2/3]
        bbox_half_size = (bbox_max - bbox_min) / 2  # [2/3]
        bbox_min_expanded = bbox_center - bbox_half_size * bbox_scale  # [2/3]
        bbox_max_expanded = bbox_center + bbox_half_size * bbox_scale  # [2/3]
        
        # Sample points uniformly within expanded bounding box
        coords = []
        for d in range(dim):
            # [B*N, 1]
            coord_d = torch.rand(self.batch_size * self.num_of_points, 1).cuda() * \
                     (bbox_max_expanded[d] - bbox_min_expanded[d]) + bbox_min_expanded[d]
            coords.append(coord_d)
        coords = torch.cat(coords, dim=-1)  # [B*N, 2/3]
        coords = coords.reshape(self.batch_size, self.num_of_points, -1)  # [B, N, 2/3]

        # Generate random covariates
        covariates = {name: ((torch.rand((self.batch_size, self.num_of_points, 1)).cuda() - 0.5) * 6) 
                     for name in self.covariate_names}

        return {"coords": coords, "covariates": covariates}



    def forward(self, model_input) -> dict:
        """
        Forward pass for the inverse encoder using randomly sampled input.

        Args:
            num_points (int): Number of points to sample for the input.

        Returns:
            dict: Dictionary containing predicted latent age (t) and ground truth age (t).
        """
        # # Step 1: Randomly sample input
        model_input = self.sample_random_input()

        # Step 2: Use the teacher model to predict deformation
        teacher_output = self.teacher.forward_as_teacher(model_input)
        deformation = teacher_output['vec_fields']['overall']  # Extract overall deformation

        # Step 3: Concatenate spatial coordinates and deformation
        coords = model_input['coords'] # Ensure coords require gradient
        input_coords = self.encode_coord(coords.detach().clone().requires_grad_(True))
        input_deformation = self.encode_coord(deformation.detach().clone().requires_grad_(True))
        input_data = torch.cat([input_coords, input_deformation], dim=-1)

        # Step 4: Predict latent age (t)
        predicted_t = self.network(input_data) 

        # Step 5: Compute gradient calculations for orthogonality loss
        coords_for_grad = model_input['coords'].detach().clone().requires_grad_(True)
        predicted_t_for_grad = predicted_t.detach().clone().requires_grad_(True)
        gt_t = self.dict2array(model_input['covariates'])
        
        # Predict deformation using predicted age
        predicted_deformation = self.teacher.predict_deformations_as_teacher({
            "coords": coords_for_grad, 
            "covariates": {self.covariate_names[0]: predicted_t_for_grad}
        })['overall']
        
        g_i = torch.autograd.grad(
            outputs=predicted_deformation,
            inputs=predicted_t_for_grad,
            grad_outputs=torch.ones_like(coords_for_grad),
            create_graph=True
        )[0]

        model_output = {
            "model_in": model_input,
            "vec_fields": teacher_output['vec_fields'],
            "predicted_t": predicted_t,
            "ground_truth_t": self.dict2array(model_input['covariates']).detach().clone(),  # Randomly sampled ground truth age
            "g_i": g_i,  # Gradient of deformation w.r.t. latent age

        }
        #model_output.update(orthogonality_components)  # Add orthogonality components to output
        return model_output







    def inference(self, model_input) -> dict:
        """
        Forward pass for the inverse encoder using randomly sampled input.

        Args:
            num_points (int): Number of points to sample for the input.

        Returns:
            dict: Dictionary containing predicted latent age (t) and ground truth age (t).
        """
        # Step 1: Prepare Input
        #model_input = self.sample_random_input()
        model_input = self.prepare_model_input(
            model_input=model_input,
            pts_on_template=self.template_mesh[None, ...]
        )

        arr_t = model_input['covariates'][self.covariate_names[0]]
        predicted_variance  = self.teacher.predict_sigma(model_input['coords'], arr_t)
        fisher_information = 1 / predicted_variance


        # Step 2: Use the teacher model to predict deformation
        deformation = self.teacher.predict_teacher_deformation(model_input)

        # Step 3: Concatenate spatial coordinates and deformation
        coords = model_input['coords'] # Ensure coords require gradient
        input_coords = self.encode_coord(coords.detach().clone().requires_grad_(True))
        input_deformation = self.encode_coord(deformation.detach().clone().requires_grad_(True))
        input_data = torch.cat([input_coords, input_deformation], dim=-1)

        # Step 4: Predict latent age (t)
        predicted_t = self.network(input_data) 
        predicted_global_t = (predicted_t * fisher_information).sum(dim=[-2], keepdim=True) / fisher_information.sum(dim=[-2], keepdim=True)
        z_score = (predicted_t - arr_t) / predicted_variance.sqrt()
        predicted_t[z_score.abs() > 2] = predicted_global_t#[z_score.abs() > 3].unsqueeze(-1)


        dict_pred = self.array2dict(predicted_t)
        model_output = {
            "predicted_t": predicted_t,
            "ground_truth_t": self.dict2array(model_input['covariates']).detach().clone(),  # Randomly sampled ground truth age
            "predicted_covariates": dict_pred
        }
        return model_output


    def inference_global_time(self, model_input: dict) -> dict:
        """
        Forward pass for the inverse encoder using randomly sampled input.

        Args:
            num_points (int): Number of points to sample for the input.

        Returns:
            dict: Dictionary containing predicted latent age (t) and ground truth age (t).
        """
        # Step 1: Prepare Input
        #model_input = self.sample_random_input()

        model_input = self.prepare_model_input(
            model_input=model_input,
            pts_on_template=self.template_mesh[None, ...]
        )

        arr_t = model_input['covariates'][self.covariate_names[0]]
        predicted_variance  = self.teacher.predict_sigma(model_input['coords'], arr_t)
        fisher_information = 1 / predicted_variance


        # Step 2: Use the teacher model to predict deformation
        deformation = self.teacher.predict_teacher_deformation(model_input)

        # Step 3: Concatenate spatial coordinates and deformation
        coords = model_input['coords'] # Ensure coords require gradient
        input_coords = self.encode_coord(coords.detach().clone().requires_grad_(True))
        input_deformation = self.encode_coord(deformation.detach().clone().requires_grad_(True))
        input_data = torch.cat([input_coords, input_deformation], dim=-1)

        # Step 4: Predict latent age (t)
        predicted_t = self.network(input_data) 


        predicted_global_t = (predicted_t * fisher_information).sum(dim=[-2]) / fisher_information.sum(dim=[-2])
        
        # Step 5: Compute global variance using top-k fisher information
        predicted_global_sigma = self.teacher.predict_global_sigma(arr_t).sqrt()

        dict_pred = self.array2dict(predicted_t)
        dict_global_pred = self.array2dict(predicted_global_t)
        model_output = {
            "predicted_t": predicted_t,
            "predicted_global_t": predicted_global_t,
            "predicted_global_sigma": predicted_global_sigma,
            "ground_truth_t": self.dict2array(model_input['covariates']).detach().clone(),  # Randomly sampled ground truth age
            "predicted_covariates": dict_pred,
            "predicted_global_covariates": dict_global_pred
        }

        
        return model_output





    def dict2array(self, dict_covariates: dict) -> torch.Tensor:
        """
        Convert covariate dictionary to tensor array.
        
        Args:
            dict_covariates: Dictionary of covariates {name: tensor}
        
        Returns:
            torch.Tensor: Concatenated covariate tensor
        """
        list_covariates = []
        for ith_cov_name in self.covariate_names:
            current_covariate_value = dict_covariates[ith_cov_name]
            list_covariates.append(current_covariate_value)
        arr_covariates = torch.cat(list_covariates, dim=-1)
        return arr_covariates
    
    def array2dict(self, arr_covariates: torch.Tensor) -> dict:
        """
        Convert covariate tensor array to dictionary.
        
        Args:
            arr_covariates: Concatenated covariate tensor [B, N, num_covariates]
        
        Returns:
            dict: Dictionary of covariates {name: tensor}
        """
        dict_covariates = {}
        for idx, cov_name in enumerate(self.covariate_names):
            # Extract the idx-th covariate from the last dimension
            dict_covariates[cov_name] = arr_covariates[..., [idx]].detach().clone()
        return dict_covariates
    
    def predict_covariates(self, coords: torch.Tensor, deformation: torch.Tensor) -> torch.Tensor:
        """
        Predict age directly from input coordinates and deformation.
        
        Args:
            coords (torch.Tensor): Template point coordinates [B, N, 3] (xyz)
            deformation (torch.Tensor): Deformation vectors [B, N, 3]
            
        Returns:
            torch.Tensor: Predicted age [B, 1]
        """

        # Apply positional encoding to coordinates and deformation
        encoded_coords = self.encode_coord(coords)
        encoded_deformation = self.encode_coord(deformation)
        
        # Concatenate encoded coordinates and deformation
        input_data = torch.cat([encoded_coords, encoded_deformation], dim=-1)
        
        # Run the network to get predicted age
        predicted_covariates = self.network(input_data)

        dict_output_covariates = {}
        for ith_cov in self.covariate_names:
            dict_output_covariates[ith_cov] = predicted_covariates[..., [self.dict_cov_idx[ith_cov]]]

        return dict_output_covariates

    def compute_orthogonality_constraint(self, coords, random_deformation, pred_deformation, predicted_t):
        """
        Compute orthogonality constraint: (random_deform - pred_deform) ⊥ ∂pred_deform/∂t
        
        Args:
            coords (torch.Tensor): Spatial coordinates [B, N, 3]
            random_deformation (torch.Tensor): Random sampled deformation [B, N, 3]
            pred_deformation (torch.Tensor): Teacher predicted deformation [B, N, 3]
            predicted_t (torch.Tensor): Predicted age [B, N, 1]
            
        Returns:
            dict: Orthogonality constraint components
        """
        # Step 1: Compute residual (random_deform - pred_deform)
        residual = random_deformation - pred_deformation  # [B, N, 3]
        
        # Step 2: Compute gradient of pred_deformation w.r.t. predicted_t using finite difference
        h = 1e-4
        
        # Predicted deformation at predicted_t + h
        pred_deform_plus_h = self.teacher.predict_deformations_as_teacher({
            "coords": coords,
            "covariates": {"age": predicted_t + h}
        })['overall']
        
        # Predicted deformation at predicted_t - h
        pred_deform_minus_h = self.teacher.predict_deformations_as_teacher({
            "coords": coords,
            "covariates": {"age": predicted_t - h}
        })['overall']
        
        # Gradient: ∂pred_deform/∂t
        pred_deform_gradient = (pred_deform_plus_h - pred_deform_minus_h) / (2 * h)  # [B, N, 3]
        
        # Step 3: Compute orthogonality constraint
        # We want residual ⊥ pred_deform_gradient, i.e., residual · pred_deform_gradient = 0
        dot_product = torch.sum(residual * pred_deform_gradient, dim=-1, keepdim=True)  # [B, N, 1]
        orthogonality_loss = torch.mean(dot_product ** 2)
        
        # Step 4: Additional metrics for monitoring
        residual_norm = torch.mean(torch.norm(residual, dim=-1))
        gradient_norm = torch.mean(torch.norm(pred_deform_gradient, dim=-1))
        
        return {
            "orthogonality_loss": orthogonality_loss,
            "residual": residual,  # random_deform - pred_deform
            "pred_deform_gradient": pred_deform_gradient,  # ∂pred_deform/∂t
            "dot_product": dot_product,  # residual · gradient
            "residual_norm": residual_norm,
            "gradient_norm": gradient_norm
        }