"""Functions for policy learning."""

import time
from einops import repeat
import numpy as np
from torch import Tensor
import torch
from utils.types import FloatListOrNestedOrTensor
from typing import Optional, Tuple
from data.function_sampling import (
    get_sample_indices_from_chunk_indices,
    sample_factorized_subspaces,
)
from TAMO.model.tamo import TAMO

# Default discount factor; used as the default `discount_factor` argument of
# `compute_policy_loss` below.
GAMMA = 0.98
# NOTE(review): not referenced in the visible part of this file; its purpose
# cannot be determined from here (possibly a noise level) — confirm or remove.
SIGMA = 0.0


def sample_start_points(
    chunks: Tensor,
    y: Tensor,
    num_start: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """(Old, retired) Randomly sample starting points.

    This function is no longer functional: every call raises immediately.
    The argument and return descriptions below record the contract of the
    legacy implementation, which is preserved as a comment in the body.

    Args:
        chunks (Tensor): Input data chunks, shape [B, n, d, dx_max].
        y (Tensor): Function value at inputs, shape [B, m, dy_max].
        num_start (int): Number of starting points to sample.

    Returns:
        Never returns. The legacy version returned:
        x_ctx (Tensor): Sampled input data, shape [B, num_start, dx_max].
        y_ctx (Tensor): Sampled function values, shape [B, num_start, dy_max].
        indices_chunk (Tensor): Sampled indices from chunks, shape [B, num_start, n].
        indices_sample (Tensor): Sampled indices from samples, shape [B, num_start, 1].

    Raises:
        ValueError: Always.
    """
    retired_message = "Old"
    raise ValueError(retired_message)

    # --- Legacy implementation, kept for reference only (unreachable) ---
    # B, n, d, dx_max = chunks.shape
    # dy_max = y.shape[-1]
    #
    # # Randomly sample indices: [B, num_start, n] and [B, num_start, 1]
    # indices_chunk = torch.randint(0, d, (B, num_start, n), device=chunks.device)
    # indices_sample = get_sample_indices_from_chunk_indices(indices_chunk, n, d)
    #
    # # Find starting points: [B, num_start, dx_max] and [B, num_start, dy_max]
    # x_ctx = get_sample_from_chunk_indices(chunks, indices_chunk)
    # y_ctx = torch.gather(y, dim=1, index=indices_sample.expand(B, num_start, dy_max))
    #
    # return x_ctx, y_ctx, indices_chunk, indices_sample


def _get_cumulative_rewards(reward: Tensor, discount_factor: float = 0.98) -> Tensor:
    """Compute future cumulative rewards from step rewards."""
    B, H = reward.shape
    cumulative_rewards = torch.zeros_like(reward)

    for t in reversed(range(H)):
        # t = H - 1: cumulative_rewards[:, H - 1] = reward[:, H - 1]
        # t = H - 2: cumulative_rewards[:, H - 2] = reward[:, H - 2] + discount_factor * reward[:, H - 1]
        if t == H - 1:
            # the last step
            cumulative_rewards[:, t] = reward[:, t]
        else:
            cumulative_rewards[:, t] = reward[:, t] + discount_factor * reward[:, t + 1]

    # [B, H]
    return cumulative_rewards


def compute_policy_loss(
    step_rewards: Tensor,  # [B, H]
    log_probs: Tensor,  # [B, H]
    eps: float = np.finfo(np.float32).eps.item(),
    use_cumulative_r: bool = False,
    discount_factor: float = GAMMA,
    batch_standardize: bool = True,
    clip_rewards: bool = True,
    sum_over_tra: bool = False,
    batch_first: bool = True,
    entropy: Optional[Tensor] = None,
    entropy_coeff: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    """Compute policy gradient loss.

    The caller's ``step_rewards`` tensor is never modified; the processed
    (detached, optionally clipped) rewards are returned instead.

    Args:
        step_rewards: immediate rewards of actions, [B, H]
            ([H, B] when ``batch_first`` is False)
        log_probs: log probabilities of actions, same layout as ``step_rewards``
        eps: small value to avoid division by zero
        use_cumulative_r: whether to use cumulative rewards, default False
            - True: R_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
              (delegated to ``_get_cumulative_rewards``)
            - False: R_t = r_t * gamma^t
            where gamma is the discount factor
        discount_factor: default 0.98
        batch_standardize: if True, standardize rewards over the batch
            dimension (requires B > 1); otherwise over the horizon dimension
            (requires H > 1)
        clip_rewards: whether to clip rewards to zero if they are not
            informative, i.e. below the running maximum along the horizon
        sum_over_tra: whether to sum over trajectories or take mean, default False
        batch_first: if False, inputs are given as [H, B] and are transposed
            to [B, H] internally
        entropy: optional per-step entropies, same layout as ``step_rewards``;
            detached before use
        entropy_coeff: weight of the entropy term subtracted from the loss;
            the term is applied only when ``entropy`` is given and the
            coefficient is > 0

    Returns: loss [1] (mean over the batch), step_rewards [B, H] (detached,
        after clipping when ``clip_rewards`` is True)

    Raises:
        AssertionError: on shape mismatch, or when the standardized dimension
            has a single element.
    """
    if not batch_first:
        # [H, B] -> [B, H]
        step_rewards = step_rewards.transpose(0, 1)
        log_probs = log_probs.transpose(0, 1)

    B, H = step_rewards.shape
    assert log_probs.shape == (B, H), f"{log_probs.shape}"

    # No gradients from rewards
    step_rewards = step_rewards.detach()

    # Set non-informative reward to zero
    if clip_rewards:
        # [1, 0, 3, 2, 4] -> [1, 1, 3, 3, 4]
        step_rewards_cummax = torch.cummax(step_rewards, dim=-1).values

        # e.g. [1, 1, 3, 3, 4] * [T, F, T, F, T] = [1, 0, 3, 0, 4]
        is_info = step_rewards == step_rewards_cummax
        # Out-of-place on purpose: detach() (and transpose()) return tensors
        # sharing storage with the caller's input, so an in-place `*=` would
        # silently overwrite the caller's rewards.
        step_rewards = step_rewards * is_info.float()

    # Compute cumulative or discounted immediate rewards
    if use_cumulative_r:
        reward = _get_cumulative_rewards(
            reward=step_rewards, discount_factor=discount_factor
        )
    else:
        discounts = discount_factor ** torch.arange(H, device=step_rewards.device)
        reward = discounts * step_rewards

    # reward standardization over batch dim or trajectory dim
    if batch_standardize:
        assert B > 1
        reward = (reward - reward.mean(dim=0, keepdim=True)) / (
            reward.std(dim=0, keepdim=True) + eps
        )
    else:
        assert H > 1
        reward = (reward - reward.mean(dim=-1, keepdim=True)) / (
            reward.std(dim=-1, keepdim=True) + eps
        )

    loss = -reward * log_probs

    # NOTE Aggregate by sum or mean: [B]
    # when H varies, SHOULD take mean
    loss = loss.sum(dim=-1) if sum_over_tra else loss.mean(dim=-1)

    if entropy is not None and entropy_coeff > 0.0:
        if not batch_first:
            # [H, B] -> [B, H]
            entropy = entropy.transpose(0, 1)

        assert entropy.shape == (B, H)
        # NOTE(review): the entropy is detached, so this term shifts the loss
        # value but contributes no gradient. If it is meant to act as an
        # entropy regularizer on the policy, the detach likely has to go —
        # confirm intent before changing.
        entropy = entropy.detach()
        # Aggregate the same way as the per-step loss: [B]
        entropy = entropy.sum(dim=-1) if sum_over_tra else entropy.mean(dim=-1)
        loss = loss - entropy_coeff * entropy

    return loss.mean(), step_rewards


def compute_orthogonality_loss(E: Tensor) -> Tensor:
    """(Old, retired) Measure how far the rows of E are from being orthogonal.

    This function is no longer functional: every call raises immediately.
    The legacy implementation is preserved as a comment in the body.

    Raises:
        ValueError: Always.
    """
    retired_message = "Old"
    raise ValueError(retired_message)

    # --- Legacy implementation, kept for reference only (unreachable) ---
    # n, d = E.shape
    # dot = E @ E.T  # [n, n]
    #
    # # orthogonal matrix
    # I = torch.eye(n, device=E.device)
    #
    # # Normalized difference
    # return ((dot - I) ** 2).sum() / n**2


def select_next_query(
    model: TAMO,
    x_mask: Tensor,
    y_mask: Tensor,
    x_ctx: Tensor,
    y_ctx: Tensor,
    input_bounds: FloatListOrNestedOrTensor,
    d: int,
    t: int,
    T: int,
    use_grid_sampling: bool,
    use_fixed_query_set: bool,
    use_factorized_policy: bool,
    use_time_budget: bool,
    y_mask_tar: Optional[Tensor] = None,
    q_chunk: Optional[Tensor] = None,
    q_chunk_mask: Optional[Tensor] = None,
    evaluate: bool = False,
    read_cache: bool = False,
    write_cache: bool = False,
    logit_mask: Optional[Tensor] = None,
    epsilon: float = 1.0,
    auto_clear_cache: bool = True,
):
    """TAMO selects the next query point based on the current context and query set.

    Args:
        model: TAMO model; only its ``action`` method is used here
        x_mask: [max_x_dim] | [B, max_x_dim] | [B, num_ctx, max_x_dim]
        y_mask: [max_y_dim] | [B, max_y_dim] | [B, num_ctx, max_y_dim]
        x_ctx: [B, num_ctx, max_x_dim]
        y_ctx: [B, num_ctx, max_y_dim]
        input_bounds: Input bounds, list / nested list / tensor
        d: Number of points in each subspace
        t: Current time step
        T: Forwarded to ``model.action`` as ``T`` (presumably the total number
            of steps / budget — confirm against the model)
        use_grid_sampling: Forwarded to ``sample_factorized_subspaces``
        use_fixed_query_set: If False, a new query set is sampled on every
            call even when ``q_chunk`` / ``q_chunk_mask`` are provided
        use_factorized_policy: Forwarded to ``sample_factorized_subspaces``
        use_time_budget: Forwarded to ``model.action`` as ``use_budget``
        y_mask_tar: Optional target mask, [max_y_dim] | [B, max_y_dim];
            forwarded to ``model.action`` as ``y_dim_mask_tar``
        q_chunk: Optional query chunks, [d, max_x_dim] or already expanded
            to 4 dims (presumably [B, n, d, max_x_dim])
        q_chunk_mask: Optional query chunk masks, [n, max_x_dim] or
            [B, n, max_x_dim]
        evaluate: If True, keep logits for evaluation; also initializes
            ``logit_mask`` to all-True [B, n, d] when none is given
        read_cache, write_cache, auto_clear_cache: Forwarded unchanged to
            ``model.action``
        logit_mask: Optional mask over logits, [B, n, d]. When present (or
            created because ``evaluate`` is True), the entries selected in
            this call are masked out via ``mask_out_used_chunks``, which
            currently supports n == 1 only
        epsilon: Forwarded unchanged to ``model.action``

    Returns:
        A 9-tuple, in this order:
        x: [B, 1, max_x_dim]
        indices, log_probs, entropies: [B] (log-probs and entropies are
            summed over the n chunks)
        logits: [B, n, d]
        q_chunk: [d, max_x_dim] (as given or freshly sampled, unexpanded)
        q_chunk_mask: [n, max_x_dim] (as given or freshly sampled, unexpanded)
        infer_time: float, seconds spent inside ``model.action``
        logit_mask: updated logit mask, or None when no mask is in use

    Raises:
        AssertionError: on unexpected ranks/shapes of the query set, the
            masks, or the outputs of ``model.action``.
    """
    # Unpacking also enforces a 3-D [B, num_ctx, max_x_dim] context.
    B, _, _ = x_ctx.shape

    # Generate query set if not provided or not fixed
    if q_chunk is None or q_chunk_mask is None or not use_fixed_query_set:
        q_chunk, q_chunk_mask = sample_factorized_subspaces(
            d=d,
            x_mask=x_mask,  # NOTE
            input_bounds=input_bounds,
            use_grid_sampling=use_grid_sampling,
            use_factorized_policy=use_factorized_policy,
        )

    # Expand dimensions if necessary
    def _expand_query_chunk_mask(q_chunk_mask, b: int):
        """Broadcast a [n, dim] chunk mask to [b, n, dim]; pass 3-D masks through."""
        assert q_chunk_mask.ndim in (2, 3)
        if q_chunk_mask.ndim == 2:
            q_chunk_mask = repeat(q_chunk_mask, "n dim -> b n dim", b=b)
        return q_chunk_mask

    q_chunk_mask_exp = _expand_query_chunk_mask(q_chunk_mask, B)

    def _expand_query_chunk(q_chunk, n: int, b: int):
        """Broadcast a [d, dim] query chunk to [b, n, d, dim]; pass 4-D chunks through."""
        assert q_chunk.ndim in (2, 4)
        if q_chunk.ndim == 2:
            q_chunk = repeat(q_chunk, "d dim -> b n d dim", b=b, n=n)
        return q_chunk

    n = q_chunk_mask_exp.shape[1]
    q_chunk_exp = _expand_query_chunk(q_chunk, n, B)

    def _expand_dimension_mask(mask, b: int, n: Optional[int] = None):
        """Broadcast a 1-D mask over the batch; validate the shape of 3-D masks."""
        if mask.ndim == 1:
            mask = repeat(mask, "dim -> b dim", b=b)

        if mask.ndim == 3:
            assert n is not None
            assert mask.shape == (b, n, mask.shape[2])

        return mask

    # Expand dimension masks for context datapoints
    x_mask = _expand_dimension_mask(x_mask, B, x_ctx.shape[1])
    y_mask = _expand_dimension_mask(y_mask, B, y_ctx.shape[1])
    if y_mask_tar is not None:
        y_mask_tar = _expand_dimension_mask(y_mask_tar, B)

    # Run model inference
    if evaluate and logit_mask is None:
        logit_mask = torch.ones((B, n, d), dtype=torch.bool, device=x_ctx.device)

    # perf_counter is monotonic and high-resolution, unlike wall-clock
    # time.time(), which can jump when the system clock is adjusted.
    t0 = time.perf_counter()
    results = model.action(
        x_ctx=x_ctx,
        y_ctx=y_ctx,
        x_dim_mask=x_mask,
        y_dim_mask=y_mask,
        y_dim_mask_tar=y_mask_tar,
        q_chunk=q_chunk_exp,
        q_dim_mask=q_chunk_mask_exp,
        t=t,
        T=T,
        use_budget=use_time_budget,
        evaluate=evaluate,
        read_cache=read_cache,
        write_cache=write_cache,
        auto_clear_cache=auto_clear_cache,
        logit_mask=logit_mask,
        epsilon=epsilon,
    )
    infer_time = time.perf_counter() - t0

    assert results[1].shape == (B, n), f"indice.shape={results[1].shape}, B={B}, n={n}"
    assert results[2].shape == (B, n), f"logp.shape={results[2].shape}, B={B}, n={n}"
    assert results[3].shape == (B, n), f"ent.shape={results[3].shape}, B={B}, n={n}"

    # NOTE mask out the entries that were just selected so that they cannot
    # be chosen again in a later step
    if logit_mask is not None:
        logit_mask = mask_out_used_chunks(
            logit_mask=logit_mask, used_indices=results[1]
        )

    # Collapse across chunks: [B, n] -> [B]
    indices = get_sample_indices_from_chunk_indices(chunk_indices=results[1], n=n, d=d)
    indices = indices.squeeze(-1)
    logp = results[2].sum(-1)
    entropy = results[3].sum(-1)

    return (
        results[0],
        indices,
        logp,
        entropy,
        results[4],
        q_chunk,
        q_chunk_mask,
        infer_time,
        logit_mask,
    )


def mask_out_used_chunks(
    logit_mask: Tensor,  # [B, n, d]
    used_indices: Tensor,  # [B, n]
) -> Tensor:
    """Return a copy of ``logit_mask`` with the selected entries disabled.

    The input mask is never modified, regardless of its dtype or memory
    layout; the result is always a new boolean tensor.

    Args:
        logit_mask: Mask over logits, [B, n, d]; converted to bool.
        used_indices: Index (along the last dim of the mask) chosen for each
            batch element and chunk, [B, n]. Only n == 1 is supported.

    Returns:
        Boolean mask of shape [B, n, d] where, for every row, the entry at
        the used index is set to False.

    Raises:
        AssertionError: if n != 1.
    """
    B, n = used_indices.shape
    d = logit_mask.shape[-1]
    assert (
        n == 1
    ), "Only support full policy (n=1) for now, since: masking out element in a chunk will lead to masking out all datapoints from that chunk"

    # copy=True guarantees a fresh tensor even when the mask is already bool,
    # so the in-place write below can never alias the caller's mask.
    # reshape (not view) also tolerates non-contiguous inputs.
    updated = logit_mask.to(torch.bool, copy=True).reshape(B * n, -1)  # [B * n, d]
    rows = torch.arange(B * n, device=updated.device)
    updated[rows, used_indices.reshape(-1)] = False
    return updated.reshape(B, n, d)
