import sys
sys.path.append("/playpen-raid/jyn/PRISM")
import torch
from torch import nn
import logging
from model.base import modules, HyperNetwork, Siren
from model.loss import embedding_loss_naixr_v4
from .base_net import BaseNet
from typing import Dict, Optional, List
# Configure root logging at INFO level when this module is imported and
# create the module-level logger used by Teacher.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Frequency factor passed to the Siren template network as both
# first_omega_0 and hidden_omega_0 (see Teacher.__init__).
OMEGA_0 = 30


class Teacher(BaseNet):
    """
    Teacher class for modeling deformation fields and SDF using a hypernetwork-based architecture.

    Attributes:
        forward_deform_codes (nn.Embedding): Embedding for forward deformation latent codes.
        backward_deform_codes (nn.Embedding): Embedding for backward deformation latent codes.
        hyper_decoder_forward (HyperNetwork): Hypernetwork for forward deformation.
        hyper_decoder_backward (HyperNetwork): Hypernetwork for backward deformation.
        template (nn.Module): SDF template network.
    """
    def __init__(self, cfg):
        """
        Initializes the Teacher model.

        Builds, in order: the forward/backward latent code tables, the forward
        and backward deformation decoders together with the hypernetworks that
        generate their weights, and the SDF template network selected by
        ``cfg["Backbone"]``.

        Args:
            cfg (dict): Configuration dictionary containing model parameters.
                Every key listed in ``required_keys`` below must be present.

        Raises:
            ValueError: If a required key is missing, or if ``Backbone`` is
                neither ``"mlp"`` nor ``"siren"``.
        """
        # Initialize BaseNet first to handle positional encoding
        super().__init__(cfg)
        logger.info("Initializing Teacher...")

        # Validate configuration (remove PosEnc from required keys since BaseNet handles it)
        # Teacher model doesn't need CovariateNames - it's only needed for Student model
        required_keys = [
            "InFeatures", "HiddenFeatures", "HiddenLayers",
            "HyperHiddenLayers", "HyperHiddenFeatures", "NumInstances",
            "OutFeatures", "Device", "Backbone", "CodeLength"
        ]
        self.validate_cfg(cfg, required_keys)

        # Extract configuration parameters
        self.original_in_dim = cfg.get("InFeatures", 3)
        self.hidden_features = cfg.get("HiddenFeatures", 512)
        self.hidden_layers = cfg.get("HiddenLayers", 6)
        self.hyper_hidden_layers = cfg.get("HyperHiddenLayers", 3)
        self.hyper_hidden_features = cfg.get("HyperHiddenFeatures", 512)
        self.num_instances = cfg.get("NumInstances", 0)
        self.out_features = cfg.get("OutFeatures", 1)
        self.backbone = cfg.get("Backbone", "siren")
        self.latent_size = cfg.get("CodeLength", 256)

        # Fail fast on an unknown backbone. Without this check the constructor
        # would finish without ever setting ``self.template`` and the problem
        # would only surface later as an AttributeError during a forward pass.
        if self.backbone not in ("mlp", "siren"):
            raise ValueError(
                f"Unsupported Backbone: {self.backbone!r} (expected 'mlp' or 'siren')"
            )

        # Latent embeddings: one code of length ``latent_size`` per instance.
        # NOTE(review): the tables are moved with .cuda() while the required
        # "Device" key is never read here - confirm whether that is intended.
        self.forward_deform_codes = nn.Embedding(self.num_instances, self.latent_size).cuda()
        nn.init.normal_(self.forward_deform_codes.weight, mean=0, std=0.01)

        self.backward_deform_codes = nn.Embedding(self.num_instances, self.latent_size).cuda()
        nn.init.normal_(self.backward_deform_codes.weight, mean=0, std=0.01)

        # Forward deformation decoder; its weights are produced by the
        # hypernetwork below from a forward latent code.
        self.decoder_forward = modules.MetaSingleBVPNet(
            type="sin30",
            mode="sin30",
            hidden_features=128,
            num_hidden_layers=3,
            in_features=self.in_dim,  # Use BaseNet's calculated input dimension
            out_features=self.original_in_dim  # Output deformation (x, y, z)
        )
        self.hyper_decoder_forward = HyperNetwork(
            hyper_in_features=self.latent_size,
            hyper_hidden_layers=self.hyper_hidden_layers,
            hyper_hidden_features=self.hyper_hidden_features,
            hypo_module=self.decoder_forward
        )

        # Backward deformation decober mirrors the forward one but has its own
        # weights, hypernetwork and latent code table.
        self.decoder_backward = modules.MetaSingleBVPNet(
            type="sin30",
            mode="sin30",
            hidden_features=128,
            num_hidden_layers=3,
            in_features=self.in_dim,  # Use BaseNet's calculated input dimension
            out_features=self.original_in_dim  # Output deformation (x, y, z)
        )
        self.hyper_decoder_backward = HyperNetwork(
            hyper_in_features=self.latent_size,
            hyper_hidden_layers=self.hyper_hidden_layers,
            hyper_hidden_features=self.hyper_hidden_features,
            hypo_module=self.decoder_backward
        )

        # SDF template network
        if self.backbone == "mlp":
            self.template = modules.BaseFCBlock(
                in_features=self.in_dim,  # Use BaseNet's calculated input dimension
                out_features=self.out_features,
                hidden_layers=self.hidden_layers,
                hidden_features=self.hidden_features,
                outermost_linear=True,
                nonlinearity="relu"
            )
        else:  # "siren" - the only other value accepted by the check above
            self.template = Siren(
                in_features=self.in_dim,  # Use BaseNet's calculated input dimension
                hidden_features=self.hidden_features,
                hidden_layers=self.hidden_layers,
                out_features=self.out_features,
                is_first=True,
                zero_init_last_layer=True,
                outermost_linear=True,
                first_omega_0=OMEGA_0,
                hidden_omega_0=OMEGA_0,
            )

    @property
    def device(self):
        """Get the device of the model."""
        return next(self.parameters()).device

    def validate_cfg(self, cfg: dict, required_keys: List[str]):
        """
        Validates the configuration dictionary.

        Args:
            cfg (dict): Configuration dictionary.
            required_keys (list[str]): List of required keys.

        Raises:
            ValueError: If any required key is missing.
        """
        for key in required_keys:
            if key not in cfg:
                raise ValueError(f"Missing required configuration key: {key}")

    def calculate_forward_displacement(self, model_input: torch.Tensor, current_idxes: torch.Tensor):
        embedding = self.forward_deform_codes(current_idxes)
        # get network weights for Deform-net using Hyper-net
        coords_encoded = self.encode_coord(model_input)

        hypo_params = self.hyper_decoder_forward(embedding)
        deformation = self.decoder_forward(coords_encoded, params=hypo_params)['model_out']
        return deformation


    def forward_deformation(self, model_input: dict):
        current_idxes = model_input['teacher_idx'].long()
        coords_init = model_input['coords']
        '''
        forward
        '''
        forward_deformation = self.calculate_forward_displacement(coords_init, current_idxes)
        pts_on_template = coords_init + forward_deformation
        '''
        query template
        '''
        output_sdf = self.template(self.encode_coord(pts_on_template))
        return output_sdf, forward_deformation



    def calculate_backward_displacement(self, model_input: torch.Tensor, current_idxes: torch.Tensor):
        embedding = self.backward_deform_codes(current_idxes)
        # get network weights for Deform-net using Hyper-net
        coords_encoded = self.encode_coord(model_input)

        hypo_params = self.hyper_decoder_backward(embedding)
        deformation = self.decoder_backward(coords_encoded, params=hypo_params)['model_out']
        return deformation


    def backward_deformation(self, model_input: dict):
        current_idxes = model_input['teacher_idx'].long()
        coords_init = model_input['coords']
        '''
        forward
        '''
        backward_deformation = self.calculate_backward_displacement(coords_init, current_idxes)
        return  backward_deformation


    def unpack_grad_deform_gradicon(self, model_input: torch.Tensor, deformation: torch.Tensor):
        """
        Calculate gradient of deformation field for grad_icon loss (same as Teacher).
        
        This computes ∇deformation - I for the grad_icon consistency loss.
        
        Args:
            model_input: Input coordinates [B, N, 3] (requires_grad=True)
            deformation: Deformation field [B, N, 3]
            
        Returns:
            torch.Tensor: Gradient of deformation - I
        """
        if model_input.shape[-1] == 3:
            # 3D case
            x = model_input  # input coordinates
            u = deformation[:, :, 0]
            v = deformation[:, :, 1]
            w = deformation[:, :, 2]

            grad_outputs = torch.ones_like(u)
            grad_u =torch.autograd.grad(u, [x], grad_outputs=grad_outputs, create_graph=True)[0]
            grad_v = torch.autograd.grad(v, [x], grad_outputs=grad_outputs, create_graph=True)[0]
            grad_w = torch.autograd.grad(w, [x], grad_outputs=grad_outputs, create_graph=True)[0]
            grad_deform = torch.stack([grad_u, grad_v, grad_w], dim=2)  # gradient of deformation wrt. input position
            grad_deform = grad_deform - torch.eye(grad_deform.shape[-1], device=grad_deform.device)
        elif model_input.shape[-1] == 2:
            # 2D case
            x = model_input  # input coordinates
            u = deformation[:, :, 0]
            v = deformation[:, :, 1]

            grad_outputs = torch.ones_like(u)
            grad_u = torch.autograd.grad(u, [x], grad_outputs=grad_outputs, create_graph=True)[0]
            grad_v = torch.autograd.grad(v, [x], grad_outputs=grad_outputs, create_graph=True)[0]
            grad_deform = torch.stack([grad_u, grad_v], dim=2)  # gradient of deformation wrt. input position
            grad_deform = grad_deform - torch.eye(grad_deform.shape[-1], device=grad_deform.device)
        
        return grad_deform

    def unpack_icon(self, model_input: torch.Tensor, deformation: torch.Tensor, inv_deformation: torch.Tensor):
        """
        Unpack ICON loss components (same as Teacher).
        
        ICON loss enforces cycle consistency: coords → deformed → back to coords
        
        Args:
            model_input: Original coordinates
            deformation: Forward deformation
            inv_deformation: Backward deformation
            
        Returns:
            tuple: (cycled_points, delta) where delta should be close to zero
        """
        delta = deformation + inv_deformation
        model_output = model_input + deformation + inv_deformation
        return model_output, delta


    def forward(self, model_input: dict):

        coords_init = model_input['coords'].requires_grad_(True)

        '''
        forward
        '''
        output_sdf, forward_deformation = self.forward_deformation(model_input)
        dict_vf = {}
        dict_vf['overall'] = forward_deformation
        pts_on_template = coords_init + forward_deformation


        '''
        backward from ind
        '''
        model_input_inv = model_input.copy()
        model_input_inv['coords'] = pts_on_template
        backward_deformation = self.backward_deformation(model_input_inv)


        dict_vf['overall_inv'] = backward_deformation


        cycle_pts_on_src, cycle_delta = self.unpack_icon(coords_init, forward_deformation, backward_deformation)
        grad_icon = self.unpack_grad_deform_gradicon(coords_init, cycle_pts_on_src)


        # get map
        model_output = {'model_in': coords_init,
                        'all_input': coords_init,
                        'model_out': output_sdf,
                        'vec_fields': dict_vf,
                        'template': self.template_sdf(coords_init),
                        "pts_on_template": pts_on_template,
                        "icon": cycle_delta,
                        "gradicon": grad_icon,
                        }

        return model_output  # , coords

    def inference(self, model_input):


        coords_init = model_input['coords'].requires_grad_(True)

        '''
        forward
        '''
        output_sdf, forward_deformation = self.forward_deformation(model_input)
        dict_vf = {}
        dict_vf['overall'] = forward_deformation
        pts_on_template = coords_init + forward_deformation


        '''
        backward from ind
        '''
        model_input_inv = model_input.copy()
        model_input_inv['coords'] = pts_on_template
        backward_deformation = self.backward_deformation(model_input_inv)


        dict_vf['overall_inv'] = backward_deformation



        # get map
        model_output = {'model_in': coords_init,
                        'all_input': coords_init,
                        'model_out': output_sdf,
                        'vec_fields': dict_vf,
                        'template': self.template_sdf(coords_init),
                        "pts_on_template": pts_on_template,
                        }

        return model_output  # , coords


    def template_sdf(self, coords):
        # Use BaseNet's positional encoding if available
        coords_encoded = self.encode_coord(coords)
        return self.template(coords_encoded)


    def embedding(self, coords, var_codes, gt):
        """
        Evaluate the template at coordinates displaced using code 0 of ``var_codes``.

        Args:
            coords: Query coordinates; flagged in place as requiring gradients.
            var_codes: Callable looked up with index 0 to obtain the latent
                code (presumably an ``nn.Embedding`` - confirm against caller).
            gt: Ground truth forwarded unchanged to ``embedding_loss_naixr_v4``.

        Returns:
            tuple: ``(model_output, losses)`` where ``model_output`` is the
            dict assembled below and ``losses`` is the value returned by
            ``embedding_loss_naixr_v4``.
        """
        coords_init = coords.requires_grad_(True)

        # Look up code index 0. The index tensor is created on the current
        # CUDA device, matching the previous behaviour of this method.
        code_index = torch.tensor([0], dtype=torch.long).cuda()
        embeddings = var_codes(code_index)

        # Use the embedding as prior.
        # NOTE(review): calculate_displacement is not defined in the visible
        # part of this class - presumably inherited; confirm it still exists
        # and accepts (coords, embeddings). Only its first return value is used.
        vf, _ = self.calculate_displacement(coords_init, embeddings)
        dict_vf = {'overall': vf}

        # Displace once and reuse the result for the template query.
        pts_on_template = vf + coords
        output = self.template(self.encode_coord(pts_on_template))

        # get map
        model_output = {'model_in': coords,
                        'all_input': coords,
                        'model_out': output,
                        'vec_fields': dict_vf,
                        'embedding': embeddings,
                        'template': self.template_sdf(coords),
                        "pts_on_template": pts_on_template,
                        }
        losses = embedding_loss_naixr_v4(model_output, gt, dict_losses={'whether_kld': False})
        return model_output, losses

    def calculate_inv_displacement(self, model_input):
        '''
        backward from ind
        '''
        # if 'teacher_idx' in model_input:
        #     model_input['idx'] = model_input['teacher_idx']
        backward_deformation = self.backward_deformation(model_input)
        return backward_deformation

    def inv_transform(self, model_input):
        dict_vf = {}
        vf = self.calculate_inv_displacement(model_input)
        dict_vf['overall'] = vf
        transformed_p = vf + model_input['coords']
        return transformed_p