import os
from typing import Optional

from ruamel.yaml import YAML
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from constants import TQDM_OPTIONS
from evaluation.area_under_roc_curve import calculate_area_under_roc_curve
from util import yaml_object_to_string
from . import _training_init
from ._training_config import TrainingConfig


def train(config: Optional[TrainingConfig] = None, checkpoint_path: Optional[str] = None):
    """
    Trains a model as specified in the given configuration, or continues training from the given checkpoint.

    Either `config` or `checkpoint_path` must be provided.
    If both are provided, `config` is ignored.

    Each epoch iterates over the (train_loader, val_loader) pairs yielded by `_training_init.data_loaders(config)`;
    each pair is one "chunk". After every chunk the model is scored on the validation loader (area under the ROC
    curve), the learning-rate scheduler is stepped on that score, and every `config.checkpoint_interval` chunks a
    checkpoint named `{epoch}.{chunk}-checkpoint.pt` is written to the log directory.

    When resuming, training continues with the chunk after the one the checkpoint was written at. Checkpoints that
    do not record the chunk resume at the start of the next epoch.

    :raises ValueError: if neither `config` nor `checkpoint_path` is given, or a training loader yields no edges.
    """
    if config is None and checkpoint_path is None:
        raise ValueError("Either config or checkpoint_path must be provided")

    checkpoint = None
    if checkpoint_path is not None:
        # NOTE: weights_only=False unpickles arbitrary objects - only load checkpoints from trusted sources.
        # map_location="cpu" lets a checkpoint written on one device be resumed on another; load_state_dict
        # below copies the tensors onto the model's / optimiser's parameter devices.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        config: TrainingConfig = YAML().load(checkpoint["training_config"])

    data_loaders = _training_init.data_loaders(config)
    model, device = _training_init.model(config)

    optimiser = AdamW(model.parameters(), lr=config.initial_lr)
    lr_scheduler = ReduceLROnPlateau(optimiser, patience=config.lr_scheduler_patience, mode="max", verbose=True)

    loss_function = _training_init.loss_function(config)

    if checkpoint is not None:
        # restore training state from the checkpoint
        model.load_state_dict(checkpoint["model"])
        optimiser.load_state_dict(checkpoint["optimiser"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch, completed_chunks = _resume_position(checkpoint, config.num_evaluations_per_epoch)
        log_dir = os.path.dirname(checkpoint_path)
        # the state has been copied into the live objects; don't keep the loaded copy alive for the whole run
        del checkpoint
        if completed_chunks:
            print(f"Continuing training at epoch {start_epoch}, chunk {completed_chunks + 1}")
        else:
            print("Continuing training at epoch", start_epoch)
    else:
        # start training from scratch
        start_epoch = 1
        completed_chunks = 0
        log_dir = _training_init.log_dir(config)

    for epoch in range(start_epoch, config.num_epochs + 1):
        for chunk, (train_loader, val_loader) in enumerate(data_loaders, start=1):
            if epoch == start_epoch and chunk <= completed_chunks:
                # already trained on before the checkpoint was written; `data_loaders` still yields the
                # pair, it is just not trained on again
                continue

            loss_mean, train_accuracy = _train_chunk(model, train_loader, optimiser, loss_function, device)
            area_under_roc_curve = calculate_area_under_roc_curve(model, val_loader, device)
            lr_scheduler.step(area_under_roc_curve)
            print(
                f"[Epoch {epoch}.{chunk}] "
                f"Train Loss: {loss_mean}, "
                f"Train Accuracy: {train_accuracy}, "
                f"Val Area under ROC Curve: {area_under_roc_curve}"
            )

            # running chunk count across epochs; assumes each epoch has num_evaluations_per_epoch chunks
            global_chunk = ((epoch - 1) * config.num_evaluations_per_epoch) + chunk
            if global_chunk % config.checkpoint_interval == 0:
                _save_checkpoint(
                    os.path.join(log_dir, f"{epoch}.{chunk}-checkpoint.pt"),
                    config,
                    model,
                    optimiser,
                    lr_scheduler,
                    epoch,
                    chunk,
                )


def _resume_position(checkpoint, evaluations_per_epoch):
    """
    Returns `(epoch, completed_chunks)`: the epoch to resume in and how many of its chunks are already done.

    A checkpoint that records its chunk resumes in the same epoch, right after that chunk, unless it was the
    epoch's last chunk (`chunk >= evaluations_per_epoch`), in which case the next epoch starts fresh.
    A checkpoint without a "chunk" entry resumes at the start of the next epoch.
    """
    epoch = checkpoint["epoch"]
    chunk = checkpoint.get("chunk")
    if chunk is None or chunk >= evaluations_per_epoch:
        return epoch + 1, 0
    return epoch, chunk


def _train_chunk(model, train_loader, optimiser, loss_function, device):
    """
    Runs one optimisation pass over `train_loader` and returns `(loss_sum / total_edges, accuracy)`.

    :raises ValueError: if the loader yields no edges (the statistics would be undefined).
    """
    # (re-)enable training mode on every call; presumably validation switches the model to eval mode - confirm
    model.train()

    loss_sum = 0.0
    correct_edges = 0
    total_edges = 0

    for graph in tqdm(train_loader, **TQDM_OPTIONS):
        # rebinding is correct whether `.to` moves the object in place or returns a moved copy
        graph = graph.to(device)
        optimiser.zero_grad()

        edge_predictions = model(graph)

        loss = loss_function(edge_predictions, graph)

        loss.backward()
        optimiser.step()

        loss_sum += loss.item()
        # NOTE(review): a threshold of 0 assumes the model outputs logits - confirm
        predicted_labels = edge_predictions > 0
        correct_edges += (predicted_labels == graph.y).sum().item()
        total_edges += predicted_labels.size(0)

    if total_edges == 0:
        raise ValueError("Training loader yielded no edges; cannot compute loss or accuracy")

    # NOTE(review): loss_sum holds one loss.item() per batch but is divided by the number of edges; this is a
    # per-edge mean only if loss_function uses sum reduction - confirm
    return loss_sum / total_edges, correct_edges / total_edges


def _save_checkpoint(path, config, model, optimiser, lr_scheduler, epoch, chunk):
    """Writes everything `train(checkpoint_path=path)` needs to resume after this epoch/chunk to `path`."""
    torch.save(
        {
            "training_config": yaml_object_to_string(config),
            "model": model.state_dict(),
            "epoch": epoch,
            "chunk": chunk,
            "optimiser": optimiser.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        },
        path,
    )


if __name__ == "__main__":
    from constants import TRAINING_CONFIG_FILE

    # Script entry point: start a fresh run (no checkpoint) from the configuration loaded by ruamel.yaml.
    # NOTE(review): ruamel's load() parses a plain str as YAML text, not as a path; this presumably relies on
    # TRAINING_CONFIG_FILE being a pathlib.Path - confirm.
    train(config=YAML().load(TRAINING_CONFIG_FILE))
