import os
import os.path as osp
import json
import numpy as np
import random
import torch.utils.data as data
import torch
from torch_geometric.nn.pool import knn_graph
from src.chroma.data import Protein
from src.tools.design_utils import _dihedrals
from .bio_tokenizer import BioTokenizer


class PDBDataset(data.Dataset):
    """Dataset of protein backbone records read from JSON-lines files.

    Each item is parsed lazily in ``__getitem__`` from either a raw JSON line
    or an already-parsed dict, optionally cropped to ``max_length`` residues,
    and augmented with a k-nearest-neighbour graph built on the atom stored
    at index 1 of the second axis of ``X`` (CA for the train layout, which
    stacks N, CA, C, O in that order).

    ``__getitem__`` returns ``None`` for rejected entries, so the consumer
    (e.g. a collate function) is expected to drop ``None`` items.
    """

    # Keys of an item that are indexed by residue along their first axis and
    # therefore have to be cropped together.
    _PER_RESIDUE_KEYS = ('seq', 'X', 'chain_encoding', 'mask', 'node_idx')

    def __init__(self, data_path='./', split='train', data=None, min_length=30, max_length=500, k_neighbors=50, tm_thr=0.9, dump_features=False):
        """Store the configuration and load (or accept) the raw records.

        Args:
            data_path: Path of the training JSON-lines file.
            split: Which split to serve: 'train', 'val' or 'test'.
            data: Optional pre-loaded sequence of records (JSON strings or
                parsed dicts). When given, nothing is read from disk.
            min_length: Minimum sequence length accepted when parsing a line.
            max_length: Items longer than this are randomly cropped to this
                many residues in ``__getitem__``.
            k_neighbors: ``k`` passed to ``knn_graph``.
            tm_thr: For non-train splits, entries whose ``tmscore`` is not
                strictly greater than this value are rejected.
            dump_features: If true, the output of ``_dihedrals`` is added to
                each item under the key ``'features'``.
        """
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored ``self`` on the instance (a reference cycle).
        self.data_path = data_path
        self.data_dir = data_path
        self.split = split
        self.min_length = min_length
        self.max_length = max_length
        self.k_neighbors = k_neighbors
        self.tm_thr = tm_thr
        self.dump_features = dump_features
        self.valandtest_dir = "raw_data/pair_data"
        self.tokenizer = BioTokenizer()

        if data is None:
            self.all_data = self.cache_pair_data()
            self.data = self.all_data[split]
        else:
            self.data = data

    def read_line_single_chain(self, line, min_length=30, max_length=1024):
        """Parse one record into a dict of tensors.

        Args:
            line: A JSON string, or an already-parsed dict (returned as is).
            min_length: Minimum accepted ``len(entry['seq'])``.
            max_length: Maximum accepted ``len(entry['seq'])``. Note that
                ``__getitem__`` does not override this default, so entries of
                up to 1024 residues are parsed and cropped afterwards.

        Returns:
            A dict with keys 'title', 'seq', 'X', 'chain_encoding' and 'mask',
            or ``None`` when the entry is rejected (length out of range, or,
            for non-train splits, ``tmscore`` not above ``self.tm_thr``).
        """
        if isinstance(line, dict):
            return line

        entry = json.loads(line)
        seq_len = len(entry['seq'])
        if not (min_length <= seq_len <= max_length):
            return None

        if self.split == 'train':
            coords = entry['coords']
            return {
                'title': entry['name'],
                'seq': torch.tensor(self.tokenizer.encode(entry['seq'])),
                # Backbone atoms stacked along dim 1 in the order N, CA, C, O.
                'X': torch.stack(
                    [torch.tensor(coords[atom]) for atom in ('N', 'CA', 'C', 'O')],
                    dim=1,
                ),
                # Chains are numbered from 1 in the iteration order of
                # entry['chain_length'].
                'chain_encoding': torch.cat([
                    torch.full((length,), idx + 1, dtype=torch.int32)
                    for idx, length in enumerate(entry['chain_length'].values())
                ]),
                'mask': torch.ones(seq_len),
            }

        # val / test: filter on TM-score before building any tensors.
        if not entry['tmscore'] > self.tm_thr:
            return None
        return {
            'title': entry['name'],
            'seq': torch.tensor(self.tokenizer.encode(entry['seq'])),
            # presumably already laid out like the train 'X' -- TODO confirm
            'X': torch.tensor(entry['coords']),
            'chain_encoding': torch.ones(seq_len),
            'mask': torch.tensor(entry['valid_mask']),
        }

    def cache_pair_data(self):
        """Read the raw lines of all three splits from disk.

        The train split is read from ``self.data_dir``; 'val' and 'test' are
        read from ``<self.valandtest_dir>/<split>.jsonl``.

        Returns:
            Dict mapping 'train', 'val' and 'test' to lists of unparsed lines.
        """
        data_dict = {'train': [], 'val': [], 'test': []}
        for split in ('train', 'val', 'test'):
            if split in ("val", "test"):
                split_path = osp.join(self.valandtest_dir, split + '.jsonl')
            else:
                split_path = self.data_dir
            with open(split_path) as f:
                data_dict[split] = f.readlines()
        return data_dict

    def _get_features(self, batch):
        """Build the model-ready dict (including the kNN graph) for one item.

        Args:
            batch: Item dict with keys 'title', 'seq', 'X', 'chain_encoding',
                'mask' and 'node_idx'.

        Returns:
            Dict with keys 'title', 'X', 'S', 'C', 'node_idx', 'edge_idx',
            'batch_id', 'mask', 'num_nodes', plus 'features' when
            ``self.dump_features`` is set.
        """
        X, seq = batch['X'], batch['seq']
        if self.dump_features:
            # Per the original note: cos/sin of backbone dihedral and bond
            # angles -- exact layout is defined by _dihedrals.
            X_features = _dihedrals(X)
        mask = batch['mask']

        # Single-protein item: every node gets batch id 0.
        batch_id = torch.zeros_like(seq)
        C_a = X[:, 1, :]
        edge_idx = knn_graph(C_a, k=self.k_neighbors, batch=batch_id, loop=True, flow='target_to_source')
        batch = {
            'title': batch['title'],
            'X': X,
            'S': seq,
            'C': batch['chain_encoding'],
            'node_idx': batch['node_idx'],
            'edge_idx': edge_idx,
            'batch_id': batch_id,
            'mask': mask,
            'num_nodes': torch.tensor(X.shape[0]).reshape(1,),
        }
        if self.dump_features:
            batch["features"] = X_features.squeeze()
        return batch

    def _crop_to_max_length(self, item):
        """Randomly crop every per-residue field of ``item`` to ``self.max_length``.

        All fields listed in ``_PER_RESIDUE_KEYS`` that are present are sliced
        with the same window so they stay aligned. ``item`` is modified in
        place and returned; items that are already short enough are returned
        untouched.
        """
        length = len(item['seq'])
        if length <= self.max_length:
            return item

        start = random.randint(0, length - self.max_length)
        stop = start + self.max_length
        for key in self._PER_RESIDUE_KEYS:
            if key in item:
                item[key] = item[key][start:stop]
        return item

    def __len__(self):
        """Return the number of raw records in the active split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the processed item at ``index``, or ``None`` if rejected.

        An item is rejected when parsing filters it out, when it has fewer
        than 30 residues, or when its coordinates contain NaN.
        """
        item = self.read_line_single_chain(self.data[index], self.min_length)
        if item is None:
            return None
        # NOTE(review): the floor of 30 is hard-coded rather than tied to
        # self.min_length; kept as is to preserve existing behaviour.
        if item['X'].shape[0] < 30 or bool(torch.isnan(item['X']).any()):
            return None

        # Shallow copy: pre-parsed dicts passed in via ``data`` are returned
        # by reference, and must not be altered by the crop below.
        item = dict(item)
        item['node_idx'] = torch.arange(len(item['seq']))
        item = self._crop_to_max_length(item)

        return self._get_features(item)