import json
import os
import random
import time
import uuid

import torch
from torch.utils.data import DataLoader

import wandb
from adversarial_superposition.constants import (
    DATA_DIR,
    DEVICE,
    FLOAT_PRECISION_MAP,
    MODEL_DIR,
    RESULTS_DIR,
)
from adversarial_superposition.modulo.utils.logger import MetricsLogger
from adversarial_superposition.modulo.utils.utils import (
    Config,
    cross_entropy_float16,
    cross_entropy_float32,
    cross_entropy_float64,
    evaluate,
    get_dataset,
    get_model,
    get_optimizer,
    process_metrics,
    stablemax_cross_entropy,
)


def _toy_dir(root, experiment_key):
    """Return the per-experiment directory ``root / "toy_models/<experiment_key>"``."""
    return root / f"toy_models/{experiment_key}"


def _select_loss_function(config):
    """Return the loss callable named by ``config.loss_function``.

    ``config.softmax_precision`` (16, 32 or 64) is only consulted for
    ``"cross_entropy"``, so stablemax runs do not need a valid value for it.

    Raises:
        KeyError: if the loss name, or the softmax precision for
            ``"cross_entropy"``, is not recognised.
    """
    if config.loss_function == "stablemax":
        return stablemax_cross_entropy
    if config.loss_function == "cross_entropy":
        cross_entropy_by_precision = {
            16: cross_entropy_float16,
            32: cross_entropy_float32,
            64: cross_entropy_float64,
        }
        return cross_entropy_by_precision[config.softmax_precision]
    raise KeyError(config.loss_function)


def _subset_to_tensors(subset, precision):
    """Gather a dataset subset's inputs and targets into tensors on ``DEVICE``.

    Inputs are cast to ``precision``; targets are cast to ``long``.
    Assumes ``subset`` is ``torch.utils.data.Subset``-like (``.indices`` plus a
    ``.dataset`` exposing ``.data``/``.targets`` tensors) — TODO confirm
    against ``get_dataset``.
    """
    base = subset.dataset
    data = base.data[subset.indices].to(DEVICE).to(precision)
    targets = base.targets[subset.indices].to(DEVICE).long()
    return data, targets


def _latest_accuracy(metrics_df, input_type):
    """Return the most recently logged ``accuracy`` value for ``input_type``.

    ``input_type`` is matched against the ``input_type`` column (this module
    uses ``"train"`` and ``"test"``). Raises ``IndexError`` if no row matches.
    """
    rows = metrics_df[
        (metrics_df["metric_name"] == "accuracy")
        & (metrics_df["input_type"] == input_type)
    ]
    return rows.iloc[-1]["value"]


def train_network(
    experiment_key,
    config,
    return_grokking_epoch=False,
    verbose=False,
    grokking_threshold=0.99,
):
    """Train a model on ``config``'s dataset and persist the run's artifacts.

    Every epoch performs a single optimizer step on the whole (shuffled)
    training set. Logits are multiplied by ``config.alpha`` and the model and
    optimizer are built while ``config.lr`` holds ``lr / config.alpha**2``;
    the caller's ``config.lr`` is restored afterwards (even if training
    raises), so the saved ``config.json`` and any reuse of ``config`` see the
    unscaled rate.

    Every ``config.log_frequency`` epochs, metrics are recorded through
    ``MetricsLogger`` and sent to wandb with ``step=epoch``.

    Artifacts are written under ``toy_models/<experiment_key>`` in
    ``DATA_DIR`` (train/test datasets), ``MODEL_DIR`` (checkpoint dict and
    optimizer) and ``RESULTS_DIR`` (``metrics.csv`` and ``config.json``).

    Args:
        experiment_key: Name of the per-run subdirectory.
        config: Run configuration. Mutated: ``batch_size`` is set to the
            training-set size when ``config.full_batch`` is true.
        return_grokking_epoch: If true, return the grokking epoch (below).
        verbose: Print loss, timing and accuracies at each logging epoch.
        grokking_threshold: Test accuracy that must be exceeded for a logged
            epoch to count as the grokking epoch.

    Returns:
        When ``return_grokking_epoch`` is true, the first logged epoch whose
        test accuracy exceeded ``grokking_threshold`` (``None`` if none did);
        otherwise ``None``.
    """
    torch.set_num_threads(5)

    random.seed(config.seed)
    torch.manual_seed(config.seed)

    print(f"Saving the experiments in the directory: {experiment_key}")
    results_dir = _toy_dir(RESULTS_DIR, experiment_key)
    model_dir = _toy_dir(MODEL_DIR, experiment_key)
    data_dir = _toy_dir(DATA_DIR, experiment_key)
    for directory in (results_dir, model_dir, data_dir):
        os.makedirs(directory, exist_ok=True)

    train_precision = FLOAT_PRECISION_MAP[config.train_precision]

    train_dataset, test_dataset = get_dataset(config)
    if config.full_batch:
        config.batch_size = len(train_dataset)
    # NOTE(review): the training loop below always steps on the full training
    # set at once; config.batch_size is not consumed by it.

    # Used only for the final CPU evaluation after training.
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    torch.save(train_dataset, data_dir / "last_train_loader.pt")
    torch.save(test_dataset, data_dir / "last_test_loader.pt")

    # The model/optimizer are built (and trained) with lr / alpha**2, paired
    # with the alpha-scaled logits in the loop. Restore the caller's lr in
    # `finally` so config is not left modified.
    original_lr = config.lr
    config.lr = original_lr / (config.alpha**2)
    try:
        model = get_model(config)
        logger = MetricsLogger(config.num_epochs, config.log_frequency)
        optimizer = get_optimizer(model, config)
        loss_function = _select_loss_function(config)

        save_model_checkpoints = range(0, config.num_epochs, config.log_frequency)
        # Placeholders; MetricsLogger.log_metrics receives this dict.
        saved_models = dict.fromkeys(save_model_checkpoints)

        all_data, all_targets = _subset_to_tensors(train_dataset, train_precision)
        all_test_data, all_test_targets = _subset_to_tensors(
            test_dataset, train_precision
        )

        start_time = time.time()
        model.to(DEVICE).to(train_precision)

        grokking_epoch = None

        for epoch in range(config.num_epochs):
            permutation = torch.randperm(all_data.size(0))
            shuffled_data = all_data[permutation]
            shuffled_targets = all_targets[permutation]

            model.train()
            optimizer.zero_grad()
            output = model(shuffled_data)
            if config.use_transformer:
                # Keep only the last index along dim 1 (presumably the final
                # sequence position).
                output = output[:, -1]
            output = output * config.alpha
            loss = loss_function(output, shuffled_targets)
            loss.backward()
            optimizer.step()

            if epoch % logger.log_frequency != 0:
                continue

            logger.log_metrics(
                model=model,
                epoch=epoch,
                save_model_checkpoints=save_model_checkpoints,
                saved_models=saved_models,
                all_data=shuffled_data,
                all_targets=shuffled_targets,
                all_test_data=all_test_data,
                all_test_targets=all_test_targets,
                args=config,
                loss_function=loss_function,
            )

            if verbose:
                print(f"Epoch {epoch}: Training loss: {loss.item():.4f}")
                if epoch > 0:
                    print(
                        f"Time taken for the last {config.log_frequency} epochs: {(time.time() - start_time) / 60:.2f} min"
                    )
            start_time = time.time()

            test_acc = _latest_accuracy(logger.metrics_df, "test")
            train_acc = _latest_accuracy(logger.metrics_df, "train")

            wandb.log(process_metrics(logger.metrics_df), step=epoch)

            if verbose:
                print(f"Train accuracy: {train_acc}; Test accuracy: {test_acc}")

            # `is None`, not truthiness: epoch 0 is a valid grokking epoch.
            if grokking_epoch is None and test_acc > grokking_threshold:
                grokking_epoch = epoch
    finally:
        config.lr = original_lr

    model.eval().to("cpu")
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(f"Test set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}")

    torch.save(saved_models, model_dir / "last_run_saved_model_checkpoints.pt")
    optimizer_path = model_dir / "last_optimizer.pt"
    torch.save(optimizer, optimizer_path)
    print(f"Saving to {optimizer_path}")

    logger.metrics_df.to_csv(results_dir / "metrics.csv", index=False)

    with open(results_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(vars(config), f, indent=4)

    print(f"Saving run: {experiment_key}")

    return grokking_epoch if return_grokking_epoch else None


if __name__ == "__main__":
    experiment_key = uuid.uuid4().hex[:8]

    grokking_cfg = Config(
        seed=3,
        lr=0.001,
        num_epochs=10_000,
        input_size=113,
        modulo=113,
        log_frequency=50,
        train_fraction=0.3,
        weight_decay=4.0,
    )

    # Record hyperparameters as run config instead of via wandb.log(): a
    # wandb.log() call without `step` advances wandb's step counter, after
    # which train_network's wandb.log(..., step=0) for epoch 0 would be
    # rejected as a non-monotonic step.
    wandb.init(
        project="toy_models_of_addition",
        tags=["train"],
        mode="disabled",
        config=grokking_cfg.dict(),
    )

    train_network(
        experiment_key=experiment_key,
        config=grokking_cfg,
        verbose=True,
    )

    wandb.finish()
