import os
import os.path as osp

import wandb
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from model_selection import ModelSelector
from model_selection.utils import get_mse_loss, get_predictions, compute_rmse
from utils import load_model, set_seed
from data import get_test_data

# ---------------------------------------------------------------------------
# Experiment parameters -- fill in the placeholders before running.
# ---------------------------------------------------------------------------
ENTITY = "..."  # wandb entity; runs are queried from f"{ENTITY}/{PROJECT}"
PROJECT = "..."  # wandb project; the value "heatsink" switches on whole-domain evaluation
# wandb filter that selects every run belonging to one experiment
EXP_ID = {"config.logging.experiment_id": "..."}
# root folder with one sub-folder per wandb run name, each containing "best.pth"
MODEL_CKP_DIR = "..."
# names of the model-selection algorithms handed to ModelSelector
MODEL_SELECTION_ALGS = ["IWV", "DEV", "SB", "TB"]

BATCH_SIZE = 16
DEVICE = "cuda"

# cuBLAS workspace setting used for reproducible CUDA runs (see the PyTorch
# reproducibility notes); it is set before any CUDA work happens, together with the seed.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
set_seed(42)

print(f"Evaluating models for: {PROJECT}")

# DA algorithm under which the unregularized (da_loss_weight = 0) baseline runs were logged;
# change it to whichever algorithm you used when training without the DA loss.
BASELINE_DA_ALGORITHM = "deep_coral"


def _checkpoint_path(run_name):
    """Return the path of the ``best.pth`` checkpoint stored for the wandb run ``run_name``."""
    return osp.join(MODEL_CKP_DIR, run_name, "best.pth")


def _query_runs(api, filters):
    """Return the runs of ``ENTITY/PROJECT`` matching ``filters``, or ``None`` if there are none.

    The lookup is best-effort: when resolving the query fails, the error and the filters are
    printed and ``None`` is returned, so the caller skips this configuration instead of
    aborting the whole sweep.
    """
    runs = api.runs(f"{ENTITY}/{PROJECT}", filters=filters)
    try:
        # NOTE(review): presumably wandb resolves the query lazily, so len() is where
        # it can fail -- verify against the wandb version in use.
        n_runs = len(runs)
    except Exception as err:  # deliberate best-effort skip, but report the cause
        print(f"an error occurred while querying runs with filters {filters}: {err!r}")
        return None
    return runs if n_runs > 0 else None


def _per_sample_mean(values, batch_index, n_samples):
    """Average node-level ``values`` over the nodes of each sample.

    Args:
        values: tensor whose first dimension indexes nodes, e.g. ``[n_nodes]`` or
            ``[n_nodes, n_channels]``.
        batch_index: int64 tensor ``[n_nodes]`` holding the sample id of every node.
        n_samples: number of samples; ids must lie in ``[0, n_samples)``.

    Returns:
        Tensor of shape ``[n_samples, *values.shape[1:]]``. Samples without nodes stay 0.
    """
    out = torch.zeros((n_samples, *values.shape[1:]), dtype=values.dtype, device=values.device)
    index = batch_index.reshape(-1, *([1] * (values.dim() - 1))).expand_as(values)
    # include_self=False: with the default (True) the zero initial value is counted as one
    # extra element, biasing every mean towards 0 (sum / (n + 1) instead of sum / n).
    out.scatter_reduce_(0, index, values, reduce="mean", include_self=False)
    return out


def _domain_metrics(model, dataset):
    """Evaluate a single ``model`` on one test ``dataset``.

    Returns:
        dict with tensors
        ``"loss"``: scalar RMSE of the normalized field channels (per-sample RMSE per
            channel, averaged over samples and channels),
        ``"fields"``: ``[n_field_channels]`` denormalized RMSE per channel, averaged over
            samples,
        ``"deformation"``: scalar mean Euclidean distance between predicted and true
            denormalized node coordinates (averaged per sample, then over samples).
    """
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=dataset.collate)
    # assumes preds is [total_nodes, n_fields + n_coords] (fields first, coordinates last) and
    # batch_index is [total_nodes] with each node's sample id -- TODO confirm in get_predictions
    preds, batch_index = get_predictions(model, loader, device=DEVICE)
    preds = preds.to("cpu")
    batch_index = batch_index.to("cpu")
    torch.cuda.empty_cache()

    # Ground truth of the whole dataset as one collated batch; neither loader shuffles,
    # so node order matches the predictions.
    full_loader = DataLoader(dataset, batch_size=len(dataset), collate_fn=dataset.collate)
    gt = next(iter(full_loader)).to("cpu")
    n_coords = gt.y_mesh_coords.shape[-1]
    n_samples = len(dataset)
    pred_fields = preds[..., :-n_coords]
    pred_coords = preds[..., -n_coords:]

    # 1) normalized RMSE over all field channels (coordinates excluded). Computed before any
    #    denormalization in case denormalize/denormalize_coords operate in place.
    norm_sq_err = (pred_fields - gt.y) ** 2
    loss = _per_sample_mean(norm_sq_err, batch_index, n_samples).sqrt().mean()

    # 2) denormalized RMSE per field channel
    field_sq_err = (
        dataset.denormalize(None, pred_fields) - dataset.denormalize(None, gt.y)
    ) ** 2
    fields = _per_sample_mean(field_sq_err, batch_index, n_samples).sqrt().mean(0)

    # 3) deformation error: per-node Euclidean distance between denormalized coordinates
    coord_diff = dataset.denormalize_coords(pred_coords) - dataset.denormalize_coords(
        gt.y_mesh_coords
    )
    node_dist = (coord_diff**2).sum(dim=-1).sqrt()
    deformation = _per_sample_mean(node_dist, batch_index, n_samples).mean()

    return {"loss": loss, "fields": fields, "deformation": deformation}


def _result_row(model_name, da_algorithm_name, selection_name, seed, channels, source, target):
    """Flatten source/target metrics into one row of the results table.

    ``source`` and ``target`` are dicts with tensor entries ``"loss"``, ``"deformation"``
    (scalars) and ``"fields"`` (per-channel); ``channels`` maps each field name to the
    index/slice of its channels in ``"fields"``.
    """
    row = {
        "model_name": model_name,
        "da_algorithm_name": da_algorithm_name,
        "model_selection_algorithm_name": selection_name,
        "seed": seed,
        "test_loss_source": source["loss"].item(),
        "test_loss_target": target["loss"].item(),
        "test_loss_source_deformation": source["deformation"].item(),
        "test_loss_target_deformation": target["deformation"].item(),
    }
    for field_name, field_slice in channels.items():
        row[f"test_loss_source_{field_name}"] = source["fields"][field_slice].mean().item()
        row[f"test_loss_target_{field_name}"] = target["fields"][field_slice].mean().item()
    return row


def main():
    """Evaluate all runs of the experiment and write the results table (pickle + CSV).

    For every (model, DA algorithm, seed) the runs with a non-zero DA loss weight are
    aggregated by each algorithm in ``MODEL_SELECTION_ALGS``; additionally, for every model
    and seed the unregularized baseline run (DA loss weight 0) is evaluated on its own.
    """
    api = wandb.Api()
    # Materialize once: the run list is iterated several times below.
    all_runs = list(api.runs(f"{ENTITY}/{PROJECT}", filters=EXP_ID))
    if not all_runs:
        raise RuntimeError(f"No runs found in {ENTITY}/{PROJECT} matching {EXP_ID}")

    # Load the datasets once (from the first run's checkpoint) and reuse them for all models.
    ckp_dict = load_model(
        _checkpoint_path(all_runs[0].name), load_opt=False, load_trainset=True
    )
    trainset_source = ckp_dict["trainset_source"]
    valset_source = ckp_dict["valset_source"]
    trainset_target = ckp_dict["trainset_target"]
    testset_source, testset_target = get_test_data(
        ckp_dict["cfg"], trainset_source.normalization_stats
    )
    del ckp_dict

    if PROJECT == "heatsink":
        # evaluate the whole domain for heatsink instead of a node subsample
        for dataset in (
            trainset_source, valset_source, trainset_target, testset_source, testset_target
        ):
            dataset.n_subsampled_nodes = None

    channels = valset_source.channels
    unique_models = sorted({run.config["model"]["name"] for run in all_runs})
    unique_da_algorithms = sorted({run.config["da_algorithm"]["name"] for run in all_runs})
    unique_seeds = sorted({run.config["seed"] for run in all_runs})

    results = []
    for model_type_name in tqdm(unique_models, desc="Models"):
        for da_algorithm_type_name in tqdm(
            unique_da_algorithms, desc="DA Algorithms", leave=False
        ):
            for seed in tqdm(unique_seeds, desc="Seeds", leave=False):
                # candidate runs: one per DA loss weight (lambda), excluding lambda = 0
                runs = _query_runs(
                    api,
                    {
                        **EXP_ID,
                        "config.seed": seed,
                        "config.model.name": model_type_name,
                        "config.da_algorithm.name": da_algorithm_type_name,
                        "config.da_algorithm.da_loss_weight": {"$ne": 0},
                    },
                )
                if runs is None:
                    continue

                models = [
                    load_model(
                        _checkpoint_path(run.name), load_opt=False, valset=valset_source
                    )["model"]
                    for run in runs
                ]
                model_selector = ModelSelector(
                    algorithm_names=MODEL_SELECTION_ALGS,
                    candidate_models=models,
                    trainset_source=trainset_source,
                    valset_source=valset_source,
                    trainset_target=trainset_target,
                    testset_source=testset_source,
                    testset_target=testset_target,
                    batch_size=BATCH_SIZE,
                    device=DEVICE,
                )
                model_selector.compute_model_weights()
                # losses / deformations: [n_selection_algorithms]
                # per-field RMSE: [n_selection_algorithms, n_fields]
                (
                    loss_src, loss_tgt, fields_src, fields_tgt, deform_src, deform_tgt,
                ) = model_selector.compute_test_performance()

                for i, algo_name in enumerate(MODEL_SELECTION_ALGS):
                    results.append(
                        _result_row(
                            model_type_name,
                            da_algorithm_type_name,
                            algo_name,
                            seed,
                            channels,
                            source={
                                "loss": loss_src[i],
                                "fields": fields_src[i],
                                "deformation": deform_src[i],
                            },
                            target={
                                "loss": loss_tgt[i],
                                "fields": fields_tgt[i],
                                "deformation": deform_tgt[i],
                            },
                        )
                    )
                del models, model_selector
                torch.cuda.empty_cache()

        # for every model type, also evaluate the unregularized run (da_loss_weight = 0)
        for seed in unique_seeds:
            runs = _query_runs(
                api,
                {
                    **EXP_ID,
                    "config.seed": seed,
                    "config.model.name": model_type_name,
                    "config.da_algorithm.name": BASELINE_DA_ALGORITHM,
                    "config.da_algorithm.da_loss_weight": 0,
                },
            )
            if runs is None:
                continue
            if len(runs) != 1:
                raise RuntimeError(
                    f"Expected exactly one baseline run for model={model_type_name}, "
                    f"seed={seed}, found {len(runs)}"
                )

            model = load_model(
                _checkpoint_path(runs[0].name), load_opt=False, valset=valset_source
            )["model"]
            source = _domain_metrics(model, testset_source)
            target = _domain_metrics(model, testset_target)
            results.append(
                _result_row(model_type_name, "-", "-", seed, channels, source, target)
            )
            # clear gpu memory
            del model
            torch.cuda.empty_cache()

    # Create a Pandas DataFrame with the results
    out_dir = "./results/model_selection_results"
    os.makedirs(out_dir, exist_ok=True)
    df_results = pd.DataFrame(results)
    df_results.to_pickle(osp.join(out_dir, f"results_{PROJECT}.pkl"))
    df_results.to_csv(osp.join(out_dir, f"results_{PROJECT}.csv"), index=False)


if __name__ == "__main__":
    main()
