import torch
from .rigid_utils import Rigid
from .residue_constants import restype_order
import numpy as np
import pandas as pd
from .geometry import atom37_to_torsions, atom14_to_atom37, atom14_to_frames
import random
import tqdm
import os
import deeptime
from mdgen.optimal_transport import OptimalTransportSolver
from . import residue_constants as rc


class MDGenDataset(torch.utils.data.Dataset):
    """Dataset of paired MD frames ``(x_t, x_{t+tau})`` sampled through an MSM.

    For every protein in the ``split`` CSV (indexed by ``name``) the constructor
    locates ``_NUM_REPLICAS`` trajectory slices, loads their MSM cluster
    assignments (or a dummy single-cluster assignment when the cluster file is
    missing), fits a deeptime MSM on those assignments and derives the
    cluster-to-cluster sampling probabilities. Only trajectory *paths* are
    kept in memory; frames are read from disk in ``__getitem__``.

    Args:
        args: Configuration namespace; every ``args.*`` attribute referenced
            in this class must be present.
        split: Path to a CSV with a ``name`` index column and a ``seqres``
            column.
        repeat: Number of times each protein appears per epoch.
    """

    # Number of trajectory slices ("replicas") stored on disk per protein.
    _NUM_REPLICAS = 5

    def __init__(self, args, split, repeat=1):
        super().__init__()
        self.df = pd.read_csv(split, index_col="name")
        self.args = args
        self.repeat = repeat
        self.num_transitions_per_traj = args.num_samples_per_cluster * args.num_clusters
        self.num_samples_per_cluster = args.num_samples_per_cluster
        self.num_clusters_to_sample = args.num_clusters
        self.cluster_sampling_mode = args.cluster_sampling_mode
        self.x0_sampling_mode = args.x0_sampling_mode

        self.files = {}
        self.number_of_one_state_proteins = 0

        # In overfit mode only one peptide is needed. Processing it exactly once
        # (instead of once per CSV row until it succeeds) avoids redundant I/O
        # and keeps a skipped peptide from being counted once per row below.
        if self.args.overfit_peptide:
            names = [self.args.overfit_peptide]
        else:
            names = self.df.index

        for name in tqdm.tqdm(names):
            # Determine which "path_variables" (5 or 10) to use (works around
            # inconsistent file naming on disk).
            path_variables = self._determine_path_variables(name)

            file_paths, msm_cluster_files = self._construct_file_paths(
                name, path_variables
            )

            if not self._validate_files(file_paths):
                print(f"Missing trajectory file(s) for {name}. Skipping...")
                continue

            # Trajectory paths plus MSM cluster arrays (dummy arrays if missing).
            arr_list, clusters_list = self._load_arrays_and_clusters(
                file_paths, msm_cluster_files
            )

            # Identify (and optionally drop) slices that visit a single state.
            arr_list, clusters_list, single_state = self._filter_single_state(
                arr_list, clusters_list, name
            )
            if len(arr_list) == 0:
                print(
                    f"Protein {name} has only one state and msm_include_single_state is set to False"
                )
                self.number_of_one_state_proteins += 1
                continue

            ## Build MSM transition matrix
            transition_matrix = self._build_transition_matrix(clusters_list)

            ## Compute indices (members) for each cluster in every trajectory slice
            cluster_members_list = self._compute_clusters_members(
                clusters_list, transition_matrix
            )

            ## Compute the sampling probabilities based on the chosen mode
            probabilities = self._compute_probabilities(transition_matrix, single_state)

            ## Zero out clusters absent from the data and collect the unique ones
            probabilities, unique_clusters = self._fix_missing_clusters(
                probabilities, clusters_list
            )

            self.files[name] = {
                "arr_paths": arr_list,  # paths, not memmaps
                "clusters": clusters_list,
                "probabilities": probabilities,
                "cluster_members": cluster_members_list,
                "unique_clusters": unique_clusters,
            }

            # ``single_state`` describes the unfiltered slices.
            if all(single_state):
                self.number_of_one_state_proteins += 1

        if self.args.overfit_peptide:
            self.valid_names = [self.args.overfit_peptide]
        else:
            self.valid_names = list(self.files.keys())

        print(
            "Number of proteins with only one state:", self.number_of_one_state_proteins
        )

    def _determine_path_variables(self, name):
        """Determine whether to use 5 or 10 for file naming based on file existence."""
        path_file = f"{self.args.data_dir}/{name}_transition_matrix_5_{self.args.data_temperature}.npy"
        return 5 if os.path.exists(path_file) else 10

    def _construct_file_paths(self, name, path_variables):
        """Build the trajectory and MSM-cluster file paths for every replica.

        Returns:
            ``(file_paths, msm_cluster_files)``: two lists of ``_NUM_REPLICAS``
            paths each, aligned by replica index. The cluster-file naming
            depends on ``args.msm_vampnet``, ``args.msm_observables`` and
            ``args.msm_num_clusters``.
        """
        data_dir = self.args.data_dir
        temperature = self.args.data_temperature
        replicas = range(self._NUM_REPLICAS)

        file_paths = [f"{data_dir}/{name}_{temperature}_{i}.npy" for i in replicas]

        observables = self.args.msm_observables
        if observables and observables.strip():
            selected_observables = [
                feat.strip() for feat in observables.split(",") if feat.strip()
            ]
        else:
            selected_observables = []
        # Loop-invariant; only used when observables were selected.
        feature_label = "_".join(selected_observables)

        msm_cluster_files = []
        for i in replicas:
            if self.args.msm_vampnet:
                msm_cluster_file = f"{data_dir}/{name}_vampnet_{temperature}_{i}.npy"
            elif selected_observables:
                if self.args.msm_num_clusters == 5:
                    msm_cluster_file = f"{data_dir}/{name}_{feature_label}_{temperature}_{i}.npy"
                else:
                    msm_cluster_file = (
                        f"{data_dir}/{name}_{feature_label}_"
                        f"{self.args.msm_num_clusters}_{temperature}_{i}.npy"
                    )
            else:
                msm_cluster_file = (
                    f"{data_dir}/{name}_msm_cluster_{path_variables}_"
                    f"{temperature}_{i}.npy"
                )
            msm_cluster_files.append(msm_cluster_file)
        return file_paths, msm_cluster_files

    def _validate_files(self, file_paths):
        """Return True when every trajectory file exists."""
        return all(os.path.exists(fp) for fp in file_paths)

    def _load_arrays_and_clusters(self, file_paths, msm_cluster_files):
        """Load per-replica cluster assignments; keep trajectories as paths.

        When a cluster file is missing, a dummy all-zero (single-state)
        assignment with the trajectory's length is used instead.

        Returns:
            ``(arr_paths, clusters_list)`` where ``arr_paths`` is
            ``file_paths`` itself and ``clusters_list`` holds one array per
            replica.
        """
        clusters_list = []
        for cluster_file, traj_file in zip(msm_cluster_files, file_paths):
            if os.path.exists(cluster_file):
                clusters_list.append(np.load(cluster_file, mmap_mode=None))
            else:  # dummy 1-state assignment
                # Only the length is needed: read the header, not the frames.
                length = self._trajectory_length(traj_file)
                clusters_list.append(np.zeros(length, dtype=np.int32))
        return file_paths, clusters_list

    def _filter_single_state(self, arr_list, clusters_list, name):
        """
        Identify slices with only a single unique state.
        If msm_include_single_state is False, remove those slices.

        Returns:
            ``(arr_list, clusters_list, single_state)``. ``single_state`` is
            always computed over the *unfiltered* input slices.
        """
        single_state = [np.unique(clusters).shape[0] == 1 for clusters in clusters_list]
        if not self.args.msm_include_single_state:
            keep = [not is_single for is_single in single_state]
            arr_list = [a for a, k in zip(arr_list, keep) if k]
            clusters_list = [c for c, k in zip(clusters_list, keep) if k]

        return arr_list, clusters_list, single_state

    def _build_transition_matrix(self, clusters_list):
        """Build the MSM transition matrix using deeptime and symmetrize it."""
        counts_estimator = deeptime.markov.TransitionCountEstimator(
            self.args.msm_lagtime, "sliding"
        )
        counts = counts_estimator.fit_fetch(clusters_list)
        msm = deeptime.markov.msm.MaximumLikelihoodMSM(
            allow_disconnected=True, reversible=True
        ).fit_fetch(counts)
        transition_matrix = msm.transition_matrix
        return (transition_matrix + transition_matrix.T) / 2

    def _compute_clusters_members(self, clusters_list, transition_matrix):
        """Return, per slice, a list of frame-index arrays for every cluster id
        in ``range(transition_matrix.shape[0])``."""
        num_clusters = transition_matrix.shape[0]
        return [
            [np.where(clusters == c)[0] for c in range(num_clusters)]
            for clusters in clusters_list
        ]

    def _compute_probabilities(self, transition_matrix, single_state):
        """Compute the (unnormalized) cluster-to-cluster sampling matrix.

        Raises:
            ValueError: for an unknown ``cluster_sampling_mode``, or when
                ``"transition"`` mode is combined with
                ``msm_include_single_state``.
        """
        num_clusters = transition_matrix.shape[0]
        if (
            num_clusters == 1 or all(single_state)
        ) and self.args.msm_include_single_state:
            probabilities = transition_matrix.copy()
        elif self.cluster_sampling_mode == "transition":
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if self.args.msm_include_single_state:
                raise ValueError(
                    "cluster_sampling_mode='transition' requires "
                    "msm_include_single_state=False"
                )
            probabilities = transition_matrix.copy()
            np.fill_diagonal(probabilities, 0)
        elif self.cluster_sampling_mode == "uniform":
            # Self-transitions are intentionally kept in uniform mode.
            probabilities = np.ones((num_clusters, num_clusters), dtype=float)
        elif self.cluster_sampling_mode == "uniform_with_zeros":
            probabilities = np.ones((num_clusters, num_clusters), dtype=float)
            if not self.args.msm_include_single_state:
                np.fill_diagonal(probabilities, 0)
            probabilities[transition_matrix == 0] = 0
        elif self.cluster_sampling_mode == "original":
            probabilities = transition_matrix.copy()
        else:
            raise ValueError(
                f"Invalid cluster sampling mode: {self.cluster_sampling_mode}"
            )
        return probabilities

    def _fix_missing_clusters(self, probabilities, clusters_list):
        """Set probabilities to zero for any clusters missing from the data and return unique clusters.

        Note: ``probabilities`` is modified in place and also returned.
        """
        num_clusters = probabilities.shape[0]
        all_clusters = np.concatenate(clusters_list)
        missing_clusters = np.setdiff1d(np.arange(num_clusters), all_clusters)
        if missing_clusters.size > 0:
            probabilities[missing_clusters, :] = 0
            probabilities[:, missing_clusters] = 0
        unique_clusters = np.unique(all_clusters)
        return probabilities, unique_clusters

    def __len__(self):
        return self.repeat * len(self.valid_names) if not self.args.overfit else 100

    def __getitem__(self, idx):
        """Sample ``num_transitions_per_traj`` frame pairs for one protein."""
        if self.args.overfit:
            idx = 0

        name = (
            self.valid_names[idx % len(self.valid_names)]
            if not self.args.overfit_peptide
            else self.args.overfit_peptide
        )
        seqres = self.df.seqres[name]
        full_name = f"{name}_{self.args.data_temperature}"
        data_info = self.files[name]

        if self.args.msm_merge_replicas:
            arr, clusters, probabilities, cluster_members, unique_clusters = (
                self._merge_replicas(data_info)
            )
        else:
            arr, clusters, probabilities, cluster_members, unique_clusters = (
                self._select_replica(data_info, name)
            )

        if self.x0_sampling_mode == "cluster_based":
            x_t, x_t_plus_tau = self._sample_cluster_based(
                arr, probabilities, cluster_members, unique_clusters
            )
        elif self.x0_sampling_mode == "cluster_based_v2":
            if self.args.optimal_transport_mode in (None, "", "none"):
                x_t, x_t_plus_tau = self._sample_cluster_based_v2(
                    arr, clusters, probabilities, cluster_members, unique_clusters
                )
            else:
                x_t, x_t_plus_tau = self._sample_cluster_based_v2_OT(
                    arr, clusters, probabilities, cluster_members, unique_clusters
                )
        elif self.x0_sampling_mode == "sample_based":
            x_t, x_t_plus_tau = self._sample_sample_based(
                arr, clusters, probabilities, cluster_members
            )
        elif self.x0_sampling_mode == "uniform_frames":
            x_t, x_t_plus_tau = self._sample_uniform_frames(arr)
        else:
            raise ValueError(f"Invalid x0 sampling mode: {self.x0_sampling_mode}")

        return self._process_output(x_t, x_t_plus_tau, seqres, full_name)

    def _merge_replicas(self, data_info):
        """Concatenate all replicas into one trajectory with matching clusters.

        Raises:
            ValueError: if the concatenated frames and cluster labels differ
                in length.
        """
        arr = np.concatenate(
            [self._read_trajectory(p) for p in data_info["arr_paths"]]
        )
        clusters = np.concatenate(data_info["clusters"])
        if arr.shape[0] != clusters.shape[0]:
            raise ValueError("Array and clusters have different lengths")
        probabilities = data_info["probabilities"].copy()
        cluster_members = self._concat_replicas(data_info["cluster_members"])
        unique_clusters = data_info["unique_clusters"]
        return arr, clusters, probabilities, cluster_members, unique_clusters

    def _concat_replicas(self, replica_lists):
        """Merge per-replica cluster-member index lists into global indices.

        Each replica's indices are shifted by an offset equal to one past the
        largest index seen in all previous replicas.
        """
        num_states = len(replica_lists[0])
        pieces = [[np.array([], dtype=int)] for _ in range(num_states)]
        offset = 0
        for replica in replica_lists:
            for state_idx, state_array in enumerate(replica):
                pieces[state_idx].append(state_array + offset)
            non_empty_arrays = [a for a in replica if a.size > 0]
            if non_empty_arrays:
                offset += max(a.max() for a in non_empty_arrays) + 1
        # One concatenation per state instead of re-concatenating every replica.
        return [np.concatenate(state_pieces) for state_pieces in pieces]

    def _select_replica(self, data_info, name):
        """Pick one replica at random and restrict the MSM to its clusters.

        Raises:
            ValueError: if no replica is eligible.
        """
        clusters_per_replica = data_info["clusters"]
        if self.args.msm_include_single_state:
            candidates = np.arange(len(clusters_per_replica))
        else:
            candidates = [
                i
                for i, clust in enumerate(clusters_per_replica)
                if np.unique(clust).shape[0] != 1
            ]
        if len(candidates) == 0:
            raise ValueError(f"Protein {name} has no valid replica for MSM sampling")

        i_choice = np.random.choice(candidates)
        arr = self._read_trajectory(data_info["arr_paths"][i_choice])
        clusters = clusters_per_replica[i_choice]
        probabilities = data_info["probabilities"].copy()
        cluster_members = data_info["cluster_members"][i_choice]

        # Zero-out probabilities corresponding to clusters not present in this replica
        missing_clusters = np.setdiff1d(np.arange(probabilities.shape[0]), clusters)
        probabilities[missing_clusters, :] = 0
        probabilities[:, missing_clusters] = 0
        return arr, clusters, probabilities, cluster_members, np.unique(clusters)

    @staticmethod
    def _transition_row(probabilities, source):
        """Return row ``source`` of ``probabilities`` normalized to sum to one."""
        row = probabilities[source]
        return row / row.sum()

    def _sample_uniform_frames(self, arr):
        """Pick every (x0, x1) completely uniformly from the same trajectory slice."""
        n = self.num_transitions_per_traj  # how many pairs you want
        max_t = arr.shape[0]  # total frames in this slice
        x_t_indices = np.random.randint(0, max_t, size=n)
        x_t_plus_tau_indices = np.random.randint(0, max_t, size=n)
        return (
            arr[x_t_indices].astype(np.float32),
            arr[x_t_plus_tau_indices].astype(np.float32),
        )

    def _sample_cluster_based(
        self, arr, probabilities, cluster_members, unique_clusters
    ):
        """Draw x0 clusters uniformly, one x1 cluster per x0 from the MSM row,
        then ``num_samples_per_cluster`` frames from each of the two clusters."""
        num_states = probabilities.shape[0]
        x0_clusters = np.random.choice(
            unique_clusters, size=self.num_clusters_to_sample, replace=True
        )

        x1_clusters = np.array(
            [
                np.random.choice(
                    np.arange(num_states),
                    p=self._transition_row(probabilities, x0),
                )
                for x0 in x0_clusters
            ]
        )

        sampled_pairs = []
        for x0, x1 in zip(x0_clusters, x1_clusters):
            sampled_x0 = np.random.choice(
                cluster_members[x0], size=self.num_samples_per_cluster, replace=True
            )
            sampled_x1 = np.random.choice(
                cluster_members[x1], size=self.num_samples_per_cluster, replace=True
            )
            sampled_pairs.extend(zip(sampled_x0, sampled_x1))
        sampled_pairs = np.array(sampled_pairs)

        x_t_indices = sampled_pairs[:, 0]
        x_t_plus_tau_indices = sampled_pairs[:, 1]

        return arr[x_t_indices].astype(np.float32), arr[x_t_plus_tau_indices].astype(
            np.float32
        )

    def _sample_cluster_based_v2(
        self, arr, clusters, probabilities, cluster_members, unique_clusters
    ):
        """Like ``_sample_cluster_based`` but draws an x1 cluster per sample.

        Without ``args.tree_adjustment`` the first x0 cluster is the cluster of
        frame 0 and the rest are drawn from the other present clusters.
        """
        first_cluster = clusters[0]
        if not self.args.tree_adjustment:
            if self.num_clusters_to_sample - 1 > 0:
                remaining_clusters = np.setdiff1d(
                    unique_clusters, np.array([first_cluster])
                )
                if remaining_clusters.size == 0:
                    additional_x0 = np.full(
                        self.num_clusters_to_sample - 1, first_cluster
                    )
                else:
                    additional_x0 = np.random.choice(
                        remaining_clusters,
                        size=self.num_clusters_to_sample - 1,
                        replace=True,
                    )
                x0_clusters = np.concatenate(([first_cluster], additional_x0))
            else:
                x0_clusters = np.array([first_cluster])
        else:
            x0_clusters = np.random.choice(
                unique_clusters, size=self.num_clusters_to_sample, replace=True
            )

        x1_clusters = np.array(
            [
                np.random.choice(
                    np.arange(probabilities.shape[0]),
                    size=self.num_samples_per_cluster,
                    replace=True,
                    p=self._transition_row(probabilities, x0),
                )
                for x0 in x0_clusters
            ]
        )
        x0_clusters = np.repeat(
            x0_clusters, self.num_samples_per_cluster
        )  # shape: (num_clusters_to_sample * num_samples_per_cluster,)
        x1_clusters = x1_clusters.flatten()

        x_t_indices = MDGenDataset.vectorized_sample(x0_clusters, cluster_members)
        x_t_plus_tau_indices = MDGenDataset.vectorized_sample(
            x1_clusters, cluster_members
        )

        return arr[x_t_indices].astype(np.float32), arr[x_t_plus_tau_indices].astype(
            np.float32
        )

    @staticmethod
    def vectorized_sample(cluster_array, cluster_members):
        """
        For an array of cluster ids, group indices by unique cluster id and
        sample random members from cluster_members accordingly.
        """
        sampled = np.empty(cluster_array.shape[0], dtype=int)
        unique_ids = np.unique(cluster_array)
        for cl in unique_ids:
            group_idxs = np.where(cluster_array == cl)[0]
            members = cluster_members[cl]
            rand_indices = np.random.randint(0, len(members), size=group_idxs.shape[0])
            sampled[group_idxs] = members[rand_indices]
        return sampled

    def _sample_cluster_based_v2_OT(
        self, arr, clusters, probabilities, cluster_members, unique_clusters
    ):
        """Cluster-based sampling whose within-pair frames are matched by OT.

        Raises:
            ValueError: for an unknown ``args.optimal_transport_mode``.
        """
        # --- choose the x0 clusters (first one is always the first frame's cluster) ----
        if not self.args.tree_adjustment:
            first_cluster = clusters[0]
            if self.num_clusters_to_sample - 1 > 0:
                remaining = np.setdiff1d(unique_clusters, [first_cluster])
                extra = np.random.choice(
                    remaining if remaining.size else [first_cluster],
                    size=self.num_clusters_to_sample - 1,
                    replace=True,
                )
                x0_clusters = np.concatenate(([first_cluster], extra))
            else:
                x0_clusters = np.array([first_cluster])
        else:
            x0_clusters = np.random.choice(
                unique_clusters, size=self.num_clusters_to_sample, replace=True
            )

        # One (x0 -> x1) cluster pair per source cluster; the
        # `num_samples_per_cluster` frames within that pair are aligned with OT.
        sampled_pairs = []

        for x0 in x0_clusters:
            # pick a *single* destination cluster for this source cluster
            x1 = np.random.choice(
                np.arange(probabilities.shape[0]),
                p=self._transition_row(probabilities, x0),
            )

            # sample candidates inside each cluster
            idx_x0 = np.random.choice(
                cluster_members[x0], size=self.num_samples_per_cluster, replace=True
            )
            idx_x1 = np.random.choice(
                cluster_members[x1], size=self.num_samples_per_cluster, replace=True
            )

            # ------- optimal transport alignment -------------------------------------
            if self.args.optimal_transport_mode == "time":
                sel0, sel1 = OptimalTransportSolver("time").solve(idx_x0, idx_x1)

            elif self.args.optimal_transport_mode == "rmsd":
                bb_x0 = self.extract_backbone(arr[idx_x0])
                bb_x1 = self.extract_backbone(arr[idx_x1])
                sel0, sel1 = OptimalTransportSolver("rmsd").solve(bb_x0, bb_x1)

            else:
                raise ValueError(
                    f"Unknown optimal_transport_mode: {self.args.optimal_transport_mode}"
                )

            sampled_pairs.extend(zip(idx_x0[sel0], idx_x1[sel1]))

        sampled_pairs = np.asarray(sampled_pairs, dtype=int)
        x_t_indices = sampled_pairs[:, 0]
        x_t_plus_tau_indices = sampled_pairs[:, 1]

        return (
            arr[x_t_indices].astype(np.float32),
            arr[x_t_plus_tau_indices].astype(np.float32),
        )

    def _sample_sample_based(self, arr, clusters, probabilities, cluster_members):
        """Draw x_t frames uniformly; draw each partner via the MSM row of
        x_t's cluster, then a uniform frame of the chosen cluster."""
        max_t = arr.shape[0]
        t_values = np.random.choice(
            np.arange(max_t),
            size=self.num_transitions_per_traj,
            replace=True,
        )
        x_t = arr[t_values].astype(np.float32)

        num_states = probabilities.shape[0]
        # Single pass: the partner indices were previously computed twice, with
        # the first result discarded (wasted work and RNG draws).
        x_t_plus_tau_indices = []
        for x_t_cluster in clusters[t_values]:
            x1_cluster = np.random.choice(
                np.arange(num_states),
                p=self._transition_row(probabilities, x_t_cluster),
            )
            x_t_plus_tau_indices.append(np.random.choice(cluster_members[x1_cluster]))

        x_t_plus_tau = arr[x_t_plus_tau_indices].astype(np.float32)

        return x_t, x_t_plus_tau

    def _process_output(self, x_t, x_t_plus_tau, seqres, full_name):
        """Convert atom14 coordinate batches into frames/torsions and crop or pad
        the residue axis to ``args.crop``.

        ``x_t``/``x_t_plus_tau`` presumably have shape (N, L, 14, 3) with
        N == num_transitions_per_traj (TODO confirm). The returned torsion
        tensors are (N, crop, 7, 2) and ``mask`` is (crop,).
        """
        frames = atom14_to_frames(torch.from_numpy(x_t))
        frames_plus_tau = atom14_to_frames(torch.from_numpy(x_t_plus_tau))
        seqres = np.array([restype_order[c] for c in seqres])

        aatype = torch.from_numpy(seqres)[None].expand(
            self.num_transitions_per_traj, -1
        )

        atom37 = torch.from_numpy(atom14_to_atom37(x_t, aatype)).float()
        atom37_plus_tau = torch.from_numpy(
            atom14_to_atom37(x_t_plus_tau, aatype)
        ).float()
        L = frames.shape[1]
        mask = np.ones(L, dtype=np.float32)
        torsions, torsion_mask = atom37_to_torsions(atom37, aatype)
        torsions_plus_tau, _ = atom37_to_torsions(atom37_plus_tau, aatype)

        # All samples share one sequence, so one torsion mask suffices.
        torsion_mask = torsion_mask[0]

        crop = self.args.crop
        if L > crop:
            start = np.random.randint(0, L - crop + 1)
            stop = start + crop
            torsions = torsions[:, start:stop]
            torsions_plus_tau = torsions_plus_tau[:, start:stop]
            frames = frames[:, start:stop]
            frames_plus_tau = frames_plus_tau[:, start:stop]
            seqres = seqres[start:stop]
            mask = mask[start:stop]
            torsion_mask = torsion_mask[start:stop]

        elif L < crop:
            pad = crop - L
            frames = Rigid.cat(
                [
                    frames,
                    Rigid.identity(
                        (self.num_transitions_per_traj, pad),
                        requires_grad=False,
                        fmt="rot_mat",
                    ),
                ],
                1,
            )
            frames_plus_tau = Rigid.cat(
                [
                    frames_plus_tau,
                    Rigid.identity(
                        (self.num_transitions_per_traj, pad),
                        requires_grad=False,
                        fmt="rot_mat",
                    ),
                ],
                1,
            )
            mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)])
            seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)])
            torsions = torch.cat(
                [
                    torsions,
                    torch.zeros((torsions.shape[0], pad, 7, 2), dtype=torch.float32),
                ],
                1,
            )
            torsions_plus_tau = torch.cat(
                [
                    torsions_plus_tau,
                    torch.zeros(
                        (torsions_plus_tau.shape[0], pad, 7, 2), dtype=torch.float32
                    ),
                ],
                1,
            )
            torsion_mask = torch.cat(
                [torsion_mask, torch.zeros((pad, 7), dtype=torch.float32)]
            )

        return {
            "name": full_name,
            "torsions": torsions,
            "torsions_plus_tau": torsions_plus_tau,
            "torsion_mask": torsion_mask,
            "trans": frames._trans,
            "trans_plus_tau": frames_plus_tau._trans,
            "rots": frames._rots._rot_mats,
            "rots_plus_tau": frames_plus_tau._rots._rot_mats,
            "seqres": seqres,
            "mask": mask,  # (L,)
        }

    def extract_backbone(self, atom14):
        """Return concatenated (N, CA, C) backbone coordinates – shape (N, L*3, 3).

        The output holds all N atoms, then all CA atoms, then all C atoms
        (blocks along axis 1, not interleaved per residue).
        """
        # NOTE(review): indices come from ``rc.atom_order``; if that is the
        # atom37 ordering, this relies on N/CA/C occupying the same slots in
        # the atom14 layout — confirm.
        N_idx = rc.atom_order["N"]
        CA_idx = rc.atom_order["CA"]
        C_idx = rc.atom_order["C"]
        N = atom14[:, :, N_idx, :]
        CA = atom14[:, :, CA_idx, :]
        C = atom14[:, :, C_idx, :]
        return np.concatenate([N, CA, C], axis=1)

    @staticmethod
    def _trajectory_length(path):
        """Return ``shape[0]`` of the ``.npy`` array at *path* without copying it.

        The file is memory-mapped read-only just long enough to read the header
        shape, then the mapping (and its file descriptor) is closed.
        """
        mm = np.lib.format.open_memmap(path, mode="r")
        try:
            return mm.shape[0]
        finally:
            mm._mmap.close()
            del mm

    @staticmethod
    def _read_trajectory(path, idx=None):
        """Read (a slice of) the ``.npy`` file at *path* into a fresh ndarray.

        The memmap is closed before returning so no file descriptor is left
        open between ``__getitem__`` calls.
        """
        mm = np.lib.format.open_memmap(path, mode="r")  # FD open
        try:
            view = mm if idx is None else mm[idx]
            arr = np.array(view, copy=True)  # force real ndarray
        finally:
            mm._mmap.close()  # FD closed
            del mm
        return arr
