import torch.multiprocessing as mp

# Choose the multiprocessing start method up front, before the remaining
# imports run, and only when no method has been selected yet.
# "forkserver" is the alternative to "spawn" here.
_preselected_start_method = mp.get_start_method(allow_none=True)
if _preselected_start_method is None:
    mp.set_start_method(method="spawn", force=True)

import argparse
import datetime
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

import common_args
from dataset import CorruptedBanditDataset, Dataset, ImageDataset
from net import ImageTransformer, Transformer
from utils import (
    build_bandit_data_filename,
    build_bandit_model_filename,
    build_darkroom_data_filename,
    build_darkroom_model_filename,
    build_linear_bandit_data_filename,
    build_linear_bandit_model_filename,
    build_miniworld_data_filename,
    build_miniworld_model_filename,
    worker_init_fn,
)

# Device every tensor and model in this script is moved to: the GPU when
# PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _seed_everything(seed):
    """Seed every RNG the script uses and request deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to ``torch`` (CPU and, when CUDA is
            available, all GPUs), ``numpy.random`` and Python's ``random``.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def _batch_loss(model, loss_fn, batch, true_actions, env, action_dim):
    """Move one batch to ``device``, run the model and return the loss tensor.

    Args:
        model: Callable network. Its output is flattened to
            ``(-1, action_dim)`` and its second axis gives the number of
            times each target is repeated.
        loss_fn: Loss applied to the flattened predictions and targets.
        batch: Model input yielded by the data loader. When ``env`` is
            ``"miniworld"`` it is indexed as a pair and both elements are
            moved to the device individually.
        true_actions: Target tensor for the batch; repeated along a new
            axis 1 so that there is one target per predicted step.
        env: Environment name, used only to detect the miniworld layout.
        action_dim: Size of the last axis after flattening.

    Returns:
        The loss tensor returned by ``loss_fn`` (not normalised here).
    """
    if env == "miniworld":
        batch = (batch[0].to(device), batch[1].to(device))
    else:
        batch = batch.to(device)
    true_actions = true_actions.to(device)

    pred_actions = model(batch)
    del batch

    # One copy of the target per prediction along axis 1, then flatten both
    # tensors so the loss sees one row per (sample, step) pair.
    true_actions = true_actions.unsqueeze(1).repeat(1, pred_actions.shape[1], 1)
    true_actions = true_actions.reshape(-1, action_dim)
    pred_actions = pred_actions.reshape(-1, action_dim)

    return loss_fn(pred_actions, true_actions)


def _save_loss_plot(train_loss, test_loss, filename):
    """Write the log-scale train/test loss curves to ``figs/loss``.

    The first entry of each list is skipped when plotting.

    Args:
        train_loss: Per-epoch training losses.
        test_loss: Per-epoch test losses.
        filename: Model filename stem used to name the PNG.
    """
    plt.yscale("log")
    plt.plot(train_loss[1:], label="Train")
    plt.plot(test_loss[1:], label="Test")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig(f"figs/loss/{filename}_train_loss.png")
    plt.clf()


if __name__ == "__main__":
    # Output directories for loss plots and model checkpoints.
    os.makedirs("figs/loss", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    # ------------------------------------------------------------------
    # Command-line arguments
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    common_args.add_model_args(parser)
    common_args.add_train_args(parser)
    common_args.add_corrupt_train_args(parser)
    common_args.add_logging_args(parser)

    parser.add_argument("--seed", type=int, default=1)

    print("Start:", datetime.datetime.now(datetime.UTC))
    args = vars(parser.parse_args())
    print("Args: ", args)

    log_wandb = args["log"] == "wandb"

    env = args["env"]
    n_envs = args["envs"]
    n_hists = args["hists"]
    n_samples = args["samples"]
    H = args["H"]
    dim = args["dim"]
    # Defaults; each environment branch below overrides what it needs.
    state_dim = dim
    action_dim = dim
    arch = args["arch"]
    shuffle, lr, dropout, n_embd, n_layer, n_head = common_args.get_model_params_from_arch(args)
    var = args["var"]
    cov = args["cov"]
    num_epochs = args["num_epochs"]
    seed = args["seed"]
    lin_d = args["lin_d"]

    reward_type = args["reward"]

    # A seed of -1 is mapped to 0 for RNG seeding; the original value is
    # still what goes into the model filename and the wandb config.
    _seed_everything(0 if seed == -1 else seed)

    if shuffle and env == "linear_bandit":
        raise ValueError(
            "Are you sure you want to shuffle on the linear bandit? Data collected from an adaptive algorithm in a stochastic setting can bias the learner if shuffled."
        )

    # ------------------------------------------------------------------
    # Dataset paths and model filename for the chosen environment
    # ------------------------------------------------------------------
    dataset_config = {
        "n_hists": n_hists,
        "n_samples": n_samples,
        "H": H,
        "dim": dim,
    }
    model_config = {
        "arch": arch,
        "shuffle": shuffle,
        "lr": lr,
        "dropout": dropout,
        "n_embd": n_embd,
        "n_layer": n_layer,
        "n_head": n_head,
        "n_envs": n_envs,
        "n_hists": n_hists,
        "n_samples": n_samples,
        "H": H,
        "dim": dim,
        "seed": seed,
    }
    if env in ("bandit", "bandit_thompson"):
        state_dim = 1

        # The two bandit variants differ only in the dataset "type" field.
        data_type = "uniform" if env == "bandit" else "bernoulli"
        dataset_config.update({"var": var, "cov": cov, "type": data_type})
        path_train = build_bandit_data_filename(env, n_envs, dataset_config, mode=0)
        path_test = build_bandit_data_filename(env, n_envs, dataset_config, mode=1)

        model_config.update({"var": var, "cov": cov})
        filename = build_bandit_model_filename(env, model_config)

    elif env == "linear_bandit":
        state_dim = 1

        dataset_config.update({"lin_d": lin_d, "var": var, "cov": cov})
        path_train = build_linear_bandit_data_filename(env, n_envs, dataset_config, mode=0)
        path_test = build_linear_bandit_data_filename(env, n_envs, dataset_config, mode=1)

        model_config.update({"lin_d": lin_d, "var": var, "cov": cov})
        filename = build_linear_bandit_model_filename(env, model_config)

    elif env.startswith("darkroom"):
        state_dim = 2
        action_dim = 5

        dataset_config.update({"rollin_type": "uniform"})
        path_train = build_darkroom_data_filename(env, n_envs, dataset_config, mode=0, reward_type=reward_type)
        path_test = build_darkroom_data_filename(env, n_envs, dataset_config, mode=1, reward_type=reward_type)

        filename = build_darkroom_model_filename(env, model_config, reward_type=reward_type)

    elif env == "miniworld":
        state_dim = 2  # direction vector is 2D, no position included
        action_dim = 4

        dataset_config.update({"rollin_type": "uniform"})

        # The data is addressed as one file per range of `increment`
        # environment ids, with inclusive [start, end] bounds.
        # NOTE(review): when n_envs is not a multiple of `increment` the last
        # range ends past n_envs - 1 -- confirm this matches how the data
        # files were named when they were generated.
        increment = 5000
        starts = np.arange(0, n_envs, increment)
        ends = starts + increment - 1

        paths_train = []
        paths_test = []
        for start_env_id, end_env_id in zip(starts, ends):
            path_train = build_miniworld_data_filename(env, start_env_id, end_env_id, dataset_config, mode=0)
            path_test = build_miniworld_data_filename(env, start_env_id, end_env_id, dataset_config, mode=1)

            paths_train.append(path_train)
            paths_test.append(path_test)

        filename = build_miniworld_model_filename(env, model_config)
        print(f"Generate filename: {filename}")
    else:
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    config = {
        "H": H,
        "state_dim": state_dim,
        "action_dim": action_dim,
        "n_layer": n_layer,
        "n_embd": n_embd,
        "n_head": n_head,
        "shuffle": shuffle,
        "dropout": dropout,
        "test": False,
        "store_gpu": True,
        "var": var,
    }
    if env == "miniworld":
        config.update({"image_size": 25, "store_gpu": False})
        model = ImageTransformer(config).to(device)
    else:
        model = Transformer(config).to(device)

    # DataLoader keyword arguments; the same set is used for the train and
    # the test loader (so the test loader is shuffled as well).
    params = {
        "batch_size": 64,
        "shuffle": True,
    }

    if log_wandb:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb

        run = wandb.init(
            project="robust-icrl",
            name=filename,
            config={
                "env": env,
                "n_envs": n_envs,
                "n_hists": n_hists,
                "n_samples": n_samples,
                "H": H,
                "dim": dim,
                "state_dim": state_dim,
                "action_dim": action_dim,
                "arch": arch,
                "shuffle": shuffle,
                "var": var,
                "cov": cov,
                "num_epochs": num_epochs,
                "seed": seed,
                "lin_d": lin_d,
                "reward_type": reward_type,
            },
        )

    # ------------------------------------------------------------------
    # Datasets and loaders
    # ------------------------------------------------------------------
    if env == "miniworld":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        params.update(
            {
                "num_workers": 16,
                "prefetch_factor": 2,
                "persistent_workers": True,
                "pin_memory": True,
                "batch_size": 64,
                "worker_init_fn": worker_init_fn,
            }
        )

        print("Loading miniworld data...")
        train_dataset = ImageDataset(paths_train, config, transform)
        test_dataset = ImageDataset(paths_test, config, transform)
        print("Done loading miniworld data")
    else:
        train_dataset = Dataset(path_train, config)
        test_dataset = Dataset(path_test, config)

    train_loader = DataLoader(train_dataset, **params)
    test_loader = DataLoader(test_dataset, **params)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # The loss is summed, not averaged; it is normalised by H and by the
    # dataset size when the per-epoch value is recorded below.
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

    test_loss = []
    train_loss = []

    num_train_batches = len(train_loader)
    num_test_batches = len(test_loader)
    print(f"Num train batches: {num_train_batches}")
    print(f"Num test batches: {num_test_batches}")

    # ------------------------------------------------------------------
    # Training loop: every epoch evaluates first, then trains
    # ------------------------------------------------------------------
    for epoch in range(num_epochs):
        # EVALUATION
        start_time = time.time()
        epoch_test_loss = 0.0

        # Put the model in eval mode so train-only layers (e.g. dropout) are
        # disabled while the test loss is measured.
        model.eval()
        with torch.no_grad():
            for i, (batch, true_actions) in enumerate(test_loader):
                if i % 100 == 0 or i == num_test_batches - 1:
                    print(f"Batch {i} of {num_test_batches}", end="\r")

                loss = _batch_loss(model, loss_fn, batch, true_actions, env, action_dim)
                epoch_test_loss += loss.item() / H

        test_loss.append(epoch_test_loss / len(test_dataset))
        end_time = time.time()
        if log_wandb:
            run.log({"test/loss": test_loss[-1], "test/eval_time": end_time - start_time}, step=epoch)

        # TRAINING
        epoch_train_loss = 0.0
        start_time = time.time()

        # Back to train mode after the evaluation pass above.
        model.train()
        for i, (batch, true_actions) in enumerate(train_loader):
            if i % 100 == 0 or i == num_train_batches - 1:
                print(f"Batch {i} of {num_train_batches}", end="\r")

            optimizer.zero_grad()
            loss = _batch_loss(model, loss_fn, batch, true_actions, env, action_dim)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item() / H

        train_loss.append(epoch_train_loss / len(train_dataset))
        end_time = time.time()
        if log_wandb:
            # "train/eval_time" holds the duration of the training pass; the
            # key name is kept unchanged so existing runs stay comparable.
            run.log({"train/loss": train_loss[-1], "train/eval_time": end_time - start_time}, step=epoch)

        # LOGGING
        # Checkpoint every 50 epochs, or every 10 for the linear bandit.
        if (epoch + 1) % 50 == 0 or (env == "linear_bandit" and (epoch + 1) % 10 == 0):
            torch.save(model.state_dict(), f"models/{filename}_epoch{epoch+1}.pt")

        # PLOTTING
        if (epoch + 1) % 10 == 0:
            print(f"Epoch: {epoch + 1}")
            print(f"Test Loss:        {test_loss[-1]}")
            print(f"Train Loss:       {train_loss[-1]}")
            print("\n")

            _save_loss_plot(train_loss, test_loss, filename)

    torch.save(model.state_dict(), f"models/{filename}.pt")
    print("Done.")
    if log_wandb:
        run.finish()

    print("End:", datetime.datetime.now(datetime.UTC))
