import dataclasses
import logging
from typing import Literal, List, Tuple

import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override

from openpi.models import model as _model
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
import openpi.models.value as _value
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils

logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar=None, block_sizes=None):
    """Create a boolean ``[B, N, N]`` (query, key) attention mask.

    Two modes are supported:

    * ``block_sizes is None`` (default): prefix-LM / causal masking driven by
      ``mask_ar``. Token ``i`` may attend to token ``j`` iff
      ``cumsum(mask_ar)[j] <= cumsum(mask_ar)[i]``.
    * ``block_sizes`` given: a fixed layout of at most three blocks
      (prefix, policy, value):

        - block 0 attends only within itself;
        - blocks 1 and 2 attend to block 0 and within themselves, except that the
          first token of each such block cannot see the later tokens of its block.

      The last block considered always extends to ``N``. So the final size in
      ``block_sizes`` does not affect the boundaries, and any sizes past the third
      are merged into the third block. ``mask_ar`` is ignored in this mode.

    In both modes, every (query, key) pair involving padding is masked out.

    Args:
      input_mask: bool[B, N] true if it's part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.
        Required when ``block_sizes`` is None.
      block_sizes: Optional non-empty list of block sizes selecting the block mode.

    Returns:
      bool[B, N, N] attention mask.

    Raises:
      ValueError: if ``mask_ar`` is missing in the default mode, or ``block_sizes``
        is empty.
    """
    if block_sizes is None:
        if mask_ar is None:
            raise ValueError("mask_ar is required when block_sizes is not given")
        # Standard prefix-LM / causal attention.
        mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
        cumsum = jnp.cumsum(mask_ar, axis=1)
        attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
    else:
        if len(block_sizes) == 0:
            raise ValueError("block_sizes must contain at least one block size")
        batch_size, seq_len = input_mask.shape[0], input_mask.shape[1]

        # Boundaries of the (at most three) blocks used. The last one always runs
        # to the end of the sequence, so a single block covers everything.
        num_blocks = min(len(block_sizes), 3)
        bounds = [0]
        for size in block_sizes[: num_blocks - 1]:
            bounds.append(bounds[-1] + size)
        bounds.append(seq_len)

        # Start with a mask where nothing can attend to anything.
        attn_mask = jnp.zeros((batch_size, seq_len, seq_len), dtype=bool)

        # Block 0 (prefix): attends only within itself.
        prefix_start, prefix_end = bounds[0], bounds[1]
        attn_mask = attn_mask.at[:, prefix_start:prefix_end, prefix_start:prefix_end].set(True)

        # Block 1 (policy) and block 2 (value): attend to the prefix and within
        # themselves, but the block's first token cannot see its later tokens.
        for start, end in zip(bounds[1:-1], bounds[2:]):
            attn_mask = attn_mask.at[:, start:end, prefix_start:prefix_end].set(True)
            attn_mask = attn_mask.at[:, start:end, start:end].set(True)
            attn_mask = attn_mask.at[:, start : start + 1, start + 1 : end].set(False)

    # Mask out every (query, key) pair that involves a padding token.
    valid_mask = jnp.logical_and(input_mask[:, None, :], input_mask[:, :, None])
    return jnp.logical_and(attn_mask, valid_mask)


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Embed each scalar position as ``[sin(2*pi*pos/T), cos(2*pi*pos/T)]``.

    The periods ``T`` are spaced geometrically from ``min_period`` to
    ``max_period`` over ``embedding_dim // 2`` frequencies. Sines fill the first
    half of the output features and cosines fill the second half.

    Raises:
        ValueError: if ``embedding_dim`` is odd.
    """
    half_dim, is_odd = divmod(embedding_dim, 2)
    if is_odd:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    exponents = jnp.linspace(0.0, 1.0, half_dim)
    periods = min_period * (max_period / min_period) ** exponents
    angular_freqs = 1.0 / periods * 2 * jnp.pi
    # Outer product positions x frequencies, computed at full precision.
    angles = jnp.einsum("i,j->ij", pos, angular_freqs, precision=jax.lax.Precision.HIGHEST)
    sin_part, cos_part = jnp.sin(angles), jnp.cos(angles)
    return jnp.concatenate([sin_part, cos_part], axis=-1)

    
@dataclasses.dataclass(frozen=True)
class Pi0ValueConfig(_model.BaseModelConfig):
    """Configuration for the ``Pi0Value`` Q/V critic.

    ``method`` selects how V is regressed toward Q in ``Pi0Value.compute_loss``
    and how advantages are turned into weights in ``Pi0Value.get_weight``.
    ``"expectile"`` uses ``expectile``; ``"exponential"`` uses ``beta``.
    """

    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_value"
    # "q": a single Q head over the flattened (action_horizon * action_dim) chunk.
    # "hierarchical_q": one Q head per index group in `hierarchical_actions`.
    network_type: Literal["hierarchical_q", "q"] = "q"
    # Expectile used when method == "expectile".
    expectile: float = 0.8
    # Temperature used when method == "exponential".
    beta: float = 1.0
    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    # Groups of action-dimension indices; used when network_type == "hierarchical_q".
    hierarchical_actions: list = dataclasses.field(
        default_factory=lambda: [[7, 8, 9], [17, 18, 19, 20], [0, 1, 2, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16]]
    )
    max_token_len: int = 72
    # Discount factor applied to V(s') in the TD target.
    gamma: float = 1.0
    method: Literal["expectile", "exponential"] = "exponential"

    @property
    @override
    def model_type(self) -> _model.ModelType:
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0Value":
        """Instantiate a ``Pi0Value`` model seeded from ``rng``."""
        return Pi0Value(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Return shape/dtype specs for a batch of observations and action chunks."""
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "left_wrist_0_rgb": image_spec,
                    "right_wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "left_wrist_0_rgb": image_mask_spec,
                    "right_wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config.

        If the LLM variant is a LoRA variant, the LLM parameters are frozen except
        for the LoRA parameters themselves. With no LoRA, nothing is frozen.
        """
        filters = []
        has_lora = False
        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        # Pi0ValueConfig declares no `action_expert_variant` field, because Pi0Value
        # builds a single LLM expert. Read it defensively, so this method does not
        # raise AttributeError when the base config lacks it too.
        action_expert_variant = getattr(self, "action_expert_variant", "")
        if "lora" in self.paligemma_variant:
            filters.append(
                gemma_params_filter,
            )
            if "lora" not in action_expert_variant:
                # If only freeze gemma params, exclude action expert params.
                filters.append(
                    nnx.Not(action_expert_params_filter),
                )
            has_lora = True
        elif "lora" in action_expert_variant:
            filters.append(
                action_expert_params_filter,
            )
            has_lora = True

        if has_lora:
            # If any lora is used, exclude all lora params.
            filters.append(
                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
            )
        if not filters:
            return nnx.Nothing
        return nnx.All(*filters)


class Pi0Value(_model.BaseModel):
    """PaliGemma-backed critic that estimates Q(s, a) and V(s) from offline transitions.

    Observations (images + tokenized prompt) are encoded by a single PaliGemma
    expert. The output at the final prefix position feeds a value head
    (``_value.VModel``) and a Q head (``_value.QModel`` or
    ``_value.HierarchicalQModel``). ``compute_loss`` trains Q with a TD target
    and trains V toward Q. ``get_weight`` turns advantages ``Q - V`` into
    per-action-dimension weights.

    NOTE(review): ``get_weight`` broadcasts head outputs to ``[b, action_dim]`` and
    the reward keeps a trailing axis, which suggests the heads return ``[b, 1]``.
    TODO: confirm against ``openpi.models.value``.
    """

    def __init__(self, config: Pi0ValueConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        value_config = _gemma.get_config(config.paligemma_variant)

        # Loss / weighting hyperparameters.
        self.expectile = config.expectile
        self.gamma = config.gamma
        self.beta = config.beta
        self.method = config.method

        # Initialize main model components
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[value_config],
                embed_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=value_config.width,
                variant="M/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)

        # NOTE: state_proj is part of the parameter tree, but embed_prefix does not
        # currently embed the proprioceptive state, so it is unused.
        self.state_proj = nnx.Linear(config.action_dim, value_config.width, rngs=rngs)
        self.value_head = _value.VModel(value_config.width, rngs=rngs)
        self.network_type = config.network_type
        if self.network_type == "hierarchical_q":
            # One Q head per action group; each head sees its group's dims over the whole horizon.
            self.hierarchical_actions = config.hierarchical_actions
            hierarchical_action_dim = [len(actions) * self.action_horizon for actions in config.hierarchical_actions]
            self.q_head = _value.HierarchicalQModel(value_config.width, hierarchical_action_dim, rngs=rngs)
        else:
            self.q_head = _value.QModel(value_config.width, self.action_dim * self.action_horizon, rngs=rngs)

    @at.typecheck
    def embed_prefix(
        self, obs: _model.Observation, method: str = "policy"
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embed images and the tokenized prompt into a single prefix sequence.

        Args:
            obs: Observation containing images, image masks and (optionally) a
                tokenized prompt with its mask.
            method: Unused; kept for call-site compatibility. Every prefix is
                embedded with the shared PaliGemma image encoder and LLM embedder.

        Returns:
            tokens: [b, s, emb] concatenated image and language embeddings.
            input_mask: [b, s] True for real (non-padding) tokens.
            ar_mask: [s] all False, i.e. full attention within the prefix.
        """
        input_mask = []
        ar_mask = []
        tokens = []

        # Embed each camera image; image tokens attend to each other.
        for name in obs.images:
            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
            tokens.append(image_tokens)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_tokens.shape[1],
                )
            )
            ar_mask += [False] * image_tokens.shape[1]

        # Language tokens: full attention between image and language inputs.
        if obs.tokenized_prompt is not None:
            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
            tokens.append(tokenized_inputs)
            input_mask.append(obs.tokenized_prompt_mask)
            ar_mask += [False] * tokenized_inputs.shape[1]

        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)

        return tokens, input_mask, ar_mask

    def _encode_prefix(self, observation: _model.Observation):
        """Run the LLM over the embedded prefix and return its per-expert outputs."""
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        prefix_positions = jnp.cumsum(prefix_mask, axis=1) - 1
        prefix_output, _ = self.PaliGemma.llm([prefix_tokens], mask=prefix_attn_mask, positions=prefix_positions)
        return prefix_output

    @staticmethod
    def _pooled_embedding(prefix_output):
        """Return the first expert's output at the last prefix position, [b, emb].

        NOTE(review): if the tokenized prompt is right-padded, this position may be
        a padding token. Confirm this is the intended pooling.
        """
        return prefix_output[0][:, -1, :]

    @staticmethod
    def _as_column(x):
        """Reshape a head output to [b, k]; a no-op for [b, 1], and [b] becomes [b, 1]."""
        return x.reshape(x.shape[0], -1)

    def get_q(
        self,
        observation: _model.Observation,
        actions: _model.Actions,
        prefix_output: list[at.Float[at.Array, "b s emb"]] | None = None,
    ) -> at.Float[at.Array, "b"] | List[at.Float[at.Array, "b"]]:
        """Estimate Q(s, a) for an action chunk.

        Args:
            observation: Observation s (used only when ``prefix_output`` is None).
            actions: [b, action_horizon, action_dim] action chunk.
            prefix_output: Precomputed LLM outputs for ``observation``, to avoid
                re-encoding the prefix.

        Returns:
            The Q estimate, or for ``hierarchical_q`` a list with one estimate per
            action group.
        """
        if prefix_output is None:
            prefix_output = self._encode_prefix(observation)
        embed = self._pooled_embedding(prefix_output)
        if self.network_type == "hierarchical_q":
            # Each group sees only its own action dimensions, flattened over the horizon.
            hierarchy_actions = [
                actions[:, :, actions_idx].reshape(-1, len(actions_idx) * self.action_horizon)
                for actions_idx in self.hierarchical_actions
            ]
            return self.q_head(embed, hierarchy_actions)
        flat_actions = actions.reshape(-1, self.action_dim * self.action_horizon)
        return self.q_head(embed, flat_actions)

    def get_value(
        self,
        observation: _model.Observation,
        prefix_output: list[at.Float[at.Array, "b s emb"]] | None = None,
    ) -> at.Float[at.Array, "b"]:
        """Estimate V(s); reuses ``prefix_output`` when given."""
        if prefix_output is None:
            prefix_output = self._encode_prefix(observation)
        return self.value_head(self._pooled_embedding(prefix_output))

    def get_advantages(
        self,
        observation: _model.Observation,
        actions: _model.Actions,
        prefix_output: list[at.Float[at.Array, "b s emb"]] | None = None,
    ) -> at.Float[at.Array, "b"] | List[at.Float[at.Array, "b"]]:
        """Return stop-gradient advantages ``Q(s, a) - V(s)``.

        The prefix is encoded once and shared by the Q and V heads. For
        ``hierarchical_q`` a list with one advantage per action group is returned.
        """
        if prefix_output is None:
            prefix_output = self._encode_prefix(observation)
        q_values = self.get_q(observation=observation, actions=actions, prefix_output=prefix_output)
        v_values = self.get_value(observation=observation, prefix_output=prefix_output)

        if self.network_type == "hierarchical_q":
            return [jax.lax.stop_gradient(q - v_values) for q in q_values]
        return jax.lax.stop_gradient(q_values - v_values)

    def _value_loss(self, q_values, v_values):
        """Loss pulling V toward the (stop-gradient) Q estimate, according to ``self.method``."""
        diff = jax.lax.stop_gradient(q_values) - v_values
        if self.method == "expectile":
            # Asymmetric L2: residuals with Q > V weigh `expectile`, the rest `1 - expectile`.
            return jnp.mean(
                jnp.where(
                    diff > 0,
                    self.expectile * jnp.square(diff),
                    (1 - self.expectile) * jnp.square(diff),
                )
            )
        if self.method == "exponential":
            # Per-element gradient w.r.t. V is proportional to -(exp(beta * diff) - 1).
            exp_diff = jax.lax.stop_gradient(jnp.exp(diff * self.beta))
            return jnp.mean((exp_diff - 1) * diff)
        raise NotImplementedError(f"Unknown method: {self.method}")

    @override
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        next_observation: _model.Observation,
        rewards: _model.Rewards,
        terminals: _model.Terminals,
        train: bool = False
    ) -> tuple[at.Float[at.Array, ""], dict]:
        """Compute the critic loss: a TD loss for Q plus a value loss pulling V toward Q.

        The TD target is ``r + gamma * V(s')``. ``V(s')`` is replaced by 0 for
        terminal transitions, and gradients are stopped through the target. The
        value loss is expectile regression (``method="expectile"``) or the
        exponential loss (``method="exponential"``) on ``stop_gradient(Q) - V``.
        For ``hierarchical_q`` both losses are averaged over the Q heads.

        Args:
            rng: Random key for observation preprocessing.
            observation: Current observation s.
            actions: Action chunk taken in s.
            next_observation: Next observation s'.
            rewards: Transition rewards. The entry with the largest magnitude along
                the last axis is used as the reward.
            terminals: Terminal flags. A transition is terminal if any flag along
                the last axis is set.
            train: Forwarded to observation preprocessing.

        Returns:
            Tuple of (total_loss, info_dict), where info_dict contains the loss
            components and mean Q / V / advantage diagnostics.
        """
        # NOTE(review): the same rng preprocesses both observations, so any random
        # augmentation is shared between s and s'. Confirm this is intended.
        observation = _model.preprocess_observation(rng, observation, train=train)
        next_observation = _model.preprocess_observation(rng, next_observation, train=train)

        # Reward with the largest magnitude along the last axis, keeping that axis
        # as a singleton (e.g. [b, h] -> [b, 1]).
        idx = jnp.argmax(jnp.abs(rewards), axis=-1)
        reward = jnp.take_along_axis(rewards, idx[..., None], axis=-1)
        # keepdims keeps the flag aligned with `reward`. Without it, a [b] flag
        # broadcast against [b, 1] values produces a [b, b] target that mixes samples.
        is_terminal = jnp.any(terminals, axis=-1, keepdims=True)

        # Encode s once and share it between the Q and V heads.
        prefix_output = self._encode_prefix(observation)
        current_q = self.get_q(observation, actions, prefix_output=prefix_output)
        current_value = self._as_column(self.get_value(observation, prefix_output=prefix_output))
        if self.network_type == "hierarchical_q":
            q_heads = [self._as_column(q) for q in current_q]
        else:
            q_heads = [self._as_column(current_q)]

        next_value = self._as_column(self.get_value(next_observation))
        next_value_terminal = jnp.where(is_terminal, 0.0, next_value)
        target_value = jax.lax.stop_gradient(reward + self.gamma * next_value_terminal)

        # Losses are averaged over the Q heads (a single head for network_type == "q").
        td_loss = jnp.mean(jnp.stack([jnp.mean(jnp.square(target_value - q)) for q in q_heads]))
        value_loss = jnp.mean(jnp.stack([self._value_loss(q, current_value) for q in q_heads]))

        total_loss = td_loss + value_loss

        info = {
            "loss": total_loss,
            "td_loss": td_loss,
            "value_loss": value_loss,
            "next_value": jnp.mean(next_value),
        }

        if self.network_type == "hierarchical_q":
            info["current_q"] = jnp.mean(jnp.stack(q_heads))
            info["advantages"] = jnp.mean(jnp.stack([q - current_value for q in q_heads]))
            for i, q_val in enumerate(q_heads):
                info[f"current_q_{i}"] = jnp.mean(q_val)
                info[f"advantages_{i}"] = jnp.mean(q_val - current_value)
        else:
            (q_val,) = q_heads
            info["current_q"] = jnp.mean(q_val)
            info["advantages"] = jnp.mean(q_val - current_value)

        return total_loss, info

    def _advantage_weight(self, advantages, clip_max=None):
        """Map advantages A to non-negative sample weights according to ``self.method``.

        expectile:   ``expectile`` where A >= 0, ``1 - expectile`` where A < 0.
        exponential: ``beta * |exp(beta * A) - 1| / (|A| + 1e-4)``, optionally
                     clipped to ``[0, clip_max]``. The 1e-4 keeps the ratio finite at A == 0.
        """
        if self.method == "expectile":
            return jnp.abs(self.expectile - jnp.where(advantages < 0, 1.0, 0.0))
        if self.method == "exponential":
            weight = self.beta * jnp.abs(jnp.exp(self.beta * advantages) - 1) / (jnp.abs(advantages) + 1e-4)
            if clip_max is not None:
                weight = jnp.clip(weight, 0.0, clip_max)
            return weight
        raise NotImplementedError(f"Unknown method: {self.method}")

    def get_weight(self, observation: _model.Observation, actions: _model.Actions) -> at.Float[at.Array, "b action_dim"]:
        """Return per-sample, per-action-dimension weights derived from advantages.

        For ``hierarchical_q``, each group's weight is written to that group's own
        action-dimension columns. Columns not listed in any group stay 0, and
        indices >= action_dim are dropped. Otherwise a single weight is broadcast
        to all ``action_dim`` columns.
        """
        advantages = self.get_advantages(observation, actions)

        if self.network_type == "hierarchical_q":
            # Clip higher for some tasks; only the hierarchical path is clipped.
            group_weights = [self._advantage_weight(adv, clip_max=2.0) for adv in advantages]
            batch_size = group_weights[0].shape[0]
            weights_tensor = jnp.zeros((batch_size, self.action_dim), dtype=group_weights[0].dtype)
            for action_group, group_weight in zip(self.hierarchical_actions, group_weights, strict=True):
                # Scatter into the group's actual columns, so non-contiguous groups stay aligned.
                weights_tensor = weights_tensor.at[:, jnp.asarray(action_group)].set(
                    group_weight.reshape(batch_size, 1)
                )
            return weights_tensor

        weight = self._advantage_weight(advantages)
        weight = weight.reshape(weight.shape[0], 1)
        # Broadcast weight to all action dimensions [b, action_dim]
        return jnp.broadcast_to(weight, (weight.shape[0], self.action_dim))

    @override
    def sample_actions(self, rng: at.KeyArrayLike, observation: _model.Observation) -> _model.Actions:
        """Not supported: this model only provides Q / V estimates and weights."""
        raise NotImplementedError("Pi0Value does not sample actions; use get_q / get_value / get_weight instead.")

