import torch
from .rigid_utils import Rigid
from .residue_constants import restype_order, restype_atom37_mask
from .geometry import atom37_to_torsions, atom14_to_atom37, atom14_to_frames
import numpy as np
import pandas as pd

class MDGenDataset(torch.utils.data.Dataset):
    """Dataset of atom14 MD trajectories featurized for MDGen.

    The split CSV is indexed by ``name`` and must contain a ``seqres`` column
    (a residue-type string, one character per residue, looked up in
    ``restype_order``). For each name the trajectory is read from
    ``{args.data_dir}/{name}_0_{args.suffix}.npy``. At construction, names
    whose file cannot be opened are skipped, and names with fewer than
    ``MIN_FRAMES`` frames are dropped and listed in ``invalid_frames.csv``
    (written to the current working directory).

    Args:
        args: option namespace. Attributes read: ``data_dir``, ``suffix``,
            ``overfit``, ``frame_interval``, ``coarse_start``, ``num_frames``,
            ``threshold``, ``copy_frames``, ``no_frames``, ``crop``.
        split: path to the split CSV.
        repeat: virtual repetition factor applied in ``__len__``.
    """

    # Minimum trajectory length kept by the constructor filter; also the
    # number of leading frames read when ``args.frame_interval`` is set.
    MIN_FRAMES = 400

    def __init__(self, args, split, repeat=1):
        super().__init__()
        self.df = pd.read_csv(split, index_col='name')
        self.args = args
        self.repeat = repeat

        # Remember the pre-filter size: ``self.df`` is reassigned below, and
        # the summary message must report against the original count.
        n_total = len(self.df)
        valid_names = []
        invalid_names = []
        for name in self.df.index:
            path = self._traj_path(self._full_name(name))
            try:
                arr = np.lib.format.open_memmap(path, 'r')
            except (OSError, ValueError):
                # Missing or unreadable trajectory file: skip it (best-effort
                # filter) without hiding unrelated errors.
                continue
            n_frames = arr.shape[0]
            if n_frames >= self.MIN_FRAMES:
                valid_names.append(name)
            else:
                invalid_names.append({
                    'name': name,
                    'frames': n_frames,
                })

        self.df = self.df.loc[valid_names]
        if invalid_names:
            pd.DataFrame(invalid_names).to_csv('invalid_frames.csv', index=False)
        print(f"Dataset filtered: {len(valid_names)} valid trajectories out of {n_total}")

    @staticmethod
    def _full_name(name):
        """Return the trajectory identifier ``{name}_0`` used in paths and outputs."""
        return f"{name}_0"

    def _traj_path(self, full_name):
        """Return the ``.npy`` path of trajectory ``full_name``."""
        return f'{self.args.data_dir}/{full_name}_{self.args.suffix}.npy'

    @staticmethod
    def _pair_indices(n_frames, step):
        """Return frame indices ``[0, 1, step, step + 1, 2*step, 2*step + 1, ...]``.

        Only pairs whose second index is ``< n_frames`` are kept, so the result
        is always a valid index into an array of ``n_frames`` frames (possibly
        empty).
        """
        starts = np.arange(0, n_frames - 1, step)
        return np.stack([starts, starts + 1], axis=1).reshape(-1)

    def __len__(self):
        return self.repeat * len(self.df)

    def __getitem__(self, idx):
        """Load trajectory ``idx`` and return its features.

        ``idx`` wraps modulo the number of kept trajectories (so ``repeat`` > 1
        revisits them); with ``args.overfit`` it is forced to 0.

        Returns:
            dict. When ``args.no_frames`` is set: ``name``, ``frame_start``,
            ``atom37``, ``seqres``, ``mask`` (``restype_atom37_mask`` rows for
            each residue). Otherwise: ``name``, ``frame_start``, ``torsions``,
            ``torsion_mask``, ``trans``, ``rots``, ``seqres``, ``mask``,
            ``arr``, with the residue axis randomly cropped or zero-padded to
            ``args.crop`` (padded residues get ``mask`` 0 and identity frames).
        """
        idx = idx % len(self.df)
        if self.args.overfit:
            idx = 0

        name = self.df.index[idx]
        seqres = self.df.seqres[name]
        full_name = self._full_name(name)
        print(full_name)

        arr = np.lib.format.open_memmap(self._traj_path(full_name), 'r')
        total_frames = arr.shape[0]

        if self.args.frame_interval:
            # Use (at most) the first MIN_FRAMES frames; subsampled below.
            frame_start = 0
            end = min(self.MIN_FRAMES, total_frames)
            arr = arr[frame_start:end]
        else:
            # Contiguous window of ``num_frames`` frames from ``coarse_start``.
            frame_start = self.args.coarse_start
            end = frame_start + self.args.num_frames
            arr = arr[frame_start:end]

            ########################### multi_resolution setting ##################
            # frame_start_coarse = int(self.args.threshold*(frame_start//self.args.num_frames))
            # arr_coarse = np.lib.format.open_memmap(f'{self.args.data_dir_coarse}/{full_name}.npy', 'r')
            # arr_1st = arr_coarse[frame_start_coarse]
            # arr_2nd = arr_coarse[frame_start_coarse+1]
            # arr = np.copy(arr).astype(np.float32)
            # arr[0] = arr_1st
            # arr[1] = arr_2nd
            ########################## multi_resolution setting ###################

        print('the pred start frame is : ', frame_start)
        if self.args.frame_interval:
            if self.args.threshold != 1:
                # Keep consecutive frame pairs every ``frame_interval`` frames,
                # never indexing past the last frame.
                arr = arr[self._pair_indices(arr.shape[0], self.args.frame_interval)]
            else:
                arr = arr[::self.args.frame_interval]

        # Materialize the (memory-mapped) selection as float32 in one copy.
        arr = np.array(arr, dtype=np.float32)

        # Zero every atom14 slot except indices 0-3 (presumably the backbone
        # N, CA, C, O) so only those coordinates are kept.
        arr[:, :, 4:14, :] = 0

        if self.args.copy_frames:
            # Overwrite every frame with the first one.
            arr[1:] = arr[0]

        # arr should be in ANGSTROMS
        frames = atom14_to_frames(torch.from_numpy(arr))
        seqres = np.array([restype_order[c] for c in seqres])
        # Broadcast the sequence over the frames actually loaded; after
        # frame_interval subsampling this need not equal args.num_frames.
        aatype = torch.from_numpy(seqres)[None].expand(arr.shape[0], -1)
        atom37 = torch.from_numpy(atom14_to_atom37(arr, aatype)).float()

        L = frames.shape[1]
        mask = np.ones(L, dtype=np.float32)

        if self.args.no_frames:
            return {
                'name': full_name,
                'frame_start': frame_start,
                'atom37': atom37,
                'seqres': seqres,
                'mask': restype_atom37_mask[seqres],
            }

        torsions, torsion_mask = atom37_to_torsions(atom37, aatype)
        # aatype is identical for every frame, so (presumably) is the torsion
        # mask; keep the first frame's copy.
        torsion_mask = torsion_mask[0]

        crop = self.args.crop
        if L > crop:
            # Random contiguous window of ``crop`` residues.
            start = np.random.randint(0, L - crop + 1)
            window = slice(start, start + crop)
            torsions = torsions[:, window]
            frames = frames[:, window]
            seqres = seqres[window]
            mask = mask[window]
            torsion_mask = torsion_mask[window]
            arr = arr[:, window]
        elif L < crop:
            pad = crop - L
            frames = Rigid.cat([
                frames,
                # Pad with identity frames matching the loaded frame count.
                Rigid.identity((frames.shape[0], pad), requires_grad=False, fmt='rot_mat'),
            ], 1)
            mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)])
            seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)])
            torsions = torch.cat([torsions, torch.zeros((torsions.shape[0], pad, 7, 2), dtype=torch.float32)], 1)
            torsion_mask = torch.cat([torsion_mask, torch.zeros((pad, 7), dtype=torch.float32)])
            arr = np.concatenate([
                arr,
                np.zeros((arr.shape[0], pad, *arr.shape[2:]), dtype=arr.dtype),
            ], axis=1)

        return {
            'name': full_name,
            'frame_start': frame_start,
            'torsions': torsions,
            'torsion_mask': torsion_mask,
            'trans': frames._trans,
            'rots': frames._rots._rot_mats,
            'seqres': seqres,
            'mask': mask,
            'arr': arr,
        }
        