import os
import os.path as osp
import json
import numpy as np
import random
import torch.utils.data as data
import torch
from torch_geometric.nn.pool import knn_graph
from src.chroma.data import Protein
from src.tools.design_utils import _dihedrals
from .bio_tokenizer import BioTokenizer


class PairDataset(data.Dataset):
    """Dataset of single-chain entries that carry two coordinate sets.

    Each item holds a tokenized sequence, coordinates ``X``, a second set of
    coordinates ``afX`` (read from the ``afdb_coords`` field of the JSON
    record), a validity ``mask`` and a chain encoding. Items are either parsed
    lazily from ``<split>.jsonl`` files or supplied pre-parsed through the
    ``data`` argument.

    ``__getitem__`` returns ``None`` for entries that are filtered out, so the
    consumer (e.g. a collate function) must be prepared to drop ``None``.
    """

    # Item keys indexed along the residue axis; they must always be cropped
    # together so that they stay aligned with each other.
    _PER_RESIDUE_KEYS = ('seq', 'X', 'afX', 'chain_encoding', 'mask', 'node_idx')

    def __init__(self, data_path='./', split='train', data=None, min_length=30, max_length=500, k_neighbors=50, tm_thr=0.9, dump_features=True):
        """Store the configuration and load the requested split.

        Args:
            data_path: Directory holding ``train.jsonl`` / ``val.jsonl`` /
                ``test.jsonl``, or a path ending in ``.jsonl`` whose parent
                directory holds them.
            split: Key of the split to expose (``'train'``, ``'val'`` or
                ``'test'``). Only used when ``data`` is ``None``.
            data: Optional pre-built sequence of items (raw JSON lines or
                already parsed dicts). When given, nothing is read from disk.
            min_length: Entries shorter than this are skipped.
            max_length: Longer entries are randomly cropped to this length.
            k_neighbors: ``k`` passed to ``knn_graph``.
            tm_thr: Stored for a TM-score filter that is currently disabled.
            dump_features: Whether to attach ``_dihedrals`` features to items.
        """
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored ``self`` on the instance (a reference cycle).
        self.data_path = data_path
        self.split = split
        self.min_length = min_length
        self.max_length = max_length
        self.k_neighbors = k_neighbors
        self.tm_thr = tm_thr
        self.dump_features = dump_features
        self.data_dir = osp.dirname(data_path) if data_path.endswith('.jsonl') else data_path
        self.tokenizer = BioTokenizer()

        if data is None:
            self.all_data = self.cache_pair_data()
            self.data = self.all_data[split]
        else:
            self.data = data

    def read_line_single_chain(self, line, min_length=30, max_length=1024):
        """Parse one JSON line into an item dict.

        Args:
            line: A JSON string, or an already parsed dict (returned as is).
            min_length: Minimum accepted sequence length (inclusive).
            max_length: Maximum accepted sequence length (inclusive).

        Returns:
            The item dict, or ``None`` when the sequence length falls outside
            ``[min_length, max_length]``.
        """
        if isinstance(line, dict):
            return line

        entry = json.loads(line)
        seq = entry['seq']
        if not (min_length <= len(seq) <= max_length):
            return None

        # NOTE: a filter on entry['tmscore'] > self.tm_thr used to be applied
        # here and is currently disabled.
        return {
            'title': entry['name'],
            'seq': torch.tensor(self.tokenizer.encode(seq)),
            'X': torch.tensor(entry['coords']),
            'afX': torch.tensor(entry['afdb_coords']),
            'chain_encoding': torch.ones(len(seq)),
            'mask': torch.tensor(entry['valid_mask']),
        }

    def cache_pair_data(self):
        """Read the raw lines of every split file.

        Files are looked up in ``self.data_dir`` so that a ``data_path``
        pointing at a ``.jsonl`` file resolves to its parent directory.

        Returns:
            Dict mapping ``'train'``, ``'val'`` and ``'test'`` to the list of
            unparsed lines of the corresponding ``<split>.jsonl`` file.
        """
        data_dict = {}
        for split in ('train', 'val', 'test'):
            split_path = osp.join(self.data_dir, split + '.jsonl')
            with open(split_path, encoding='utf-8') as f:
                data_dict[split] = f.readlines()
        return data_dict

    def _get_features(self, batch):
        """Build the model-ready dict for one item.

        Positions flagged invalid by ``mask`` are filled in ``X`` with the
        corresponding ``afX`` coordinates, then a k-nearest-neighbour graph is
        built on ``X[:, 1, :]``.

        Args:
            batch: Item dict with keys ``title``, ``X``, ``afX``, ``seq``,
                ``mask``, ``chain_encoding`` and ``node_idx``.

        Returns:
            The feature dict, or ``None`` when the masked fill fails (for
            example because ``mask`` and ``X`` have different lengths).
        """
        X, afX, seq = batch['X'], batch['afX'], batch['seq']
        if self.dump_features:
            # Computed before the masked fill below, i.e. from the original X.
            # Original author's note on the layout of _dihedrals output:
            # cos(psi), omega, phi, sin... cos(alpha, beta, gamma), sin...
            X_features = _dihedrals(X)
            afX_features = _dihedrals(afX)
        mask = batch['mask']
        try:
            # ``~`` on an integer 0/1 mask would be a bitwise NOT (yielding
            # indices -1/-2), so always index with a boolean mask.
            invalid = ~mask.bool()
            if invalid.any():
                # Clone so the caller's tensor is not modified in place.
                X = X.clone()
                X[invalid] = afX[invalid]
        except Exception:
            # Best effort: unusable entries are skipped rather than raised.
            return None

        batch_id = torch.zeros_like(seq)
        C_a = X[:, 1, :]  # presumably the C-alpha atom of each residue - TODO confirm
        edge_idx = knn_graph(C_a, k=self.k_neighbors, batch=batch_id, loop=True, flow='target_to_source')
        features = {
            'title': batch['title'],
            'X': X,
            'afX': afX,
            'S': seq,
            'C': batch['chain_encoding'],
            'node_idx': batch['node_idx'],
            'edge_idx': edge_idx,
            'batch_id': batch_id,
            'mask': mask,
            'num_nodes': torch.tensor(X.shape[0]).reshape(1,),
        }
        if self.dump_features:
            features["features"] = X_features.squeeze()
            features['af_features'] = afX_features.squeeze()
        return features

    def __len__(self):
        """Return the number of (unfiltered) entries in the active split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the feature dict for entry ``index``.

        Entries longer than ``self.max_length`` are cropped to a random
        contiguous window of that length; every per-residue field (including
        ``mask`` and ``node_idx``) is cropped with the same window.

        Returns:
            The feature dict, or ``None`` when the entry is filtered out or
            its features cannot be built.
        """
        item = self.read_line_single_chain(self.data[index], self.min_length)
        if item is None or item['X'].shape[0] < self.min_length:
            return None

        # Shallow copy: pre-parsed dicts supplied through ``data`` must not be
        # altered by the node_idx assignment or the crop below.
        item = dict(item)
        L = len(item['seq'])
        item['node_idx'] = torch.arange(L)
        if L > self.max_length:
            start = random.randint(0, L - self.max_length)
            stop = start + self.max_length
            for key in self._PER_RESIDUE_KEYS:
                item[key] = item[key][start:stop]

        return self._get_features(item)