##
## (c) Anonymous authors (2026)
##
## > Synthetic informed POMDP
##
##

from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch

# Presumably the discount factor for computing episode returns (not referenced in this chunk).
GAMMA: float = 0.99
# Scale applied to the standard-normal noise added to observations in generate_obs;
# 0.0 yields noiseless observations (noise is still sampled, just multiplied by zero).
NOISE_LEVEL: float = 0.0


@dataclass()
class Episode:
    """Plain container bundling the trajectory buffers of one episode.

    Every field is an ordinary list. Presumably each list holds one entry per
    time step (TODO confirm against the rollout code that fills it).
    """

    states: list  # assumed: visited state indices
    obs: list  # assumed: observations produced from those states
    actions: list  # assumed: actions taken
    rewards: list  # assumed: per-step rewards
    returns: list  # assumed: (discounted) returns-to-go


class RandomSyntheticInformedPOMDP:
    """
    Random synthetic informed POMDP class.

    Every state owns a latent feature vector (the privileged information). The
    observation of a state is its first ``obs_dim`` latent features plus Gaussian
    noise scaled by the module-level ``NOISE_LEVEL``. The reward of a state is the
    dot product of its latent features with ``reward_weights``.

        num_states: number of states, default=10
        num_actions: number of actions, default=4
        latent_dim: number of state features, default=5
        obs_dim: number of observation features, must satisfy 0 <= obs_dim <= latent_dim, default=2
        fixed_latent_map: pre-defined latent feature map, expected shape (num_states, latent_dim), default=None
        transitions: pre-defined transition probabilities, expected shape (num_states, num_actions, num_states),
            each [s, a] row summing to 1, default=None
        reward_weights: pre-defined reward weights, expected shape (latent_dim,), default=None
        seed: seed to ensure reproducibility of the random POMDP instance, default=None

    Note: construction reseeds NumPy's *global* RNG (``np.random.seed``), and
    ``reset``, ``step`` and ``generate_obs`` all draw from that same global RNG.

    Raises:
        ValueError: if ``obs_dim`` is not in ``[0, latent_dim]``.
    """

    def __init__(self, num_states=10, num_actions=4, latent_dim=5, obs_dim=2, fixed_latent_map=None, transitions=None,
                 reward_weights=None, seed=None):
        # Validate before touching the global RNG so an invalid configuration has no side effects.
        if not 0 <= obs_dim <= latent_dim:
            raise ValueError(
                f"obs_dim must satisfy 0 <= obs_dim <= latent_dim, got obs_dim={obs_dim}, latent_dim={latent_dim}"
            )

        self.S = num_states
        self.A = num_actions
        self.obs_dim = obs_dim
        self.latent_dim = latent_dim

        # Seed NumPy's global RNG. With seed=None it is reseeded from OS entropy,
        # which discards any seeding the caller performed beforehand.
        if seed is not None:
            np.random.seed(seed)
        else:
            np.random.seed()

        # Latent state feature map: one row of latent features per state.
        if fixed_latent_map is not None:
            self.state_latent_map = fixed_latent_map
        else:
            self.state_latent_map = np.random.randn(self.S, latent_dim)

        # Observation mask: the observation exposes the first `obs_dim` latent features.
        # Tying the mask to obs_dim keeps it consistent with the noise vector of size
        # obs_dim that generate_obs adds to the masked features.
        self.observation_mask = np.zeros(latent_dim, dtype=bool)
        self.observation_mask[:obs_dim] = True

        # Current environment state; None until reset() is called.
        self.state = None

        # Transition function P(s' | s, a), shape (S, A, S).
        if transitions is not None:
            self.transitions = transitions
        else:
            self.transitions = np.zeros((self.S, self.A, self.S))

            # The per-entry draw order below defines the RNG stream consumed for a given
            # seed, so it is kept as an explicit loop to keep seeded instances reproducible.
            for s in range(self.S):
                for a in range(self.A):
                    for s_prime in range(self.S):
                        # Each entry has a 25% chance to be non-zero (sparse transitions).
                        if np.random.rand() < 0.25:
                            self.transitions[s, a, s_prime] = np.random.rand()
                    # Ensure at least one reachable next state so the row can be normalized.
                    if np.sum(self.transitions[s, a]) == 0:
                        random_sp = np.random.randint(self.S)
                        self.transitions[s, a, random_sp] = 1.0

                    # Normalize into a probability distribution over next states.
                    self.transitions[s, a] /= np.sum(self.transitions[s, a])

        # Linear reward weights over the latent features.
        if reward_weights is not None:
            self.reward_weights = reward_weights
        else:
            self.reward_weights = np.random.uniform(-1.0, 1.0, size=latent_dim)

    def reset(self):
        """
        Reset the environment to a uniformly random state.

        Returns:
            The new current state index, in ``[0, S)``.
        """
        self.state = np.random.randint(0, self.S)
        return self.state

    def reward_function(self, state_idx):
        """
        Reward of a state: dot product of its latent features with ``reward_weights``.

        Returns:
            The reward as a Python float.
        """
        latent = self.state_latent_map[state_idx]
        return float(np.dot(latent, self.reward_weights))

    def step(self, action):
        """
        Environment step.

        Samples the next state from ``transitions[state, action]``, makes it the
        current state, and rewards the agent with the reward of that next state.

        Returns:
            ``(reward, next_state)``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        # Without a current state, transitions[None, action] would be 2-D and
        # np.random.choice would fail with an unrelated-looking error.
        if self.state is None:
            raise RuntimeError("step() called before reset(); call reset() first")

        # Transition to next state
        next_state = np.random.choice(self.S, p=self.transitions[self.state, action])
        self.state = next_state
        # The reward depends on the state that was just entered.
        reward = self.reward_function(next_state)

        return reward, self.state

    def generate_latent(self, state):
        """
        Return the privileged information of ``state``: its full latent feature vector.
        """
        return self.state_latent_map[state]

    def generate_obs(self, state):
        """
        Return the observation of ``state``.

        The observation is the masked (first ``obs_dim``) latent features plus
        ``NOISE_LEVEL``-scaled standard-normal noise. The noise is always sampled,
        even when ``NOISE_LEVEL == 0``, so this call always advances the global RNG.
        """
        full_latent = self.state_latent_map[state]

        assert full_latent.shape == self.observation_mask.shape, f"Shape mismatch: latent {full_latent.shape}, mask {self.observation_mask.shape}"

        obs = full_latent[self.observation_mask] + NOISE_LEVEL * np.random.normal(0.0, 1, size=self.obs_dim)
        return obs
