import os
import pickle
import time
from dataclasses import asdict

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from args import (
    DatasetConfig,
    LoggingConfig,
    ModelConfig,
    SeedConfig,
    get_model_name,
    get_model_save_name,
    parse_args_to_dataclass,
)
from bandit2.bandit_dataset import BanditDatasetTorch
from mdp.mdp_dataset import MDPDatasetTorch
from net import Transformer
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed

# Device the model is moved to in ``main``: "cuda" when a GPU is visible, otherwise None.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = None


def _get_env_dims(dataset_config: DatasetConfig) -> tuple[int, int]:
    """Return ``(state_dim, action_dim)`` for the environment named by ``dataset_config.env``.

    Raises:
        ValueError: if ``dataset_config.env`` is not "bandit", "chain" or "darkroom".
    """
    env = dataset_config.env
    if env == "bandit":
        return 1, dataset_config.n_actions
    if env == "chain":
        return dataset_config.n_states, 2
    if env == "darkroom":
        return 2, 5
    raise ValueError(f"Unknown env {env!r}; expected 'bandit', 'chain' or 'darkroom'.")


def _load_dataset(path: str, purpose: str) -> "MDPDatasetTorch | BanditDatasetTorch":
    """Unpickle and return the dataset stored at ``path``, printing which file is used for ``purpose``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load dataset files you trust.
    """
    with open(path, "rb") as f:
        print(f"Using '{f.name}' for {purpose}.")
        return pickle.load(f)


def _run_epoch(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_fn: torch.nn.Module,
    action_dim: int,
    optimizer: "torch.optim.Optimizer | None" = None,
) -> tuple[float, int]:
    """Make one pass over ``dataloader`` and return ``(summed loss, number of correct predictions)``.

    With an ``optimizer`` the model is put in train mode and updated once per batch; without one it
    is put in eval mode and run under ``torch.no_grad()``. A prediction counts as correct when the
    argmax of the model output matches the argmax of the target row.
    """
    training = optimizer is not None
    # Toggle dropout/normalization behaviour to match the phase (the original never switched modes).
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(training):
        for batch, true_actions in dataloader:
            # NOTE(review): batches are fed as produced by the DataLoader; presumably the dataset
            # already yields tensors on ``device`` — confirm.
            pred_actions = model(batch)
            del batch

            # Flatten to one row per action prediction; targets are per-class rows of width action_dim.
            true_actions = true_actions.reshape(-1, action_dim)
            pred_actions = pred_actions.reshape(-1, action_dim)

            loss: Tensor = loss_fn(pred_actions, true_actions)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() already syncs per batch, so converting the accuracy count here costs nothing extra
            # and keeps the logged values plain Python numbers instead of (possibly CUDA) tensors.
            total_loss += loss.item()
            total_correct += (pred_actions.argmax(dim=-1) == true_actions.argmax(dim=-1)).sum().item()

    return total_loss, total_correct


def main(logging_config: LoggingConfig, seed_config: SeedConfig, dataset_config: DatasetConfig, model_config: ModelConfig):
    """Train a Transformer on the pickled train split, evaluating on the test split every epoch.

    Per-epoch loss, accuracy and timing for both splits are sent to a wandb or print logger.
    Intermediate checkpoints are written to ``models/`` every 50 epochs (except the last one),
    and the final weights are written to ``models/`` after training.

    Args:
        logging_config: selects the logger; ``log == "wandb"`` uses ``WandbLogger``.
        seed_config: provides ``seed`` for ``set_seed`` and is part of the checkpoint file names.
        dataset_config: environment description; ``get_summary()`` names the dataset files.
        model_config: model hyper-parameters, learning rate, epoch count and shuffle flag.

    Raises:
        ValueError: if ``dataset_config.env`` is not a supported environment.
    """
    run_name = get_model_name(dataset_config, model_config)
    # Resolve the environment before starting a (possibly remote) logging run, so a bad config fails fast.
    state_dim, action_dim = _get_env_dims(dataset_config)
    context_len = dataset_config.context_len

    if logging_config.log == "wandb":
        # NOTE(review): seed_config is not part of the wandb config, so runs that differ only by
        # seed share the same logged config — consider adding it.
        logger = WandbLogger(
            run_name,
            config={
                **asdict(dataset_config),
                **asdict(model_config),
            },
            step_name="Epoch",
        )
    else:
        logger = PrintLogger(run_name, "Epoch")

    set_seed(seed_config.seed)

    dataset_filename = dataset_config.get_summary()
    dataset_train = _load_dataset(f"datasets/{dataset_filename}_train.pkl", "training")
    dataset_test = _load_dataset(f"datasets/{dataset_filename}_test.pkl", "testing")

    batch_size = 64

    dataset_train.shuffle = model_config.shuffle
    dataset_test.shuffle = model_config.shuffle

    dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True)

    model = Transformer(
        model_config.get_params({"H": context_len, "state_dim": state_dim, "action_dim": action_dim})
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=model_config.lr, weight_decay=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

    # Make sure checkpoints can be written; otherwise torch.save fails only after training finishes.
    os.makedirs("models", exist_ok=True)

    for epoch in range(model_config.n_epochs):
        # Evaluation runs before this epoch's training, so the test metrics logged at step `epoch`
        # describe the model after `epoch` completed training epochs.
        start_time = time.perf_counter()
        test_loss, test_correct = _run_epoch(model, dataloader_test, loss_fn, action_dim)
        test_time = time.perf_counter() - start_time

        start_time = time.perf_counter()
        train_loss, train_correct = _run_epoch(model, dataloader_train, loss_fn, action_dim, optimizer)
        train_time = time.perf_counter() - start_time

        # Loss is summed over every predicted action; normalise to a per-env, per-step average.
        logger.log(
            {
                "test/loss": test_loss / context_len / dataset_test.n_envs,
                "test/time": test_time,
                "test/accuracy": test_correct / dataset_test.n_envs / context_len,
                "train/loss": train_loss / context_len / dataset_train.n_envs,
                "train/time": train_time,
                "train/accuracy": train_correct / dataset_train.n_envs / context_len,
            },
            step=epoch,
        )

        # Intermediate checkpoint every 50 epochs; the final epoch is covered by the save below.
        if (epoch + 1) % 50 == 0 and (epoch + 1) != model_config.n_epochs:
            model_path = f"models/{get_model_save_name(run_name, seed_config, model_config.n_epochs, epoch + 1)}.pt"
            torch.save(model.state_dict(), model_path)
            print(f"Saved model to '{model_path}'.")

    model_path = f"models/{get_model_save_name(run_name, seed_config)}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to '{model_path}'.")

    logger.finish()


if __name__ == "__main__":
    logging_config, seed_config, dataset_config, model_config = parse_args_to_dataclass((LoggingConfig, SeedConfig, DatasetConfig, ModelConfig))

    # Echo every parsed config, including the seed, so the console output records the full run setup.
    print(logging_config, seed_config, dataset_config, model_config, sep="\n")

    # perf_counter is monotonic, so the reported runtime is unaffected by wall-clock adjustments.
    time_start = time.perf_counter()
    main(logging_config, seed_config, dataset_config, model_config)
    total_runtime = time.perf_counter() - time_start

    print(f"Total runtime: {total_runtime:.2f} s")
