import wandb
import torch
import hydra
import os
import rootutils
import lightning as L
import h5py
import numpy as np
import pickle

from omegaconf import DictConfig, OmegaConf

# Register the project root (the directory holding ".project-root") on
# sys.path BEFORE importing from ``src``: with ``pythonpath=True`` this call is
# what makes the project-local imports below resolvable when the script is
# launched from a directory other than the project root.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.models import DynamicsCFM  # noqa: E402
from src.utils.amino_acid_vocab import AA_1_TO_ID  # noqa: E402


def get_wandb_checkpoint_path(
    entity: str,
    project: str,
    run_id: str,
    ckpt_name: str,
) -> "tuple[str, wandb.apis.public.Run]":
    """Locate a checkpoint file of a logged W&B run and return the run handle.

    The checkpoint is not downloaded: the path is built from the
    ``paths.output_dir`` entry stored in the run's config, so it refers to the
    filesystem the training run wrote to.

    Args:
        entity: W&B entity that owns the run.
        project: W&B project that contains the run.
        run_id: Identifier of the run to look up.
        ckpt_name: File name of the checkpoint inside ``<output_dir>/checkpoints``.

    Returns:
        A ``(checkpoint_path, run)`` tuple: the joined checkpoint path and the
        run object returned by ``wandb.Api().run``.

    Raises:
        KeyError: If the run config has no ``paths.output_dir`` entry.
    """
    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")
    output_dir = run.config["paths"]["output_dir"]
    checkpoint_path = os.path.join(output_dir, "checkpoints", ckpt_name)
    return checkpoint_path, run


def _resolve_device(gpu_device) -> torch.device:
    """Return the torch device selected by ``gpu_device``.

    Args:
        gpu_device: ``None`` to run on CPU, otherwise the index of a single
            CUDA device.

    Returns:
        ``torch.device("cpu")`` when ``gpu_device`` is ``None``; otherwise
        ``torch.device(f"cuda:{gpu_device}")``, after making it the current
        CUDA device.

    Raises:
        RuntimeError: If a GPU index is requested but CUDA is unavailable.
    """
    if gpu_device is None:
        # CPU run: never call torch.cuda.set_device(None), which raises on
        # machines where CUDA happens to be available.
        return torch.device("cpu")
    if not torch.cuda.is_available():
        # Fail before any W&B run is created or weights are loaded instead of
        # erroring later while mapping the checkpoint onto a missing device.
        raise RuntimeError(
            f"gpu_device={gpu_device} was requested but CUDA is not available"
        )
    torch.cuda.set_device(gpu_device)
    return torch.device(f"cuda:{gpu_device}")


@hydra.main(
    version_base=None,
    config_path="../configs",
    config_name="generate_fast_folders.yaml",
)
def generate_fast_folders(cfg: DictConfig):
    """Generate trajectories for one fast-folder protein with a trained DynamicsCFM.

    The model architecture and checkpoint location are read from the W&B run
    identified by ``cfg.wandb``; a new W&B run is opened to record this
    generation job. Every sample starts from the same reference frame (folded
    or unfolded, per ``cfg.start_frames``). The stacked results are written to
    ``<cfg.wandb.dir>/generated_coords.pt``.

    Args:
        cfg: Hydra config. Fields read here: ``gpu_device``, ``wandb.*``,
            ``paths.project_data_dir``, ``protein_name``, ``start_frames``,
            ``samples_per_iteration``, ``number_of_iterations``, ``step``,
            ``integrator_steps``, ``method``, ``number_of_steps`` and
            ``save_intermediate_steps``.

    Raises:
        RuntimeError: If ``cfg.gpu_device`` is set but CUDA is unavailable.
        ValueError: If ``cfg.start_frames`` is neither ``"folded"`` nor
            ``"unfolded"``.
    """
    torch.set_float32_matmul_precision("medium")
    torch.backends.cudnn.benchmark = True
    L.seed_everything(12345, workers=True)

    # Set GPU device - either None (CPU) or a single GPU integer
    device = _resolve_device(cfg.gpu_device)

    checkpoint_path, run = get_wandb_checkpoint_path(
        entity=cfg.wandb.entity,
        project=cfg.wandb.project,
        run_id=cfg.wandb.run_id,
        ckpt_name=cfg.wandb.ckpt_name,
    )

    # Log the fully resolved generation config to a separate target run.
    logger_config = OmegaConf.to_container(cfg, resolve=True)
    wandb.init(
        entity=cfg.wandb.entity,
        project=cfg.wandb.target_project,
        name=cfg.wandb.target_name,
        config=logger_config,
        dir=cfg.wandb.dir,
    )

    # Rebuild both networks from the training run's own config so that the
    # architecture matches the checkpoint being loaded (strict=True below).
    velocity_net = hydra.utils.instantiate(
        run.config["model"]["velocity_net"], use_compile=True
    )
    structure_net = hydra.utils.instantiate(
        run.config["model"]["structure_net"], use_compile=True
    )

    if structure_net.finetune == 0:
        structure_net.load_pretrained_weights(
            run.config["model"]["structure_net"]["weights_path"]
        )

    model = DynamicsCFM.load_from_checkpoint(
        checkpoint_path,
        velocity_net=velocity_net,
        structure_net=structure_net,
        strict=True,
        map_location=device,
    )
    model = model.to(device)
    model.eval()

    dataset_path = f"{cfg.paths.project_data_dir}/fast_folders"
    h5_file_path = f"{dataset_path}/{cfg.protein_name}_xyz_aligned_to_folded_state.h5"
    with h5py.File(h5_file_path, "r") as f:
        # Materialise the dataset in memory before the file is closed.
        coords = np.array(f["ca_xyz_0000"])
    metadata = torch.load(f"{dataset_path}/metadata.pt")[cfg.protein_name]
    temperature = metadata["temperature"]

    sequence_emb = metadata["esmc_6b_seq_embeddings"]
    residue_ids = [AA_1_TO_ID[aa] for aa in metadata["sequence"]]

    # presumably coords is (frames, residues, 3) - TODO confirm; axis 1 is used
    # as the residue count when reshaping the generated output below.
    num_res = coords.shape[1]

    saved_tica_dir = f"{dataset_path}/reference_MSMs"
    # NOTE: pickle.load executes arbitrary code from the file; only point this
    # at reference MSM files from a trusted source.
    with open(f"{saved_tica_dir}/{cfg.protein_name}.pkl", "rb") as f:
        msm_model = pickle.load(f)

    num_samples = cfg.samples_per_iteration

    # Every sample in the batch starts from the same reference frame.
    if cfg.start_frames == "folded":
        start_indices = [int(msm_model["start_frame_folded"])] * num_samples
    elif cfg.start_frames == "unfolded":
        start_indices = [int(msm_model["start_frame_unfolded"])] * num_samples
    else:
        raise ValueError(f"Unknown start_frames option: {cfg.start_frames}")

    x0 = torch.stack(
        [torch.tensor(coords[idx], dtype=torch.float32) for idx in start_indices],
        dim=0,
    )

    # Tile the per-protein conditioning inputs across the batch dimension.
    sequence_emb = sequence_emb.unsqueeze(0).repeat(num_samples, 1, 1)
    rest_conditions = {
        "lag": torch.tensor([cfg.step] * num_samples, dtype=torch.long).to(device),
        # NOTE(review): the long cast truncates a non-integer temperature -
        # confirm the model expects integer temperatures.
        "temp": torch.tensor([temperature] * num_samples, dtype=torch.long).to(
            device
        ),
        "residue_ids": torch.tensor(
            [residue_ids] * num_samples, dtype=torch.long
        ).to(device),
        "sequence_emb": sequence_emb.to(device),
        # Constant placeholder values for the deepseek_* conditioning inputs.
        "deepseek_classification": torch.ones(num_samples, dtype=torch.long).to(
            device
        ),
        "deepseek_confidence": torch.ones(num_samples, dtype=torch.long).to(device)
        * 2,
        "deepseek_evidence_emb": torch.zeros(
            num_samples, 2048, dtype=torch.float32
        ).to(device),
    }
    if "cath_code" in metadata:
        print(f"Using cath_code: {metadata['cath_code']}")
        rest_conditions["cath_code"] = np.array(
            [[metadata["cath_code"]]] * num_samples
        )

    generated_coords = []
    for _ in range(cfg.number_of_iterations):
        generated_coords.append(
            model.generate_trajectory(
                x0=x0.to(device),
                ode_steps=cfg.integrator_steps,
                ode_method=cfg.method,
                trajectory_steps=cfg.number_of_steps,
                return_intermediates=cfg.save_intermediate_steps,
                **rest_conditions,
            )
        )

    # Merge the iteration and batch axes into a single leading sample axis.
    if cfg.save_intermediate_steps:
        generated_coords = torch.stack(generated_coords).reshape(
            -1, cfg.number_of_steps + 1, num_res, 3
        )
    else:
        generated_coords = torch.stack(generated_coords).reshape(-1, num_res, 3)

    coords_path = os.path.join(cfg.wandb.dir, "generated_coords.pt")
    torch.save(generated_coords, coords_path)
    wandb.finish()


# Script entry point. The function is called without arguments because it is
# wrapped by ``@hydra.main``, which is expected to supply ``cfg``.
if __name__ == "__main__":
    generate_fast_folders()
