"""Function environment built on batch of static GP samples."""

import torch
from torch import Tensor
import numpy as np
from data.sampler import gp_sampler
from data.function_sampling import sample_full_inputs_from_subspaces
from data.data_masking import generate_dim_mask, restore_full_dims
from utils.dataclasses import SamplerConfig
from utils.moo import norm_compute_hv_batch
from typing import Optional, List, Tuple
from einops import repeat

# NOTE(review): SIGMA is not referenced in this file; presumably kept for
# callers elsewhere — confirm before removing.
SIGMA: float = 0.0
# Default `num_samples` for GPSampleFunction: how many times each sampled
# function is repeated along the batch dimension.
NUM_SAMPLES: int = 1


def sample_from_gps(
    x_dim: int,
    y_dim: int,
    batch_size: int,
    d: int,
    sampler_config: SamplerConfig,
    use_grid_sampling: bool,
    use_factorized_policy: bool,
    device: str,
    zero_mean: bool = True,
):
    """Draw one shared input set and a batch of GP output samples on it.

    Args:
        x_dim: Input dimensionality; an all-True `x_mask` of this length is
            handed to the input sampler.
        y_dim: Number of output dimensions passed to `gp_sampler`.
        batch_size: Number of GP samples; every sample uses the same inputs.
        d: Forwarded to `sample_full_inputs_from_subspaces`.
        sampler_config: Supplies the input bounds (`x_range`) and all GP
            sampler settings.
        use_grid_sampling: Forwarded to the input sampler and, as `grid`,
            to `gp_sampler`.
        use_factorized_policy: Forwarded to the input sampler.
        device: Device for the input mask and the GP samples.
        zero_mean: If True, subtract each sample's mean over the point axis
            (dim 1) from `y`.

    Returns:
        x: Input points, [b, m, dx]
        y: Output points, [b, m, dy]
        chunks: Chunks of input points, [d, dx]
        chunk_mask: Mask for valid chunks, [num_chunks, dx]
    """
    cfg = sampler_config
    full_mask = torch.ones((x_dim,), dtype=torch.bool, device=device)
    shared_inputs, chunks, chunk_mask = sample_full_inputs_from_subspaces(
        d=d,
        x_mask=full_mask,
        input_bounds=cfg.x_range,
        use_grid_sampling=use_grid_sampling,
        use_factorized_policy=use_factorized_policy,
    )

    # Broadcast view (no copy): every batch element sees the same inputs.
    batched_inputs = shared_inputs.unsqueeze(0).expand(batch_size, -1, -1)

    outputs = gp_sampler(
        y_dim=y_dim,
        x=batched_inputs,
        x_range=cfg.x_range,
        sampler_list=cfg.sampler_list,
        sampler_weights=cfg.sampler_weights,
        data_kernel_type_list=cfg.data_kernel_type_list,
        sample_kernel_weights=cfg.sample_kernel_weights,
        lengthscale_range=cfg.lengthscale_range,
        std_range=cfg.std_range,
        min_rank=cfg.min_rank,
        max_rank=cfg.max_rank,
        p_iso=cfg.p_iso,
        jitter=cfg.jitter,
        max_tries=cfg.max_tries,
        standardize=cfg.standardize,
        grid=use_grid_sampling,
        device=device,
    )

    centered = (
        outputs - outputs.mean(dim=1, keepdim=True) if zero_mean else outputs
    )
    return batched_inputs, centered, chunks, chunk_mask


class GPSampleFunction:
    """Function environment built on batch of static GP samples.

    A fixed set of candidate inputs `x` and GP outputs `y` sampled on them is
    generated once in `__init__`. `step` gathers candidates by index, appends
    them to the running context and optionally scores the context by
    hypervolume (reward) and regret.

    Storage layout: when `restore_full_dim_later` is True the stored data keeps
    the reduced `x_dim` / `y_dim` feature sizes and is scattered into the padded
    `max_x_dim` / `max_y_dim` layout on access; otherwise it is padded once in
    `__init__` and returned as-is.

    Some attributes (B = batch_size * num_samples):
        chunks: [d, max_x_dim]
        chunk_mask: [num_chunks, max_x_dim]
        y_mins: [B, max_y_dim]
        y_maxs: [B, max_y_dim]
        max_hv: [B]
        max_hv_norm: [B]
    """

    def __init__(
        self,
        batch_size: int,
        x_dim: int,
        y_dim: int,
        max_x_dim: int,
        max_y_dim: int,
        dim_scatter_mode: str,
        d: int,
        sampler_config: SamplerConfig,
        use_grid_sampling: bool,
        use_factorized_policy: bool,
        num_samples: int = NUM_SAMPLES,
        online_generate: bool = True,
        device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        restore_full_dim_later: bool = True,
        zero_mean: bool = True,
        mask_out_new: bool = False,
        **kwargs,
    ):
        """Sample the GP functions and precompute per-function max hypervolumes.

        NOTE only for joint policy.

        Args:
            batch_size: Number of distinct GP functions to sample.
            x_dim / y_dim: Active input / output dimensionality.
            max_x_dim / max_y_dim: Padded input / output dimensionality.
            dim_scatter_mode: Passed to `generate_dim_mask`.
            d: Passed to `sample_from_gps`.
            sampler_config: GP sampler configuration.
            use_grid_sampling: Passed to `sample_from_gps`.
            use_factorized_policy: Must be False (not supported).
            num_samples: Each sampled function is repeated this many times
                along the batch, so the effective batch is
                `batch_size * num_samples`.
            online_generate: Must be True; offline generation is not implemented.
            device: Device for the sampled data and dimension masks.
            restore_full_dim_later: Keep reduced dims and pad on access (True)
                or pad everything immediately (False).
            zero_mean: Passed to `sample_from_gps`.
            mask_out_new: Stored only; not acted on yet (see TODO in `step`).
            **kwargs: Ignored.

        Raises:
            AssertionError: If `use_factorized_policy` is True.
            NotImplementedError: If `online_generate` is False.
        """
        assert not use_factorized_policy, "Factorized policy not supported yet."
        self.restore_full_dim_later = restore_full_dim_later
        self.online_generate = online_generate
        self.mask_out_new = mask_out_new
        self.x_dim = x_dim
        self.y_dim = y_dim

        if not online_generate:
            # Only reachable for factorized policies when asserts are disabled (-O).
            if use_factorized_policy:
                raise NotImplementedError(
                    "Factorized policy not implemented for offline data"
                )
            raise NotImplementedError(
                "Offline/batched generation is not implemented because online generation is sufficiently fast."
            )

        self.max_x_dim = max_x_dim
        self.max_y_dim = max_y_dim

        x, y, chunks_, chunk_mask_ = sample_from_gps(
            x_dim=x_dim,
            y_dim=y_dim,
            batch_size=batch_size,
            d=d,
            sampler_config=sampler_config,
            use_grid_sampling=use_grid_sampling,
            use_factorized_policy=use_factorized_policy,
            device=device,
            zero_mean=zero_mean,
        )

        # Generate padding masks: [max_x_dim], [max_y_dim]
        self.x_mask, _ = generate_dim_mask(
            k=x_dim,
            max_dim=max_x_dim,
            dim_scatter_mode=dim_scatter_mode,
            device=device,
        )
        self.y_mask, _ = generate_dim_mask(
            k=y_dim,
            max_dim=max_y_dim,
            dim_scatter_mode=dim_scatter_mode,
            device=device,
        )

        if restore_full_dim_later:
            # Keep reduced dimensions: [*, x_dim], [*, y_dim]
            self.chunks_, self.chunk_mask_ = chunks_, chunk_mask_
        else:
            # Restore full dimensions now: [*, max_x_dim], [*, max_y_dim]
            x, self.chunks_, self.chunk_mask_ = self.prepare_full_dimensions(
                mask=self.x_mask, tensors=[x, chunks_, chunk_mask_]
            )
            y = self.prepare_full_dimensions(mask=self.y_mask, tensors=[y])[0]

        self._x_base = x
        self._y_base = y
        self._num_samples = num_samples
        self.num_points = x.shape[1]
        self.batch_size = x.shape[0] * num_samples

        # Stored y is reduced (no mask needed) or already padded (mask needed).
        hv_mask = None if restore_full_dim_later else self.y_mask

        def _hv_of_all_points(normalize: bool):
            # HV of the full candidate set w.r.t. its own min/max bounds.
            y_all = self._y
            return norm_compute_hv_batch(
                solutions=y_all,
                minimum=torch.min(y_all, dim=1).values,
                maximum=torch.max(y_all, dim=1).values,
                y_mask=hv_mask,
                normalize=normalize,
            )[0]

        self.max_hv = _hv_of_all_points(normalize=False)
        self.max_hv_norm = _hv_of_all_points(normalize=True)

    @property
    def _y(self):
        """Stored outputs with each function repeated `num_samples` times along dim 0."""
        return self.repeat_along_batch(self._y_base, self._num_samples)

    @property
    def _x(self):
        """Stored inputs with each function repeated `num_samples` times along dim 0."""
        return self.repeat_along_batch(self._x_base, self._num_samples)

    def _to_full_dims(self, data: Tensor, mask: Tensor) -> Tensor:
        """Scatter reduced-dim `data` into the padded layout.

        When `restore_full_dim_later` is False the stored data was already padded
        in `__init__`, so it is returned unchanged instead of being restored twice.
        """
        if self.restore_full_dim_later:
            return restore_full_dims(data=data, mask=mask, dim=-1)
        return data

    @property
    def y_mins(self):
        """Per-function minimum over all candidate points, padded: [B, max_y_dim]."""
        return self._to_full_dims(torch.min(self._y, dim=1).values, self.y_mask)

    @property
    def y_maxs(self):
        """Per-function maximum over all candidate points, padded: [B, max_y_dim]."""
        return self._to_full_dims(torch.max(self._y, dim=1).values, self.y_mask)

    @property
    def chunks(self):
        """Input chunks in the padded layout: [d, max_x_dim]."""
        return self._to_full_dims(self.chunks_, self.x_mask)

    @property
    def chunk_mask(self):
        """Chunk validity mask in the padded layout: [num_chunks, max_x_dim]."""
        return self._to_full_dims(self.chunk_mask_, self.x_mask)

    @staticmethod
    def repeat_along_batch(tensor: Tensor, num_repeat: int):
        """Repeat each element in the batch dimension `num_repeat` times.

        Elements are repeated consecutively ([a, b] -> [a, a, b, b]). When
        `num_repeat == 1` the input tensor itself is returned.

        Raises:
            ValueError: If `tensor` has fewer than 2 dimensions.
        """
        if tensor.ndim < 2:
            raise ValueError("Expected at least 2 dimensions: (batch + features).")
        if num_repeat == 1:
            return tensor
        expanded = tensor.unsqueeze(1).expand(-1, num_repeat, *tensor.shape[1:])
        return expanded.reshape(-1, *tensor.shape[1:])

    @staticmethod
    def prepare_full_dimensions(mask: Tensor, tensors: List[Tensor]) -> List[Tensor]:
        """Apply `restore_full_dims` along the last dim of every tensor in `tensors`."""
        return [restore_full_dims(data=tensor, mask=mask, dim=-1) for tensor in tensors]

    @staticmethod
    def update_context(new: Tensor, old: Optional[Tensor]) -> Tensor:
        """Update context with new observations.
        - If old is None, return a copy of new [B, num_new, DY]
        - Else concatenate old and new, return [B, num_old + num_new, DY]
        """
        if old is None:
            return new.clone()

        batch_size_old, _, dim_old = old.shape
        batch_size_new, _, dim_new = new.shape
        assert batch_size_new == batch_size_old
        assert dim_new == dim_old
        return torch.cat((old, new), dim=1)

    @staticmethod
    def _safe_denominator(values):
        """Return a copy of `values` with exact zeros replaced by 1.0.

        The input is never modified, so stored maxima (`max_hv`,
        `max_hv_norm`) stay intact across calls.
        """
        if isinstance(values, Tensor):
            safe = values.clone()
        else:
            safe = np.array(values, copy=True)
        safe[safe == 0.0] = 1.0
        return safe

    def compute_hv(self, solutions: Tensor, normalize: bool = False) -> Tuple:
        """Compute hypervolume.

        Args:
            solutions: [B, N, dy_max]
            normalize: Whether to normalize sols and ref_points before computing hv

        Returns:
            reward: shape [B]
            sols_tfm: (Optionally transformed) solutions, [B, N, dy_max]
            reward_ref_points: reference points, [B, dy_max]
        """
        reward, sols_tfm, ref_points = norm_compute_hv_batch(
            solutions=solutions,
            minimum=self.y_mins,
            maximum=self.y_maxs,
            y_mask=self.y_mask,
            normalize=normalize,
        )
        return reward, sols_tfm, ref_points

    def compute_regret(
        self, solutions: Tensor, regret_type: str = "ratio"
    ) -> np.ndarray:
        """Compute regret.

        Args:
            solutions: [B, N, dy_max]
            regret_type: one of
                "value":          max_hv - hv
                "ratio":          (max_hv - hv) / max_hv
                "norm_ratio":     (max_hv_norm - hv_norm) / max_hv_norm
                "norm_ratio_dim": "norm_ratio" divided by `y_dim`
                Zero maxima are replaced by 1.0 in the denominator only, so
                `self.max_hv` / `self.max_hv_norm` are never modified.

        Returns:
            regret: [B]

        Raises:
            ValueError: On an unknown `regret_type`.
        """
        if regret_type in ("value", "ratio"):
            hv = self.compute_hv(solutions, normalize=False)[0]
            regret = self.max_hv - hv
            if regret_type == "ratio":
                regret = regret / self._safe_denominator(self.max_hv)
        elif regret_type in ("norm_ratio", "norm_ratio_dim"):
            hv = self.compute_hv(solutions, normalize=True)[0]
            regret = (self.max_hv_norm - hv) / self._safe_denominator(self.max_hv_norm)
            if regret_type == "norm_ratio_dim":
                regret = regret / self.y_dim
        else:
            raise ValueError(f"Unknown regret type: {regret_type}")

        return regret

    @staticmethod
    def batch_gather(tensor, dim, index, full_dim_mask=None) -> Tensor:
        """Gather tensor along dim by index, optionally restore full data dimensions.

        `index` is expected to have a trailing size-1 dim (e.g. [B, n, 1]); it is
        expanded over the feature dim of `tensor` before `torch.gather`.
        """
        index_expanded = index.expand(-1, -1, tensor.shape[-1])
        gathered = torch.gather(tensor, dim=dim, index=index_expanded)
        if full_dim_mask is not None:
            gathered = restore_full_dims(data=gathered, mask=full_dim_mask, dim=-1)
        return gathered

    def step(
        self,
        index_new: Tensor,
        x_ctx: Optional[Tensor] = None,
        y_ctx: Optional[Tensor] = None,
        compute_hv: bool = True,
        compute_regret: bool = True,
        regret_type: str = "ratio",
    ) -> Tuple[Tensor, Tensor, Optional[np.ndarray], Optional[np.ndarray]]:
        """Evaluate the candidates at `index_new`, update context, optionally compute reward / regret.

        Args:
            index_new: Index of new datapoints in training data, [B, num_new, 1]
            x_ctx: Optional context input points, shape [B, num_ctx, dx_max]
            y_ctx: Optional context output points, shape [B, num_ctx, dy_max]
            compute_hv: compute reward (hypervolume of the updated context) if True
            compute_regret: compute regret of the updated context if True
            regret_type: Type of regret to compute, defaults to "ratio"

        Returns:
            x_ctx: Updated context input points, shape [B, num_ctx + num_new, dx_max]
            y_ctx: Updated context output points, shape [B, num_ctx + num_new, dy_max]
            reward (np.ndarray): Reward of the updated context, shape [B] or None if not computed
            regret (np.ndarray): Regret of the updated context, shape [B] or None if not computed
        """
        # Stored data is padded lazily only when restore_full_dim_later is set.
        x_new = self.batch_gather(
            tensor=self._x,
            dim=1,
            index=index_new,
            full_dim_mask=self.x_mask if self.restore_full_dim_later else None,
        )
        y_new = self.batch_gather(
            tensor=self._y,
            dim=1,
            index=index_new,
            full_dim_mask=self.y_mask if self.restore_full_dim_later else None,
        )

        x_ctx = GPSampleFunction.update_context(new=x_new, old=x_ctx)
        y_ctx = GPSampleFunction.update_context(new=y_new, old=y_ctx)

        reward = self.compute_hv(y_ctx)[0] if compute_hv else None
        regret = self.compute_regret(y_ctx, regret_type) if compute_regret else None

        # TODO: `self.mask_out_new` is stored by __init__ but not acted on yet. The
        # intent (from an earlier draft) is to mask the chosen indices out of
        # x / y / chunks so they cannot be re-sampled.

        return x_ctx, y_ctx, reward, regret

    def init(
        self,
        num_initial_points: int,
        regret_type: str,
        compute_hv: bool = True,
        compute_regret: bool = True,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Sample initial points and evaluate function at them.

        The same random subset of candidate indices is used for every batch
        element. Indices are drawn on `device` and then moved to the device of
        the stored data, since `torch.gather` requires both on the same device.

        Returns:
            x_ctx: [B, num_initial_points, dx_max]
            y_ctx: [B, num_initial_points, dy_max]
            reward (np.ndarray): [B] or None
            regret (np.ndarray): [B] or None
        """
        indices = torch.randperm(self.num_points, device=device)[:num_initial_points]
        indices = indices.to(self._x_base.device)
        indices = repeat(indices, "n -> b n 1", b=self.batch_size)

        return self.step(
            index_new=indices,
            x_ctx=None,
            y_ctx=None,
            compute_hv=compute_hv,
            compute_regret=compute_regret,
            regret_type=regret_type,
        )
