import time
from dataclasses import asdict
from math import sqrt

import torch

from args import (
    DatasetConfig,
    LoggingConfig,
    PPOConfig,
    SeedConfig,
    get_model_name,
    parse_args_to_dataclass,
)
from mdp.darkroom_env import DarkroomEnv
from mdp.mdp_controller import MDPOptimalController, PPOController
from mdp.mdp_dataset import MDPDatasetTorch
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed

# Target device string handed to the environments and the PPO controller:
# "cuda" when torch reports CUDA as available, otherwise None.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = None


def main(logging_config: LoggingConfig, seed_config: SeedConfig, dataset_config: DatasetConfig, model_config: PPOConfig) -> None:
    """Train a PPO controller on sampled Darkroom environments and log per-epoch metrics.

    Every epoch deploys both the PPO controller and the optimal controller
    ``model_config.n_episodes`` times, passes the PPO rollouts to
    ``controller.update`` and logs the first returned metrics dict together
    with the normalized rewards and the wall-clock duration of the update.

    Args:
        logging_config: ``log == "wandb"`` selects ``WandbLogger``; any other
            value selects ``PrintLogger``.
        seed_config: Provides ``seed`` for ``set_seed``.
        dataset_config: Provides ``env``, ``n_states``, ``n_envs`` and
            ``context_len``.
        model_config: Provides ``n_episodes`` and ``n_epochs`` and is forwarded
            to ``PPOController``.

    Raises:
        ValueError: If ``dataset_config.env`` is neither ``"chain"`` nor
            ``"darkroom"``.
    """
    set_seed(seed_config.seed)

    run_name = get_model_name(dataset_config, model_config)

    # Resolve the dimensions before any logger is created so that an unsupported
    # env fails immediately with a clear message instead of an UnboundLocalError
    # further down (and without leaving a started logging run behind).
    if dataset_config.env == "chain":
        # NOTE(review): the environments sampled below are always DarkroomEnv,
        # even for "chain" -- confirm this combination is intended.
        state_dim = dataset_config.n_states
        action_dim = 2
    elif dataset_config.env == "darkroom":
        state_dim = 2
        action_dim = 5
    else:
        raise ValueError(f"Unsupported env {dataset_config.env!r}; expected 'chain' or 'darkroom'")

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(dataset_config),
                **asdict(model_config),
            },
            step_name="Batch",
        )
    else:
        logger = PrintLogger(run_name, "Batch")

    # Side length of the grid, derived from the number of states.
    square_len = int(sqrt(dataset_config.n_states))
    n_episodes = model_config.n_episodes
    n_minibatches = 1
    n_envs = dataset_config.n_envs

    controller = PPOController(model_config, n_envs, dataset_config.context_len, dataset_config.n_states, state_dim, action_dim, device=device)
    envs = DarkroomEnv.sample(n_envs, dataset_config.context_len, square_len, device=device)
    opt = MDPOptimalController(envs.optimal_actions, n_envs, dataset_config.context_len, dataset_config.n_states, state_dim, action_dim)

    # Flip manually to render a value-function animation every 100 epochs.
    debug = False

    for epoch in range(model_config.n_epochs):
        train_reward = 0
        opt_reward = 0

        # Rollout collection runs without gradient tracking.
        with torch.no_grad():
            datasets: list[MDPDatasetTorch] = []
            for _ in range(n_episodes):
                dataset = envs.deploy(controller)
                dataset_opt = envs.deploy(opt)
                datasets.append(dataset)

                train_reward += dataset.rewards_original.sum().item()
                opt_reward += dataset_opt.rewards_original.sum().item()

        # Only the update call itself is timed.
        start_time = time.time()
        metrics, _ = controller.update(datasets, n_minibatches)
        train_time = time.time() - start_time

        logger.log(
            {
                **metrics[0],  # TODO: fix if multiple metrics?
                # Summed rewards normalized by environment and episode count.
                "train/reward": train_reward / n_envs / n_episodes,
                "train/opt_reward": opt_reward / n_envs / n_episodes,
                "train/time": train_time,
            },
            step=epoch,
        )

        if debug and ((epoch + 1) % 100 == 0):
            with torch.no_grad():
                # Enumerate every (row, col) pair of the grid, shared across all envs.
                all_states = torch.stack(
                    [
                        torch.arange(square_len, device=device).repeat_interleave(square_len),
                        torch.arange(square_len, device=device).repeat(square_len),
                    ],
                    dim=1,
                ).expand((n_envs, -1, -1))
                _, values_all = controller.model(all_states.float())

                # `dataset` is the last rollout collected in this epoch.
                anim = envs.visualize_dataset(dataset, values=values_all, title="PPO")
                anim.save("test.mp4")

    logger.finish()


if __name__ == "__main__":
    # One config dataclass is produced per requested type, in this order.
    config_types = (LoggingConfig, SeedConfig, DatasetConfig, PPOConfig)
    logging_config, seed_config, dataset_config, model_config = parse_args_to_dataclass(config_types)

    # Echo the resolved configuration, one dataclass per line.
    for cfg in (logging_config, seed_config, dataset_config, model_config):
        print(cfg)

    # Measure the wall-clock duration of the whole training run.
    time_start = time.time()
    main(logging_config, seed_config, dataset_config, model_config)
    elapsed = time.time() - time_start

    print(f"Total runtime: {elapsed:.2f} s")
