import colorama
from typing import Tuple

# Envs - Hollow Knight
from envs.hollow_knight.env_wrapper import HKEnv, IndependentActionSpace, LocalAbstractHKEnv
from envs.hollow_knight.build_env import build_hollow_knight_env

# Feature Extractor - Visual
from feature_extractor.visual.build_feature_extractor import ResizeObservation

# Buffer
from utils.replay_buffer import VisualReplayBuffer

# Agent - Visual input (STORM)
from agents.visual.agent import Agent

# Params
from dataclasses import dataclass, field
@dataclass
class WorldModelParams:
    """Transformer settings that build() forwards to the Agent constructor.

    Each field is passed to ``Agent`` under the keyword of the same name.
    """

    # Passed as ``transformer_max_length``.
    transformer_max_length: int = 64
    # Passed as ``transformer_hidden_dim``.
    transformer_hidden_dim: int = 256
    # Passed as ``transformer_num_layers``.
    transformer_num_layers: int = 2
    # Passed as ``transformer_num_heads``.
    transformer_num_heads: int = 4

@dataclass
class PolicyParams:
    """Policy settings that build() forwards to the Agent constructor."""

    # Passed to Agent as ``policy_num_layers``.
    num_layers: int = 3
    # Passed to Agent as ``policy_hidden_dim``.
    hidden_dim: int = 512
    # Passed to Agent as ``gamma`` (presumably the discount factor -- confirm in Agent).
    gamma: float = 0.985
    # Passed to Agent as ``lambd`` (presumably the lambda-return coefficient -- confirm in Agent).
    lambd: float = 0.95
    # Passed to Agent as ``entropy_coef``.
    entropy_coef: float = 1e-3

@dataclass
class Params:
    """Top-level hyperparameters for the Hollow Knight visual-agent setup.

    The nested ``world_model`` and ``policy`` groups are created through
    ``default_factory`` so that every ``Params()`` owns its own copies.
    Using an instance directly as the default is rejected by ``dataclasses``
    on Python 3.11+ (unhashable default) and, on older versions, silently
    shares one object between all ``Params`` instances.
    """

    # Placeholder value. NOTE(review): the original comment said this is set
    # in build(), but the build() in this file passes it to Agent unchanged.
    num_objects: int = -1
    frame_skip: int = 1

    object_feature_dim: int = 2048

    buffer_max_length: int = 200_000  # TODO: approximately equal to Atari100k's sample time, not steps
    buffer_warm_up: int = 4096

    latent_width: int = 16
    world_model: WorldModelParams = field(default_factory=WorldModelParams)
    policy: PolicyParams = field(default_factory=PolicyParams)

    max_sample_steps: int = 200_000 + 20
    min_train_ratio: float = 1.0  # train_steps/sample_steps, has nothing to do with batch_size, etc.
    max_train_ratio: float = 1.0

    batch_size: int = 32
    batch_length: int = 32

    imagine_batch_size: int = 512
    imagine_context_length: int = 4
    imagine_batch_length: int = 16

    eval_context_length: int = 8

    save_every_steps: int = 40_000  # TODO: debug


def build(env_name, seed) -> Tuple[Params, LocalAbstractHKEnv, IndependentActionSpace, ResizeObservation, VisualReplayBuffer, Agent]:
    """Create every component needed for a Hollow Knight training run.

    Args:
        env_name: Environment id such as ``"HollowKnight/HornetProtector"``;
            only the text after the last ``/`` is used (as the boss name).
        seed: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(params, env, action_space, feature_extractor, replay_buffer, agent)``.
        The agent has been moved to CUDA and the replay buffer is created
        with ``store_on_gpu=True``.
    """

    def _yellow(value) -> str:
        # Wrap a value in colorama's yellow foreground / reset codes.
        return colorama.Fore.YELLOW + f"{value}" + colorama.Style.RESET_ALL

    hyper = Params()

    # "HollowKnight/HornetProtector" -> "HornetProtector"
    boss_name = env_name.rsplit("/", 1)[-1]
    print("Boss Name: " + _yellow(boss_name))

    # --- Environment ---
    # Hollow Knight original resolution (per the original comment). The
    # original code bound 1275 to `height` and 711 to `width`.
    # NOTE(review): confirm the order build_hollow_knight_env expects.
    native_resolution = (1275, 711)
    env, action_space = build_hollow_knight_env(
        boss_name,
        obs_size=native_resolution,
        target_fps=9,
    )
    print("action_dim: " + _yellow(action_space.dim))
    print("action_choices: " + _yellow(action_space.choices_per_dim))

    # --- Feature extractor ---
    channels = 3
    state_resolution = (64, 64)
    feature_extractor = ResizeObservation(state_resolution=state_resolution)

    # --- Replay buffer ---
    replay_buffer = VisualReplayBuffer(
        obs_shape=(channels, *state_resolution),
        action_dim=action_space.dim,
        num_envs=1,
        max_length=hyper.buffer_max_length,
        warmup_length=hyper.buffer_warm_up,
        store_on_gpu=True,
    )

    # --- Agent ---
    wm_cfg = hyper.world_model
    policy_cfg = hyper.policy
    # One entry per action dimension ([2] * 7 for Hollow Knight, per the original note).
    per_dim_choices = [action_space.choices_per_dim] * action_space.dim
    agent = Agent(
        input_channels=channels,
        action_dims=per_dim_choices,
        num_objects=hyper.num_objects,
        latent_width=hyper.latent_width,
        transformer_max_length=wm_cfg.transformer_max_length,
        transformer_hidden_dim=wm_cfg.transformer_hidden_dim,
        transformer_num_layers=wm_cfg.transformer_num_layers,
        transformer_num_heads=wm_cfg.transformer_num_heads,
        policy_num_layers=policy_cfg.num_layers,
        policy_hidden_dim=policy_cfg.hidden_dim,
        gamma=policy_cfg.gamma,
        lambd=policy_cfg.lambd,
        entropy_coef=policy_cfg.entropy_coef,
    )
    agent = agent.cuda()

    return hyper, env, action_space, feature_extractor, replay_buffer, agent