import torch

from .rigid_utils import Rigid
from .residue_constants import restype_order, restype_atom37_mask
from .geometry import atom37_to_torsions, atom14_to_atom37, atom14_to_frames
import numpy as np
import pandas as pd

       
class MDGenDataset(torch.utils.data.Dataset):
    """Dataset of per-protein trajectories stored as ``.npy`` coordinate files.

    Each row of the ``split`` CSV (indexed by ``name``, with a ``seqres``
    column holding the one-letter sequence) names one trajectory. An item is
    a selection of frames from that trajectory, with every atom slot except
    the first four zeroed, converted to rigid frames and torsions, and
    randomly cropped or zero-padded along the residue axis to ``args.crop``.

    Args:
        args: config namespace. Attributes read: ``overfit``, ``data_dir``,
            ``suffix``, ``frame_interval``, ``num_frames``, ``overfit_frame``,
            ``threshold``, ``copy_frames`` and ``crop``.
        split: path to the CSV listing the proteins.
        repeat: multiplier on the reported length; indices wrap modulo the
            number of CSV rows.
    """

    # In frame_interval mode every trajectory is tiled (if shorter) and
    # truncated to exactly this many frames before striding.
    _FIXED_NUM_FRAMES = 400

    def __init__(self, args, split, repeat=1):
        super().__init__()
        self.df = pd.read_csv(split, index_col='name')
        self.args = args
        self.repeat = repeat

    def __len__(self):
        return self.repeat * len(self.df)

    def _select_frames(self, arr):
        """Select the frames of one item from a full trajectory array.

        With ``args.frame_interval`` set, the trajectory is tiled (when it has
        fewer than ``_FIXED_NUM_FRAMES`` frames) and truncated to exactly that
        many frames, then strided by ``k = frame_interval``. The stride is a
        plain ``[::k]`` when ``args.threshold == 1``; otherwise it keeps
        consecutive pairs ``(t, t + 1)`` for ``t = 0, k, 2k, ...``.
        Otherwise a window of ``args.num_frames`` frames is taken, starting at
        a random multiple of ``num_frames`` (or at 0 when
        ``args.overfit_frame`` is set).

        Args:
            arr: trajectory array whose first axis is frames; assumed 4-D
                (frame, residue, atom, coord) -- TODO confirm against data.

        Returns:
            ``(frames, frame_start)``: a fresh float32 copy of the selected
            frames and the index of the first selected frame.

        Raises:
            ValueError: if the trajectory is empty, or (in random-window mode)
                shorter than ``args.num_frames``.
        """
        n_total = arr.shape[0]

        if self.args.frame_interval:
            if n_total == 0:
                raise ValueError("cannot sample frames from an empty trajectory")
            target = self._FIXED_NUM_FRAMES
            if n_total < target:
                arr = np.tile(arr, (target // n_total + 1, 1, 1, 1))
            arr = arr[:target]
            frame_start = 0

            step = self.args.frame_interval
            if self.args.threshold != 1:
                # Pairs (t, t + 1) every `step` frames. Pair starts stop before
                # the last frame so that t + 1 is always a valid index.
                pair_starts = np.arange(0, arr.shape[0] - 1, step)
                arr = arr[np.stack([pair_starts, pair_starts + 1], axis=1).ravel()]
            else:
                arr = arr[::step]
        else:
            num_frames = self.args.num_frames
            # Non-overlapping windows aligned to multiples of num_frames. The
            # "+ 1" keeps the last full window (ending at n_total) reachable.
            window_starts = np.arange(0, n_total - num_frames + 1, num_frames)
            if window_starts.size == 0:
                raise ValueError(
                    f"trajectory has {n_total} frames, fewer than "
                    f"num_frames={num_frames}"
                )
            if self.args.overfit_frame:
                frame_start = 0
            else:
                frame_start = int(np.random.choice(window_starts))
            arr = arr[frame_start:frame_start + num_frames]

        # Copy out of the (read-only) memmap so callers may modify it in place.
        return np.array(arr, dtype=np.float32), frame_start

    def __getitem__(self, idx):
        """Load, select frames from and featurize one trajectory.

        Returns:
            dict with keys:
              - ``name``: ``f"{name}_0"``.
              - ``frame_start``: index of the first selected frame.
              - ``torsions``: tensor ``(T, crop, 7, 2)``, zero-padded.
              - ``torsion_mask``: tensor ``(crop, 7)``, taken from the first
                frame and zero-padded.
              - ``trans``, ``rots``: the translation and rotation-matrix
                tensors of the per-residue ``Rigid`` frames.
              - ``seqres``: int array ``(crop,)`` of ``restype_order``
                indices, 0-padded.
              - ``mask``: float32 ``(crop,)``; 1 for real residues, 0 for
                padding.
              - ``arr``: the float32 coordinate array, cropped or padded along
                axis 1.
            ``T`` is the number of selected frames.
        """
        idx = idx % len(self.df)
        if self.args.overfit:
            idx = 0

        name = self.df.index[idx]
        seqres = self.df.seqres[name]

        # NOTE(review): the replica/seed index is hard-coded to 0 in both the
        # file name and the folder name; confirm this is intended.
        full_name = f"{name}_0"
        seed_folder = "test_0"
        arr = np.lib.format.open_memmap(
            f'{self.args.data_dir}/{seed_folder}/{full_name}_{self.args.suffix}.npy', 'r'
        )

        arr, frame_start = self._select_frames(arr)

        # Zero every atom slot except the first four (presumably the backbone
        # N, CA, C, O in the standard atom14 layout).
        arr[:, :, 4:, :] = 0

        if self.args.copy_frames:
            # Overwrite every frame with the first one.
            arr[1:] = arr[0]

        # Actual number of selected frames. In frame_interval mode this can
        # differ from args.num_frames, so every per-frame shape below uses it.
        num_frames = arr.shape[0]

        # arr should be in ANGSTROMS
        frames = atom14_to_frames(torch.from_numpy(arr))
        seqres = np.array([restype_order[c] for c in seqres])

        aatype = torch.from_numpy(seqres)[None].expand(num_frames, -1)
        atom37 = torch.from_numpy(atom14_to_atom37(arr, aatype)).float()

        L = frames.shape[1]
        mask = np.ones(L, dtype=np.float32)

        torsions, torsion_mask = atom37_to_torsions(atom37, aatype)
        torsion_mask = torsion_mask[0]

        crop = self.args.crop
        if L > crop:
            # Random contiguous window of `crop` residues.
            start = np.random.randint(0, L - crop + 1)
            window = slice(start, start + crop)
            torsions = torsions[:, window]
            frames = frames[:, window]
            seqres = seqres[window]
            mask = mask[window]
            torsion_mask = torsion_mask[window]
            arr = arr[:, window]
        elif L < crop:
            # Pad residues with identity frames and zeros (mask 0).
            pad = crop - L
            frames = Rigid.cat([
                frames,
                Rigid.identity((num_frames, pad), requires_grad=False, fmt='rot_mat'),
            ], 1)
            mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)])
            seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)])
            torsions = torch.cat([
                torsions,
                torch.zeros((torsions.shape[0], pad, 7, 2), dtype=torch.float32),
            ], 1)
            torsion_mask = torch.cat([
                torsion_mask,
                torch.zeros((pad, 7), dtype=torch.float32),
            ])
            arr = np.concatenate([
                arr,
                np.zeros((arr.shape[0], pad, arr.shape[2], arr.shape[3]), dtype=arr.dtype),
            ], axis=1)

        return {
            'name': full_name,
            'frame_start': frame_start,
            'torsions': torsions,
            'torsion_mask': torsion_mask,
            'trans': frames._trans,
            'rots': frames._rots._rot_mats,
            'seqres': seqres,
            'mask': mask,  # (L,)
            'arr': arr,
        }
