import os
import sys

# Put the absolute path of the current working directory at the front of the
# module search path. NOTE(review): presumably so that the `proteinfoundation`
# import below resolves when the script is launched from the repository root —
# confirm.
root = os.path.abspath(".")
sys.path.insert(0, root)


import argparse
from typing import Dict, Tuple

import hydra
import lightning as L

import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from loguru import logger
from sklearn.decomposition import PCA

from proteinfoundation.partial_autoencoder.autoencoder import AutoEncoder

# Palette of 20 colors given as "#RRGGBB" hex strings.
# NOTE(review): not referenced anywhere in the visible part of this file, and
# the meaning of the "RT" suffix is not evident from here — confirm it is still
# needed.
COLORS_RT = [
    "#FF0000",
    "#008000",
    "#0000FF",
    "#FFFF00",
    "#FFA500",
    "#800080",
    "#00FFFF",
    "#FF00FF",
    "#00FF00",
    "#FFC0CB",
    "#008080",
    "#E6E6FA",
    "#A52A2A",
    "#F5F5DC",
    "#800000",
    "#808000",
    "#FF7F50",
    "#000080",
    "#AAF0D1",
    "#FFDB58",
]


def parse_args_and_cfg() -> Tuple[Dict, Dict, str]:
    """Parse the command line and compose the corresponding Hydra config.

    Command-line options:
        --config_name: Name of the config to load (default ``inference_ae``).
        --config_number: When not ``-1``, overrides the name with
            ``inf_<config_number>``.
        --config_subdir: Optional sub-directory of ``../configs`` to search.

    Returns:
        A 3-tuple ``(args, cfg, config_name)``: the namespace returned by
        ``argparse``, the config returned by ``hydra.compose`` and the name of
        the config that was composed.
    """
    cli = argparse.ArgumentParser(description="Job info")
    cli.add_argument(
        "--config_name",
        type=str,
        default="inference_ae",
        help="Name of the config yaml file.",
    )
    cli.add_argument(
        "--config_number", type=int, default=-1, help="Number of the config yaml file."
    )
    cli.add_argument(
        "--config_subdir",
        type=str,
        help="(Optional) Name of directory with config files, if not included uses base inference config.\
            Likely only used when submitting to the cluster with script.",
    )
    args = cli.parse_args()

    # Base config directory, optionally narrowed to a sub-directory.
    config_path = "../configs"
    if args.config_subdir is not None:
        config_path = f"../configs/{args.config_subdir}"

    # -1 is the "no number given" sentinel; any other value selects inf_<number>.
    config_name = (
        args.config_name if args.config_number == -1 else f"inf_{args.config_number}"
    )

    with hydra.initialize(config_path, version_base=hydra.__version__):
        cfg = hydra.compose(config_name=config_name)
        logger.info(f"Inference config {cfg}")

    return args, cfg, config_name


def extract_ckpt_info(ckpt_file_path):
    """Derive the autoencoder name and checkpoint file name from a path.

    The path is split on ``"/"``; the third-from-last component is returned as
    the autoencoder name and the last component as the checkpoint name.
    NOTE(review): this presumably expects a layout like
    ``.../<ae_name>/<subdir>/<ckpt_file>`` — confirm against how checkpoints
    are stored.

    Args:
        ckpt_file_path: ``"/"``-separated path to the checkpoint file.

    Returns:
        Tuple ``(ae_name, ckpt_name)``.

    Raises:
        IndexError: If the path has fewer than three ``"/"``-separated parts.
    """
    parts = ckpt_file_path.split("/")
    return parts[-3], parts[-1]


def setup(
    cfg: Dict,
    config_name: str,
    create_root: bool = True,
) -> str:
    """Prepare logging, the results directory and the random seeds for a run.

    Args:
        cfg: Run config; only ``cfg.seed`` is read here.
        config_name: Config name, used as the results sub-directory name.
        create_root: When True, create ``./inference/<config_name>`` if it is
            missing; when False, require that it already exists.

    Returns:
        The results directory path ``./inference/<config_name>``.

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If ``create_root`` is False and the results directory does
            not exist.
    """
    # Record the exact command line used to launch the run.
    logger.info(" ".join(sys.argv))

    # Raised explicitly instead of via `assert`, which is compiled away under
    # `python -O` and would silently skip this check. The exception type is
    # kept as AssertionError so existing callers see the same failure.
    if not torch.cuda.is_available():
        raise AssertionError("CUDA not available")
    logger.add(
        sys.stdout,
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}",
    )

    root_path = f"./inference/{config_name}"
    if create_root:
        os.makedirs(root_path, exist_ok=True)
    elif not os.path.exists(root_path):
        raise ValueError(f"Results path {root_path} does not exist")

    logger.info(f"Seeding everything to seed {cfg.seed}")
    L.seed_everything(cfg.seed)

    return root_path


def load_dataloader(cfg):
    """Build the validation dataloader for the dataset named by ``cfg.dataset``.

    The dataset name selects a Hydra dataset config, whose datamodule batch
    size is overridden with ``cfg.bs`` before the datamodule is instantiated,
    prepared and set up for the ``"fit"`` stage.

    Args:
        cfg: Run config; ``cfg.dataset`` and ``cfg.bs`` are read.

    Returns:
        The object returned by the datamodule's ``val_dataloader()``.

    Raises:
        ValueError: If ``cfg.dataset`` is not one of the supported names.
    """
    # (dataset name, Hydra config directory, Hydra config name)
    known_datasets = (
        ("genie2", "../configs/dataset/afdb_fromraw", "genie2"),
        ("pdb", "../configs/dataset/pdb", "pdb_train"),
        ("pdb_multimer", "../configs/dataset/pdb_multimer", "pdb_multimer_train"),
    )
    for dataset, config_path, config_name in known_datasets:
        if cfg.dataset == dataset:
            break
    else:
        # No entry matched the requested dataset.
        raise ValueError(f"Dataset {cfg.dataset} not implemented")

    with hydra.initialize(config_path, version_base=hydra.__version__):
        cfg_data = hydra.compose(config_name=config_name)
        cfg_data["datamodule"]["batch_size"] = cfg.bs

    datamodule = hydra.utils.instantiate(cfg_data.datamodule)
    datamodule.prepare_data()
    datamodule.setup("fit")
    dataloader = datamodule.val_dataloader()

    n_batches = len(dataloader)
    print(
        f"Number of batches in dataloader: {n_batches}, batch size: {cfg.bs}, "
        f"total number of structures: {n_batches * cfg.bs}"
    )
    return dataloader


def extract_pdb_ids(predictions):
    """Flatten the ``"id"`` entries of all input batches into a single list.

    Args:
        predictions: Iterable of ``(x_in, x_out)`` pairs; only ``x_in["id"]``
            is read, and it must itself be iterable.

    Returns:
        List with the elements of every ``x_in["id"]``, in iteration order.
    """
    logger.info("Extracting PDBs we test on")
    return [pdb_id for batch_in, _ in predictions for pdb_id in batch_in["id"]]


def compute_all_atom_rmsd(predictions, model):
    """Gather the ``rmsd_no_align_a37_ang_justlog`` values over all batches.

    Args:
        predictions: Iterable of ``(x_in, x_out)`` pairs.
        model: Object exposing ``compute_struct_rec_loss(output_dec=..., batch=...)``
            that returns a mapping containing ``"rmsd_no_align_a37_ang_justlog"``.

    Returns:
        Flat list of the per-batch values converted with ``.tolist()``.
        NOTE(review): presumably one value per structure — confirm against
        the model's loss implementation.
    """
    logger.info("Computing all-atom RMSD")
    rmsds = []
    for batch_in, batch_out in predictions:
        losses = model.compute_struct_rec_loss(output_dec=batch_out, batch=batch_in)
        rmsds.extend(losses["rmsd_no_align_a37_ang_justlog"].tolist())
    return rmsds


def compute_sec_rec_rate(predictions, model):
    """Gather the ``seq_rec_rate_justlog`` values over all batches.

    Args:
        predictions: Iterable of ``(x_in, x_out)`` pairs.
        model: Object exposing ``compute_seq_rec_loss(output_dec=..., batch=...)``
            that returns a mapping containing ``"seq_rec_rate_justlog"``.

    Returns:
        Flat list of the per-batch values converted with ``.tolist()``.
        NOTE(review): presumably one value per structure — confirm against
        the model's loss implementation.
    """
    logger.info("Computing sequence recovery rate")
    rates = []
    for batch_in, batch_out in predictions:
        losses = model.compute_seq_rec_loss(output_dec=batch_out, batch=batch_in)
        rates.extend(losses["seq_rec_rate_justlog"].tolist())
    return rates


def compute_kl_latent(predictions, model):
    """Gather the ``kl_now_justlog`` values of the latent KL penalty over all batches.

    Args:
        predictions: Iterable of ``(x_in, x_out)`` pairs; only ``x_out`` is
            used, and it must provide ``"mean"``, ``"log_scale"`` and
            ``"residue_mask"``.
        model: Object exposing ``compute_kl_penalty(mean=..., log_scale=...,
            mask=..., w=...)`` that returns a mapping containing
            ``"kl_now_justlog"``.

    Returns:
        Flat list of the per-batch values converted with ``.tolist()``.
    """
    # The original message was copy-pasted from the sequence-recovery metric
    # and mislabelled this step in the logs.
    logger.info("Computing KL penalty of the latent distribution")
    vals = []
    for _, x_out in predictions:
        v = model.compute_kl_penalty(
            mean=x_out["mean"],
            log_scale=x_out["log_scale"],
            mask=x_out["residue_mask"],
            w=1.0,  # NOTE(review): presumably the KL weight; confirm in the model
        )["kl_now_justlog"]
        vals += v.tolist()
    return vals


def compute_metric(metric, predictions, model):
    """Dispatch to the function that computes the metric called ``metric``.

    Supported names are ``"all_atom_rmsd"``, ``"seq_rec_rate"`` and
    ``"kl_latent_dist"``.

    Args:
        metric: Name of the metric to compute.
        predictions: Passed through to the metric function.
        model: Passed through to the metric function.

    Returns:
        Whatever the selected metric function returns.

    Raises:
        IOError: If ``metric`` is not one of the supported names.
    """
    # Lambdas keep the lookup lazy: only the selected metric is evaluated.
    handlers = {
        "all_atom_rmsd": lambda: compute_all_atom_rmsd(predictions, model),
        "seq_rec_rate": lambda: compute_sec_rec_rate(predictions, model),
        "kl_latent_dist": lambda: compute_kl_latent(predictions, model),
    }
    if metric not in handlers:
        raise IOError(f"Metric {metric} not implemented")
    return handlers[metric]()


def get_df_stats(df):
    """Summarise every numeric column of ``df`` by its mean and standard deviation.

    Args:
        df: Input ``pandas.DataFrame``; non-numeric columns are ignored.

    Returns:
        A two-row ``pandas.DataFrame`` with a ``stat_type`` column
        (``"mean"``, ``"std"``) followed by one column per numeric input
        column holding that column's mean and standard deviation.
    """
    numeric_cols = [
        name for name in df.columns if pd.api.types.is_numeric_dtype(df[name])
    ]
    numeric_part = df[numeric_cols]
    col_means = numeric_part.mean()
    col_stds = numeric_part.std()

    summary = {"stat_type": ["mean", "std"]}
    summary.update({name: [col_means[name], col_stds[name]] for name in numeric_cols})
    return pd.DataFrame(summary)


def main() -> None:
    """Evaluate an autoencoder checkpoint and store the results as CSV files.

    Loads the config and the checkpoint named by ``cfg.ckpt_file``, runs
    prediction on the validation dataloader, computes every metric enabled in
    ``cfg.metrics`` and writes two files next to the per-config results
    directory: ``results_<config_name>.csv`` (one row per structure) and
    ``results_<config_name>_summary.csv`` (one row with mean/std/max/min of
    each metric).
    """
    load_dotenv()

    _, cfg, config_name = parse_args_and_cfg()
    ae_name, ckpt_name = extract_ckpt_info(cfg.ckpt_file)

    root_path = setup(cfg, create_root=True, config_name=config_name)
    # The "../" places both CSVs in the parent of the per-config directory.
    df_file_store = os.path.join(root_path, f"../results_{config_name}.csv")
    df_file_store_summary = os.path.join(
        root_path, f"../results_{config_name}_summary.csv"
    )

    dataloader = load_dataloader(cfg)

    model = AutoEncoder.load_from_checkpoint(cfg.ckpt_file)

    # NOTE(review): int() truncates, so cfg.n_structs < cfg.bs yields a limit
    # of 0 batches — confirm that this is intended.
    trainer = L.Trainer(
        accelerator="gpu", devices=1, limit_predict_batches=int(cfg.n_structs / cfg.bs)
    )
    predictions = trainer.predict(model, dataloader)

    # Only metrics whose config entry is truthy are computed.
    metrics = {
        metric: compute_metric(metric=metric, predictions=predictions, model=model)
        for metric in cfg.metrics
        if cfg.metrics[metric]
    }

    pdb_id = extract_pdb_ids(predictions)

    # Per-structure table: one row per id, one column per metric.
    df = pd.DataFrame({"pdb_id": pdb_id, **metrics})

    # Single-row summary table with aggregate statistics of each metric.
    summary = {
        "ae_name": [ae_name],
        "ckpt_name": [ckpt_name],
        "dataset": [cfg.dataset],
    }
    for name, metric_vals in metrics.items():
        arr = np.array(metric_vals)
        summary[f"{name}_mean"] = [arr.mean()]
        summary[f"{name}_std"] = [arr.std()]
        summary[f"{name}_max"] = [arr.max()]
        summary[f"{name}_min"] = [arr.min()]
    df_summary = pd.DataFrame(summary)

    # Each file is written exactly once (the original wrote both twice).
    df.to_csv(df_file_store, index=False)
    df_summary.to_csv(df_file_store_summary, index=False)
    print("Done saving dataframes")


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
