import os
import torch
import h5py
import json
import random
import numpy as np
from torch.utils.data import Dataset
from src.data import FrameData
from tqdm import tqdm
from src.utils.amino_acid_vocab import AA_1_TO_ID
from src.data.constants import (
    MDCATH_ALL_REPLICAS,
    MDCATH_ALL_TEMPS,
    MDCATH_FRAME_SPACING,
    DEEPSEEK_CLASSIFICATION_MAP,
    DEEPSEEK_CONFIDENCE_MAP,
)


class MDCATH(Dataset):
    """Randomly sampled ``(x0, xt)`` C-alpha frame pairs from mdCATH trajectories.

    All requested coordinate trajectories are loaded into memory when the
    dataset is constructed. Each ``__getitem__`` call ignores its index and
    draws a fresh random (protein, temperature, replica, start frame, lag)
    sample. ``__len__`` therefore only sets how many samples make up an epoch.
    """

    def __init__(
        self,
        dataset_path,
        protein_names,
        seq_emb_name,
        temperatures: "list | None" = None,
        max_lag: int = 200,  # in ns
        samples_per_epoch: int = 10000,
    ):
        """Load metadata, sequence embeddings and coordinates.

        Args:
            dataset_path: Root directory. It must contain ``metadata.json``,
                the sequence-embedding file, and a ``data/`` folder with one
                ``mdcath_dataset_<protein>.h5`` file per protein.
            protein_names: Iterable of protein/domain names to load. Each must
                be a key of the metadata, of the embedding mapping, and of its
                own HDF5 file.
            seq_emb_name: File name (relative to ``dataset_path``) of the
                ``torch.load``-able mapping from protein name to sequence
                embedding.
            temperatures: Temperature group names to load and sample from.
                ``None`` (the default) means ``MDCATH_ALL_TEMPS``.
            max_lag: Maximum time lag between the two returned frames, in ns.
            samples_per_epoch: Value reported by ``__len__``.

        Raises:
            ValueError: If ``max_lag`` is shorter than one frame spacing, so
                no positive lag could ever be sampled.
        """
        super().__init__()

        # Resolve the default here rather than in the signature. Copying it
        # means the shared module-level constant is never aliased (and
        # possibly mutated) through ``self.temperatures``.
        if temperatures is None:
            temperatures = MDCATH_ALL_TEMPS
        # ``random.choice`` needs an indexable sequence, so normalise to lists.
        self.protein_names = list(protein_names)
        self.temperatures = list(temperatures)
        self.dataset_path = dataset_path
        self.seq_emb_name = seq_emb_name
        self.max_lag = max_lag
        self.samples_per_epoch = samples_per_epoch
        # Maximum lag expressed in frames; lags returned by __getitem__ use
        # this unit.
        self.max_lag_frames = self.max_lag // MDCATH_FRAME_SPACING
        if self.max_lag_frames < 1:
            raise ValueError(
                f"max_lag={max_lag} is shorter than the frame spacing "
                f"({MDCATH_FRAME_SPACING}); no positive lag can be sampled."
            )

        with open(os.path.join(dataset_path, "metadata.json"), "r") as f:
            self.metadata = json.load(f)

        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects. Only point this at trusted files.
        self.seq_emb_per_domain = torch.load(
            os.path.join(dataset_path, self.seq_emb_name)
        )

        # all_coords[protein][temp][replica] -> coordinate tensor whose first
        # axis indexes frames. Every replica present in the HDF5 group is
        # loaded.
        self.all_coords = {}
        for protein_name in tqdm(self.protein_names):
            h5_path = os.path.join(
                dataset_path, "data", f"mdcath_dataset_{protein_name}.h5"
            )
            coords_by_temp = self.all_coords[protein_name] = {}
            with h5py.File(h5_path, "r") as f:
                for temp in self.temperatures:
                    temp_group = f[protein_name][temp]
                    coords_by_repl = coords_by_temp[temp] = {}
                    for repl in temp_group.keys():
                        # h5py ``[:]`` already returns a fresh ndarray, so
                        # asarray avoids a second copy. The /10.0 is
                        # presumably an Angstrom -> nm conversion; confirm
                        # against the dataset docs.
                        raw = np.asarray(temp_group[repl]["ca_coords"][:])
                        coords_by_repl[repl] = torch.from_numpy(raw) / 10.0

    def __len__(self):
        """Return the number of random samples that make up one epoch."""
        return self.samples_per_epoch

    @staticmethod
    def _sample_pair_indices(n_frames, max_lag_frames):
        """Sample a start frame and a positive lag that stay inside a trajectory.

        Args:
            n_frames: Number of frames in the trajectory.
            max_lag_frames: Upper bound on the lag, in frames.

        Returns:
            ``(idx0, lag)`` satisfying ``1 <= lag <= max_lag_frames`` and
            ``0 <= idx0 < idx0 + lag <= n_frames - 1``.

        Raises:
            ValueError: If ``n_frames < 2`` or ``max_lag_frames < 1``.
        """
        if n_frames < 2:
            raise ValueError(
                f"Need at least 2 frames to sample a frame pair, got {n_frames}"
            )
        if max_lag_frames < 1:
            raise ValueError(f"max_lag_frames must be >= 1, got {max_lag_frames}")
        idx0 = random.randint(0, n_frames - 2)
        # The lag is clipped so idx0 + lag stays in range. As a result, start
        # frames near the end of the trajectory can only pair with short lags.
        lag = random.randint(1, min(n_frames - idx0 - 1, max_lag_frames))
        return idx0, lag

    def __getitem__(self, idx):
        """Draw one random frame pair. ``idx`` is ignored.

        The returned ``lag`` is measured in frames, not ns.

        Raises:
            KeyError: If none of ``MDCATH_ALL_REPLICAS`` was loaded for the
                sampled protein/temperature.
            ValueError: If the sampled trajectory has fewer than 2 frames.
        """
        protein = random.choice(self.protein_names)
        metadata = self.metadata[protein]
        temp = random.choice(self.temperatures)
        coords_by_repl = self.all_coords[protein][temp]
        # Only sample replicas that actually exist in the HDF5 file. Keeping
        # the MDCATH_ALL_REPLICAS order preserves the original random stream
        # when every replica is present.
        replicas = [r for r in MDCATH_ALL_REPLICAS if r in coords_by_repl]
        if not replicas:
            raise KeyError(
                f"No replica from MDCATH_ALL_REPLICAS was loaded for protein "
                f"{protein!r} at temperature {temp!r}"
            )
        repl = random.choice(replicas)
        trajectory = coords_by_repl[repl]
        residue_ids = [AA_1_TO_ID[aa] for aa in metadata["sequence"]]

        idx0, lag = self._sample_pair_indices(
            trajectory.shape[0], self.max_lag_frames
        )
        x0 = trajectory[idx0]
        xt = trajectory[idx0 + lag]

        return FrameData(
            id=protein,
            x0=x0,
            xt=xt,
            lag=lag,
            temp=torch.tensor(int(temp), dtype=torch.long),
            residue_ids=torch.tensor(residue_ids, dtype=torch.long),
            sequence_emb=self.seq_emb_per_domain[protein],
            chain_breaks_per_residue=torch.tensor(
                metadata["chain_breaks"], dtype=torch.long
            ),
            residue_pdb_idx=torch.tensor(metadata["residue_pdb_idx"], dtype=torch.long),
            cath_code=[metadata["cath_code"]],
            deepseek_classification=torch.tensor(
                DEEPSEEK_CLASSIFICATION_MAP[metadata["deepseek_classification"]],
                dtype=torch.long,
            ),
            deepseek_confidence=torch.tensor(
                DEEPSEEK_CONFIDENCE_MAP[metadata["deepseek_confidence"]],
                dtype=torch.long,
            ),
        )
