import torch
from .rigid_utils import Rigid
from .residue_constants import restype_order
import numpy as np
import pandas as pd
from .geometry import atom37_to_torsions, atom14_to_atom37, atom14_to_frames
import os


class MDGenDataset(torch.utils.data.Dataset):
    """Dataset of fixed-length frame windows cut from ``.npy`` trajectories.

    Each item picks a trajectory file for one row of the split CSV, takes
    ``args.num_frames`` consecutive frames (after optional striding by
    ``args.frame_interval``) and converts them to rigid frames and torsions.

    Attributes of ``args`` read by this class: ``data_dir``, ``suffix``,
    ``atlas``, ``mdcath``, ``desres``, ``data_temperature``, ``num_frames``,
    ``frame_interval``, ``crop``, ``overfit``, ``overfit_peptide``,
    ``overfit_frame``, ``copy_frames`` and ``no_frames``.

    Args:
        args: Namespace-like object carrying the attributes listed above.
        split: Path of a CSV file with a ``name`` column (used as index) and
            a ``seqres`` column.
        repeat: Factor by which ``__len__`` multiplies the number of CSV rows.
    """

    # mdCATH replica indices are drawn from 0 .. _MDCATH_REPLICAS - 1.
    _MDCATH_REPLICAS = 4

    def __init__(self, args, split, repeat=1):
        super().__init__()
        self.df = pd.read_csv(split, index_col="name")
        self.args = args
        self.repeat = repeat

        if args.desres:
            # Index the candidate trajectory files of every protein once, so
            # __getitem__ does not have to list directories.
            self.desres_files = {}
            for name in self.df.index:
                protein_folder = os.path.join(self.args.data_dir, name)
                if not os.path.isdir(protein_folder):
                    continue
                files = [
                    f
                    for f in os.listdir(protein_folder)
                    if f.endswith(f"{self.args.suffix}.npy")
                    and not f.startswith(".")
                ]
                if files:
                    self.desres_files[name] = files
                else:
                    print(
                        f"Warning: No .npy files found in {protein_folder} for protein {name}"
                    )

    def __len__(self):
        """Return 1000 when ``args.overfit_peptide`` is set, else ``repeat * rows``."""
        if self.args.overfit_peptide:
            return 1000
        return self.repeat * len(self.df)

    def _open_trajectory(self, file_path):
        """Memory-map ``file_path`` read-only and apply ``args.frame_interval``.

        Failures are printed with the offending path and then re-raised.
        """
        try:
            arr = np.lib.format.open_memmap(file_path, "r")
        except ValueError as e:
            print(f"Error opening file {file_path}: {e}")
            raise
        except Exception as e:
            print(f"Unexpected error opening file {file_path}: {e}")
            raise

        if self.args.frame_interval:
            arr = arr[:: self.args.frame_interval]
        return arr

    def _pick_long_enough(self, name, draw, n_candidates):
        """Draw random candidates until one has at least ``num_frames + 1`` frames.

        Args:
            name: Protein name, used only in the error message.
            draw: Zero-argument callable returning ``(key, full_name,
                file_path)`` for one randomly chosen candidate.
            n_candidates: Number of distinct keys ``draw`` can produce.

        Returns:
            ``(full_name, arr)`` for the first sufficiently long candidate.

        Raises:
            ValueError: If every distinct candidate was tried and all of them
                are too short.
        """
        min_len = self.args.num_frames + 1
        tried = set()
        while True:
            key, full_name, file_path = draw()
            tried.add(key)
            arr = self._open_trajectory(file_path)
            if arr.shape[0] >= min_len:
                return full_name, arr
            # Give up only when the current draw failed AND nothing is left;
            # a successful last candidate must not be reported as a failure.
            if len(tried) >= n_candidates:
                raise ValueError(
                    f"Could not find a file with {self.args.num_frames} frames for {name}"
                )

    def _select_trajectory(self, name):
        """Choose the trajectory file for ``name`` according to the dataset flags.

        Returns:
            ``(full_name, arr)`` where ``arr`` is the read-only memmap already
            strided by ``args.frame_interval``.
        """
        args = self.args

        if args.atlas:
            i = np.random.randint(1, 4)  # upper bound exclusive: replicas 1..3
            full_name = f"{name}_R{i}"
            file_path = f"{args.data_dir}/{full_name}{args.suffix}.npy"
            return full_name, self._open_trajectory(file_path)

        if args.mdcath:

            def draw_replica():
                i = np.random.randint(0, self._MDCATH_REPLICAS)
                full_name = f"{name}_{args.data_temperature}_{i}"
                return i, full_name, f"{args.data_dir}/{full_name}{args.suffix}.npy"

            return self._pick_long_enough(name, draw_replica, self._MDCATH_REPLICAS)

        if args.desres:
            protein_folder = os.path.join(args.data_dir, name)
            all_files = self.desres_files[name]

            def draw_file():
                chosen = np.random.choice(all_files)
                return (
                    chosen,
                    os.path.splitext(chosen)[0],
                    os.path.join(protein_folder, chosen),
                )

            return self._pick_long_enough(name, draw_file, len(set(all_files)))

        full_name = name
        file_path = f"{args.data_dir}/{full_name}{args.suffix}.npy"
        return full_name, self._open_trajectory(file_path)

    def __getitem__(self, idx):
        """Sample one window of ``args.num_frames`` frames.

        Returns:
            A dict. With ``args.no_frames`` it holds ``name``, ``frame_start``,
            ``atom37``, ``seqres`` and ``mask``. Otherwise it holds ``name``,
            ``frame_start``, ``torsions`` (num_frames, L, 7, 2),
            ``torsion_mask`` (L, 7), ``trans``, ``rots``, ``seqres`` and
            ``mask`` (L,), where L equals ``args.crop`` for ATLAS / mdCATH
            items (cropped or zero-padded) and the sequence length otherwise.

        Raises:
            ValueError: If no candidate trajectory has more than
                ``args.num_frames`` frames.
        """
        args = self.args
        idx = idx % len(self.df)
        if args.overfit:
            idx = 0

        if args.overfit_peptide is None:
            name = self.df.index[idx]
            seqres = self.df.seqres[name]
        else:
            name = args.overfit_peptide
            # Without mdCATH the peptide name itself is used as the sequence.
            seqres = self.df.seqres[name] if args.mdcath else name

        full_name, arr = self._select_trajectory(name)

        num_frames = args.num_frames
        n_starts = arr.shape[0] - num_frames
        if n_starts <= 0:
            raise ValueError(
                f"Trajectory {full_name} has {arr.shape[0]} frames; "
                f"more than {num_frames} are required"
            )
        frame_start = np.random.choice(np.arange(n_starts))
        if args.overfit_frame:
            frame_start = 0
        end = frame_start + num_frames

        # Single writable float32 copy of the window (the memmap is read-only).
        arr = np.array(arr[frame_start:end], dtype=np.float32)
        if args.copy_frames:
            arr[1:] = arr[0]  # repeat the first frame across the whole window

        # arr should be in ANGSTROMS
        frames = atom14_to_frames(torch.from_numpy(arr))
        seqres = np.array([restype_order[c] for c in seqres])
        aatype = torch.from_numpy(seqres)[None].expand(num_frames, -1)
        atom37 = torch.from_numpy(atom14_to_atom37(arr, aatype)).float()

        L = frames.shape[1]
        mask = np.ones(L, dtype=np.float32)

        if args.no_frames:
            # NOTE(review): the original referenced restype_atom37_mask without
            # importing it (NameError). This assumes .residue_constants exports
            # it, as the AlphaFold/OpenFold module does -- confirm. The import
            # is local so that a missing name cannot break module import.
            from .residue_constants import restype_atom37_mask

            return {
                "name": full_name,
                "frame_start": frame_start,
                "atom37": atom37,
                "seqres": seqres,
                "mask": restype_atom37_mask[seqres],  # per-residue lookup; shape TODO confirm
            }

        torsions, torsion_mask = atom37_to_torsions(atom37, aatype)
        torsion_mask = torsion_mask[0]

        if args.atlas or args.mdcath:
            crop = args.crop
            if L > crop:
                # Random contiguous window of `crop` residues.
                start = np.random.randint(0, L - crop + 1)
                window = slice(start, start + crop)
                torsions = torsions[:, window]
                frames = frames[:, window]
                seqres = seqres[window]
                mask = mask[window]
                torsion_mask = torsion_mask[window]
            elif L < crop:
                # Pad the residue axis up to `crop`; padded positions get
                # identity frames, zero torsions and mask == 0.
                pad = crop - L
                frames = Rigid.cat(
                    [
                        frames,
                        Rigid.identity(
                            (num_frames, pad),
                            requires_grad=False,
                            fmt="rot_mat",
                        ),
                    ],
                    1,
                )
                mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)])
                seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)])
                torsions = torch.cat(
                    [
                        torsions,
                        torch.zeros(
                            (torsions.shape[0], pad, 7, 2), dtype=torch.float32
                        ),
                    ],
                    1,
                )
                torsion_mask = torch.cat(
                    [torsion_mask, torch.zeros((pad, 7), dtype=torch.float32)]
                )

        return {
            "name": full_name,
            "frame_start": frame_start,
            "torsions": torsions,
            "torsion_mask": torsion_mask,
            "trans": frames._trans,
            "rots": frames._rots._rot_mats,
            "seqres": seqres,
            "mask": mask,  # (L,)
        }
