from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from omegaconf import OmegaConf
from torch_geometric.data import Data
from torch_geometric.seed import seed_everything
from tqdm import tqdm

from graphsmodel.models import LP
from graphsmodel.utils import (
    get_appropriate_edge_index,
    get_model,
    get_optimizer,
    weights_to_cpu,
)

# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def get_logits(data, model):
    """
    Run a gradient-free forward pass of ``model`` on ``data`` and return the output on CPU.

    Both ``data`` and ``model`` are moved to the module-level ``device`` for the
    forward pass and moved back to the CPU before returning. The model is left
    in eval mode.

    Args:
        data: Graph object exposing ``to`` plus ``x`` (or ``y`` and ``train_mask``
            when ``model`` is an ``LP`` instance).
        model (torch.nn.Module): The model used for computing the logits.

    Returns:
        torch.Tensor: The model output, detached and on the CPU.
    """
    data = data.to(device)
    model.to(device)
    model.eval()

    # LP models are called with labels and the training mask; every other
    # model is called with node features and the selected edge index.
    if isinstance(model, LP):
        forward_kwargs = {"y": data.y, "mask": data.train_mask}
    else:
        graph_edges = get_appropriate_edge_index(data)
        forward_kwargs = {"x": data.x, "edge_index": graph_edges}
    output = model(**forward_kwargs)

    # Hand both objects back on the CPU.
    data.to("cpu")
    model.to("cpu")

    return output.detach().cpu()


@torch.no_grad()
def evaluate(model, loader, evaluator):
    """
    Run ``model`` over every batch of ``loader`` and score the predictions.

    Args:
        model (torch.nn.Module): Model to evaluate; it is switched to eval mode.
        loader: Iterable of batches exposing ``x``, ``edge_attr``, ``batch`` and ``y``.
        evaluator: Object whose ``eval`` method accepts a dict with ``"y_true"``
            and ``"y_pred"`` and returns a mapping containing ``"rocauc"``.

    Returns:
        Tuple: ``(logits, y, rocauc)`` where ``logits`` and ``y`` are the CPU
        tensors concatenated over all batches along dimension 0.
    """
    model.eval()

    collected_scores = []
    collected_targets = []
    for graphs in loader:
        graphs = graphs.to(device)
        edges = get_appropriate_edge_index(graphs)
        scores = model(
            x=graphs.x,
            edge_index=edges,
            edge_attr=graphs.edge_attr,
            batch=graphs.batch,
        )
        collected_scores.append(scores.cpu())
        collected_targets.append(graphs.y.cpu())

    all_targets = torch.cat(collected_targets, dim=0)
    all_scores = torch.cat(collected_scores, dim=0)

    rocauc = evaluator.eval({"y_true": all_targets, "y_pred": all_scores})["rocauc"]
    return all_scores, all_targets, rocauc


def graph_level_step(loader, model, mode="eval", optimizer=None):
    """
    Run one full pass over ``loader``, either training or evaluating the model.

    The loss is ``F.binary_cross_entropy_with_logits`` between the model output
    and ``data.y`` cast to float. In eval mode the forward pass runs with
    gradient tracking disabled, so no autograd graph is built for validation.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_attr``, ``batch``,
            ``y`` and ``num_graphs``.
        model (torch.nn.Module): The model to be trained or evaluated.
        mode (str, optional): The mode of operation. Either "train" or "eval". Defaults to "eval".
        optimizer (torch.optim.Optimizer, optional): The optimizer used for training.
            Required if mode is "train", must be None otherwise. Defaults to None.

    Returns:
        float: The per-graph average of the batch losses, weighted by ``num_graphs``.

    Raises:
        AssertionError: If ``mode`` and ``optimizer`` are inconsistent.
        ZeroDivisionError: If ``loader`` yields no graphs.
    """
    assert (mode == "eval" and optimizer is None) or (
        mode == "train" and optimizer is not None
    )
    training = mode == "train"
    if training:
        model.train()
    else:
        model.eval()

    total_loss = total_examples = 0
    # Track gradients only when they will be used; an enclosing no_grad
    # context set by the caller is still respected.
    with torch.set_grad_enabled(training and torch.is_grad_enabled()):
        for data in loader:
            data = data.to(device)
            if training:
                optimizer.zero_grad()
            edge_index = get_appropriate_edge_index(data)
            out = model(
                x=data.x,
                edge_index=edge_index,
                edge_attr=data.edge_attr,
                batch=data.batch,
            )
            loss = F.binary_cross_entropy_with_logits(out, data.y.to(torch.float))
            if training:
                loss.backward()
                optimizer.step()

            total_loss += float(loss) * data.num_graphs
            total_examples += data.num_graphs

    return total_loss / total_examples


def graph_level_subset_training(
    train_loader,
    val_loader,
    model,
    optimizer,
    patience,
    max_epochs,
    pbar=True,
):
    """
    Train a graph-level model with early stopping on the validation loss.

    After every epoch the validation loss is computed; whenever it strictly
    improves, a CPU snapshot of the weights is stored. Training stops once
    ``patience`` epochs have passed since the best epoch, and the best
    snapshot (if any) is loaded back into the model.

    Args:
        train_loader: Loader yielding the training batches.
        val_loader: Loader yielding the validation batches.
        model: The model to be trained; it is moved to the module-level ``device``.
        optimizer: The optimizer used for training.
        patience: The number of epochs to wait for improvement in validation loss before early stopping.
        max_epochs: The maximum number of epochs to train.
        pbar: Whether to wrap the epoch loop in a tqdm progress bar.

    Returns:
        A tuple containing the trained model, training loss trace, validation loss trace
        and best epoch (0 when the validation loss never improved).
    """
    model.to(device)

    train_trace = []
    val_trace = []
    best_loss = float("Inf")
    # Initialised up front so that a run whose validation loss never improves
    # (e.g. NaN losses) or that executes zero epochs does not hit unbound names.
    best_epoch = 0
    best_state = None

    progress = tqdm(range(max_epochs)) if pbar else range(max_epochs)
    for epoch in progress:
        loss = graph_level_step(
            loader=train_loader, model=model, optimizer=optimizer, mode="train"
        )
        val_loss = graph_level_step(loader=val_loader, model=model, mode="eval")
        train_trace.append(loss)
        val_trace.append(val_loss)

        if val_loss < best_loss:
            best_epoch = epoch
            best_loss = val_loss
            # copy=True forces a real copy even when the weights already live
            # on the CPU; a plain .cpu() would return the live tensors there and
            # the snapshot would be overwritten by later optimizer steps.
            best_state = {
                key: value.detach().to("cpu", copy=True)
                for key, value in model.state_dict().items()
            }

        # Early stopping
        if epoch - best_epoch >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_trace, val_trace, best_epoch


def graph_level_training(
    train_loader,
    val_loader,
    test_loader,
    evaluator,
    model,
    optimizer,
    patience,
    max_epochs,
    pbar=True,
):
    """
    Train a graph-level model with early stopping, then score it on the test loader.

    Args:
        train_loader: Loader yielding the training batches.
        val_loader: Loader yielding the validation batches.
        test_loader: Loader yielding the test batches.
        evaluator: Evaluator forwarded to ``evaluate`` to compute the test metric.
        model (torch.nn.Module): Model to be trained.
        optimizer (torch.optim.Optimizer): Optimizer used for training.
        patience (int): Number of epochs to wait for improvement in validation loss before early stopping.
        max_epochs (int): Maximum number of epochs for training.
        pbar (bool): Whether to show a progress bar during training.

    Returns:
        dict: A dictionary with the following keys:
            - "weights": result of ``weights_to_cpu`` applied to the trained model's state dict.
            - "train_trace" (list): Training loss values, one per epoch.
            - "val_trace" (list): Validation loss values, one per epoch.
            - "best_epoch" (int): Epoch at which the best validation loss was achieved.
            - "test_metric": Third element returned by ``evaluate`` on the test loader.
    """
    trained_model, train_losses, val_losses, best_epoch = graph_level_subset_training(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        patience=patience,
        max_epochs=max_epochs,
        pbar=pbar,
    )

    # evaluate returns (logits, y, metric); only the metric is reported.
    test_metric = evaluate(trained_model, test_loader, evaluator)[2]

    results = {}
    results["weights"] = weights_to_cpu(trained_model.state_dict())
    results["train_trace"] = train_losses
    results["val_trace"] = val_losses
    results["best_epoch"] = best_epoch
    results["test_metric"] = test_metric
    return results


def node_level_step(data, model, mask, mode="eval", optimizer=None):
    """
    Perform a single full-graph step of training or evaluation.

    The loss is the cross-entropy between the model output and ``data.y``,
    restricted to the nodes selected by ``mask``. In eval mode the forward
    pass runs with gradient tracking disabled, so no autograd graph is built.

    Args:
        data (torch_geometric.data.Data): The input data.
        model (torch.nn.Module): The model to be trained or evaluated.
        mask (torch.Tensor): The mask selecting the nodes the loss is computed on.
        mode (str, optional): The mode of operation. Either "train" or "eval". Defaults to "eval".
        optimizer (torch.optim.Optimizer, optional): The optimizer used for training.
            Required if mode is "train", must be None otherwise. Defaults to None.

    Returns:
        float: The loss value as a scalar.

    Raises:
        AssertionError: If ``mode`` and ``optimizer`` are inconsistent.
    """
    assert (mode == "eval" and optimizer is None) or (
        mode == "train" and optimizer is not None
    )
    training = mode == "train"
    if training:
        model.train()
        optimizer.zero_grad()
    else:
        model.eval()

    # Track gradients only when they will be used; an enclosing no_grad
    # context set by the caller is still respected.
    with torch.set_grad_enabled(training and torch.is_grad_enabled()):
        edge_index = get_appropriate_edge_index(data)
        out = model(data.x, edge_index)
        loss = F.cross_entropy(out[mask], data.y[mask])

    if training:
        loss.backward()
        optimizer.step()

    return loss.item()


def node_level_subset_training(
    data,
    model,
    optimizer,
    patience,
    max_epochs,
    pbar=True,
):
    """
    Train a node-level model on the nodes selected by ``data.train_mask``.

    When ``data.val_mask`` selects at least one node, the validation loss
    drives early stopping and the best weights are restored at the end.
    When it selects none, the model is trained for a fixed 100 epochs and the
    last epoch is reported as the best one. When ``data.train_mask`` selects no
    node, no training happens and the logits of the untouched model are returned.

    Args:
        data: The input data; must expose ``train_mask`` and ``val_mask``.
        model: The model to be trained.
        optimizer: The optimizer used for training.
        patience: The number of epochs to wait for improvement in validation loss before early stopping.
        max_epochs: The maximum number of epochs to train (ignored when there are no validation nodes).
        pbar: Whether to wrap the epoch loop in a tqdm progress bar.

    Returns:
        A tuple containing the trained model, training loss trace, validation loss trace, best epoch,
        and the logits computed by ``get_logits`` on ``data``.
    """
    data = data.to(device)
    model.to(device)

    train_trace = []
    val_trace = []
    best_loss = float("Inf")
    best_epoch = 0
    # Stays None when the validation loss never improves (e.g. NaN losses),
    # in which case the current weights are kept instead of failing.
    best_state = None

    # train only if training nodes are available
    if data.train_mask.sum() > 0:
        # The mask does not change during training, so test it once.
        has_val = bool(data.val_mask.sum() > 0)

        # if no validation nodes available, overfit the model for 100 epochs (a 2-layers gcn takes roughly 100 epochs to converge)
        if not has_val:
            max_epochs = 100

        progress = tqdm(range(max_epochs)) if pbar else range(max_epochs)

        for epoch in progress:
            loss = node_level_step(
                data=data,
                model=model,
                optimizer=optimizer,
                mask=data.train_mask,
                mode="train",
            )
            train_trace.append(loss)

            if not has_val:
                continue

            val_loss = node_level_step(
                data=data, model=model, mask=data.val_mask, mode="eval"
            )
            val_trace.append(val_loss)

            if val_loss <= best_loss:
                best_epoch = epoch
                best_loss = val_loss
                # copy=True forces a real copy even when the weights already
                # live on the CPU; a plain .cpu() would return the live tensors
                # there and later optimizer steps would overwrite the snapshot.
                best_state = {
                    key: value.detach().to("cpu", copy=True)
                    for key, value in model.state_dict().items()
                }

            # Early stopping
            if epoch - best_epoch >= patience:
                break

        if has_val:
            if best_state is not None:
                model.load_state_dict(best_state)
        elif train_trace:
            # Without validation there is no early stopping: the last epoch run is the one kept.
            best_epoch = len(train_trace) - 1

    logits = get_logits(data=data, model=model)
    return model, train_trace, val_trace, best_epoch, logits


def node_level_training(
    data,
    model,
    optimizer,
    train_mask,
    val_mask,
    test_mask,
    patience,
    max_epochs,
    pbar=True,
):
    """
    Train a node-level model on the given split and report its test accuracy.

    ``node_level_subset_training`` reads the split from ``data.train_mask`` and
    ``data.val_mask`` and does not accept mask arguments, so the requested masks
    are attached to a clone of ``data`` before training. The caller's ``data``
    is left untouched.

    Args:
        data: Input graph data; must expose ``clone`` and ``y``.
        model (torch.nn.Module): Model to be trained.
        optimizer (torch.optim.Optimizer): Optimizer used for training.
        train_mask (torch.Tensor): Mask indicating the training samples.
        val_mask (torch.Tensor): Mask indicating the validation samples.
        test_mask (torch.Tensor): Mask indicating the test samples.
        patience (int): Number of epochs to wait for improvement in validation loss before early stopping.
        max_epochs (int): Maximum number of epochs for training.
        pbar (bool): Whether to show a progress bar during training.

    Returns:
        dict: A dictionary with the following keys:
            - "weights": result of ``weights_to_cpu`` applied to the trained model's state dict.
            - "train_trace" (list): Training loss values, one per epoch.
            - "val_trace" (list): Validation loss values, one per epoch.
            - "best_epoch" (int): Epoch at which the best validation loss was achieved.
            - "test_acc" (torch.Tensor): Mean accuracy over the nodes selected by ``test_mask``.
    """
    # Same approach as get_subset_data: work on a clone carrying the requested split.
    split_data = data.clone()
    split_data.train_mask = train_mask
    split_data.val_mask = val_mask

    model, train_trace, val_trace, best_epoch, logits = node_level_subset_training(
        data=split_data,
        model=model,
        optimizer=optimizer,
        patience=patience,
        max_epochs=max_epochs,
        pbar=pbar,
    )

    test_acc = (logits.argmax(dim=-1) == data.y.cpu()).float()[test_mask].mean()
    return {
        "weights": weights_to_cpu(model.state_dict()),
        "train_trace": train_trace,
        "val_trace": val_trace,
        "best_epoch": best_epoch,
        "test_acc": test_acc,
    }


def get_subset_data(cfg, data, subset):
    """
    Build the train/val/test masks for a subset of ``data`` and the matching data object.

    Depending on ``cfg.data.subset_mode``:
        - "train", "val" or "test": the entries of that split's mask that are
          set in ``data`` are overwritten with ``subset``; the other two masks
          are copied unchanged.
        - "mixed": each split keeps only the nodes that are both in the split
          and selected by ``subset`` (indexed over all nodes).

    When ``cfg.task.induced_subgraph`` is truthy the returned data is
    ``data.subgraph`` over the nodes in any of the three masks; otherwise it is
    a clone of ``data`` with the three NumPy masks assigned to it.

    Args:
        cfg: Configuration exposing ``data.subset_mode`` and ``task.induced_subgraph``.
        data: Graph data exposing ``train_mask``, ``val_mask`` and ``test_mask`` tensors.
        subset: Boolean NumPy array selecting the nodes of the subset.

    Returns:
        Tuple: ``(sub_data, train_mask, val_mask, test_mask)`` with the masks as NumPy arrays.

    Raises:
        ValueError: If ``cfg.data.subset_mode`` is not one of the supported modes.
    """
    splits = ("train", "val", "test")
    mode = cfg.data.subset_mode

    if mode in splits:
        # Private copies first, so writing the subset never touches data's own masks.
        masks = {
            split: getattr(data, f"{split}_mask").clone().numpy() for split in splits
        }
        in_split = getattr(data, f"{mode}_mask").numpy()
        masks[mode][in_split] = subset
    elif mode == "mixed":
        masks = {split: np.zeros(data.num_nodes, dtype=bool) for split in splits}
        for split in splits:
            in_split = getattr(data, f"{split}_mask").numpy()
            masks[split][in_split] = subset[getattr(data, f"{split}_mask").numpy()]
    else:
        raise ValueError("Invalid subset mode.")

    train_mask = masks["train"]
    val_mask = masks["val"]
    test_mask = masks["test"]

    if not cfg.task.induced_subgraph:
        sub_data = data.clone()
        sub_data.train_mask = train_mask
        sub_data.val_mask = val_mask
        sub_data.test_mask = test_mask
        return sub_data, train_mask, val_mask, test_mask

    # Keep only the nodes that belong to at least one of the three masks.
    any_split = train_mask | val_mask | test_mask
    nodes_to_keep = torch.from_numpy(any_split.nonzero()[0])
    sub_data = data.subgraph(nodes_to_keep)
    return sub_data, train_mask, val_mask, test_mask


def train_subset(
    subset_idx: Optional[int],
    subset: np.ndarray,
    cfg: OmegaConf,
    data: Data,
    logits_on_data: Optional[bool] = False,
):
    """
    Train ``cfg.train.n_models`` models on one subset of ``data`` and collect their logits.

    For every seed in ``range(cfg.train.n_models)`` a fresh model is built. If
    ``cfg.train.optimizer`` contains a ``"_target_"`` entry the model is trained
    with ``node_level_subset_training``; otherwise its ``fit`` method is called
    (always for ``LP`` models, and only when training nodes exist for the others)
    and the logits are computed with ``get_logits``.

    Args:
        subset_idx: Index of the subset, used in the wandb keys. Logging is skipped when None.
        subset: Boolean array selecting the nodes of the subset (see ``get_subset_data``).
        cfg: Configuration; reads ``train.n_models``, ``train.optimizer``,
            ``train.patience``, ``train.epochs`` and ``log``.
        data: The full graph data.
        logits_on_data: When truthy, every trained model is also run on the full ``data``.

    Returns:
        Tuple: ``(subset, mapped_sub_logits, sub_y)`` where
            - ``mapped_sub_logits`` has shape ``(n_models, num_nodes, num_classes)``
              and is NaN for nodes outside the subset masks,
            - ``sub_y`` has shape ``(num_nodes,)`` and is -1 for nodes outside the subset masks.
        When ``logits_on_data`` is truthy a fourth element is appended: the
        per-model logits on the full ``data``, stacked along a new first dimension.
    """
    sub_data, train_mask, val_mask, test_mask = get_subset_data(cfg, data, subset)

    num_classes = data.y.max().item() + 1
    sub_y = torch.full((data.num_nodes,), -1, dtype=torch.long)

    mapped_sub_logits = torch.full(
        (cfg.train.n_models, data.num_nodes, num_classes),
        float("NaN"),
        dtype=torch.float32,
    )
    logits = [] if logits_on_data else None

    for model_seed in range(cfg.train.n_models):
        seed_everything(model_seed)
        model = get_model(cfg, sub_data, num_classes=num_classes)

        # Loss traces only exist for optimizer-based training; models trained
        # through fit() leave them as None and are not logged to wandb.
        train_trace = val_trace = best_epoch = None

        if "_target_" in cfg.train.optimizer:
            optimizer = get_optimizer(cfg.train.optimizer, model)
            model, train_trace, val_trace, best_epoch, sub_logits = (
                node_level_subset_training(
                    data=sub_data,
                    model=model,
                    optimizer=optimizer,
                    patience=cfg.train.patience,
                    max_epochs=cfg.train.epochs,
                    pbar=False,
                )
            )
        else:
            # LP models are always fitted; other models only when training nodes exist.
            if isinstance(model, LP) or sub_data.train_mask.sum() > 0:
                edge_index = get_appropriate_edge_index(sub_data)
                model.fit(
                    x=sub_data.x,
                    edge_index=edge_index,
                    y=sub_data.y,
                    train_mask=sub_data.train_mask,
                )
            sub_logits = get_logits(sub_data, model)

        # Scatter the subset logits back to their positions in the full node set.
        mapped_sub_logits[model_seed, train_mask] = sub_logits[sub_data.train_mask]
        mapped_sub_logits[model_seed, val_mask] = sub_logits[sub_data.val_mask]
        mapped_sub_logits[model_seed, test_mask] = sub_logits[sub_data.test_mask]

        if logits_on_data:
            logits.append(get_logits(data, model))

        if (
            "wandb" in cfg.log
            and subset_idx is not None
            and train_trace is not None
        ):
            wandb.log(
                {
                    f"subset_{subset_idx}/trends": wandb.plot.line_series(
                        # One x value per recorded epoch; the traces are plain
                        # lists, so their length (not the list) sets the range.
                        xs=[
                            list(range(len(train_trace))),
                            list(range(len(val_trace))),
                        ],
                        ys=[train_trace, val_trace],
                        keys=["train", "val"],
                        title="Losses",
                        xname="epoch",
                    )
                }
            )
            wandb.log({f"subset_{subset_idx}/best_epoch": best_epoch})

    sub_y[train_mask] = sub_data.y[sub_data.train_mask]
    sub_y[val_mask] = sub_data.y[sub_data.val_mask]
    sub_y[test_mask] = sub_data.y[sub_data.test_mask]

    if logits_on_data:
        return (
            subset,
            mapped_sub_logits,
            sub_y,
            torch.stack(logits),
        )

    return (
        subset,
        mapped_sub_logits,
        sub_y,
    )
