import os
import os.path as osp
import json
import numpy as np
import random
import torch.utils.data as data
import torch
from torch_geometric.nn.pool import knn_graph
from src.chroma.data import Protein
from src.tools.design_utils import _dihedrals
from .bio_tokenizer import BioTokenizer


class FeaturePredictionDataset(data.Dataset):
    """Single-chain protein dataset yielding kNN structure graphs and dihedral features.

    Each raw example is one JSON line with keys ``name``, ``seq`` (a string that
    is tokenized, or an already-tokenized list), ``coords`` and ``afdb_coords``.
    Pre-parsed dict examples (passed via ``data``) are accepted as-is.
    ``__getitem__`` returns ``None`` for filtered-out examples, so the
    DataLoader ``collate_fn`` presumably has to drop ``None`` items — TODO confirm.
    """

    # Absolute floor on the number of structure positions, applied in
    # __getitem__ independently of ``min_length``.
    _MIN_NUM_RESIDUES = 30

    # Per-residue fields that must be cropped together so they stay aligned.
    _PER_RESIDUE_KEYS = ('seq', 'X', 'afX', 'chain_encoding', 'mask', 'node_idx')

    def __init__(self, data_path='./', split='train', data=None, min_length=30, max_length=2048,
                 k_neighbors=50, tm_thr=0.9, dump_features=True,
                 valandtest_dir="raw_data/pair_data"):
        """Create the dataset.

        Args:
            data_path: Path of the training JSON-lines file.
            split: Which split ('train', 'val' or 'test') to expose when
                ``data`` is None.
            data: Optional pre-loaded list of examples (JSON lines or dicts);
                when given, no files are read.
            min_length: Minimum sequence length kept when parsing JSON lines.
            max_length: Maximum sequence length kept when parsing JSON lines;
                longer dict examples are randomly cropped to this length.
            k_neighbors: ``k`` for the kNN graph over C-alpha positions.
            tm_thr: Stored on the instance; not used by this class itself.
            dump_features: Whether to add dihedral ``features``/``af_features``.
            valandtest_dir: Directory holding ``val.jsonl`` and ``test.jsonl``.
        """
        self.data_path = data_path
        self.split = split
        self.min_length = min_length
        self.max_length = max_length
        self.k_neighbors = k_neighbors
        self.tm_thr = tm_thr
        self.dump_features = dump_features
        self.data_dir = data_path
        self.valandtest_dir = valandtest_dir
        self.tokenizer = BioTokenizer()

        if data is None:
            self.all_data = self.cache_pair_data()
            self.data = self.all_data[split]
        else:
            self.data = data

    def read_line_single_chain(self, line, min_length=30, max_length=2048):
        """Parse one JSON line into a dict of tensors.

        Args:
            line: A JSON string, or an already-parsed dict (returned unchanged,
                without length filtering).
            min_length: Inclusive lower bound on ``len(seq)``.
            max_length: Inclusive upper bound on ``len(seq)``.

        Returns:
            dict with 'title', 'seq', 'X', 'afX', 'chain_encoding', 'mask',
            or ``None`` if the sequence length is outside the bounds.
        """
        if isinstance(line, dict):
            return line

        entry = json.loads(line)
        seq = entry['seq']
        num_residues = len(seq)
        if not (min_length <= num_residues <= max_length):
            return None

        tokens = self.tokenizer.encode(seq) if isinstance(seq, str) else seq
        return {
            'title': entry['name'],
            'seq': torch.tensor(tokens),
            'X': torch.tensor(entry['coords']),
            'afX': torch.tensor(entry['afdb_coords']),
            # Single chain: every residue gets chain id 1 and is unmasked.
            'chain_encoding': torch.ones(num_residues),
            'mask': torch.ones(num_residues),
        }

    def cache_pair_data(self):
        """Read the raw lines of all three splits into memory.

        'train' is read from ``self.data_dir``; 'val' and 'test' come from
        ``<valandtest_dir>/<split>.jsonl``. All three files must exist.

        Returns:
            dict mapping split name to its list of raw lines.
        """
        data_dict = {'train': [], 'val': [], 'test': []}
        for split_name in data_dict:
            if split_name == 'train':
                split_path = self.data_dir
            else:
                split_path = osp.join(self.valandtest_dir, split_name + '.jsonl')
            with open(split_path, encoding='utf-8') as f:
                data_dict[split_name] = f.readlines()
        return data_dict

    def _get_features(self, batch):
        """Build the model input dict: kNN graph plus optional dihedral features.

        Args:
            batch: One example with 'title', 'X', 'seq', 'chain_encoding',
                'node_idx', 'mask' and, when ``dump_features`` is set, 'afX'.

        Returns:
            dict with 'title', 'X', 'S', 'C', 'node_idx', 'edge_idx',
            'batch_id', 'mask', 'num_nodes', plus 'features'/'af_features'
            when ``dump_features`` is set.
        """
        X, seq = batch['X'], batch['seq']
        # A single example: every node belongs to graph 0.
        batch_id = torch.zeros_like(seq)
        # Assumes X is (L, num_atoms, 3) with atom index 1 = C-alpha — TODO confirm.
        C_a = X[:, 1, :]
        edge_idx = knn_graph(C_a, k=self.k_neighbors, batch=batch_id, loop=True,
                             flow='target_to_source')
        out = {
            'title': batch['title'],
            'X': X,
            'S': seq,
            'C': batch['chain_encoding'],
            'node_idx': batch['node_idx'],
            'edge_idx': edge_idx,
            'batch_id': batch_id,
            'mask': batch['mask'],
            'num_nodes': torch.tensor(X.shape[0]).reshape(1,),
        }
        if self.dump_features:
            # Presumably cos/sin of backbone dihedrals and bond angles — see _dihedrals.
            out['features'] = _dihedrals(X).squeeze()
            out['af_features'] = _dihedrals(batch['afX']).squeeze()
        return out

    def _crop_to_max_length(self, item):
        """Randomly crop all per-residue fields of ``item`` to ``self.max_length``.

        Returns ``item`` itself when it is short enough; otherwise a shallow copy
        whose per-residue tensors are sliced with the same random window, so
        'X', 'afX', 'mask', 'node_idx', etc. stay aligned. The input is never mutated.
        """
        num_residues = len(item['seq'])
        if num_residues <= self.max_length:
            return item
        start = random.randint(0, num_residues - self.max_length)
        stop = start + self.max_length
        cropped = dict(item)
        for key in self._PER_RESIDUE_KEYS:
            if key in cropped:
                cropped[key] = cropped[key][start:stop]
        return cropped

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return model inputs for example ``index``, or ``None`` if it is filtered out."""
        item = self.read_line_single_chain(self.data[index], self.min_length, self.max_length)
        if item is None or item['X'].shape[0] < self._MIN_NUM_RESIDUES:
            return None

        # Shallow copy: for dict examples read_line_single_chain returns the stored
        # object, and writing into it would make the first random crop permanent.
        item = dict(item)
        # Original residue indices; they survive cropping, so they index into the full chain.
        item['node_idx'] = torch.arange(len(item['seq']))
        item = self._crop_to_max_length(item)
        return self._get_features(item)