import os
import gc
import time
import torch
import numpy as np
import pandas as pd

from dgl.dataloading import GraphDataLoader
from src.datasets.base_graph import BaseGraphDataset
from src.datasets.multi_graph import MultiGraphDataset
from src.utils.seed import seed_everything
from src.utils.utils import save_fields, relative_error
from src.utils.metrics import relative_rmse_field
from .buffer import ErrorBuffer


# Target device for the model and graphs: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Announce every visible CUDA device once, at import time (prints nothing on CPU-only hosts).
for i in range(torch.cuda.device_count()):
    print(f"💻 Using device {i}: {torch.cuda.get_device_properties(i).name}")


def _build_dataloaders(args, train_dataset, test_dataset):
    """Wrap train/test multigraph datasets in ``GraphDataLoader``s.

    The train loader uses ``args.batch_size`` and shuffles; the test loader
    yields one sample per batch in a fixed order.
    """
    train_loader = GraphDataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
    )
    test_loader = GraphDataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )
    return train_loader, test_loader


def _cache_batch_errors(error_buffer, graph, error_pred, sample_ids):
    """Split a batched per-node error prediction per graph and store it in ``error_buffer``."""
    num_nodes_per_graph = graph.batch_num_nodes().tolist()
    err_splits = list(torch.split(error_pred.detach().cpu(), num_nodes_per_graph, dim=0))
    sid_list = sample_ids.tolist() if torch.is_tensor(sample_ids) else list(sample_ids)
    error_buffer.update_batch(sid_list, err_splits)


def _relative_error_target(y_scaler, field_pred_scaled, field_true_scaled):
    """Return the detached relative-error target computed in unscaled units.

    A 1-D result from ``relative_error`` is reshaped to a column ``(N, 1)``.
    """
    field_true_real = unscale_with_y_scaler(y_scaler, field_true_scaled)
    field_pred_real = unscale_with_y_scaler(y_scaler, field_pred_scaled)
    error_true, _ = relative_error(field_pred_real, field_true_real)
    error_true = error_true.detach()  # target only: never backpropagate through it
    if error_true.ndim == 1:
        error_true = error_true.unsqueeze(-1)
    return error_true


def _require_filled(error_buffer, split_name):
    """Raise if any entry of ``error_buffer.get_all()`` is still ``None``."""
    if any(e is None for e in error_buffer.get_all()):
        raise RuntimeError(
            f"{split_name} error buffer contains samples without a cached error map; "
            "edges cannot be rebuilt from it."
        )


def train_mise_gnn(
    args,
    model,
    optimizer,
    loss_function,
    train_data,
    test_data,
    train_indices,
    test_indices,
    base_dataset_is_tree=False,
    node_builder_name="basic_nodes",
    edge_builder_name="basic_edges",
):
    """Train a two-head (field + error) GNN with periodic error-driven edge rebuilds.

    Each step runs the model on two graphs per sample:

    * the *base* graph, whose error-head output is regressed against the
      relative error of its own field prediction (computed in unscaled units);
    * the *new* graph, whose field-head output is regressed against the
      scaled targets ``ndata["y"]``.

    The predicted per-node errors are cached in ``ErrorBuffer``s. Every
    ``args.error_check_interval`` epochs (except the last), the multigraph
    datasets and dataloaders are rebuilt from those cached error maps.

    Side effects: creates ``<save_path>/<dataset_name>/<run_name>/{checkpoints,
    predictions,metrics}``. It writes periodic and best checkpoints,
    predicted test fields (``.h5``) and ``metrics.csv``. The metrics file is
    always flushed at the final epoch.

    Args:
        args: Namespace read for ``seed``, ``save_path``, ``dataset_name``,
            ``run_name``, ``batch_size``, ``num_epochs``,
            ``error_check_interval``. Optional attributes are ``use_amp``
            and ``return_latent``.
        model: Module called as ``model(x, edge_feats, graph)``. Its output's
            first two items are the field and error predictions; with
            ``return_latent`` the base-graph call in validation must also
            yield node/edge/processor latents at positions 2..4.
        optimizer: Optimizer over ``model``'s parameters.
        loss_function: Callable ``loss(pred, target) -> scalar tensor``.
        train_data, test_data: Raw data passed to ``BaseGraphDataset``
            (``train_data["Y_fields"]`` is read for a scaler diagnostic).
        train_indices, test_indices: Sample indices for ``MultiGraphDataset``.
        base_dataset_is_tree: Forwarded as ``use_tree_algo``.
        node_builder_name, edge_builder_name: Graph builder identifiers.

    Raises:
        RuntimeError: If an error buffer has unfilled entries when an epoch ends.
    """
    # Hard-coded cadences (in epochs) for periodic artifacts.
    CHECKPOINT_EVERY = 200
    METRICS_EVERY = 10

    seed_everything(args.seed)
    model.to(device)

    # ==== Dirs ====
    checkpoint_dir  = os.path.join(args.save_path, f"{args.dataset_name}/{args.run_name}/checkpoints")
    predictions_dir = os.path.join(args.save_path, f"{args.dataset_name}/{args.run_name}/predictions")
    metrics_dir     = os.path.join(args.save_path, f"{args.dataset_name}/{args.run_name}/metrics")
    for directory in (checkpoint_dir, predictions_dir, metrics_dir):
        os.makedirs(directory, exist_ok=True)

    # ==== Base datasets (test reuses the scalers fitted on train) ====
    train_base_dataset = BaseGraphDataset(
        args, train_data, data_type="train",
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )
    test_base_dataset = BaseGraphDataset(
        args, test_data, data_type="test",
        in_scaler=train_base_dataset.in_scaler,
        out_scaler=train_base_dataset.out_scaler,
        y_scaler=train_base_dataset.y_scaler,
        node_builder_name=node_builder_name, edge_builder_name=edge_builder_name,
        use_tree_algo=base_dataset_is_tree,
    )

    y_scaler = train_base_dataset.y_scaler
    print("Y scaler:", y_scaler.summary())
    y_all = np.vstack(train_data["Y_fields"])
    print("RT err (Y): ", y_scaler.roundtrip_error(y_all))

    # ==== MultiGraph + DataLoaders (initial refinement) ====
    train_multigraph_dataset = MultiGraphDataset(
        args=args, base_dataset=train_base_dataset, indices=train_indices,
        epoch=None, init_refine=True, stock_graph=False,
        edge_builder_name=edge_builder_name,
    )
    test_multigraph_dataset = MultiGraphDataset(
        args=args, base_dataset=test_base_dataset, indices=test_indices,
        epoch=None, init_refine=True, stock_graph=False,
        edge_builder_name=edge_builder_name,
    )
    train_dataloader, test_dataloader = _build_dataloaders(
        args, train_multigraph_dataset, test_multigraph_dataset
    )

    # ==== Stats ====
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"⚙️ Total trainable parameters: {total_trainable_params}")

    start_time = time.time()
    metrics = []
    best_relative_rmse_mean = float('inf')
    best_epoch = -1
    total_edge_update_time = 0.0

    # ==== Error caches (one slot per base-dataset sample) ====
    train_error_buffer = ErrorBuffer(num_samples=len(train_base_dataset))
    test_error_buffer  = ErrorBuffer(num_samples=len(test_base_dataset))

    # AMP is only enabled on CUDA; with enabled=False the scaler is a pass-through.
    use_amp = getattr(args, "use_amp", False) and (device.type == "cuda")
    grad_scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    wants_latent = getattr(args, "return_latent", False)

    num_epochs = args.num_epochs
    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_start_time = time.time()
        total_train_loss = 0.0

        for base_graph, new_graph, _, _, sample_ids in train_dataloader:
            base_graph = base_graph.to(device)
            new_graph  = new_graph.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=use_amp):
                # ==== BASE GRAPH: supervise the error head ====
                # Index outputs rather than unpacking a fixed arity so models
                # that append latents to their return tuple still work.
                base_out = model(base_graph.ndata["x"], base_graph.edata["f"], base_graph)
                field_pred_base, error_pred_base = base_out[0], base_out[1]
                error_true_base = _relative_error_target(
                    y_scaler, field_pred_base, base_graph.ndata["y"]
                )
                loss_error_head = loss_function(error_pred_base, error_true_base)

                # ==== NEW GRAPH: supervise the field head ====
                field_pred_new = model(new_graph.ndata["x"], new_graph.edata["f"], new_graph)[0]
                loss_field_head = loss_function(field_pred_new, new_graph.ndata["y"])

                total_loss = loss_field_head + loss_error_head

            if use_amp:
                grad_scaler.scale(total_loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

            total_train_loss += total_loss.item()
            _cache_batch_errors(train_error_buffer, base_graph, error_pred_base, sample_ids)

        average_train_loss = total_train_loss / max(1, len(train_dataloader))

        # ===== Validation & cache test errors =====
        model.eval()
        total_val_loss = 0.0
        list_field_pred_test, list_field_true_test = [], []
        # Latents are collected only for the (currently disabled) dump below.
        all_node_enc, all_edge_enc, all_pro_enc = [], [], []

        with torch.no_grad():
            for base_graph, new_graph, _, _, sample_ids in test_dataloader:
                base_graph = base_graph.to(device)
                new_graph  = new_graph.to(device)

                base_out = model(base_graph.ndata["x"], base_graph.edata["f"], base_graph)
                field_pred_base, error_pred_base = base_out[0], base_out[1]
                if wants_latent:
                    node_latent_base, edge_latent_base, processor_latent_base = base_out[2:5]
                    all_node_enc.append(node_latent_base.cpu().numpy())
                    all_edge_enc.append(edge_latent_base.cpu().numpy())
                    all_pro_enc.append(processor_latent_base.cpu().numpy())

                _cache_batch_errors(test_error_buffer, base_graph, error_pred_base, sample_ids)

                error_true_base = _relative_error_target(
                    y_scaler, field_pred_base, base_graph.ndata["y"]
                )
                loss_error_val = loss_function(error_pred_base, error_true_base)

                field_pred_new = model(new_graph.ndata["x"], new_graph.edata["f"], new_graph)[0]
                loss_field_val = loss_function(field_pred_new, new_graph.ndata["y"])

                field_true_new_real = unscale_with_y_scaler(y_scaler, new_graph.ndata["y"])
                field_pred_new_real = unscale_with_y_scaler(y_scaler, field_pred_new)

                total_val_loss += (loss_field_val + loss_error_val).item()

                list_field_pred_test.append(field_pred_new_real.detach().cpu())
                list_field_true_test.append(field_true_new_real.detach().cpu())

        average_val_loss = total_val_loss / max(1, len(test_dataloader))

        # if wants_latent and len(all_node_enc) and (epoch % 100 == 0):
        #     node_enc_np = np.concatenate(all_node_enc, axis=0)
        #     edge_enc_np = np.concatenate(all_edge_enc, axis=0)
        #     pro_enc_np  = np.concatenate(all_pro_enc,  axis=0)
        #     np.save(os.path.join(metrics_dir, f"node_latent_epoch{epoch}.npy"), node_enc_np)
        #     np.save(os.path.join(metrics_dir, f"edge_latent_epoch{epoch}.npy"), edge_enc_np)
        #     np.save(os.path.join(metrics_dir, f"pro_latent_epoch{epoch}.npy"),  pro_enc_np)

        if epoch % CHECKPOINT_EVERY == 0:
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"state_epoch_{epoch}.pt"))
            save_fields(os.path.join(predictions_dir, f"predicted_fields_{epoch}.h5"), list_field_pred_test)

        relative_rmses = relative_rmse_field(list_field_true_test, list_field_pred_test)
        mean_relative_rmse = relative_rmses.mean().item()

        if mean_relative_rmse < best_relative_rmse_mean:
            best_relative_rmse_mean = mean_relative_rmse
            best_epoch = epoch
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_state.pt"))
            save_fields(os.path.join(predictions_dir, "best_predicted_fields.h5"), list_field_pred_test)

        epoch_duration = time.time() - epoch_start_time
        metrics.append({
            "epoch": epoch,
            "train_loss": average_train_loss,
            "test_loss": average_val_loss,
            "mean_rrmse": mean_relative_rmse,
            "duration": epoch_duration,
        })

        print(
            "🌟 "
            f"Epoch {epoch} | Train Loss: {average_train_loss:.7f} | "
            f"Test Loss: {average_val_loss:.7f} | RRMSE: {mean_relative_rmse:.7f} | "
            f"Duration: {epoch_duration:.2f}s | Best RRMSE: {best_relative_rmse_mean:.7f} @ {best_epoch}"
        )

        # ===== REBUILD EDGES FROM ERROR CACHE =====
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        _require_filled(train_error_buffer, "train")
        _require_filled(test_error_buffer, "test")

        if (epoch % args.error_check_interval == 0) and (epoch != num_epochs):
            t0 = time.time()

            # Drop the old datasets/loaders before building new ones to keep peak memory down.
            del train_multigraph_dataset, train_dataloader
            del test_multigraph_dataset, test_dataloader
            gc.collect()
            if device.type == "cuda":
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

            train_multigraph_dataset = MultiGraphDataset(
                args=args,
                base_dataset=train_base_dataset,
                indices=train_indices,
                epoch=epoch,
                init_refine=False,
                stock_graph=True,
                error_maps=train_error_buffer.get_all(),
                edge_builder_name=edge_builder_name,
            )
            test_multigraph_dataset = MultiGraphDataset(
                args=args,
                base_dataset=test_base_dataset,
                indices=test_indices,
                epoch=epoch,
                init_refine=False,
                stock_graph=False,
                error_maps=test_error_buffer.get_all(),
                edge_builder_name=edge_builder_name,
            )
            train_dataloader, test_dataloader = _build_dataloaders(
                args, train_multigraph_dataset, test_multigraph_dataset
            )

            dt = time.time() - t0
            total_edge_update_time += dt
            print(f"⏲️ Update edges (from cached errors) took {dt:.3f}s")

        # Flush periodically AND at the last epoch so trailing epochs are never lost.
        if epoch % METRICS_EVERY == 0 or epoch == num_epochs:
            metrics_df = pd.DataFrame(metrics)
            metrics_df.to_csv(os.path.join(metrics_dir, "metrics.csv"), index=False)

    total_training_duration = time.time() - start_time
    effective_training_duration = total_training_duration - total_edge_update_time

    print(f"⏰ Total training took {total_training_duration:.3f}s")
    print(f"⏰ Effective training (excluding edge updates) took {effective_training_duration:.3f}s")
    print(f"⏰ Total edge update time: {total_edge_update_time:.3f}s")


def unscale_with_y_scaler(y_scaler, t: torch.Tensor) -> torch.Tensor:
    """Map a scaled tensor back to physical units with ``y_scaler``.

    The tensor is detached and round-tripped through NumPy, so the result
    carries no autograd history. Half-precision inputs (``float16`` /
    ``bfloat16``, e.g. produced under AMP autocast) are upcast to ``float32``
    first. NumPy has no ``bfloat16`` dtype, and unscaling in ``float16``
    loses precision.

    Args:
        y_scaler: Object exposing ``inverse_transform(np.ndarray) -> np.ndarray``.
        t: Scaled tensor on any device.

    Returns:
        The unscaled tensor, placed on the same device as ``t``.
    """
    t_cpu = t.detach().cpu()
    if t_cpu.dtype in (torch.float16, torch.bfloat16):
        t_cpu = t_cpu.float()
    # torch.from_numpy rejects negative-stride arrays; a scaler may return views.
    inv_np = np.ascontiguousarray(y_scaler.inverse_transform(t_cpu.numpy()))
    return torch.from_numpy(inv_np).to(t.device)
