import os
import time
import torch
import numpy as np
import pandas as pd

from dgl.dataloading import GraphDataLoader
from src.datasets.base_graph import BaseGraphDataset
from src.utils.seed import seed_everything
from src.utils.utils import save_fields
from src.utils.metrics import relative_rmse_field

# Run on the current CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Log every visible GPU so the run output records the hardware it ran on.
for i in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_properties(i).name
    print(f"💻 Using device {i}: {gpu_name}")


def train_base_gnn(
    args, model, optimizer, loss_fn, train_data, test_data,
    base_dataset_is_tree=False,
    node_builder_name="basic_nodes",
    edge_builder_name="basic_edges"
):
    """Train ``model`` on ``train_data`` and evaluate it on ``test_data`` after every epoch.

    Builds a ``BaseGraphDataset`` for each split. The test split reuses the
    scalers fitted on the training split. The function then runs
    ``args.num_epochs`` epochs of training. After each epoch it computes the
    per-field relative RMSE of the test predictions, after mapping them back
    through the training ``y_scaler``.

    Outputs are written under ``<args.save_path>/<args.dataset_name>/<args.run_name>/``:

    * ``checkpoints/best_state.pt`` and ``predictions/best_predicted_fields.h5``
      whenever the mean relative RMSE improves;
    * ``checkpoints/state_epoch_<k>.pt`` and
      ``predictions/predicted_fields_<k>.h5`` every ``save_every`` epochs;
    * ``metrics/metrics.csv`` with one row per epoch. It is refreshed at every
      periodic checkpoint so an interrupted run keeps its history, and written
      once more at the end of training.

    Args:
        args: Run configuration. Reads ``seed``, ``save_path``,
            ``dataset_name``, ``run_name``, ``batch_size`` and ``num_epochs``.
            Optionally reads ``save_every`` (default 500; 0/None disables
            periodic checkpoints). Also forwarded to ``BaseGraphDataset``.
        model: Module called as ``model(node_features, edge_features, graph)``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Called as ``loss_fn(target, prediction)``. It must return a
            scalar tensor, because ``.backward()`` and ``.item()`` are called
            on it.
        train_data: Training split. Must contain a ``"Y_fields"`` entry that
            ``np.vstack`` can stack (used for the scaler round-trip check).
        test_data: Test split, consumed by ``BaseGraphDataset``.
        base_dataset_is_tree: Forwarded as ``use_tree_algo`` to both datasets.
        node_builder_name: Node-builder name forwarded to both datasets.
        edge_builder_name: Edge-builder name forwarded to both datasets.
    """
    seed_everything(args.seed)
    model.to(device)

    # Period (in epochs) of the numbered checkpoints. Defaults to the
    # historical 500; a falsy value disables them (and avoids `% 0`).
    save_every = getattr(args, "save_every", 500)

    # ==== Dirs ====
    run_dir = os.path.join(args.save_path, f"{args.dataset_name}/{args.run_name}")
    checkpoint_dir = os.path.join(run_dir, "checkpoints")
    predictions_dir = os.path.join(run_dir, "predictions")
    metrics_dir = os.path.join(run_dir, "metrics")
    for directory in (checkpoint_dir, predictions_dir, metrics_dir):
        os.makedirs(directory, exist_ok=True)

    # ==== Base datasets ====
    train_base_dataset = BaseGraphDataset(
        args, train_data, data_type="train",
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )
    # The test split must be normalised with the scalers fitted on the training split.
    test_base_dataset = BaseGraphDataset(
        args, test_data, data_type="test",
        in_scaler=train_base_dataset.in_scaler,
        out_scaler=train_base_dataset.out_scaler,
        y_scaler=train_base_dataset.y_scaler,
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )

    y_scaler = train_base_dataset.y_scaler
    print("Y scaler:",   y_scaler.summary())

    # Sanity check: how far transform -> inverse_transform drifts on the training targets.
    y_all = np.vstack(train_data["Y_fields"])
    print("RT err (Y): ", y_scaler.roundtrip_error(y_all))

    # ==== DataLoaders ====
    train_dataloader = GraphDataLoader(
        train_base_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True
    )
    # Batch size 1: each test graph's prediction is kept as its own field.
    test_dataloader = GraphDataLoader(
        test_base_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        pin_memory=True
    )

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"⚙️ Total number of model parameters: {total_params}")

    start_time = time.time()
    metrics = []
    best_relative_rmse_mean = float('inf')
    best_epoch = -1

    for epoch in range(args.num_epochs):
        epoch_start_time = time.time()
        epoch_num = epoch + 1  # 1-based number used in filenames and logs

        train_loss = _train_one_epoch(model, optimizer, loss_fn, train_dataloader)
        test_loss, y_tests, y_test_preds = _evaluate(model, loss_fn, test_dataloader, y_scaler)

        is_periodic_save = bool(save_every) and epoch_num % save_every == 0
        if is_periodic_save:
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"state_epoch_{epoch_num}.pt"))
            save_fields(os.path.join(predictions_dir, f"predicted_fields_{epoch_num}.h5"), y_test_preds)

        relative_rmses = relative_rmse_field(y_tests, y_test_preds)
        mean_relative_rmse = relative_rmses.mean().item()

        if mean_relative_rmse < best_relative_rmse_mean:
            best_relative_rmse_mean = mean_relative_rmse
            best_epoch = epoch_num
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_state.pt"))
            save_fields(os.path.join(predictions_dir, "best_predicted_fields.h5"), y_test_preds)

        epoch_duration = time.time() - epoch_start_time

        metrics.append({
            "epoch": epoch,  # kept 0-based for compatibility with existing metrics.csv files
            "train_loss": train_loss,
            "test_loss": test_loss,
            "mean_rrmse": mean_relative_rmse,
            # Store plain floats: a tensor would be written to CSV as its
            # truncated, reduced-precision repr ("tensor([...])").
            "rrmses_field": relative_rmses.detach().cpu().tolist(),
            "duration": epoch_duration
        })
        if is_periodic_save:
            # Persist progress so an interrupted run keeps its metric history.
            _write_metrics_csv(metrics, metrics_dir)

        metrics_str = (f"🌟"
            f"Epoch {epoch_num} | "
            f"Train Loss: {train_loss:.7f} | Test Loss: {test_loss:.7f} | "
            f"Mean RRMSE: {mean_relative_rmse:.7f} | "
            f"RRMSE Field: {[f'{v:.7f}' for v in relative_rmses]} | "
            f"Duration: {epoch_duration:.2f} (s) | "
            f"Best Mean RRMSE: {best_relative_rmse_mean:.7f} at epoch {best_epoch}"
        )
        print(metrics_str)

    _write_metrics_csv(metrics, metrics_dir)

    training_duration = time.time() - start_time
    print(f"⏰ Training took {training_duration:.7f} seconds")


def _train_one_epoch(model, optimizer, loss_fn, dataloader):
    """Run one optimisation pass over ``dataloader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    # Each dataset item is a 4-tuple; only the (batched) graph is used here.
    for graph, _, _, _ in dataloader:
        optimizer.zero_grad()
        graph = graph.to(device)
        pred = model(graph.ndata["x"], graph.edata["f"], graph)
        loss = loss_fn(graph.ndata["y"], pred)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # max(1, ...) keeps an empty loader from dividing by zero.
    return running_loss / max(1, len(dataloader))


def _evaluate(model, loss_fn, dataloader, y_scaler):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Returns ``(mean_loss, y_true, y_pred)``. ``y_true`` and ``y_pred`` are
    lists of CPU tensors, one per batch, mapped back through
    ``y_scaler.inverse_transform``.
    """
    model.eval()
    total_loss = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for graph, _, _, _ in dataloader:
            graph = graph.to(device)
            pred = model(graph.ndata["x"], graph.edata["f"], graph)
            total_loss += loss_fn(graph.ndata["y"], pred).item()

            field_true_real = unscale_with_y_scaler(y_scaler, graph.ndata["y"])
            field_pred_real = unscale_with_y_scaler(y_scaler, pred)
            y_pred.append(field_pred_real.detach().cpu())
            y_true.append(field_true_real.detach().cpu())
    return total_loss / max(1, len(dataloader)), y_true, y_pred


def _write_metrics_csv(metrics, metrics_dir):
    """Write the per-epoch metric rows collected so far to ``metrics_dir/metrics.csv``."""
    pd.DataFrame(metrics).to_csv(os.path.join(metrics_dir, "metrics.csv"), index=False)


def unscale_with_y_scaler(y_scaler, t: torch.Tensor) -> torch.Tensor:
    """Map a scaled tensor back through ``y_scaler.inverse_transform``.

    The scaler operates on NumPy arrays, so ``t`` is detached and copied to
    the CPU before the inverse transform. The result is wrapped back into a
    tensor on the same device as ``t``.

    Args:
        y_scaler: Object exposing ``inverse_transform(np.ndarray) -> np.ndarray``.
        t: Tensor in the scaler's normalised space.

    Returns:
        A tensor built from the inverse-transformed array, placed on ``t.device``.
        Its dtype is whatever dtype the scaler's output array has.
    """
    home_device = t.device
    scaled_array = t.detach().cpu().numpy()
    restored_array = y_scaler.inverse_transform(scaled_array)  # CompositeScaler (NumPy)
    return torch.from_numpy(restored_array).to(home_device)
