from __future__ import print_function

import os
import sys
import time

from runner_helper import *

# Make the project root importable. The path appended here is the parent of
# the directory containing this file; if the anympe directory lives elsewhere,
# change this path. (Calling sys.path.append() with no argument raises
# TypeError at import time, so a concrete path is required.)
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)  # Adds the parent directory to the system path
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
import wandb
from create_output_dir import get_output_dir

# from get_spn_class_log import SPNModel
from loguru import logger
from matplotlib import pyplot as plt
from nn_scripts import get_ll_scores, train, train_and_validate_single_example, validate
from project_utils.logging_utils import save_args_as_yaml
from project_utils.plotting import plot_images
from project_utils.results import construct_data_row, construct_header, write_to_csv


def test_and_process_outputs(
    cfg,
    device,
    fabric,
    pgm,
    model_dir,
    model_outputs_dir,
    train_loader,
    test_loader,
    val_loader,
    mpe_solutions,
    model,
    library_pgm,
    optimizer,
    best_loss,
    counter,
    all_train_losses,
    all_val_losses,
    best_model_info,
    num_data_features,
    num_pgm_feature,
    num_outputs,
    num_query_variables,
):
    """Evaluate the best checkpoint, persist per-split outputs and report metrics.

    Restores ``best_model_info["model_state"]`` into ``model``. The test split
    is always evaluated. The train and val splits are evaluated too when
    ``cfg.evaluate_training_set`` is set and neither ``cfg.debug_tuning`` nor
    ``cfg.no_train`` is. For each evaluated split it:

    * runs ``validate``, or, for the test split when ``cfg.train_on_test_set``
      is set, ``train_and_validate_batch`` /
      ``train_and_validate_single_example``;
    * scores the outputs with ``get_ll_scores`` and logs the mean root
      log-likelihoods to loguru and wandb;
    * writes raw outputs and summary statistics to
      ``{model_outputs_dir}/{split}_outputs.npz``;
    * for the test split only, appends a row to
      ``common_results/<cfg.dataset>/<cfg.model>/results.csv``.

    Afterwards it sends a wandb alert summarising every split. Unless
    ``cfg.no_train`` is set, it also saves a loss plot to
    ``{model_dir}/loss_plot.png``.

    Side effects: writes ``cfg.yaml`` into ``model_outputs_dir`` and sets
    ``cfg.outputs_path`` to the ``.npz`` path of the last evaluated split,
    which is always the test split. Returns ``None``.
    """
    # Load the best model for further use or testing
    logger.info("Loading best model...")
    model.load_state_dict(best_model_info["model_state"])
    os.makedirs(model_outputs_dir, exist_ok=True)
    if not cfg.debug_tuning and not cfg.no_train and cfg.evaluate_training_set:
        dataset_types = ["train", "val", "test"]
    else:
        dataset_types = ["test"]
    loaders = {"train": train_loader, "val": val_loader, "test": test_loader}
    # Read once up front: the wandb alert after the loop needs dataset_name,
    # so it must not depend on which branch inside the loop ran.
    dataset_name = cfg.dataset
    model_name = cfg.model
    ll_scores_nn = {data_type: None for data_type in dataset_types}
    ll_scores_pgm = {data_type: None for data_type in dataset_types}
    save_args_as_yaml(cfg, os.path.join(model_outputs_dir, "cfg.yaml"))

    for data_type in dataset_types:
        loader = loaders[data_type]

        if not cfg.train_on_test_set or data_type in ("train", "val"):
            results = validate(cfg, model, pgm, device, loader, best_loss, counter)
        else:
            # Test-time training path: the model is updated on the test data
            # while it is evaluated.
            # NOTE(review): train_and_validate_batch is not among the explicit
            # imports; presumably it comes from `from runner_helper import *`.
            if cfg.use_batch_train_on_test:
                validation_function = train_and_validate_batch
            else:
                validation_function = train_and_validate_single_example
            results = validation_function(
                cfg,
                model,
                pgm,
                fabric,
                loader,
                optimizer,
                best_loss,
                counter,
                library_pgm,
                device,
                num_data_features,
                num_pgm_feature,
                num_outputs,
                num_query_variables,
            )
        # Both paths return a 7-tuple; only the last four entries are used here.
        (
            _,
            _,
            _,
            all_unprocessed_data,
            all_nn_outputs,
            all_outputs_for_pgm,
            all_buckets,
        ) = results

        if mpe_solutions is not None:
            mpe_output = mpe_solutions[f"{data_type}_mpe_output"]
            root_ll_pgm = mpe_solutions[f"{data_type}_root_ll_pgm"]
        else:
            mpe_output = None
            root_ll_pgm = None
        root_ll_nn_torch, root_ll_nn_library, root_ll_pgm, mpe_output = get_ll_scores(
            cfg,
            library_pgm,
            pgm,
            all_unprocessed_data,
            all_outputs_for_pgm,
            all_buckets,
            mpe_output,
            root_ll_pgm,
            device,
        )
        mean_ll_nn = np.mean(root_ll_nn_library)
        mean_ll_pgm = np.mean(root_ll_pgm)
        mean_ll_nn_ours = np.mean(root_ll_nn_torch)
        std_ll_nn_library = np.std(root_ll_nn_library)
        std_ll_nn = np.std(root_ll_nn_torch)
        std_ll_pgm = np.std(root_ll_pgm)
        logger.info(f"Root LL NN Ours {data_type}: {mean_ll_nn_ours}")
        logger.info(f"Root LL NN {data_type}: {mean_ll_nn}")
        logger.info(f"Root LL PGM {data_type}: {mean_ll_pgm}")
        ll_scores_nn[data_type] = mean_ll_nn_ours
        ll_scores_pgm[data_type] = mean_ll_pgm

        wandb.log(
            {
                f"Root LL NN {data_type}": mean_ll_nn_ours,
                f"Root LL PGM {data_type}": mean_ll_pgm,
            }
        )
        output_path = f"{model_outputs_dir}/{data_type}_outputs.npz"
        np.savez(
            output_path,
            all_unprocessed_data=all_unprocessed_data,
            all_outputs_for_pgm=all_outputs_for_pgm,
            all_buckets=all_buckets,
            all_nn_outputs=all_nn_outputs,
            mpe_output=mpe_output,
            root_ll_nn_torch=root_ll_nn_torch,
            root_ll_pgm=root_ll_pgm,
            root_ll_nn=root_ll_nn_library,
            mean_ll_nn=mean_ll_nn_ours,
            mean_ll_nn_library=mean_ll_nn,
            mean_ll_pgm=mean_ll_pgm,
            std_ll_nn_library=std_ll_nn_library,
            std_ll_nn=std_ll_nn,
            std_ll_pgm=std_ll_pgm,
            runtime=time.time() - wandb.run.start_time,
        )
        # Save the test metrics to a results file shared across models and datasets.
        if data_type == "test":
            # Mean and std of the NN score come from the same (torch) scores,
            # matching the mean_ll_nn / std_ll_nn pair stored in the .npz above.
            metrics = {
                "mean_ll_nn": mean_ll_nn_ours,
                "mean_ll_pgm": mean_ll_pgm,
                "std_ll_nn": std_ll_nn,
                "std_ll_pgm": std_ll_pgm,
            }
            run_id = wandb.run.id
            runtime = time.time() - wandb.run.start_time
            common_results_path = os.path.join(
                "common_results", dataset_name, model_name
            )
            os.makedirs(common_results_path, exist_ok=True)
            csv_path = os.path.join(common_results_path, "results.csv")
            header = construct_header(cfg)
            data_row = construct_data_row(run_id, metrics, output_path, runtime, cfg)
            write_to_csv(csv_path, header, data_row)

        cfg.outputs_path = output_path

    alert_str = "".join(
        f"{model_dir} \n Root LL NN {data_type}: {ll_scores_nn[data_type]}, Root LL PGM {data_type}: {ll_scores_pgm[data_type]}, \n"
        for data_type in ll_scores_nn
    )
    wandb.alert(
        title=f"Dataset: {dataset_name}, {cfg.task}",
        text=alert_str,
    )
    logger.info(alert_str)
    if not cfg.no_train:
        _plot_losses(model_dir, all_train_losses, all_val_losses)


def _plot_losses(model_dir, all_train_losses, all_val_losses):
    """Plot training and validation losses, save to ``model_dir/loss_plot.png``, and show."""
    plt.figure()
    # Training losses are plotted at 1-based x positions.
    plt.plot(
        range(1, len(all_train_losses) + 1),
        all_train_losses,
        label="Train Loss",
        color="blue",
    )
    # Validation losses are plotted at 0-based x positions.
    # NOTE(review): the offset differs from the training curve. That is correct
    # if validation also runs once before the first epoch; otherwise the two
    # curves are shifted by one. Confirm against the training loop. The
    # "Test Loss" label is the existing one, although the values are
    # all_val_losses.
    plt.plot(
        range(len(all_val_losses)),
        all_val_losses,
        label="Test Loss",
        color="red",
    )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Train and Test Loss over Epochs")
    plt.legend()
    plt.savefig(os.path.join(model_dir, "loss_plot.png"))
    plt.show()
    # Release the figure so repeated runs in one process do not keep
    # accumulating open figures.
    plt.close()


def select_lr_scheduler(
    cfg, scheduler_name, optimizer, train_loader=None, *args, **kwargs
):
    """
    Create a PyTorch learning-rate scheduler by name.

    Args:
        cfg: Run configuration. ``cfg.epochs`` is read by "StepLR" and
            "OneCycleLR", and is the default ``T_max`` for "CosineAnnealingLR".
        scheduler_name (str): Name of the scheduler to create, or "None".
        optimizer (torch.optim.Optimizer): Optimizer the scheduler drives.
        train_loader: Object supporting ``len()``. It is required for
            "OneCycleLR", which uses its length as ``steps_per_epoch``.
        *args: Forwarded to the scheduler constructor for "MultiStepLR",
            "ExponentialLR", "CosineAnnealingLR", "CyclicLR" and
            "CosineAnnealingWarmRestarts". Extra positional arguments bind to
            ``train_loader`` first, so prefer keyword arguments.
        **kwargs: Forwarded to the same schedulers as ``*args``.

    Returns:
        torch.optim.lr_scheduler.LRScheduler | ReduceLROnPlateau | None:
        The instantiated scheduler, or ``None`` when ``scheduler_name`` is "None".

    Raises:
        ValueError: If the scheduler name is not recognized, or if
            "OneCycleLR" is requested without a ``train_loader``.
    """
    if scheduler_name == "StepLR":
        # Decay roughly five times over training. Clamp to 1 because
        # step_size == 0 (epochs < 5) divides by zero on the first step().
        return lr_scheduler.StepLR(
            optimizer, step_size=max(1, cfg.epochs // 5), gamma=0.8
        )

    if scheduler_name == "ReduceLROnPlateau":
        return lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=10, verbose=True, factor=0.8
        )

    if scheduler_name == "OneCycleLR":
        if train_loader is None:
            raise ValueError(
                "OneCycleLR requires train_loader to determine steps_per_epoch"
            )
        steps_per_epoch = len(train_loader)  # Number of batches in one epoch
        return lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.01,
            epochs=cfg.epochs,
            steps_per_epoch=steps_per_epoch,
            anneal_strategy="cos",  # Cosine annealing between the LR bounds
            div_factor=25,  # initial_lr = max_lr / div_factor
            final_div_factor=1e4,  # min_lr = initial_lr / final_div_factor
            verbose=True,
        )

    if scheduler_name == "MultiStepLR":
        # Decays the LR by gamma at each listed milestone epoch.
        # Example: select_lr_scheduler(cfg, 'MultiStepLR', optimizer, milestones=[30, 80], gamma=0.1)
        return lr_scheduler.MultiStepLR(optimizer, *args, **kwargs)

    if scheduler_name == "ExponentialLR":
        # Decays the LR of each parameter group by gamma every epoch.
        # Example: select_lr_scheduler(cfg, 'ExponentialLR', optimizer, gamma=0.95)
        return lr_scheduler.ExponentialLR(optimizer, *args, **kwargs)

    if scheduler_name == "CosineAnnealingLR":
        # Cosine curve over T_max epochs. T_max defaults to cfg.epochs unless
        # the caller supplies it, either positionally or by keyword. Passing
        # T_max both ways would raise TypeError.
        # Example: select_lr_scheduler(cfg, 'CosineAnnealingLR', optimizer, T_max=50, eta_min=0)
        if not args:
            kwargs.setdefault("T_max", cfg.epochs)
        return lr_scheduler.CosineAnnealingLR(optimizer, *args, **kwargs)

    if scheduler_name == "CyclicLR":
        # Cycles the LR between base_lr and max_lr.
        # Example: select_lr_scheduler(cfg, 'CyclicLR', optimizer, base_lr=0.001, max_lr=0.01, step_size_up=20, mode='triangular')
        return lr_scheduler.CyclicLR(optimizer, *args, **kwargs)

    if scheduler_name == "CosineAnnealingWarmRestarts":
        # Cosine annealing with periodic warm restarts.
        # Example: select_lr_scheduler(cfg, 'CosineAnnealingWarmRestarts', optimizer, T_0=10, T_mult=2, eta_min=0.001)
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, *args, **kwargs)

    if scheduler_name == "None":
        return None

    raise ValueError(f"Unrecognized scheduler name: {scheduler_name}")


def init_directories(cfg, project_name):
    """Resolve the run's output directories and record them on ``cfg``.

    Obtains the main output directory and its "models"/"outputs"
    subdirectories from ``get_output_dir(cfg, project_name)``. It then adds a
    loguru file sink at ``<main_dir>/logs.log`` and creates the models
    directory unless ``cfg.no_train`` is set. Finally it stores the paths as
    ``cfg.model_dir`` and ``cfg.model_outputs_dir``.
    """
    main_dir_path, subdirectories = get_output_dir(cfg, project_name)
    model_dir = subdirectories["models"]
    model_outputs_dir = subdirectories["outputs"]
    logger.add(
        os.path.join(main_dir_path, "logs.log"), format="{time} {level} {message}"
    )
    logger.info(f"Output directory: {model_dir}")
    if os.path.exists(model_dir):
        print("Folder already exists!")
    elif cfg.no_train:
        # Evaluation-only runs do not create the models directory. Report that
        # it is missing instead of claiming it already exists.
        print(f"Folder does not exist and was not created (no_train is set): {model_dir}")
    else:
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(model_dir, exist_ok=True)
        print("Folder created successfully!")
    cfg.model_dir = model_dir
    cfg.model_outputs_dir = model_outputs_dir


def init_debug(cfg):
    """Set ``cfg.debug`` to the boolean negation of ``cfg.no_debug`` (in place)."""
    debug_enabled = False if cfg.no_debug else True
    cfg.debug = debug_enabled
