import argparse
import glob
import os

import numpy as np
import pandas as pd
import torch
import wandb
from aion_eval.benchmarks.desiddpayne.dataset import DESIDDPayneDatasetModule
from aion_eval.benchmarks.desiddpayne.models import (
    AIONBaselineSpectrumModel,
    AIONCrossAttentionProbing,
    AIONLinearProbing,
)
from aion_eval.utils import flatten_dict, index_collated
from astropy.table import Table
from torch.utils._pytree import tree_map
from tqdm import tqdm

# This script only runs inference: allow reduced-precision float32 matmuls
# (e.g. TF32) for speed, and disable autograd for every tensor op globally.
torch.set_float32_matmul_precision("high")
torch.set_grad_enabled(False)


# WandB API
wandb_api = wandb.Api()
project_name = "aion_eval_desiddpayne_scaling"
# W&B entity (user or team) owning `project_name`. Configurable through the
# WANDB_ENTITY environment variable instead of editing this file; falls back
# to the previous empty-string value when the variable is not set.
entity_name = os.environ.get("WANDB_ENTITY", "")


def batch_process_catalog(
    catalog,
    model,
    input_keys,
    output_keys,
    means,
    stds,
    batch_size=128,
    device="cuda",
):
    """
    Run a model over a catalog in batches and return de-normalized predictions.

    Parameters:
        catalog (dict-like): Mapping from field name to per-object data. The
            number of rows is taken from the first key; every field listed in
            ``input_keys`` is assumed to have that same length.
        model (torch.nn.Module): Model called as ``model(batch)`` on a
            flattened dict of tensors. Its output is moved to CPU and converted
            to NumPy, one column per entry of ``output_keys``.
        input_keys (list): Keys of ``catalog`` fed to the model.
        output_keys (list): Names of the predicted quantities, in the order of
            the model's output columns.
        means (array-like): Per-output means used to undo target normalization.
        stds (array-like): Per-output standard deviations used to undo target
            normalization (predictions are mapped back as ``res * std + mean``).
        batch_size (int, optional): Number of rows per batch. Default is 128.
        device (str, optional): Device the input tensors are moved to.
            Default is ``"cuda"``.

    Returns:
        astropy.table.Table: One column per entry of ``output_keys`` and one
        row per catalog entry (zero rows for an empty catalog).
    """
    means = np.asarray(means)
    stds = np.asarray(stds)
    len_cat = len(catalog[list(catalog.keys())[0]])

    predictions = []
    with torch.no_grad():
        for start in tqdm(range(0, len_cat, batch_size)):
            # Prepare batch: the last batch may be shorter than batch_size.
            idx = np.arange(start, min(start + batch_size, len_cat))
            batch = {k: index_collated(catalog[k], idx) for k in input_keys}
            batch = tree_map(lambda x: torch.tensor(x).to(device), batch)
            batch = flatten_dict(batch)

            # Apply model, then map outputs back to physical units.
            res = model(batch).cpu().numpy()
            predictions.append(res * stds[None, :] + means[None, :])

    if not predictions:
        # Empty catalog: np.concatenate would raise on an empty list.
        return Table({k: np.empty(0) for k in output_keys})

    stacked = np.concatenate(predictions, axis=0)
    return Table({k: stacked[:, i] for i, k in enumerate(output_keys)})


def get_model(run, run_id):
    """
    Load the trained checkpoint of a W&B run and return it in eval mode on CUDA.

    The model class is chosen by substring-matching ``run.config["model"]["class_path"]``
    against the supported probe classes. The checkpoint is searched for under
    ``{trainer.default_root_dir}/{project_name}/{run_id}/checkpoints/*.ckpt``.

    Parameters:
        run: W&B run object whose ``config`` holds the training configuration.
        run_id (str): W&B run ID, used to locate the checkpoint directory.

    Returns:
        The loaded model, switched to eval mode and moved to ``"cuda"``.

    Raises:
        ValueError: If the run's model class is not supported by this script.
        FileNotFoundError: If no checkpoint file matches the expected path.
    """
    model_name = run.config["model"]["class_path"]
    model_pattern = (
        run.config["trainer"]["default_root_dir"]
        + f"/{project_name}/{run_id}/checkpoints/*.ckpt"
    )
    print(f"Loading model {model_name} from {model_pattern}")

    # Checked in order; the first substring match wins (same precedence as
    # the original if/elif chain).
    supported = (
        ("AIONLinearProbing", AIONLinearProbing),
        ("AIONCrossAttentionProbing", AIONCrossAttentionProbing),
        ("AIONBaselineSpectrumModel", AIONBaselineSpectrumModel),
    )
    model_cls = next((cls for key, cls in supported if key in model_name), None)
    if model_cls is None:
        raise ValueError(f"Model {model_name} not implemented in eval script.")

    # There should only be one checkpoint; sort so the choice is deterministic
    # if more than one is present.
    checkpoints = sorted(glob.glob(model_pattern))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint found matching {model_pattern}")
    if len(checkpoints) > 1:
        print(
            f"Warning: found {len(checkpoints)} checkpoints for run {run_id}, "
            f"using {checkpoints[0]}"
        )

    model = model_cls.load_from_checkpoint(checkpoints[0])
    model = model.eval()
    model = model.to("cuda")
    return model


def experiment(
    models_to_evaluate, output_dir, overwrite: bool = False, version: str = "1"
):
    """
    Evaluate adapted models on the validation split of the task they were adapted for.

    For every W&B run ID in ``models_to_evaluate["ID"]``, the function:

    - loads the run's checkpoint;
    - recomputes the target normalization statistics from the training split;
    - runs the model over ``dm.val_data``;
    - writes the de-normalized predictions to
      ``{output_dir}/{run.name}_{run_id}.fits``.

    Parameters:
        models_to_evaluate (pandas.DataFrame): Table with an ``ID`` column of
            W&B run IDs.
        output_dir (str): Directory for the output FITS files; created if missing.
        overwrite (bool): Re-process runs whose output file already exists.
        version (str): Currently unused; kept for interface compatibility.
    """
    dm = DESIDDPayneDatasetModule(data_dir="data", num_workers=0)
    dm.setup(None)

    catalog = dm.val_data

    os.makedirs(output_dir, exist_ok=True)

    for run_id in models_to_evaluate["ID"]:
        run = wandb_api.run(f"{entity_name}/{project_name}/{run_id}")

        print(f"Processing run {run_id}")
        output_path = f"{output_dir}/{run.name}_{run_id}.fits"
        # Check if the file has already been processed, if so, skip it
        if os.path.exists(output_path) and not overwrite:
            print(f"Run {run_id} already processed, skipping")
            continue

        # Load the model
        model = get_model(run, run_id)

        data_args = run.config["data"]["init_args"]
        # A missing limit means the full training set was used; slicing with
        # None keeps every row. Assumes a limited run trained on the first
        # `limit_train_size` rows of dm.train_data — TODO confirm against the
        # datamodule.
        train_size = data_args.get("limit_train_size")
        output_fields = data_args["output_fields"]

        # Presumably these must match the statistics used to standardize the
        # targets during training — TODO confirm (incl. std ddof convention).
        train_data = dm.train_data
        means = np.array([train_data[k][:train_size].mean() for k in output_fields])
        stds = np.array([train_data[k][:train_size].std() for k in output_fields])

        # Process the catalog, providing the outputs the model expects
        predictions = batch_process_catalog(
            catalog,
            model,
            data_args["input_fields"],
            output_fields,
            means,
            stds,
        )

        predictions.write(output_path, overwrite=True)

        # Drop the reference before the next checkpoint is loaded so two
        # models are not alive on the GPU at the same time.
        del model


def main(args):
    """Read the CSV of W&B run IDs named on the command line and evaluate each run."""
    runs_table = pd.read_csv(args.wandb_csv_file)
    experiment(
        runs_table,
        args.output_dir,
        overwrite=args.overwrite,
        version=args.version,
    )


if __name__ == "__main__":
    # Command-line interface; options are registered in the order they
    # appear in --help.
    cli = argparse.ArgumentParser(
        description="Run model inference on a catalog of data."
    )
    cli.add_argument(
        "--wandb_csv_file",
        default="scripts/desiddpayne_runs_v1.csv",
        type=str,
        help="csv file of all the runs we want to analyse.",
    )
    cli.add_argument(
        "--overwrite",
        help="Overwrite existing files if true.",
        action="store_true",
    )
    cli.add_argument(
        "--version",
        default="1",
        type=str,
        help="Version of the catalog to use.",
    )
    cli.add_argument(
        "--output_dir",
        default="data/analysis/desiddpayne_scaling",
        type=str,
        help="Output directory for the results.",
    )
    main(cli.parse_args())
