import time
from dataclasses import asdict
from math import sqrt

import torch

from args import (
    DatasetConfig,
    LoggingConfig,
    NPGConfig,
    SeedConfig,
    get_model_name,
    parse_args_to_dataclass,
)
from mdp.darkroom_env import DarkroomEnv
from mdp.mdp_controller import MDPNPGController, MDPOptimalController
from mdp.mdp_dataset import MDPDatasetTorch
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed

# Target device handed to the controller and the environments below.
# "cuda" is selected whenever PyTorch reports a usable GPU; otherwise None is
# passed through (presumably meaning "use the callee's default device" --
# confirm against the controller/env implementations).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = None


def main(logging_config: LoggingConfig, seed_config: SeedConfig, dataset_config: DatasetConfig, model_config: NPGConfig) -> None:
    """Train an NPG controller and log per-epoch rewards next to an optimal baseline.

    Each epoch deploys both the NPG controller and the optimal controller in
    the sampled environments ``model_config.n_episodes`` times (without
    gradient tracking), sums the ``rewards_original`` of every rollout, calls
    ``controller.update`` on the collected NPG datasets, and logs the
    resulting metrics together with the average rewards and the update time.

    Args:
        logging_config: Selects the logger; ``log == "wandb"`` uses
            ``WandbLogger``, anything else falls back to ``PrintLogger``.
        seed_config: Provides the seed passed to ``set_seed``.
        dataset_config: Environment settings read here: ``env``, ``n_states``,
            ``n_envs`` and ``context_len``.
        model_config: NPG settings read here: ``n_episodes`` and ``n_epochs``
            (the whole object is also forwarded to the controller).

    Raises:
        ValueError: If ``dataset_config.env`` is neither ``"chain"`` nor
            ``"darkroom"``.
    """
    set_seed(seed_config.seed)

    # Resolve the environment-specific dimensions before anything else is
    # created, so an unsupported env fails fast with a clear message instead
    # of an UnboundLocalError after a logging run has already been started.
    if dataset_config.env == "chain":
        state_dim = dataset_config.n_states
        action_dim = 2
    elif dataset_config.env == "darkroom":
        state_dim = 2
        action_dim = 5
    else:
        raise ValueError(f"Unsupported env {dataset_config.env!r}; expected 'chain' or 'darkroom'.")

    run_name = get_model_name(dataset_config, model_config)

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(dataset_config),
                **asdict(model_config),
            },
            step_name="Epoch",
        )
    else:
        logger = PrintLogger(run_name, "Epoch")

    # int() truncates, so an n_states that is not a perfect square is silently
    # rounded down here. Presumably this is the side length of the darkroom
    # grid -- confirm against DarkroomEnv.sample.
    square_len = int(sqrt(dataset_config.n_states))
    n_episodes = model_config.n_episodes
    n_envs = dataset_config.n_envs

    controller = MDPNPGController(model_config, n_envs, dataset_config.context_len, state_dim, dataset_config.n_states, action_dim, sample=True, device=device)
    # NOTE(review): environments are always sampled from DarkroomEnv, even when
    # dataset_config.env == "chain" -- confirm that this is intended.
    envs = DarkroomEnv.sample(n_envs, dataset_config.context_len, square_len, device=device)
    opt = MDPOptimalController(envs.optimal_actions, n_envs, dataset_config.context_len, dataset_config.n_states, state_dim, action_dim)

    for epoch in range(model_config.n_epochs):
        train_reward = 0
        opt_reward = 0
        # Rollout collection needs no gradients; only controller.update below
        # runs outside of no_grad.
        with torch.no_grad():
            datasets: list[MDPDatasetTorch] = []
            for episode in range(n_episodes):
                # Only the first rollout of an epoch requests clear_dataset=True.
                dataset = envs.deploy(controller, clear_dataset=(episode == 0))
                dataset_opt = envs.deploy(opt)
                datasets.append(dataset)

                train_reward += dataset.rewards_original.sum().item()
                opt_reward += dataset_opt.rewards_original.sum().item()

        # Time only the policy update, not the rollouts.
        start_time = time.time()
        metrics, aux = controller.update(datasets)
        train_time = time.time() - start_time

        # Manual debugging switch: flip to True to render the last rollout of
        # the first and every 25th epoch and pause until Enter is pressed.
        debug = False
        if debug and ((epoch + 1) % 25 == 0 or epoch == 0):
            anim = envs.visualize_dataset(datasets[-1], advantages=aux["advantages"], values=aux["values"], title="NPG")
            anim.save("test.mp4")
            input("Enter to continue")

        logger.log(
            {
                **metrics[0],  # TODO: fix if multiple metrics? (only the first entry is logged)
                # Rewards are averaged over environments and episodes.
                "train/reward": train_reward / n_envs / n_episodes,
                "train/opt_reward": opt_reward / n_envs / n_episodes,
                "train/time": train_time,
            },
            step=epoch,
        )
        # if (epoch + 1) % 50 == 0 and (epoch + 1) != model_config.n_epochs:
        #     torch.save(model.state_dict(), f"models/{run_name}_epoch{epoch+1}.pt")
        #     print(f"Saved model to 'models/{run_name}_epoch{epoch+1}.pt'.")

    # torch.save(model.state_dict(), f"models/{run_name}.pt")
    # print(f"Saved model to 'models/{run_name}.pt'.")

    logger.finish()


if __name__ == "__main__":
    # Script entry point: parse one dataclass per configuration group.
    config_types = (LoggingConfig, SeedConfig, DatasetConfig, NPGConfig)
    log_cfg, seed_cfg, data_cfg, npg_cfg = parse_args_to_dataclass(config_types)

    # Echo the parsed configuration, one dataclass per line.
    print(log_cfg, seed_cfg, data_cfg, npg_cfg, sep="\n")

    # Wall-clock duration of the whole training run.
    began_at = time.time()
    main(log_cfg, seed_cfg, data_cfg, npg_cfg)
    elapsed = time.time() - began_at

    print(f"Total runtime: {elapsed:.2f} s")
