import colorama
from typing import Tuple

# Envs - Hollow Knight
from envs.hollow_knight.env_wrapper import HKEnv, IndependentActionSpace, LocalAbstractHKEnv
from envs.hollow_knight.build_env import build_hollow_knight_env

# Feature Extractor - Cutie
from feature_extractor.cutie.build_feature_extractor import VisualMaskExtractor
from configs.hollow_knight_object_info import num_object_mapping_factory

# Buffer
from utils.replay_buffer import VisualReplayBuffer

# Agent - Visual input with object masks generated by Cutie (STORM)
from agents.visual.agent import Agent

# Params
from dataclasses import dataclass, field
@dataclass
class WorldModelParams:
    """Size settings for the transformer world model.

    ``build()`` forwards each field to ``Agent`` under the keyword argument
    of the same name.
    """
    transformer_max_length: int = 64   # Agent(transformer_max_length=...)
    transformer_hidden_dim: int = 256  # Agent(transformer_hidden_dim=...)
    transformer_num_layers: int = 2    # Agent(transformer_num_layers=...)
    transformer_num_heads: int = 4     # Agent(transformer_num_heads=...)

@dataclass
class PolicyParams:
    """Policy-network and return settings that ``build()`` forwards to ``Agent``.

    ``num_layers``/``hidden_dim`` become ``policy_num_layers``/``policy_hidden_dim``;
    the remaining fields are passed under their own names.
    """
    num_layers: int = 3
    hidden_dim: int = 512
    gamma: float = 0.985        # presumably the reward discount factor
    lambd: float = 0.95         # presumably the lambda-return mixing coefficient
    entropy_coef: float = 1e-3  # weight of the entropy term (per its name)

@dataclass()
class Params:
    """Top-level hyper-parameters for the Hollow Knight visual agent.

    ``num_objects`` is a placeholder (-1) until ``build()`` sets it from
    ``num_object_mapping[boss_name]``.
    """
    # Mutable defaults (the mapping dict and the nested dataclass instances)
    # must go through ``default_factory``. A plain default is a single object
    # shared by every ``Params`` instance. Since Python 3.11, ``dataclasses``
    # also rejects unhashable defaults (such as a non-frozen dataclass
    # instance) with ``ValueError`` at class-definition time.
    num_object_mapping: dict = field(default_factory=num_object_mapping_factory)
    num_objects: int = -1  # will be set in build()
    frame_skip: int = 1

    object_feature_dim: int = 2048

    buffer_max_length: int = int(1E5)  # memory not sufficient for 2E5 float obs+mask
    buffer_warm_up: int = 4096

    latent_width: int = 16
    # Fresh nested configs per Params instance (see note above).
    world_model: WorldModelParams = field(default_factory=WorldModelParams)
    policy: PolicyParams = field(default_factory=PolicyParams)

    max_sample_steps: int = int(2E5) + 20
    min_train_ratio: float = 1.0  # train_steps/sample_steps, has nothing to do with batch_size, etc.
    max_train_ratio: float = 1.0

    batch_size: int = 32
    batch_length: int = 32

    imagine_batch_size: int = 512
    imagine_context_length: int = 4
    imagine_batch_length: int = 16

    eval_context_length: int = 8

    save_every_steps: int = int(4E4)  # TODO: debug


def build(env_name, seed) -> Tuple[Params, LocalAbstractHKEnv, IndependentActionSpace, VisualMaskExtractor, VisualReplayBuffer, Agent]:
    """Construct every component needed to train on one Hollow Knight boss.

    Args:
        env_name: Environment id whose last ``/``-separated segment is the
            boss name, e.g. ``"HollowKnight/HornetProtector"``.
        seed: Accepted for interface compatibility; not used by this function.

    Returns:
        ``(params, env, action_space, feature_extractor, replay_buffer, agent)``.
        The agent has already been moved to CUDA via ``.cuda()``.

    Raises:
        KeyError: if the boss name has no entry in
            ``Params.num_object_mapping``. This is checked before the game
            environment is built.
    """
    params = Params()

    # HollowKnight/HornetProtector -> HornetProtector
    boss_name = env_name.split("/")[-1]
    print("Boss Name: " + colorama.Fore.YELLOW + f"{boss_name}" + colorama.Style.RESET_ALL)

    # Resolve the per-boss object count first so an unknown boss fails fast,
    # before the game environment is created.
    try:
        params.num_objects = params.num_object_mapping[boss_name]
    except KeyError:
        raise KeyError(
            f"No object count configured for boss {boss_name!r}; "
            f"known bosses: {list(params.num_object_mapping)}"
        ) from None

    state_resolution = (64, 64)  # STORM's default resolution
    # 3 image channels (presumably RGB) plus one mask channel per object.
    in_channels = 3 + params.num_objects

    # Build Env >>>
    # NOTE(review): 1275 is bound to `height` and 711 to `width`, which looks
    # swapped for a landscape window. Both are passed consistently as
    # (height, width) below, so confirm what build_hollow_knight_env and
    # VisualMaskExtractor expect before renaming anything.
    height, width = 1275, 711  # Hollow Knight original resolution
    env, action_space = build_hollow_knight_env(boss_name, obs_size=(height, width), target_fps=9)
    print("action_dim: " + colorama.Fore.YELLOW + f"{action_space.dim}" + colorama.Style.RESET_ALL)
    print("action_choices: " + colorama.Fore.YELLOW + f"{action_space.choices_per_dim}" + colorama.Style.RESET_ALL)
    # <<< Build Env

    # Build Feature Extractor >>>
    feature_extractor = VisualMaskExtractor(
        state_resolution=state_resolution,
        label_folder=f"segmentation_masks/HollowKnight/{boss_name}",
        num_objects=params.num_objects,
        model_size="small",
        expected_resolution=(height, width),
        resolution_scale_factor=0.675,
        frame_scale_method="bilinear",
    )
    # <<< Build Feature Extractor

    # Buffer: observations are stored at the extractor's output resolution.
    replay_buffer = VisualReplayBuffer(
        obs_shape=(in_channels, *state_resolution),
        action_dim=action_space.dim,
        num_envs=1,
        max_length=params.buffer_max_length,
        warmup_length=params.buffer_warm_up,
        store_on_gpu=True,
    )

    # Agent
    agent = Agent(
        input_channels=in_channels,
        action_dims=[action_space.choices_per_dim] * action_space.dim,  # [2] * 7 for Hollow Knight
        num_objects=params.num_objects,
        latent_width=params.latent_width,
        transformer_max_length=params.world_model.transformer_max_length,
        transformer_hidden_dim=params.world_model.transformer_hidden_dim,
        transformer_num_layers=params.world_model.transformer_num_layers,
        transformer_num_heads=params.world_model.transformer_num_heads,
        policy_num_layers=params.policy.num_layers,
        policy_hidden_dim=params.policy.hidden_dim,
        gamma=params.policy.gamma,
        lambd=params.policy.lambd,
        entropy_coef=params.policy.entropy_coef,
    ).cuda()

    return params, env, action_space, feature_extractor, replay_buffer, agent