from torch import nn
from typing import Any
from umfavi.loglikelihoods.base import BaseLogLikelihood
from umfavi.priors import kl_div_std_normal
from umfavi.encoder.reward_encoder import BaseRewardEncoder
from umfavi.regularizer.td_error import td_error_regularizer
from umfavi.types import FeedbackType
from umfavi.types import DataKey


class MultiFeedbackTypeModel(nn.Module):
    """Reward encoder, Q-model and per-feedback-type likelihood heads in one module.

    A forward pass encodes a reward posterior, samples from it, evaluates the
    KL term and the TD-error regularizer, and scores the batch with the
    likelihood head registered for the batch's feedback type.
    """

    def __init__(
        self,
        encoder: BaseRewardEncoder,
        q_model: nn.Module,
        decoders: dict[FeedbackType, BaseLogLikelihood],
        actions_discrete: bool = True
    ):
        """Store the sub-networks and the action-space flag.

        Args:
            encoder: Called as ``encoder(obs, action_feats, next_obs)`` and
                expected to return a ``(mean, log_var)`` pair; its ``sample``
                method draws reward samples from that pair.
            q_model: Q-network. Called with observations only when
                ``actions_discrete`` is true, otherwise with observations and
                action features.
            decoders: Likelihood heads keyed by feedback type; wrapped in an
                ``nn.ModuleDict`` so they are registered as submodules.
            actions_discrete: Selects the Q-model calling convention and is
                forwarded to the TD-error regularizer.
        """
        super().__init__()
        # Registration order is kept as-is: it fixes parameter/state-dict order.
        self.encoder = encoder
        self.q_model = q_model
        self.decoders = nn.ModuleDict(decoders)
        self.actions_discrete = actions_discrete

    @staticmethod
    def _negative_log_likelihood(head, fb_type, batch, r_samples, q_curr, acts_curr, valid):
        """Call ``head`` with the arguments that ``fb_type`` requires.

        Extra per-type inputs are read from ``batch`` (the keyword arguments
        of ``forward``). Raises ``ValueError`` when ``fb_type`` matches none of
        the handled feedback types.
        """
        if fb_type == FeedbackType.PREFERENCE:
            preferences = batch[DataKey.PREFERENCE]
            rationality = batch[DataKey.RATIONALITY][0]
            return head(r_samples, preferences, rationality, valid)

        if fb_type == FeedbackType.DEMONSTRATION:
            return head(acts_curr, q_curr, valid)

        if fb_type == FeedbackType.RATING:
            ratings = batch[DataKey.RATING]
            return head(r_samples, ratings, valid)

        if fb_type == FeedbackType.RANKING:
            rankings = batch[DataKey.RANKING]
            rationality = batch[DataKey.RATIONALITY][0]
            return head(r_samples, rankings, rationality, valid)

        if fb_type == FeedbackType.STOP:
            stop_times = batch[DataKey.STOP_TIME]
            lambd = batch[DataKey.LAMBDA][0]
            regret_discount = batch[DataKey.REGRET_DISCOUNT][0]
            return head(q_curr, acts_curr, stop_times, lambd, regret_discount, valid)

        raise ValueError(f"Invalid feedback type: {fb_type.value}")

    def forward(self, **kwargs) -> Any:
        """Compute the loss terms for one batch of a single feedback type.

        Args:
            **kwargs: Batch fields looked up by ``DataKey`` members. Always
                read: OBS, NEXT_OBS, ACT_FEATS, NEXT_ACT_FEATS, VALID,
                TERMINAL, ACTS, NEXT_ACTS, GAMMA, FEEDBACK_TYPE. Depending on
                the feedback type also PREFERENCE, RATING, RANKING,
                RATIONALITY, STOP_TIME, LAMBDA, REGRET_DISCOUNT. Only element
                0 of GAMMA, FEEDBACK_TYPE, RATIONALITY, LAMBDA and
                REGRET_DISCOUNT is used, i.e. they are treated as constant
                across the batch.

        Returns:
            A dict with the keys ``negative_log_likelihood``,
            ``kl_divergence``, ``td_error``, ``q_value_max`` and
            ``q_value_min``.

        Raises:
            ValueError: If the feedback type is not one of the handled types.
                The decoder lookup runs first, so a type without a registered
                head already fails at that lookup.
        """
        # --- batch fields -------------------------------------------------
        obs = kwargs[DataKey.OBS]
        next_obs = kwargs[DataKey.NEXT_OBS]
        action_feats = kwargs[DataKey.ACT_FEATS]
        next_action_feats = kwargs[DataKey.NEXT_ACT_FEATS]
        valid = kwargs[DataKey.VALID]
        terminal = kwargs[DataKey.TERMINAL]
        acts_curr = kwargs[DataKey.ACTS].long()
        acts_next = kwargs[DataKey.NEXT_ACTS].long()
        gamma = kwargs[DataKey.GAMMA][0]
        fb_type = kwargs[DataKey.FEEDBACK_TYPE][0]

        # --- reward posterior ---------------------------------------------
        mean_raw, log_var_raw = self.encoder(obs, action_feats, next_obs)
        r_mu = mean_raw.squeeze(-1)
        r_log_var = log_var_raw.squeeze(-1)
        r_samples = self.encoder.sample(r_mu, r_log_var)

        # Terminal flags go into the KL term so terminal states are masked:
        # otherwise the prior would push terminal rewards toward zero and
        # distort the learned reward structure.
        kl_div = kl_div_std_normal(r_mu, r_log_var, valid, terminal=terminal)

        # One feedback type per batch, hence a single head for all samples.
        head = self.decoders[fb_type]

        # --- Q-values and TD-error constraint -----------------------------
        if self.actions_discrete:
            curr_inputs, next_inputs = (obs,), (next_obs,)
        else:
            curr_inputs = (obs, action_feats)
            next_inputs = (next_obs, next_action_feats)
        q_curr = self.q_model(*curr_inputs)
        q_next = self.q_model(*next_inputs)

        td_error = td_error_regularizer(
            acts_curr=acts_curr,
            acts_next=acts_next,
            q_curr=q_curr,
            q_next=q_next,
            r_mu=r_mu,
            r_log_var=r_log_var,
            gamma=gamma,
            valid=valid,
            terminal=terminal,
            actions_discrete=self.actions_discrete
        )

        # Range of the current Q-values, reported as monitoring metrics.
        q_value_max = q_curr.max()
        q_value_min = q_curr.min()

        # --- feedback-specific likelihood ---------------------------------
        nll = self._negative_log_likelihood(
            head, fb_type, kwargs, r_samples, q_curr, acts_curr, valid
        )

        return {
            "negative_log_likelihood": nll,
            "kl_divergence": kl_div,
            "td_error": td_error,
            "q_value_max": q_value_max,
            "q_value_min": q_value_min,
        }
