import torch
from .residue_constants import restype_order
import numpy as np
import pandas as pd
from .geometry import atom37_to_torsions, atom14_to_atom37, atom14_to_frames
from mdgen.optimal_transport import OptimalTransportSolver
from . import residue_constants as rc

import os
import deeptime
from contextlib import contextmanager


class MDGenDataset(torch.utils.data.Dataset):
    """Dataset of (x_t, x_t_plus_tau) atom14 frame pairs sampled per peptide.

    For every peptide in the split CSV this loads per-frame MSM cluster labels,
    re-estimates a symmetrized MSM transition matrix from them and derives a
    cluster -> cluster sampling matrix. ``__getitem__`` draws a batch of frame
    pairs from the peptide's memory-mapped trajectory according to
    ``args.x0_sampling_mode`` and converts them to rigid frames and torsions.
    """

    # Stride applied to the per-frame MSM cluster labels when they are loaded.
    # NOTE(review): frame indices derived from these labels index the trajectory
    # array directly, which presumes the ``.npy`` trajectory is stored at the
    # same stride -- confirm against the preprocessing script.
    _CLUSTER_LABEL_STRIDE = 100

    def __init__(self, args, split, repeat=1):
        """
        Args:
            args: config namespace. Read here and by the helpers: lag_time,
                num_samples_per_cluster, num_clusters, cluster_sampling_mode,
                x0_sampling_mode, num_pcca_states, overfit_peptide, overfit,
                data_dir, suffix, msm_lagtime, optimal_transport_mode.
            split: path to a CSV whose ``name`` column indexes the peptides.
            repeat: number of times each peptide appears per epoch.
        """
        super().__init__()
        self.df = pd.read_csv(split, index_col="name")
        self.args = args
        self.repeat = repeat
        self.tau = args.lag_time  # stored for external use; unused in this class
        self.num_transitions_per_traj = args.num_samples_per_cluster * args.num_clusters
        self.num_samples_per_cluster = args.num_samples_per_cluster
        self.num_clusters_to_sample = args.num_clusters
        self.cluster_sampling_mode = args.cluster_sampling_mode
        self.x0_sampling_mode = args.x0_sampling_mode

        overfit_name = self.args.overfit_peptide
        # In overfit mode only the requested peptide is loaded, and only once,
        # regardless of how many rows the split contains.
        names = [overfit_name] if overfit_name else list(self.df.index)

        self.files = {}
        for name in names:
            paths = self._construct_file_paths(name, args.num_pcca_states)
            if not all(os.path.exists(path) for path in paths):
                print(f"Missing file for {name}")
                continue
            file_path, transition_file, msm_cluster_file = paths
            self.files[name] = self._load_data(
                file_path, msm_cluster_file, transition_file
            )

        # Overfit mode always exposes the requested peptide (if its files were
        # missing this surfaces as a KeyError in __getitem__).
        self.valid_names = [overfit_name] if overfit_name else list(self.files)

    def _construct_file_paths(self, name, num_pcca_states):
        """Return ``(trajectory, transition_matrix, msm_cluster)`` ``.npy`` paths.

        With 10 PCCA states the MSM files carry no state-count suffix; any other
        count appends ``_{num_pcca_states}`` to their names.
        """
        base = self.args.data_dir
        pcca_tag = "" if num_pcca_states == 10 else f"_{num_pcca_states}"
        return (
            f"{base}/{name}{self.args.suffix}.npy",
            f"{base}/{name}_transition_matrix{pcca_tag}.npy",
            f"{base}/{name}_msm_cluster{pcca_tag}.npy",
        )

    def _load_data(self, file_path, msm_cluster_file, transition_file):
        """Precompute the per-peptide sampling tables.

        Args:
            file_path: trajectory ``.npy`` path. Only the path is kept; the
                array is memory-mapped lazily in ``__getitem__``.
            msm_cluster_file: ``.npy`` file with per-frame cluster labels.
            transition_file: path of a precomputed transition matrix. Currently
                unused -- the matrix is re-estimated from the cluster labels.

        Returns:
            dict with ``arr_path``, ``clusters`` (strided labels),
            ``probabilities`` (unnormalized cluster -> cluster weights),
            ``cluster_members`` (frame indices per cluster id) and
            ``unique_clusters`` (cluster ids that actually occur).
        """
        clusters = np.load(msm_cluster_file).flatten()[:: self._CLUSTER_LABEL_STRIDE]
        transition_matrix = self._build_transition_matrix(clusters)
        num_states = transition_matrix.shape[0]

        cluster_members = [np.flatnonzero(clusters == c) for c in range(num_states)]

        probabilities = self._get_sampling_probabilities(
            self.cluster_sampling_mode, transition_matrix
        )

        # Never sample into (or out of) a cluster that has no frames.
        missing_clusters = np.setdiff1d(np.arange(probabilities.shape[0]), clusters)
        if missing_clusters.size > 0:
            probabilities[missing_clusters, :] = 0
            probabilities[:, missing_clusters] = 0

        return {
            "arr_path": file_path,
            "clusters": clusters,
            "probabilities": probabilities,
            "cluster_members": cluster_members,
            "unique_clusters": np.unique(clusters),
        }

    def _build_transition_matrix(self, clusters_list):
        """Estimate an MSM transition matrix with deeptime and symmetrize it.

        Transitions are counted with a sliding window at ``args.msm_lagtime``;
        a reversible maximum-likelihood MSM (disconnected sets allowed) is fit
        and ``(T + T.T) / 2`` is returned.
        """
        counts = deeptime.markov.TransitionCountEstimator(
            self.args.msm_lagtime, "sliding"
        ).fit_fetch(clusters_list)
        msm = deeptime.markov.msm.MaximumLikelihoodMSM(
            allow_disconnected=True, reversible=True
        ).fit_fetch(counts)
        transition_matrix = msm.transition_matrix
        return (transition_matrix + transition_matrix.T) / 2

    def _get_sampling_probabilities(self, mode, transition_matrix):
        """Turn a transition matrix into unnormalized cluster -> cluster weights.

        Modes:
            "transition": the transition matrix with its diagonal zeroed.
            "uniform": all ones (self-transitions included).
            "uniform_with_zeros": ones, except the diagonal and entries whose
                transition probability is exactly zero.
            "original": a copy of the transition matrix.

        Raises:
            ValueError: for any other ``mode``.
        """
        num_clusters = transition_matrix.shape[0]
        if mode == "transition":
            probabilities = transition_matrix.copy()
            np.fill_diagonal(probabilities, 0)
        elif mode == "uniform":
            probabilities = np.ones((num_clusters, num_clusters))
        elif mode == "uniform_with_zeros":
            probabilities = np.ones((num_clusters, num_clusters))
            np.fill_diagonal(probabilities, 0)
            probabilities[transition_matrix == 0] = 0
        elif mode == "original":
            probabilities = transition_matrix.copy()
        else:
            raise ValueError(f"Invalid sampling mode: {mode}")
        return probabilities

    def __len__(self):
        return self.repeat * len(self.valid_names) if not self.args.overfit else 100

    def __getitem__(self, idx):
        """Sample a batch of frame pairs for one peptide and featurize them.

        Returns a dict with ``name``, ``torsions`` / ``torsions_plus_tau``,
        ``torsion_mask`` (mask of the first sample), ``trans`` /
        ``trans_plus_tau``, ``rots`` / ``rots_plus_tau``, ``seqres`` (indices
        from ``restype_order``) and ``mask`` (float32 ones of length L).
        """
        idx = idx % len(self.valid_names)
        if self.args.overfit:
            idx = 0

        name = self.valid_names[idx]
        # DEBUG shortcut: the residue sequence is read from the peptide name
        # itself rather than from ``self.df.seqres``.
        seqres = name
        data_info = self.files[name]

        # Only the sampling needs the memory map. Every sampler returns copies
        # (fancy indexing + astype), so the map can be closed before featurizing.
        with MDGenDataset._open_memmap(data_info["arr_path"]) as arr:
            x_t, x_t_plus_tau = self._sample_pair_frames(arr, data_info)

        frames = atom14_to_frames(torch.from_numpy(x_t))
        frames_plus_tau = atom14_to_frames(torch.from_numpy(x_t_plus_tau))
        seqres = np.array([restype_order[c] for c in seqres])
        # Broadcast over the number of pairs actually sampled: "sample_based"
        # returns fewer than num_transitions_per_traj for short trajectories.
        aatype = torch.from_numpy(seqres)[None].expand(x_t.shape[0], -1)

        atom37 = torch.from_numpy(atom14_to_atom37(x_t, aatype)).float()
        atom37_plus_tau = torch.from_numpy(
            atom14_to_atom37(x_t_plus_tau, aatype)
        ).float()

        L = frames.shape[1]
        mask = np.ones(L, dtype=np.float32)
        torsions, torsion_mask = atom37_to_torsions(atom37, aatype)
        torsions_plus_tau, _ = atom37_to_torsions(atom37_plus_tau, aatype)
        torsion_mask = torsion_mask[0]

        return {
            "name": name,
            "torsions": torsions,
            "torsions_plus_tau": torsions_plus_tau,
            "torsion_mask": torsion_mask,
            "trans": frames._trans,
            "trans_plus_tau": frames_plus_tau._trans,
            "rots": frames._rots._rot_mats,
            "rots_plus_tau": frames_plus_tau._rots._rot_mats,
            "seqres": seqres,
            "mask": mask,  # (L,)
        }

    def _sample_pair_frames(self, arr, data_info):
        """Dispatch to the sampler selected by ``self.x0_sampling_mode``.

        Raises:
            ValueError: for an unknown sampling mode.
        """
        mode = self.x0_sampling_mode
        clusters = data_info["clusters"]
        probabilities = data_info["probabilities"]
        cluster_members = data_info["cluster_members"]
        unique_clusters = data_info["unique_clusters"]

        if mode == "cluster_based":
            return self._sample_cluster_based(
                arr, probabilities, cluster_members, unique_clusters
            )
        if mode == "cluster_based_v2":
            return self._sample_cluster_based_v2(
                arr, clusters, probabilities, cluster_members, unique_clusters
            )
        if mode == "sample_based":
            return self._sample_sample_based(
                arr, clusters, probabilities, cluster_members
            )
        if mode == "sample_based_v2":
            return self._sample_sample_based_v2(
                arr, probabilities, cluster_members, unique_clusters
            )
        if mode == "uniform_frames":
            return self._sample_uniform_frames(arr)
        raise ValueError(f"Invalid x0 sampling mode: {self.x0_sampling_mode}")

    @staticmethod
    def _sample_next_cluster(probabilities, x0, size=None):
        """Draw target cluster(s) from row ``x0`` of ``probabilities``.

        The row is normalized to a distribution. If it sums to zero (no
        admissible target, e.g. only self-transitions with the diagonal
        zeroed), the draw stays in ``x0`` instead of producing NaN weights.
        Returns a scalar when ``size`` is None, otherwise an array of ``size``.
        """
        row = probabilities[x0]
        total = row.sum()
        if total <= 0:
            return x0 if size is None else np.full(size, x0)
        return np.random.choice(len(row), size=size, p=row / total)

    def _pair_indices(self, arr, indices_x0, indices_x1):
        """Re-order two equally sized frame-index arrays into matched pairs.

        For ``optimal_transport_mode`` "time" the solver works on the raw frame
        indices; for "rmsd" on backbone coordinates. Any other mode keeps the
        sampled order, i.e. no optimal-transport matching.
        """
        mode = self.args.optimal_transport_mode
        if mode == "time":
            cost_x0, cost_x1 = indices_x0, indices_x1
        elif mode == "rmsd":
            cost_x0 = self.extract_backbone(arr[indices_x0])
            cost_x1 = self.extract_backbone(arr[indices_x1])
        else:
            return indices_x0, indices_x1
        idx0, idx1 = OptimalTransportSolver(mode=mode).solve(cost_x0, cost_x1)
        return (
            np.asarray(indices_x0)[np.asarray(idx0)],
            np.asarray(indices_x1)[np.asarray(idx1)],
        )

    def _pair_frames_from_x0_clusters(
        self, arr, probabilities, cluster_members, x0_clusters
    ):
        """Sample ``num_samples_per_cluster`` frame pairs for every start cluster.

        For each start cluster a target cluster is drawn, members are sampled
        with replacement from both, and the two member sets are matched via
        ``_pair_indices``. Returns ``(x_t, x_t_plus_tau)`` as float32 copies.
        """
        n_per_cluster = self.num_samples_per_cluster
        # Draw all target clusters first, then the members, per start cluster.
        x1_clusters = [
            self._sample_next_cluster(probabilities, x0) for x0 in x0_clusters
        ]

        x_t_chunks, x_t_plus_tau_chunks = [], []
        for x0, x1 in zip(x0_clusters, x1_clusters):
            members_x0 = np.random.choice(
                cluster_members[x0], size=n_per_cluster, replace=True
            )
            members_x1 = np.random.choice(
                cluster_members[x1], size=n_per_cluster, replace=True
            )
            paired_x0, paired_x1 = self._pair_indices(arr, members_x0, members_x1)
            x_t_chunks.append(paired_x0)
            x_t_plus_tau_chunks.append(paired_x1)

        empty = np.empty(0, dtype=int)
        x_t_indices = np.concatenate(x_t_chunks) if x_t_chunks else empty
        x_t_plus_tau_indices = (
            np.concatenate(x_t_plus_tau_chunks) if x_t_plus_tau_chunks else empty
        )
        return (
            arr[x_t_indices].astype(np.float32),
            arr[x_t_plus_tau_indices].astype(np.float32),
        )

    def _sample_cluster_based(
        self, arr, probabilities, cluster_members, unique_clusters
    ):
        """Start clusters uniformly over the observed clusters (with replacement)."""
        x0_clusters = np.random.choice(
            unique_clusters, size=self.num_clusters_to_sample, replace=True
        )
        return self._pair_frames_from_x0_clusters(
            arr, probabilities, cluster_members, x0_clusters
        )

    def _sample_cluster_based_v2(
        self, arr, clusters, probabilities, cluster_members, unique_clusters
    ):
        """Start from the cluster of frame 0 plus other observed clusters.

        The first start cluster is always ``clusters[0]``; the remaining
        ``num_clusters_to_sample - 1`` are drawn from the other observed
        clusters (or repeat ``clusters[0]`` if there are none). Each start
        cluster gets ``num_samples_per_cluster`` independent target clusters,
        and one random member frame is drawn per (start, target) slot.
        """
        first_cluster = clusters[0]
        num_extra = self.num_clusters_to_sample - 1
        if num_extra > 0:
            remaining = np.setdiff1d(unique_clusters, np.array([first_cluster]))
            if remaining.size == 0:
                additional_x0 = np.full(num_extra, first_cluster)
            else:
                additional_x0 = np.random.choice(
                    remaining, size=num_extra, replace=True
                )
            x0_clusters = np.concatenate(([first_cluster], additional_x0))
        else:
            x0_clusters = np.array([first_cluster])

        n_per_cluster = self.num_samples_per_cluster
        x1_clusters = np.concatenate(
            [
                self._sample_next_cluster(probabilities, x0, size=n_per_cluster)
                for x0 in x0_clusters
            ]
        )
        x0_clusters = np.repeat(x0_clusters, n_per_cluster)

        x_t_indices = MDGenDataset.vectorized_sample(x0_clusters, cluster_members)
        x_t_plus_tau_indices = MDGenDataset.vectorized_sample(
            x1_clusters, cluster_members
        )
        return (
            arr[x_t_indices].astype(np.float32),
            arr[x_t_plus_tau_indices].astype(np.float32),
        )

    @staticmethod
    def vectorized_sample(cluster_array, cluster_members):
        """Draw one random member frame for every entry of ``cluster_array``.

        Positions sharing a cluster id are filled with a single ``randint``
        call. Returns an int array aligned with ``cluster_array``.
        """
        sampled = np.empty(cluster_array.shape[0], dtype=int)
        for cluster_id in np.unique(cluster_array):
            positions = np.flatnonzero(cluster_array == cluster_id)
            members = cluster_members[cluster_id]
            picks = np.random.randint(0, len(members), size=positions.shape[0])
            sampled[positions] = members[picks]
        return sampled

    def _sample_uniform_frames(self, arr):
        """Pick every (x0, x1) completely uniformly from the same trajectory slice."""
        n = self.num_transitions_per_traj  # how many pairs to draw
        max_t = arr.shape[0]  # total frames in this slice
        x_t_indices = np.random.randint(0, max_t, size=n)
        x_t_plus_tau_indices = np.random.randint(0, max_t, size=n)
        return (
            arr[x_t_indices].astype(np.float32),
            arr[x_t_plus_tau_indices].astype(np.float32),
        )

    def _sample_sample_based_v2(
        self, arr, probabilities, cluster_members, unique_clusters
    ):
        """Start clusters drawn with probability proportional to their size."""
        cluster_weights = np.array(
            [len(cluster_members[cluster_id]) for cluster_id in unique_clusters]
        )
        x0_clusters = np.random.choice(
            unique_clusters,
            size=self.num_clusters_to_sample,
            replace=True,
            p=cluster_weights / cluster_weights.sum(),
        )
        return self._pair_frames_from_x0_clusters(
            arr, probabilities, cluster_members, x0_clusters
        )

    def _sample_sample_based(self, arr, clusters, probabilities, cluster_members):
        """Draw start frames uniformly, then one partner frame per start frame.

        Up to ``num_transitions_per_traj`` distinct start frames are drawn
        (fewer if the trajectory is shorter). For each, a target cluster is
        drawn from the start frame's cluster row and a random member of it
        becomes x_{t+tau}. With ``optimal_transport_mode`` "time"/"rmsd" the
        pairs are re-matched within each group of start frames sharing a cluster.
        """
        max_t = arr.shape[0]
        t_values = np.random.choice(
            np.arange(max_t),
            size=min(self.num_transitions_per_traj, max_t),
            replace=False,
        )
        x_t_clusters = clusters[t_values]

        x_t_plus_tau_indices = np.array(
            [
                np.random.choice(
                    cluster_members[self._sample_next_cluster(probabilities, c)]
                )
                for c in x_t_clusters
            ],
            dtype=int,
        )

        if self.args.optimal_transport_mode in ("time", "rmsd"):
            for cluster_id in np.unique(x_t_clusters):
                group = np.flatnonzero(x_t_clusters == cluster_id)
                paired_x0, paired_x1 = self._pair_indices(
                    arr, t_values[group], x_t_plus_tau_indices[group]
                )
                t_values[group] = paired_x0
                x_t_plus_tau_indices[group] = paired_x1

        x_t = arr[t_values].astype(np.float32)
        x_t_plus_tau = arr[x_t_plus_tau_indices].astype(np.float32)
        return x_t, x_t_plus_tau

    def extract_backbone(self, atom14):
        """
        Extract backbone atom coordinates from an atom14 array using residue constants.

        Parameters:
        atom14: np.ndarray of shape (num_samples, L, 14, 3)
                where L is the number of residues.

        Returns:
        backbone: np.ndarray of shape (num_samples, L*3, 3)
                    The backbone is constructed by concatenating N, CA, and C coordinates.
        """
        # NOTE(review): rc.atom_order is presumably the atom37 ordering; this
        # relies on N/CA/C occupying the same slots in atom14 -- confirm.
        n_idx = rc.atom_order["N"]
        ca_idx = rc.atom_order["CA"]
        c_idx = rc.atom_order["C"]

        # Stack N, CA and C blocks along the residue axis -> (num_samples, 3L, 3).
        return np.concatenate(
            [atom14[:, :, n_idx, :], atom14[:, :, ca_idx, :], atom14[:, :, c_idx, :]],
            axis=1,
        )

    @staticmethod
    @contextmanager
    def _open_memmap(path):
        """Open a ``.npy`` file read-only as a memmap; close its mmap on exit.

        Arrays that are views of the memmap must not be used after the block
        exits -- copy what you need inside it.
        """
        mm = np.lib.format.open_memmap(path, mode="r")
        try:
            yield mm
        finally:
            mm._mmap.close()

    @staticmethod
    def _read_trajectory(path, idx=None):
        """Return an in-memory copy of the whole trajectory or of ``mm[idx]``."""
        with MDGenDataset._open_memmap(path) as mm:
            view = mm if idx is None else mm[idx]
            return np.array(view, copy=True)
