import os
import time
import gc
import torch
import numpy as np
from dgl.dataloading import GraphDataLoader

from src.datasets.base_graph import BaseGraphDataset
from src.datasets.multi_graph import MultiGraphDataset
from src.train.buffer import ErrorBuffer
from src.utils.metrics import relative_rmse_field
from src.utils.utils import relative_error, save_fields

# Module-wide compute device: the current CUDA device when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def unscale_with_y_scaler(y_scaler, t: torch.Tensor) -> torch.Tensor:
    """Undo the target scaling of ``t`` using ``y_scaler.inverse_transform``.

    ``t`` is detached from the autograd graph and copied to host memory so the
    scaler can be called on a NumPy array. The array it returns is wrapped with
    ``torch.from_numpy`` (so it must be an ndarray) and moved back to the
    device ``t`` originally lived on.

    Args:
        y_scaler: Object exposing ``inverse_transform(ndarray) -> ndarray``.
        t: Tensor in scaled units.

    Returns:
        The unscaled tensor on ``t.device``, without gradient history.
    """
    source_device = t.device
    host_values = t.detach().cpu().numpy()
    restored = y_scaler.inverse_transform(host_values)
    return torch.from_numpy(restored).to(source_device)


def _make_test_loader(dataset):
    """Wrap ``dataset`` in a deterministic loader yielding one sample per batch.

    Pinned host memory only speeds up host-to-GPU copies, so it is requested
    only when the module-level ``device`` is CUDA; on CPU-only hosts recent
    PyTorch versions emit a warning when ``pin_memory=True`` is requested.
    """
    return GraphDataLoader(
        dataset,
        batch_size=1,
        shuffle=False,  # keep outputs in dataset iteration order
        drop_last=False,
        pin_memory=device.type == "cuda",
    )


def _release_memory():
    """Run the garbage collector and, on CUDA, release cached allocator blocks."""
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def _report_rrmse(label, field_trues, field_preds):
    """Print the mean relative RMSE across samples under ``label`` and return it."""
    mean_rrmse = relative_rmse_field(field_trues, field_preds).mean().item()
    print(f"📊 RRMSE ({label}): {mean_rrmse}")
    return mean_rrmse


def _save_outputs(save_dir, named_fields):
    """Create ``save_dir`` if needed and write each list of tensors via ``save_fields``.

    ``named_fields`` maps output file names to the lists to save; files are
    written in the mapping's insertion order.
    """
    os.makedirs(save_dir, exist_ok=True)
    for file_name, tensors in named_fields.items():
        save_fields(os.path.join(save_dir, file_name), tensors)
    print("✅ Saved all test predictions and errors for plotting!")


def evaluate_mise_gnn(
    args,
    model,
    checkpoint_path,
    train_data,
    test_data,
    train_indices,
    test_indices,
    save_dir,
    base_dataset_is_tree=False,
    node_builder_name="airfoil_geom_nodes",
    edge_builder_name="airfoil_geom_edges"
):
    """Evaluate a trained checkpoint on the test split and save predictions.

    Three inference phases are run:

    1. Predict fields and per-node errors on the base test graphs. The
       predicted errors are cached in an ``ErrorBuffer`` keyed by sample id.
    2. Rebuild the test ``MultiGraphDataset`` from those error maps
       (``epoch="inference"``, ``init_refine=False``, ``stock_graph=True``).
    3. Predict fields on the rebuilt ("new") graphs.

    The mean relative RMSE on base and new graphs is printed. Predicted and
    true fields (base and new) and predicted and true base-graph errors are
    written as ``.h5`` files under ``save_dir`` via ``save_fields``.

    Args:
        args: Configuration object forwarded to the dataset classes.
        model: Module called as ``model(node_feats, edge_feats, graph)`` and
            returning ``(field_pred, error_pred)``.
        checkpoint_path: Path of a state dict loadable with ``torch.load``.
        train_data: Training split. Used to build the scalers and for a
            Y-scaler round-trip diagnostic; must provide ``"Y_fields"``.
        test_data: Test split to evaluate.
        train_indices: Currently unused; kept for interface compatibility.
        test_indices: Test sample indices passed to ``MultiGraphDataset``.
        save_dir: Output directory; created if it does not exist.
        base_dataset_is_tree: Forwarded as ``use_tree_algo`` to ``BaseGraphDataset``.
        node_builder_name: Node-feature builder name for ``BaseGraphDataset``.
        edge_builder_name: Edge builder name for both dataset classes.
    """
    # perf_counter is monotonic, so measured durations are unaffected by
    # system clock adjustments (unlike time.time()).
    t_total_start = time.perf_counter()

    # ==== Load checkpoint ====
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print(f"✅ Loaded checkpoint from: {checkpoint_path}")

    # ==== Prepare dataset ====
    train_base_dataset = BaseGraphDataset(
        args, train_data, data_type="train",
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )

    # The test split is built with the training split's scalers so both
    # share the same normalization.
    in_scaler = train_base_dataset.in_scaler
    out_scaler = train_base_dataset.out_scaler
    y_scaler = train_base_dataset.y_scaler

    # Diagnostic only: report the scaler and its round-trip error on the
    # stacked training targets.
    print("Y scaler:", y_scaler.summary())
    y_all = np.vstack(train_data["Y_fields"])
    print("RT err (Y): ", y_scaler.roundtrip_error(y_all))

    test_base_dataset = BaseGraphDataset(
        args, test_data, data_type="test",
        in_scaler=in_scaler, out_scaler=out_scaler, y_scaler=y_scaler,
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )

    test_multigraph_dataset = MultiGraphDataset(
        args=args,
        base_dataset=test_base_dataset,
        indices=test_indices,
        epoch=None,
        init_refine=True,
        stock_graph=False,
        edge_builder_name=edge_builder_name,
    )
    test_dataloader = _make_test_loader(test_multigraph_dataset)

    # ==== Error caches ====
    # Sized on the whole test base dataset; the loader's sample ids are used
    # as keys (presumably indices into test_base_dataset — confirm).
    test_error_buffer = ErrorBuffer(num_samples=len(test_base_dataset))

    # ==== Inference Phase 1 ====
    t_phase1_start = time.perf_counter()

    all_error_preds_base = []
    all_error_trues_base = []

    all_field_preds_base = []
    all_field_trues_base = []

    with torch.no_grad():
        for base_graph, _, _, _, sample_ids in test_dataloader:
            base_graph = base_graph.to(device)

            # Predict on BASE GRAPH
            field_pred_base, error_pred_base = model(
                base_graph.ndata["x"], base_graph.edata["f"], base_graph
            )

            # Split the batched per-node error predictions into one tensor per
            # graph and cache them by sample id; Phase 2 consumes them.
            num_nodes_per_graph = base_graph.batch_num_nodes().tolist()
            err_splits = list(torch.split(error_pred_base.detach().cpu(), num_nodes_per_graph, dim=0))
            sid_list = sample_ids.tolist() if torch.is_tensor(sample_ids) else list(sample_ids)
            test_error_buffer.update_batch(sid_list, err_splits)

            # Compare predictions and targets in unscaled units.
            field_true_base_real = unscale_with_y_scaler(y_scaler, base_graph.ndata["y"])
            field_pred_base_real = unscale_with_y_scaler(y_scaler, field_pred_base)

            # Ground-truth error target; a 1-D result is promoted to a column.
            error_true_base, _ = relative_error(field_pred_base_real, field_true_base_real)
            if error_true_base.ndim == 1:
                error_true_base = error_true_base.unsqueeze(-1)

            # Stock all errors on base graph
            all_error_preds_base.append(error_pred_base.detach().cpu())
            all_error_trues_base.append(error_true_base.detach().cpu())

            all_field_preds_base.append(field_pred_base_real.detach().cpu())
            all_field_trues_base.append(field_true_base_real.detach().cpu())

    t_phase1 = time.perf_counter() - t_phase1_start
    print(f"⏱️ Phase 1 (base graph error prediction) took {t_phase1:.3f}s \n")

    # ==== Inference Phase 2 ====
    # NOTE(review): this times dataset/loader construction only; if
    # MultiGraphDataset rebuilds graphs lazily in __getitem__, that cost is
    # counted in Phase 3 instead — confirm.
    t_phase2_start = time.perf_counter()

    # Drop the Phase 1 dataset/loader before building the refined graphs to
    # lower peak memory.
    del test_multigraph_dataset, test_dataloader
    _release_memory()

    test_multigraph_dataset = MultiGraphDataset(
        args=args,
        base_dataset=test_base_dataset,
        indices=test_indices,
        epoch="inference",
        init_refine=False,
        stock_graph=True,
        error_maps=test_error_buffer.get_all(),
        edge_builder_name=edge_builder_name,
    )
    test_dataloader = _make_test_loader(test_multigraph_dataset)

    t_phase2 = time.perf_counter() - t_phase2_start
    print(f"⏲️ Phase 2 (update edges / rebuild new graphs) took {t_phase2:.3f}s \n")

    # ==== Inference Phase 3 ====
    t_phase3_start = time.perf_counter()

    all_field_preds_new = []
    all_field_trues_new = []

    with torch.no_grad():
        for _, new_graph, _, _, _ in test_dataloader:
            new_graph = new_graph.to(device)

            field_pred_new, _ = model(
                new_graph.ndata["x"], new_graph.edata["f"], new_graph
            )

            field_true_new_real = unscale_with_y_scaler(y_scaler, new_graph.ndata["y"])
            field_pred_new_real = unscale_with_y_scaler(y_scaler, field_pred_new)

            all_field_preds_new.append(field_pred_new_real.detach().cpu())
            all_field_trues_new.append(field_true_new_real.detach().cpu())

    t_phase3 = time.perf_counter() - t_phase3_start
    t_total = time.perf_counter() - t_total_start
    print(f"⏱️ Phase 3 (new graph field prediction) took {t_phase3:.3f}s")
    print(f"⏰ Total inference took {t_total:.3f}s \n")

    # ==== Metrics ====
    _report_rrmse("BASE GRAPH", all_field_trues_base, all_field_preds_base)
    _report_rrmse("NEW GRAPH", all_field_trues_new, all_field_preds_new)

    # ==== Save outputs ====
    _save_outputs(save_dir, {
        "field_pred_new.h5": all_field_preds_new,
        "field_true_new.h5": all_field_trues_new,
        "field_pred_base.h5": all_field_preds_base,
        "field_true_base.h5": all_field_trues_base,
        "error_pred_base.h5": all_error_preds_base,
        "error_true_base.h5": all_error_trues_base,
    })
