import os
import numpy as np
from tqdm import tqdm
import wandb
from sklearn.metrics import confusion_matrix

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from src.evaluator_encoder_only import evaluate
from src.constants import NUM_TO_LETTER


def train(config, model, train_loader, val_loader, test_loader, device):
    """
    Train RNA inverse folding model using the specified config and data loaders.

    Args:
        config (dict): wandb configuration dictionary
        model (nn.Module): RNA inverse folding model to be trained
        train_loader (DataLoader): training data loader
        val_loader (DataLoader): validation data loader
        test_loader (DataLoader): test data loader
        device (torch.device): device to train the model on
    """

    # Initialise loss function
    train_loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
    eval_loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.0)

    # Initialise optimizer and scheduler
    lr = config.lr
    optimizer = Adam(model.parameters(), lr)
    scheduler = ReduceLROnPlateau(
        optimizer, mode="max", factor=0.9, patience=1, min_lr=0.00001
    )

    if device.type == "xpu":
        import intel_extension_for_pytorch as ipex

        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    # Initialise lookup table mapping integers to nucleotides
    lookup = train_loader.dataset.featurizer.num_to_letter

    # Initialise best checkpoint information
    best_epoch, best_val_loss, best_val_acc = -1, np.inf, 0

    ##################################
    # Training loop over mini-batches
    ##################################

    for epoch in range(config.epochs):

        # Training iteration
        model.train()
        train_loss, train_acc, train_confusion = loop(
            model, train_loader, train_loss_fn, optimizer, device
        )
        print_and_log(
            epoch,
            train_loss,
            train_acc,
            train_confusion,
            lr=lr,
            mode="train",
            lookup=lookup,
        )

        if epoch % config.val_every == 0 or epoch == config.epochs - 1:

            model.eval()
            with torch.no_grad():

                # Evaluate on validation set
                val_loss, val_acc, val_confusion = loop(
                    model, val_loader, eval_loss_fn, None, device
                )
                print_and_log(
                    epoch, val_loss, val_acc, val_confusion, mode="val", lookup=lookup
                )

                # LR scheduler step
                scheduler.step(val_acc)
                lr = optimizer.param_groups[0]["lr"]

                if val_acc > best_val_acc:
                    # Update best checkpoint
                    best_epoch, best_val_loss, best_val_acc = epoch, val_loss, val_acc

                    # Evaluate on test set
                    test_loss, test_acc, test_confusion = loop(
                        model, test_loader, eval_loss_fn, None, device
                    )
                    print_and_log(
                        epoch,
                        test_loss,
                        test_acc,
                        test_confusion,
                        mode="test",
                        lookup=lookup,
                    )

                    # Update wandb summary metrics
                    wandb.run.summary["best_epoch"] = best_epoch
                    wandb.run.summary["best_val_perp"] = np.exp(best_val_loss)
                    wandb.run.summary["best_val_acc"] = best_val_acc
                    wandb.run.summary["best_test_perp"] = np.exp(test_loss)
                    wandb.run.summary["best_test_acc"] = test_acc

                    if config.save:
                        # Save best checkpoint
                        checkpoint_path = os.path.join(
                            wandb.run.dir, "best_checkpoint.h5"
                        )
                        torch.save(model.state_dict(), checkpoint_path)
                        wandb.run.summary["best_checkpoint"] = checkpoint_path

                        ## update log
                        with open(
                            os.path.join(wandb.run.dir, "best_checkpoint_log.txt"), "a"
                        ) as f:
                            f.write(
                                f"\n\nbest_epoch:\t{best_epoch}\n"
                                + f'best_val_perp:\t{wandb.run.summary["best_val_perp"]}\n'
                                + f'best_val_acc:\t{wandb.run.summary["best_val_acc"]}\n'
                                + f'best_test_perp:\t{wandb.run.summary["best_test_perp"]}\n'
                                + f'best_test_acc:\t{wandb.run.summary["best_test_acc"]}\n'
                                + f'best_checkpoint_path:\t{wandb.run.summary["best_checkpoint"]}\n'
                            )

        if config.save:
            # Save current epoch checkpoint
            torch.save(
                model.state_dict(), os.path.join(wandb.run.dir, "current_checkpoint.h5")
            )

    # End of training
    if config.save:
        # Evaluate best checkpoint
        print(
            f"EVALUATION: loading {os.path.join(wandb.run.dir, 'best_checkpoint.h5')} (epoch {best_epoch})"
        )
        model.load_state_dict(
            torch.load(os.path.join(wandb.run.dir, "best_checkpoint.h5"))
        )

        for loader, set_name in [(test_loader, "test"), (val_loader, "val")]:
            # Run evaluator
            (
                df,
                samples_list,
                recovery_list,
                perplexity_list,
                scscore_list,
                scscore_ribonanza_list,
                scscore_rmsd_list,
                scscore_tm_list,
                scscore_gdt_list,
            ) = evaluate(
                model,
                loader.dataset,
                config.n_samples,
                config.temperature,
                device,
                model_name=set_name,
                metrics=[
                    "recovery",
                    "perplexity",
                    "sc_score_eternafold",
                    "sc_score_ribonanzanet",
                    "sc_score_rhofold",
                ],
                save_structures=True,
            )
            # Update wandb summary metrics
            wandb.run.summary[f"best_{set_name}_recovery"] = np.mean(recovery_list)
            wandb.run.summary[f"best_{set_name}_perplexity"] = np.mean(perplexity_list)
            wandb.run.summary[f"best_{set_name}_scscore"] = np.mean(scscore_list)
            wandb.run.summary[f"best_{set_name}_scscore_ribonanza"] = np.mean(
                scscore_ribonanza_list
            )
            wandb.run.summary[f"best_{set_name}_scscore_rmsd"] = np.mean(
                scscore_rmsd_list
            )
            wandb.run.summary[f"best_{set_name}_scscore_tm"] = np.mean(scscore_tm_list)
            wandb.run.summary[f"best_{set_name}_scscore_gdt"] = np.mean(
                scscore_gdt_list
            )
            print(
                f"BEST {set_name} recovery: {np.mean(recovery_list):.4f}\
                    perplexity: {np.mean(perplexity_list):.4f}\
                    scscore: {np.mean(scscore_list):.4f}\
                    scscore_ribonanza: {np.mean(scscore_ribonanza_list):.4f}\
                    scscore_rmsd: {np.mean(scscore_rmsd_list):.4f}\
                    scscore_tm: {np.mean(scscore_tm_list):.4f}\
                    scscore_gdt: {np.mean(scscore_gdt_list):.4f}"
            )
            # Save results
            torch.save(
                (
                    df,
                    samples_list,
                    recovery_list,
                    perplexity_list,
                    scscore_list,
                    scscore_ribonanza_list,
                    scscore_rmsd_list,
                    scscore_tm_list,
                    scscore_gdt_list,
                ),
                os.path.join(wandb.run.dir, f"{set_name}_results.pt"),
            )


def loop(model, dataloader, loss_fn, optimizer=None, device="cpu"):
    """
    Training loop for a single epoch over the data loader.

    Args:
        model (nn.Module): RNA inverse folding model
        dataloader (DataLoader): data loader for the current epoch
        loss_fn (nn.Module): loss function to compute the loss
        optimizer (torch.optim): optimizer to update model parameters
        device (torch.device): device to train the model on

    Note:
        This function is used for both training and evaluation loops.
        Not passing an optimizer will run the model in evaluation mode.

    Returns:
        float: average loss over the epoch
        float: average accuracy over the epoch
        np.ndarray: confusion matrix over the epoch
    """

    confusion = np.zeros((model.out_dim, model.out_dim))
    total_loss, total_correct, total_count = 0, 0, 0

    t = tqdm(dataloader)
    for batch in t:
        if optimizer:
            optimizer.zero_grad()

        # move batch to device
        batch = batch.to(device)

        try:
            logits = model(batch)
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise (e)
            print("Skipped batch due to OOM", flush=True)
            for p in model.parameters():
                if p.grad is not None:
                    del p.grad  # free some memory
            torch.cuda.empty_cache()
            continue

        # compute loss
        loss_value = loss_fn(logits, batch.seq)

        if optimizer:
            # backpropagate loss and update parameters
            loss_value.backward()
            optimizer.step()

        # update metrics
        num_nodes = int(batch.seq.size(0))
        total_loss += float(loss_value.item()) * num_nodes
        total_count += num_nodes
        pred = torch.argmax(logits, dim=-1).detach().cpu().numpy()
        true = batch.seq.detach().cpu().numpy()
        total_correct += (pred == true).sum()
        confusion += confusion_matrix(true, pred, labels=range(model.out_dim))

        t.set_description("%.5f" % float(total_loss / total_count))

    return total_loss / total_count, total_correct / total_count, confusion


def print_and_log(
    epoch,
    loss,
    acc,
    confusion,
    recovery=None,
    lr=None,
    mode="train",
    lookup=NUM_TO_LETTER,  # reverse of {'A': 0, 'G': 1, 'C': 2, 'U': 3}
):
    # Create log string and wandb metrics dict
    log_str = f"\nEPOCH {epoch} {mode.upper()} loss: {loss:.4f} perp: {np.exp(loss):.4f} acc: {acc:.4f}"
    wandb_metrics = {
        f"{mode}/loss": loss,
        f"{mode}/perp": np.exp(loss),
        f"{mode}/acc": acc,
        "epoch": epoch,
    }

    if lr is not None:
        # Add learning rate to loggers
        log_str += f" lr: {lr:.6f}"
        wandb_metrics[f"lr"] = lr

    if recovery is not None:
        # Add mean sequence recovery to loggers
        log_str += f" rec: {np.mean(recovery):.4f}"
        wandb_metrics[f"{mode}/recovery"] = np.mean(recovery)

    print(log_str)
    print_confusion(confusion, lookup=lookup)
    wandb.log(wandb_metrics)


def print_confusion(mat, lookup):
    counts = mat.astype(np.int32)
    mat = (counts.T / counts.sum(axis=-1, keepdims=True).T).T
    mat = np.round(mat * 1000).astype(np.int32)
    res = "\n"
    for i in range(len(lookup.keys())):
        res += "\t{}".format(lookup[i])
    res += "\tCount\n"
    for i in range(len(lookup.keys())):
        res += "{}\t".format(lookup[i])
        res += "\t".join("{}".format(n) for n in mat[i])
        res += "\t{}\n".format(sum(counts[i]))
    print(res)
