import os
import time
from math import sqrt

import torch
from torch.cuda import is_available as is_cuda_available

from args import (
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    SeedConfig,
    get_model_name,
    parse_args_to_dataclass,
)
from mdp.chain_env import ChainEnv
from mdp.darkroom_env import DarkroomEnv
from mdp.mdp_controller import MDPTransformerController
from net import Transformer
from util.seed import set_seed


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, seed_config: SeedConfig) -> None:
    """Roll out a trained transformer controller in one sampled environment and save a video.

    The function seeds the RNGs, samples a single environment of the kind named
    by ``dataset_config.env``, loads the checkpoint ``models/<model_name>[_epoch<N>].pt``
    into a freshly built ``Transformer``, deploys an ``MDPTransformerController``
    in the environment, and writes the resulting animation to
    ``video/eval_clean_<model_name>[_epoch<N>].mp4``.

    Args:
        dataset_config: Supplies ``env`` ("chain" or "darkroom"), ``n_states``,
            ``context_len`` and, for the chain environment, ``variance``.
        model_config: Supplies the model hyper-parameters via ``get_params`` and
            contributes to the model name.
        eval_config: ``n_steps_eval`` sets the rollout length (falls back to
            ``dataset_config.context_len`` when ``None``); ``epoch`` selects an
            epoch-specific checkpoint when it is not ``None``.
        seed_config: Supplies the ``seed`` passed to ``set_seed``.

    Raises:
        NotImplementedError: If ``dataset_config.env`` is neither "chain" nor
            "darkroom".
    """
    device = "cuda" if is_cuda_available() else None
    model_name = get_model_name(dataset_config, model_config)
    # The checkpoint and the video share the same optional "_epoch<N>" suffix.
    epoch_suffix = f"_epoch{eval_config.epoch}" if eval_config.epoch is not None else ""

    n_envs_vis = 1
    n_steps_model = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len

    # Seed before sampling the environment so the visualised rollout is reproducible.
    set_seed(seed_config.seed)

    if dataset_config.env == "chain":
        n_actions = 2
        state_dim = dataset_config.n_states

        envs = ChainEnv.sample(n_envs_vis, n_steps_eval, dataset_config.n_states, dataset_config.variance, device=device)
    elif dataset_config.env == "darkroom":
        n_actions = 5
        state_dim = 2
        # The darkroom is sampled from a side length derived as int(sqrt(n_states)).
        square_len = int(sqrt(dataset_config.n_states))

        envs = DarkroomEnv.sample(n_envs_vis, n_steps_eval, square_len, device=device)
    else:
        raise NotImplementedError(f"Unsupported environment: {dataset_config.env!r}")

    model_path = f"models/{model_name}{epoch_suffix}.pt"
    model = Transformer(model_config.get_params({"H": n_steps_model, "state_dim": state_dim, "action_dim": n_actions})).to(device)
    # NOTE(review): `test` is a flag on the project's Transformer; presumably it
    # selects the evaluation-time forward path -- confirm in net.py.
    model.test = True
    print(f"Loading model '{model_path}'...")
    # Remap storages explicitly so a checkpoint saved from a CUDA run still
    # loads on a host without CUDA (where `device` is None).
    state_dict = torch.load(model_path, weights_only=True, map_location=device if device is not None else "cpu")
    model.load_state_dict(state_dict)
    model.eval()

    controller = MDPTransformerController(model, n_envs_vis, n_steps_model, dataset_config.n_states, state_dim, n_actions, sample=True, device=device)
    # controller = MDPOptimalController(n_envs_eval, n_steps, dataset_config.n_states, state_dim, n_actions, envs.optimal_actions, device=device)

    dataset = envs.deploy(controller, omit_optimal_actions=False, pbar_desc="Visualize")
    # if isinstance(envs, ChainEnv):
    #     dataset.optimal_actions = envs.get_optimal_actions_per_state(dataset.states)

    # assert dataset.optimal_actions is not None
    # nans = torch.ones((1, n_steps, 1), device=device) * torch.nan
    # print(torch.cat((dataset.states, nans, dataset.optimal_actions), dim=-1))

    anim = envs.visualize_dataset(dataset)

    video_path = f"video/eval_clean_{model_name}{epoch_suffix}.mp4"
    # video_path = f"video/eval_clean_optimal.mp4"

    # Create the output directory up front so saving does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    anim.save(video_path)
    print(f"Saved video to '{video_path}'.")


if __name__ == "__main__":
    # Parse the run configuration into the four config dataclasses that main() expects.
    dataset_config, model_config, eval_config, seed_config = parse_args_to_dataclass((DatasetConfig, ModelConfig, EvalConfig, SeedConfig))

    # Echo the resolved configuration, one dataclass per line.
    print(dataset_config, model_config, eval_config, seed_config, sep="\n")

    # perf_counter() is monotonic, so the reported duration is not distorted by
    # system clock adjustments that can affect time.time().
    time_start = time.perf_counter()
    main(dataset_config, model_config, eval_config, seed_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
