import os
import os.path as osp
import json
import numpy as np
import random
import torch.utils.data as data
import torch
from torch_geometric.nn.pool import knn_graph
from src.chroma.data import Protein
from .bio_tokenizer import BioTokenizer


class AllTestDataset(data.Dataset):
    """Dataset of protein backbones loaded from AFDB-style ``.jsonl`` files.

    Each JSON line is converted by :meth:`read_line_AFDB` into a dict of
    tensors (or ``None`` when the sequence length is outside
    ``[min_length, max_length]``). Items are cropped to ``max_length`` and
    turned into graph features by :meth:`_get_features` on access.
    """

    def __init__(self, data_path='./', split='train', jsonl_num=1, max_count=50000, min_length=30, max_length=500, k_neighbors=30):
        """Load entries from the ``.jsonl`` files found under ``data_path``.

        Args:
            data_path: A directory, or a path ending in ``.jsonl``; in the
                latter case the file's directory is scanned.
            split: Stored on the instance; not used by this class.
            jsonl_num: How many ``.jsonl`` files (first in sorted order) to read.
            max_count: Maximum number of entries loaded across all files
                (entries filtered out by length still count, stored as None).
            min_length: Minimum sequence length kept by ``read_line_AFDB``.
            max_length: Maximum sequence length kept, and crop length used in
                ``__getitem__``.
            k_neighbors: ``k`` for the k-nearest-neighbour graph.
        """
        # Explicit assignments instead of ``self.__dict__.update(locals())``,
        # which also stored a ``self.self`` reference cycle.
        self.data_path = data_path
        self.split = split
        self.jsonl_num = jsonl_num
        self.max_count = max_count
        self.min_length = min_length
        self.max_length = max_length
        self.k_neighbors = k_neighbors
        # NOTE(review): when data_path names a .jsonl file, every .jsonl file
        # in its directory is a candidate, not only that file — confirm intent.
        # ``or '.'`` keeps a bare "file.jsonl" from turning into listdir('').
        self.data_dir = (osp.dirname(data_path) or '.') if data_path.endswith('.jsonl') \
            else data_path

        self.tokenizer = BioTokenizer()
        self.data = []
        self.get_from_certain_jsons(self.data_dir, jsonl_num, max_count, min_length, max_length, k_neighbors)
        # Not read in this class; presumably consumed by external code.
        self.indicator = [0] * len(self.data)

    def get_from_certain_jsons(self, directory, num_files, max_count, min_length=30, max_length=500, k_neighbors=30):
        """Append parsed entries from the first ``num_files`` ``.jsonl`` files.

        Files are taken in sorted filename order. Loading stops as soon as
        ``max_count`` entries have been appended in total (across all files).
        Lines that are not valid JSON are reported and skipped.

        Args:
            directory: Directory to scan for ``.jsonl`` files.
            num_files: Number of files to read.
            max_count: Total entry budget across all files.
            min_length, max_length, k_neighbors: Forwarded to ``read_line_AFDB``.
        """
        files = sorted(f for f in os.listdir(directory) if f.endswith('.jsonl'))[:num_files]
        counter = 0
        for file_name in files:
            # Checked per file so the budget is shared across files (a plain
            # ``break`` in the inner loop used to let each new file add one
            # more entry) and so max_count <= 0 loads nothing.
            if counter >= max_count:
                return
            file_path = os.path.join(directory, file_name)
            print(f"Processing file: {file_path}")

            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    try:
                        entry = json.loads(line)
                    except json.JSONDecodeError as e:
                        print(f"Failed to parse JSON line: {line.strip()}")
                        print(f"Error: {e}")
                        continue
                    self.data.append(self.read_line_AFDB(entry, min_length, max_length, k_neighbors))
                    counter += 1
                    if counter >= max_count:
                        return

    def read_line_AFDB(self, entry, min_length=30, max_length=500, k_neighbors=30):
        """Convert one parsed JSON record into a dict of tensors.

        Args:
            entry: Dict with keys 'seq', 'title', 'N', 'CA', 'C', 'O'. The
                coordinate fields are presumably per-residue ``[x, y, z]``
                lists (TODO confirm against the data files).
            min_length, max_length: Inclusive range for ``len(entry['seq'])``.
            k_neighbors: Unused here; kept for signature compatibility.

        Returns:
            None when the sequence length is outside the range, otherwise a
            dict with 'title', 'seq' (token ids), 'X' (backbone coordinates
            stacked in N, CA, C, O order on axis 1), 'C' (all-ones chain
            encoding) and 'from_pdb' (0). Residues whose coordinates contain
            NaN are removed from 'seq', 'X' and 'C'.
        """
        seq = entry['seq']
        if not (min_length <= len(seq) <= max_length):
            return None

        chain_encoding = torch.ones(len(seq))
        N, CA, C, O = (np.array(entry[key]) for key in ['N', 'CA', 'C', 'O'])
        # NOTE(review): assumes tokenizer.encode yields one id per residue so
        # the mask below lines up with the token tensor.
        seq = torch.tensor(self.tokenizer.encode(seq))
        X = torch.from_numpy(np.stack([N, CA, C, O], axis=1)).float()
        # A residue is kept only if none of its coordinates is NaN.
        mask = ~torch.isnan(X.sum(dim=(1, 2)))
        return {
            'title': entry['title'],
            'seq': seq[mask],
            'X': X[mask],
            'C': chain_encoding[mask],
            'from_pdb': 0}

    def _get_features(self, batch):
        """Convert one dataset item into the feature dict used downstream.

        Adds a leading batch dimension of size 1, drops residues whose
        coordinates are not all finite, flattens back to per-residue tensors,
        and builds a k-nearest-neighbour graph over ``X[:, 1, :]`` (the CA
        slot, given the N, CA, C, O stacking in ``read_line_AFDB``).

        Args:
            batch: A single item dict; the keys 'X', 'seq', 'title', 'C',
                'node_idx' and 'from_pdb' are read.

        Returns:
            Dict with keys 'title', 'X', 'S', 'C', 'node_idx', 'edge_idx',
            'batch_id', 'mask', 'num_nodes' and 'from_pdb'.
        """
        X, S = batch['X'], batch['seq']
        X, S = X.unsqueeze(0), S.unsqueeze(0)
        # Residue mask: 1.0 where every coordinate of the residue is finite.
        mask = torch.isfinite(torch.sum(X, (2, 3))).float()
        numbers = torch.sum(mask, axis=1).int()
        # Move each row's valid residues to the front; remaining positions
        # stay NaN (coordinates) / 0 (tokens).
        S_new = torch.zeros_like(S)
        X_new = torch.zeros_like(X) + torch.nan
        for i, n in enumerate(numbers):
            X_new[i, :n, ::] = X[i][mask[i] == 1]
            S_new[i, :n] = S[i][mask[i] == 1]

        X, S = X_new, S_new
        isnan = torch.isnan(X)
        mask = torch.isfinite(torch.sum(X, (2, 3))).float()
        X[isnan] = 0.

        mask_bool = (mask == 1)

        def node_mask_select(x):
            # Keep only masked positions, merging the first two dimensions
            # (batch, residue) into one.
            shape = x.shape
            x = x.reshape(shape[0], shape[1], -1)
            out = torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1])
            out = out.reshape(-1, *shape[2:])
            return out

        batch_id = torch.arange(mask_bool.shape[0], device=mask_bool.device)[:, None].expand_as(mask_bool)
        X, seq = node_mask_select(X), node_mask_select(S)
        batch_id = node_mask_select(batch_id)
        mask = torch.masked_select(mask, mask_bool)

        C_a = X[:, 1, :]
        edge_idx = knn_graph(C_a, k=self.k_neighbors, batch=batch_id, loop=True, flow='target_to_source')

        # NOTE(review): 'C' and 'node_idx' are passed through without the
        # finite-coordinate mask, so they stay aligned with X only when every
        # residue of the input is finite (read_line_AFDB drops NaN rows but
        # not +/-inf rows).
        return {
            'title': batch['title'],
            'X': X,
            'S': seq,
            'C': batch['C'],
            'node_idx': batch['node_idx'],
            'edge_idx': edge_idx,
            'batch_id': batch_id,
            'mask': mask,
            'num_nodes': torch.tensor(X.shape[0]).reshape(1,),
            'from_pdb': batch['from_pdb']}

    def __len__(self):
        """Return the number of stored entries (including None placeholders)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the feature dict for entry ``index``, or None.

        None is returned for entries filtered out at load time and for
        entries with fewer than 30 residues (a hard-coded threshold,
        independent of ``self.min_length``). Entries longer than
        ``max_length`` are cropped to a random contiguous window.
        """
        item = self.data[index]
        if (item is None) or (item['X'].shape[0] < 30):
            return None

        # Shallow copy: the crop below must not be written back into
        # self.data, otherwise every later epoch reuses the first crop.
        item = dict(item)
        L = len(item['seq'])
        item['node_idx'] = torch.arange(L)
        if L > self.max_length:
            start = random.randint(0, L - self.max_length)
            end = start + self.max_length
            # Crop node_idx together with the other per-residue tensors so all
            # of them keep the same length.
            for key in ('seq', 'X', 'C', 'node_idx'):
                item[key] = item[key][start:end]

        features = self._get_features(item)
        return features