import colorama
from typing import Tuple

# Envs - Atari
from gymnasium.core import Env
from gymnasium import spaces
from envs.atari.build_env import build_single_atari_env, ActionSpace

# Feature Extractor - Cutie
from feature_extractor.cutie.build_feature_extractor import VisualMaskExtractor
from configs.atari_object_info import num_object_mapping_factory

# Buffer
from utils.replay_buffer import VisualReplayBuffer

# Agent - Visual input with object masks generated by Cutie (STORM)
from agents.visual.agent import Agent

# Params
from dataclasses import dataclass, field
@dataclass
class WorldModelParams:
    """Transformer settings for the agent's world model.

    ``build()`` passes every field to ``Agent`` under the identical keyword
    name (``transformer_max_length``, ``transformer_hidden_dim``,
    ``transformer_num_layers``, ``transformer_num_heads``).
    """

    transformer_max_length: int = 64
    transformer_hidden_dim: int = 256
    transformer_num_layers: int = 2
    transformer_num_heads: int = 4

@dataclass
class PolicyParams:
    """Settings for the agent's policy, forwarded to ``Agent`` by ``build()``.

    ``num_layers`` / ``hidden_dim`` become ``policy_num_layers`` /
    ``policy_hidden_dim``; ``gamma``, ``lambd`` and ``entropy_coef`` keep their
    names. ``gamma`` is presumably the discount factor and ``lambd`` the
    lambda-return mixing coefficient (``lambda`` is a keyword) — confirm
    against ``Agent``.
    """

    num_layers: int = 3
    hidden_dim: int = 512
    gamma: float = 0.985
    lambd: float = 0.95
    entropy_coef: float = 1e-3

@dataclass()
class Params:
    """Top-level hyper-parameters for the Cutie-mask visual agent.

    ``build()`` reads the object mapping, buffer sizes, ``latent_width`` and the
    nested ``world_model`` / ``policy`` settings. The remaining fields are not
    read in ``build()`` and are presumably consumed by the training loop.
    """

    # Mutable defaults (the dict below and the nested dataclass instances) must
    # use ``default_factory``. A plain default is a single object shared by
    # every Params instance, and Python 3.11+ rejects unhashable dataclass
    # defaults at class-definition time ("mutable default ... use default_factory").
    num_object_mapping: dict = field(default_factory=num_object_mapping_factory)
    num_objects: int = -1  # placeholder; build() sets it from num_object_mapping
    frame_skip: int = 4

    object_feature_dim: int = 2048

    buffer_max_length: int = int(1E5)
    buffer_warm_up: int = 1024  # passed to the replay buffer as warmup_length

    latent_width: int = 16
    world_model: WorldModelParams = field(default_factory=WorldModelParams)
    policy: PolicyParams = field(default_factory=PolicyParams)

    max_sample_steps: int = int(1E5)
    min_train_ratio: float = 1.0  # train_steps/sample_steps, has nothing to do with batch_size, etc.
    max_train_ratio: float = 1.0

    batch_size: int = 32
    batch_length: int = 32

    imagine_batch_size: int = 512
    imagine_context_length: int = 4
    imagine_batch_length: int = 16

    eval_context_length: int = 8

    save_every_steps: int = int(1E5)  # TODO: debug


def _parse_real_env_name(env_name: str) -> str:
    """Return the bare game name encoded in an Atari environment id.

    Supported formats:
      * ``"ALE/Boxing-v5"`` -> ``"Boxing"`` (the namespace prefix is optional,
        so ``"Boxing-v5"`` also yields ``"Boxing"``)
      * ``"PongNoFrameskip-v4"`` -> ``"Pong"``

    Raises:
        ValueError: if ``env_name`` matches neither format.
    """
    if "v5" in env_name:
        # Use the last "/"-separated component so an id without the namespace
        # prefix works instead of raising IndexError.
        return env_name.split("/")[-1].split("-")[0]
    if "NoFrameskip-v4" in env_name:
        return env_name.split("NoFrameskip")[0]
    raise ValueError(f"Unknown env_name format: {env_name}")


def build(env_name, seed) -> Tuple[Params, Env, ActionSpace, VisualMaskExtractor, VisualReplayBuffer, Agent]:
    """Construct everything needed to train the visual agent on one Atari game.

    Args:
        env_name: Atari environment id, either ``"ALE/<Game>-v5"`` (prefix
            optional) or ``"<Game>NoFrameskip-v4"``.
        seed: forwarded to ``build_single_atari_env``.

    Returns:
        ``(params, env, action_space, feature_extractor, replay_buffer, agent)``.
        ``params.num_objects`` is filled in from ``params.num_object_mapping``.
        The agent is moved to the GPU with ``.cuda()`` and the replay buffer is
        created with ``store_on_gpu=True``, so a CUDA device is required.

    Raises:
        ValueError: if ``env_name`` is in neither supported format.
        KeyError: if the game has no entry in ``params.num_object_mapping``.
        Both are raised before the environment is constructed.
    """
    # Validate the id and resolve the object count before building anything
    # expensive, so a bad env_name fails fast.
    real_env_name = _parse_real_env_name(env_name)

    params = Params()
    try:
        params.num_objects = params.num_object_mapping[real_env_name]
    except KeyError as err:
        raise KeyError(
            f"No object count configured for game {real_env_name!r} (env_name={env_name!r})"
        ) from err

    # Build Env >>>
    env, action_space = build_single_atari_env(env_name, seed)
    print("Possible actions: " + colorama.Fore.YELLOW + f"{action_space.choices_per_dim}" + colorama.Style.RESET_ALL)
    # <<< Build Env

    obs_resolution = (64, 64)  # STORM's default resolution
    # Channels per stored frame: presumably 3 colour channels plus one mask
    # channel per tracked object — TODO confirm against VisualMaskExtractor.
    input_channels = 3 + params.num_objects

    # Build Feature Extractor >>>
    feature_extractor = VisualMaskExtractor(state_resolution=obs_resolution,
                                            label_folder=f"segmentation_masks/Atari/{real_env_name}",
                                            num_objects=params.num_objects,
                                            model_size="small",
                                            expected_resolution=(160, 210),
                                            resolution_scale_factor=2,
                                            frame_scale_method="nearest")
    # <<< Build Feature Extractor

    # Buffer
    replay_buffer = VisualReplayBuffer(
        obs_shape=(input_channels, *obs_resolution),
        action_dim=1,
        num_envs=1,
        max_length=params.buffer_max_length,
        warmup_length=params.buffer_warm_up,
        store_on_gpu=True,
    )

    # Agent
    agent = Agent(
        input_channels=input_channels,
        action_dims=[action_space.choices_per_dim],
        num_objects=params.num_objects,
        latent_width=params.latent_width,
        transformer_max_length=params.world_model.transformer_max_length,
        transformer_hidden_dim=params.world_model.transformer_hidden_dim,
        transformer_num_layers=params.world_model.transformer_num_layers,
        transformer_num_heads=params.world_model.transformer_num_heads,
        policy_num_layers=params.policy.num_layers,
        policy_hidden_dim=params.policy.hidden_dim,
        gamma=params.policy.gamma,
        lambd=params.policy.lambd,
        entropy_coef=params.policy.entropy_coef,
    ).cuda()

    return params, env, action_space, feature_extractor, replay_buffer, agent