import sys
import os

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(project_root)
import copy
import math
import random
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset
import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.datapipes.map import SequenceWrapper
from torch.utils.data.dataset import Subset

from src.datamodules.datasets.data_utils import Alphabet
from src.datamodules.datasets.frame import Frame, Rotation

import esm
from sklearn.model_selection import train_test_split


from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import one_to_index, index_to_one
from src.datamodules.common_utils.transforms import get_transform
from src.datamodules.common_utils.protein.parsers import parse_biopython_structure
from src.datamodules.common_utils.protein.constants import resindex_to_ressymb

def reset_residue_idx(res_nb):
    """Shift residue numbers so that numbering keeps growing after a restart.

    Every position whose value equals 1, except the first such position, is
    treated as a numbering restart.  At each restart the numbering is shifted
    by ``100 +`` the value stored immediately before the restart, and the
    shifts accumulate over successive restarts.

    Args:
        res_nb: integer tensor of residue numbers, indexed along dim 0
            (presumably 1-D with all chains concatenated -- confirm with the
            caller).

    Returns:
        ``res_nb`` plus the accumulated shift at every position.
    """
    ones_at = torch.nonzero(res_nb == 1, as_tuple=True)[0]
    # The first "1" is skipped: it is not treated as a restart.
    restarts = ones_at[1:]
    jumps = torch.zeros_like(res_nb)
    jumps[restarts] = res_nb[restarts - 1] + 100
    return res_nb + jumps.cumsum(dim=0)

def sample_pref_pair(pref_pairs):
    """Randomly draw one (group_a, group_b) pair of scored sequences.

    Each item is sorted into ``group_a`` and/or ``group_b`` according to its
    length:

    * length 4: ``item[1] >= 0.2`` goes to group_a, otherwise
      ``item[1] <= 0.2`` goes to group_b; the stored score is ``item[1]``.
    * length 3: ``item[2] == 1`` goes to group_a, ``item[2] == 0`` goes to
      group_b; the stored score is ``item[2]``.
    * any other length: ``item[1] <= 0`` goes to group_a and
      ``item[1] >= -0.05`` goes to group_b (an item may land in both); the
      stored score is ``-item[1]``.

    One group_a member is drawn at random, then one group_b member whose
    stored score does not exceed the drawn group_a score.

    Args:
        pref_pairs: iterable of indexable items; ``item[0]`` is the sequence.

    Returns:
        ``(seq_a, score_a, seq_b, score_b)``, or ``None`` (after printing a
        diagnostic) when no admissible pair exists.
    """
    picked_a = []
    picked_b = []
    for record in pref_pairs:
        size = len(record)
        if size == 4:
            if record[1] >= 0.2:
                picked_a.append((record[0], record[1]))
            elif record[1] <= 0.2:
                picked_b.append((record[0], record[1]))
            continue
        if size == 3:
            if record[2] == 1:
                picked_a.append((record[0], record[2]))
            elif record[2] == 0:
                picked_b.append((record[0], record[2]))
            continue
        # Remaining layouts: the two tests are independent on purpose, so a
        # value inside [-0.05, 0] is added to both groups.
        if record[1] <= 0:
            picked_a.append((record[0], -record[1]))
        if record[1] >= -0.05:
            picked_b.append((record[0], -record[1]))

    has_a = bool(picked_a)
    has_b = bool(picked_b)
    if not has_a:
        print("No valid group_a found in the dataset.", picked_b)
    if not has_b:
        print("No valid group_b found in the dataset.", picked_a)
    if not (has_a and has_b):
        print("No valid pairs found in the dataset.")
        return None

    seq_a, ddg_a = random.choice(picked_a)
    candidates = [pair for pair in picked_b if pair[1] <= ddg_a]
    if len(candidates) == 0:
        print("No group_b_valid.")
        return None
    seq_b, ddg_b = random.choice(candidates)
    return seq_a, ddg_a, seq_b, ddg_b

class GymDataset(Dataset):
    """Map-style dataset of scored entries joined with cached structures.

    Two pickles are read from ``cache_dir``:

    * ``structures.pkl`` -- indexed by ``pdb_name``; each value unpacks into a
      ``(data, seq_map)`` pair where ``data`` is a dict.
    * ``skempi.pkl`` -- an iterable of entry dicts that carry at least
      ``pdb_name`` and ``score_label``.

    Entries are grouped by ``score_label`` (``'dms'`` / ``'ddg'``, compared
    case-insensitively; other labels are dropped).  Within each group the
    complexes are shuffled with a fixed seed and cut into ``num_cvfolds``
    folds; fold ``cvfold_index`` is the validation fold and the remaining
    folds form the training set.  With a single fold, training and validation
    both start from the full complex list.

    Args:
        cache_dir: directory holding the two pickles (created if missing).
        split: ``'val'``, ``'all'``, or anything else for the training split.
        num_cvfolds: number of cross-validation folds.
        cvfold_index: index of the fold used for validation.
    """

    # Complexes removed from every split.
    EXCLUDED_COMPLEXES = ('1N8Z_hm',)
    # Complexes that are never trained on; they stay in whichever validation
    # fold they fall into.
    VAL_ONLY_COMPLEXES = ('3VR6_ABCDEF_GH',)
    # Fixed seed so that the fold assignment is reproducible across runs.
    SPLIT_SEED = 3745754758

    def __init__(self, cache_dir, split='train', num_cvfolds=3, cvfold_index=0):
        super().__init__()
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        self.split = split
        self.num_cvfolds = num_cvfolds
        self.cvfold_index = cvfold_index
        self.data_type = 'all'

        self.structures_cache = os.path.join(cache_dir, 'structures.pkl')
        self.structures = None
        self._load_structures()

        self.pairs_cache = os.path.join(cache_dir, 'skempi.pkl')
        self.entry_pairs = None
        self._load_pairs()

    def _load_structures(self):
        """Load the structure cache into ``self.structures``."""
        # NOTE: pickle.load can execute arbitrary code; only point cache_dir
        # at files you trust.
        with open(self.structures_cache, 'rb') as f:
            self.structures = pickle.load(f)

    @staticmethod
    def _group_by_complex(entries, label):
        """Return ``{pdb_name: [entry, ...]}`` for entries with this label."""
        grouped = {}
        for entry in entries:
            if entry['score_label'].lower() == label:
                grouped.setdefault(entry['pdb_name'], []).append(entry)
        return grouped

    def _complexes_for_split(self, complex_names):
        """Return the complexes of one label group that belong to ``self.split``."""
        complex_list = sorted(complex_names)
        random.Random(self.SPLIT_SEED).shuffle(complex_list)
        if self.split == 'all':
            return complex_list

        split_size = math.ceil(len(complex_list) / self.num_cvfolds)
        folds = [
            complex_list[i * split_size:(i + 1) * split_size]
            for i in range(self.num_cvfolds)
        ]
        val_split = folds.pop(self.cvfold_index) if folds else []
        if self.split == 'val':
            return val_split

        if self.num_cvfolds > 1:
            train_split = [name for fold in folds for name in fold]
        else:
            # Copy instead of aliasing: the training list is filtered below,
            # and that must not remove complexes from the validation list.
            train_split = list(val_split)
        return [
            name for name in train_split
            if name not in self.VAL_ONLY_COMPLEXES
        ]

    def _load_pairs(self):
        """Load the entry cache and keep the entries of the requested split."""
        # NOTE: pickle.load can execute arbitrary code; only point cache_dir
        # at files you trust.
        with open(self.pairs_cache, 'rb') as f:
            all_entry_pairs = pickle.load(f)

        entries = [
            entry for entry in all_entry_pairs
            if entry['pdb_name'] not in self.EXCLUDED_COMPLEXES
        ]

        data_type = getattr(self, 'data_type', 'all')
        labels = (data_type,) if data_type in ('dms', 'ddg') else ('dms', 'ddg')

        # Each label group is resolved against its own mapping, so a complex
        # that has both DMS and ddG entries keeps both sets exactly once.
        selected_entries = []
        for label in labels:
            complex_to_entries = self._group_by_complex(entries, label)
            for complex_name in self._complexes_for_split(complex_to_entries):
                selected_entries.extend(complex_to_entries[complex_name])

        self.entry_pairs = selected_entries

    def __len__(self):
        return len(self.entry_pairs)

    def __getitem__(self, index):
        """Return the entry merged over its structure dict (entry keys win)."""
        entry = self.entry_pairs[index]
        data, _ = self.structures[entry['pdb_name']]
        return {**data, **entry}

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    Each raw item is a 5-tuple ``(label, seq, seq_mut, chain_nb, mask)``.
    Both sequences are encoded with ``alphabet.encode`` and written into
    tensors filled with ``padding_idx``; ``chain_nb`` and ``mask`` are written
    into zero-filled tensors at the same offset as the sequence tokens.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[Any, str, str, Any, Any]]):
        """Tokenize and pad a batch.

        Args:
            raw_batch: iterable of ``(label, seq, seq_mut, chain_nb, mask)``
                tuples; ``chain_nb`` and ``mask`` are per-position sequences
                (tensors, arrays or lists) convertible to int64.

        Returns:
            ``(labels, tokens, tokens_mut, chain_nb, mask)``: ``labels`` is a
            list; the others are int64 tensors of shape
            ``(batch, max_len + prepend_bos + append_eos)``.
        """
        alphabet = self.alphabet
        batch_labels, seq_str_list, seq_mut_str_list, chain_nb_list, mask_list = zip(*raw_batch)
        batch_size = len(batch_labels)
        bos = int(alphabet.prepend_bos)
        eos = int(alphabet.append_eos)

        seq_encoded_list = [alphabet.encode(seq_str) for seq_str in seq_str_list]
        seq_mut_encoded_list = [alphabet.encode(seq_str) for seq_str in seq_mut_str_list]

        limit = self.truncation_seq_length
        if limit:
            seq_encoded_list = [enc[:limit] for enc in seq_encoded_list]
            seq_mut_encoded_list = [enc[:limit] for enc in seq_mut_encoded_list]
            # Per-position annotations must be cut together with the tokens,
            # otherwise they no longer line up with (or fit) the token grid.
            chain_nb_list = [item[:limit] for item in chain_nb_list]
            mask_list = [item[:limit] for item in mask_list]

        # Size the grid for the longer of the two sequences so that a mutant
        # longer than its counterpart still fits.
        max_len = max(len(enc) for enc in seq_encoded_list + seq_mut_encoded_list)
        shape = (batch_size, max_len + bos + eos)

        tokens = torch.full(shape, alphabet.padding_idx, dtype=torch.int64)
        tokens_mut = torch.full(shape, alphabet.padding_idx, dtype=torch.int64)
        chain_nb = torch.zeros(shape, dtype=torch.int64)
        mask = torch.zeros(shape, dtype=torch.int64)

        rows = zip(seq_encoded_list, seq_mut_encoded_list, chain_nb_list, mask_list)
        for i, (seq_encoded, seq_mut_encoded, chain_nb1, mask1) in enumerate(rows):
            if alphabet.prepend_bos:
                # Both token tensors get the leading token, mirroring the way
                # both get the trailing EOS below.
                tokens[i, 0] = alphabet.cls_idx
                tokens_mut[i, 0] = alphabet.cls_idx
            tokens[i, bos:bos + len(seq_encoded)] = torch.tensor(
                seq_encoded, dtype=torch.int64)
            tokens_mut[i, bos:bos + len(seq_mut_encoded)] = torch.tensor(
                seq_mut_encoded, dtype=torch.int64)
            chain_nb[i, bos:bos + len(chain_nb1)] = torch.as_tensor(
                chain_nb1, dtype=torch.int64)
            mask[i, bos:bos + len(mask1)] = torch.as_tensor(
                mask1, dtype=torch.int64)
            if alphabet.append_eos:
                tokens[i, bos + len(seq_encoded)] = alphabet.eos_idx
                tokens_mut[i, bos + len(seq_mut_encoded)] = alphabet.eos_idx

        return list(batch_labels), tokens, tokens_mut, chain_nb, mask
    
class CoordBatchConverter(BatchConverter):
    """Batch converter that also pads coordinates and builds backbone frames.

    Args:
        alphabet: alphabet forwarded to ``BatchConverter``.
        coord_pad_inf: when true, every coordinate array gets one NaN position
            (and every confidence vector one ``-1`` entry) added at both ends.
        coord_nan_to_zero: when true, NaN coordinates are replaced by 0 after
            the masks have been computed.
        to_pifold_format: when true, the batch is passed through
            ``ToPiFoldFormat`` before the masks are computed.
    """

    def __init__(self, alphabet, coord_pad_inf=False, coord_nan_to_zero=True, to_pifold_format=False):
        super().__init__(alphabet)
        self.coord_pad_inf = coord_pad_inf
        self.to_pifold_format = to_pifold_format
        self.coord_nan_to_zero = coord_nan_to_zero

    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: iterable of tuples
                (coords, confidence, seq, seq_mut, chain_nb, mask).
                In each tuple,
                    coords: list of floats, shape L x n_atoms x 3
                    confidence: list of floats, shape L; or scalar float; or None
                    seq: string of length L; None means 'X' * L
                    seq_mut: string of length L; None means same as seq
                    chain_nb: per-residue values of length L; None means all 1
                    mask: per-residue values of length L; None means all 1
            device: optional device every returned tensor is moved to.
        Returns:
            coords: Tensor of shape batch_size x L x n_atoms x 3
            gt_frames_tensor: output of Frame.to_tensor_4x4(), zeroed where
                coord_mask is False
            confidence: Tensor of shape batch_size x L
            tokens: LongTensor of shape batch_size x L
            tokens_mut: LongTensor of shape batch_size x L
            lengths: LongTensor of shape batch_size (non-padding tokens)
            coord_mask: BoolTensor of shape batch_size x L (finite coords)
            chain_nb: LongTensor of shape batch_size x L
            mask: LongTensor of shape batch_size x L
        """
        batch = []
        for coords, confidence, seq, seq_mut, chain_nb, mask in raw_batch:
            num_res = len(coords)
            if confidence is None:
                confidence = 1.
            if isinstance(confidence, (float, int)):
                confidence = [float(confidence)] * num_res
            # Fill in the optional fields so that the None defaults accepted
            # by from_lists() produce a valid batch instead of crashing.
            if seq is None:
                seq = 'X' * num_res
            if seq_mut is None:
                seq_mut = seq
            if chain_nb is None:
                # 0 marks padding in the collated tensor, so use 1 here.
                chain_nb = torch.ones(num_res, dtype=torch.int64)
            if mask is None:
                mask = torch.ones(num_res, dtype=torch.int64)
            batch.append(((coords, confidence), seq, seq_mut, chain_nb, mask))

        coords_and_confidence, tokens, tokens_mut, chain_nb, mask = super().__call__(batch)
        if self.coord_pad_inf:
            # pad beginning and end of each protein due to legacy reasons
            coords = [
                F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.nan)
                for cd, _ in coords_and_confidence
            ]
            confidence = [
                F.pad(torch.tensor(cf), (1, 1), value=-1.)
                for _, cf in coords_and_confidence
            ]
        else:
            coords = [
                torch.tensor(cd) for cd, _ in coords_and_confidence
            ]
            confidence = [
                torch.tensor(cf) for _, cf in coords_and_confidence
            ]
        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)

        if self.to_pifold_format:
            coords, tokens, tokens_mut, confidence = ToPiFoldFormat(X=coords, S=tokens, M=tokens_mut, cfd=confidence)

        lengths = tokens.ne(self.alphabet.padding_idx).sum(1).long()
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)
            tokens_mut = tokens_mut.to(device)
            lengths = lengths.to(device)
            # Keep every returned tensor on the same device.
            chain_nb = chain_nb.to(device)
            mask = mask.to(device)

        # Masks are taken before NaNs are zeroed out below.
        coord_padding_mask = torch.isnan(coords[:, :, 0, 0])
        coord_mask = torch.isfinite(coords.sum([-2, -1]))

        confidence = confidence * coord_mask + (-1.) * coord_padding_mask

        if self.coord_nan_to_zero:
            coords[torch.isnan(coords)] = 0.

        # compute frames based on backbones
        gt_frames = Frame.from_3_points(
            p_neg_x_axis=coords[..., 2, :],  # C
            origin=coords[..., 1, :],  # CA
            p_xy_plane=coords[..., 0, :],  # N
        )

        # diag(-1, 1, -1), created on the same device as the coordinates so
        # that the composition also works when `device` is given.
        rots = torch.eye(3, device=coords.device)
        rots[0, 0] = -1
        rots[2, 2] = -1
        rots = Rotation(mat=rots)

        gt_frames = gt_frames.compose(Frame(rots, None))
        gt_frames_tensor = gt_frames.to_tensor_4x4()
        gt_frames_tensor[~coord_mask] = 0

        return coords, gt_frames_tensor, confidence, tokens, tokens_mut, lengths, coord_mask, chain_nb, mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, seq_mut_list=None,  chain_nb_list=None, mask_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
            floats in shape L x n_atoms x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                    length L to describe the confidence scores for the backbone
                    with values between 0. and 1.
            seq_list: either None or a list of strings
            seq_mut_list: either None or a list of strings
            chain_nb_list: either None or a list of per-residue sequences
            mask_list: either None or a list of per-residue sequences
            device: optional device forwarded to __call__
        Returns:
            The same 9-tuple as __call__: coords, gt_frames_tensor,
            confidence, tokens, tokens_mut, lengths, coord_mask, chain_nb,
            mask.
        """
        batch_size = len(coords_list)
        placeholder = [None] * batch_size
        # Every optional list is defaulted independently; __call__ replaces
        # the None items with per-sample defaults.
        if confidence_list is None:
            confidence_list = placeholder
        if seq_list is None:
            seq_list = placeholder
        if seq_mut_list is None:
            seq_mut_list = placeholder
        if chain_nb_list is None:
            chain_nb_list = placeholder
        if mask_list is None:
            mask_list = placeholder
        raw_batch = zip(coords_list, confidence_list, seq_list, seq_mut_list, chain_nb_list, mask_list)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

        Raises RuntimeError when the samples do not all have the same number
        of dimensions. The result uses the dtype of the first sample and is
        filled with pad_v outside each sample's extent.
        """
        if len(samples) == 0:
            return torch.Tensor()
        if len(set(x.dim() for x in samples)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )
        (device,) = tuple(set(x.device for x in samples))  # assumes all on same device
        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for result_i, t in zip(result, samples):
            # result_i is a view into `result`, so this writes in place.
            result_i[tuple(slice(0, k) for k in t.shape)] = t
        return result

def create_mask_excluding_mut_chains(chain_encoding_all, mut_chains):
    """Return a boolean mask that is False on every listed chain.

    Args:
        chain_encoding_all: tensor of per-position chain indices.
        mut_chains: iterable of chain indices to exclude.

    Returns:
        Bool tensor, True where the position's chain index matches none of
        ``mut_chains`` (all True when ``mut_chains`` is empty).
    """
    keep = torch.ones_like(chain_encoding_all, dtype=torch.bool)
    for excluded in mut_chains:
        not_this_chain = chain_encoding_all != excluded
        keep = keep & not_this_chain
    return keep

class Featurizer(object):
    """Collate callable turning dataset entries into a padded tensor batch.

    For every entry one sample is emitted for the whole complex, followed by
    one extra sample per chain that contains a flagged (``mut_flag``)
    position.  In the per-chain samples the coordinates outside the chain are
    NaN, the sequence characters outside the chain are ``'X'``, and the mask
    marks the chain's positions.

    Args:
        alphabet: alphabet handed to ``CoordBatchConverter``; its
            ``add_special_tokens`` attribute selects ``coord_pad_inf``.
        to_pifold_format: forwarded to ``CoordBatchConverter``.
        coord_nan_to_zero: forwarded to ``CoordBatchConverter``.
        atoms: atom names stacked (in this order) when an entry stores its
            coordinates as a dict keyed by atom name.
    """

    def __init__(self, alphabet: Alphabet,
                 to_pifold_format=False,
                 coord_nan_to_zero=True,
                 atoms=('N', 'CA', 'C', 'O')):
        self.alphabet = alphabet
        self.batcher = CoordBatchConverter(
            alphabet=alphabet,
            coord_pad_inf=alphabet.add_special_tokens,
            to_pifold_format=to_pifold_format,
            coord_nan_to_zero=coord_nan_to_zero
        )

        self.atoms = atoms

    def __call__(self, raw_batch: Sequence[Dict[str, Any]]):
        """Build the batch dict for a list of dataset entries.

        Args:
            raw_batch: iterable of entry dicts providing ``pos_heavyatom``,
                ``winner``, ``loser``, ``winner_score``, ``loser_score``,
                ``aa``, ``mut_flag``, ``chain_nb``, ``pdb_name`` and
                ``score_label``.

        Returns:
            Dict with the collated tensors (``coords``, ``gt_frames_tensor``,
            ``tokens``, ``tokens_mut``, ``confidence``, ``coord_mask``,
            ``lengths``, ``chain_nb``, ``mask``, ``winner_scores``,
            ``loser_scores``) plus the per-sample lists ``names`` and
            ``score_labels`` and the per-entry list ``num_mut_chains``.
        """
        coords, winners, losers, names = [], [], [], []
        chain_nbs, masks = [], []
        winner_scores, loser_scores, score_labels = [], [], []
        num_mut_chains = []

        for entry in raw_batch:
            heavy_atoms = entry['pos_heavyatom']
            if isinstance(heavy_atoms, dict):
                heavy_atoms = np.stack([heavy_atoms[atom] for atom in self.atoms], 1)
            backbone = heavy_atoms[..., :4, :]  # first four atoms per residue
            chain_encoding = entry['chain_nb'] + 1  # shifted so 0 stays free for padding
            mut_flag = torch.as_tensor(entry['mut_flag']).bool()

            mut_indices = torch.nonzero(mut_flag, as_tuple=True)[0]
            mut_chains = torch.unique(chain_encoding[mut_indices], dim=0)
            num_mut_chains.append(mut_chains.size(0))

            # (coords, winner, loser, mask) for the full complex first ...
            samples = [(
                backbone,
                entry['winner'],
                entry['loser'],
                torch.ones_like(entry['aa']),
            )]
            # ... then one chain-restricted view per mutated chain.  Only the
            # fields that differ are rebuilt; nothing else is copied because
            # the batcher never mutates its inputs.
            for chain_idx in mut_chains.tolist():
                chain_mask = chain_encoding == chain_idx
                in_chain = chain_mask.tolist()
                samples.append((
                    np.where(chain_mask[:, None, None], backbone, np.nan),
                    ''.join(a if m else 'X' for a, m in zip(entry['winner'], in_chain)),
                    ''.join(a if m else 'X' for a, m in zip(entry['loser'], in_chain)),
                    chain_mask,
                ))

            for sample_coords, winner, loser, sample_mask in samples:
                coords.append(sample_coords)
                winners.append(winner)
                losers.append(loser)
                masks.append(sample_mask)
                names.append(entry['pdb_name'])
                chain_nbs.append(chain_encoding)
                winner_scores.append(entry['winner_score'])
                loser_scores.append(entry['loser_score'])
                score_labels.append(entry['score_label'])

        (coords, gt_frames_tensor, confidence, winner_tokens, loser_tokens,
         lengths, coord_mask, chain_nb, mask) = self.batcher.from_lists(
            coords_list=coords, confidence_list=None, seq_list=winners, seq_mut_list=losers, chain_nb_list=chain_nbs, mask_list=masks
        )
        batch = {
            'coords': coords,
            'gt_frames_tensor': gt_frames_tensor,
            # NOTE(review): 'tokens' carries the loser sequence and
            # 'tokens_mut' the winner sequence -- kept as in the original
            # code; confirm this is the intended pairing.
            'tokens': loser_tokens,
            'tokens_mut': winner_tokens,
            'confidence': confidence,
            'coord_mask': coord_mask,
            'lengths': lengths,
            'names': names,
            'winner_scores': torch.tensor(winner_scores).to(coords.device),
            'loser_scores': torch.tensor(loser_scores).to(coords.device),
            'chain_nb': chain_nb,
            'mask': mask,
            'num_mut_chains': num_mut_chains,
            'score_labels': score_labels,
        }
        return batch

def ToPiFoldFormat(X, S, M, cfd, pad_special_tokens=False):
    """Pack the finite-coordinate positions of every sample to the front.

    A position is valid when the sum of its coordinates over the last two
    dims is finite.  For each batch row the valid positions of ``X``, ``S``,
    ``M`` and ``cfd`` are moved, in order, to the start of the row; the
    remaining slots hold NaN for ``X`` and 0 for the other three.

    Args:
        X: coordinate tensor, batch x L x n_atoms x 3.
        S: per-position tensor, batch x L.
        M: per-position tensor, batch x L.
        cfd: per-position tensor, batch x L.
        pad_special_tokens: unused.

    Returns:
        Tuple ``(X, S, M, cfd)`` of the packed tensors.
    """
    valid = torch.isfinite(X.sum(dim=(-2, -1)))  # atom mask

    packed_X = torch.zeros_like(X) + float('nan')
    packed_S = torch.zeros_like(S)
    packed_M = torch.zeros_like(M)
    packed_cfd = torch.zeros_like(cfd)

    for row, keep in enumerate(valid):
        count = int(keep.sum())
        packed_X[row, :count] = X[row][keep]
        packed_S[row, :count] = S[row][keep]
        packed_M[row, :count] = M[row][keep]
        packed_cfd[row, :count] = cfd[row][keep]

    return packed_X, packed_S, packed_M, packed_cfd