import os
import os.path as osp
import json
import numpy as np
import random
import torch.utils.data as data
import torch
from torch_geometric.nn.pool import knn_graph
from src.chroma.data import Protein
from .bio_tokenizer import BioTokenizer


class MixDataset(data.Dataset):
    """Protein-structure dataset mixing PDB and AFDB records.

    The class is used in two modes:

    * Cache mode (``pdb_data`` and ``afdb_data`` both ``None``): raw JSONL
      lines are read from disk and stored in ``self.pdb_data`` /
      ``self.afdb_data`` as dicts keyed by ``'train'``, ``'val'`` and
      ``'test'``.  ``self.data`` is not populated in this mode, so the
      instance itself is not indexable.
    * Split mode (at least one of ``pdb_data`` / ``afdb_data`` given): the
      supplied records are concatenated into ``self.data`` and
      ``self.indicator`` marks each record as PDB (1) or AFDB (0).

    Records are parsed lazily in ``__getitem__``.
    """

    def __init__(self, data_path='./', split='train', pdb_data=None, afdb_data=None,
                 min_length=30, max_length=500, k_neighbors=30, remove_pdb=False):
        """Create the dataset.

        Args:
            data_path: A ``.jsonl`` PDB file, or a directory containing
                ``pdb.jsonl``, ``afdb/*.jsonl``, ``afdb_test.jsonl`` and the
                ``N128`` test directory.
            split: Split name used by ``_update_afdb``.
            pdb_data: PDB records for split mode (``None`` means none).
            afdb_data: AFDB records for split mode (``None`` means none).
            min_length: Minimum chain length accepted by the readers.
            max_length: Maximum length; AFDB entries longer than this are
                rejected, longer PDB chains are randomly cropped.
            k_neighbors: ``k`` of the k-NN residue graph.
            remove_pdb: Only ``False`` loads PDB data in cache mode; any
                other value leaves the PDB splits empty.

        Raises:
            FileNotFoundError: In cache mode, if ``data_path`` or one of the
                required files/directories is missing.
        """
        # Every constructor argument is kept as a same-named attribute; the
        # other methods read them (e.g. ``self.max_length``).
        self.data_path = data_path
        self.split = split
        self.pdb_data = pdb_data
        self.afdb_data = afdb_data
        self.min_length = min_length
        self.max_length = max_length
        self.k_neighbors = k_neighbors
        self.remove_pdb = remove_pdb

        self.data_dir = osp.dirname(data_path) if data_path.endswith('.jsonl') else data_path
        self.tokenizer = BioTokenizer()

        if pdb_data is None and afdb_data is None:
            if not osp.exists(data_path):
                raise FileNotFoundError(f"No such file: {data_path}.")

            if remove_pdb is False:
                self.pdb_data = self.cache_pdb_data()
                self.pdb_data['test'] = self.cache_test_pdb_files()
            else:
                self.pdb_data = {'train': [], 'val': [], 'test': []}

            self.afdb_data = self.cache_afdb_data()
            self.afdb_data['test'] = self.cache_test_afdb_files()
        else:
            # Split mode: a missing source simply contributes no records.
            pdb_records = [] if pdb_data is None else list(pdb_data)
            afdb_records = [] if afdb_data is None else list(afdb_data)
            self.data = pdb_records + afdb_records
            self.indicator = [1] * len(pdb_records) + [0] * len(afdb_records)

    def read_line_AFDB(self, line, min_length=30):
        """Parse one AFDB JSONL line into a single-chain record dict.

        Returns ``None`` when the line is not valid JSON or when the raw
        sequence length lies outside ``[min_length, self.max_length]``.
        Residues with a NaN in any backbone coordinate are dropped.
        """
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON line: {line.strip()}")
            print(f"Error: {e}")
            return None

        raw_seq = entry['seq']
        if not (min_length <= len(raw_seq) <= self.max_length):
            return None

        chain_encoding = torch.ones(len(raw_seq))
        seq = torch.tensor(self.tokenizer.encode(raw_seq))
        N, CA, C, O = (np.array(entry[key]) for key in ['N', 'CA', 'C', 'O'])
        # Backbone atoms are stacked along axis 1 in N, CA, C, O order.
        X = torch.from_numpy(np.stack([N, CA, C, O], axis=1)).float()
        mask = ~torch.isnan(X.sum(dim=(1, 2)))
        return {
            'title': entry['title'],
            'seq': seq[mask],
            'X': X[mask],
            'C': chain_encoding[mask],
            'from_pdb': 0,
        }

    def read_line_single_chain(self, line, min_length=30, max_length=1024):
        """Parse a (possibly multi-chain) PDB JSONL line and return one chain.

        ``line`` may also be an already-parsed record dict (as produced by
        ``cache_test_pdb_files``), which is returned unchanged.

        Chains are cut from the concatenated ``seq``/``coords`` using the
        lengths in ``entry['chain_length']``.  Residues with NaN backbone
        coordinates are dropped, and chains whose remaining length lies
        outside ``[min_length, max_length]`` are discarded.  One of the
        remaining chains is chosen at random; ``None`` is returned if none
        qualify.
        """
        if isinstance(line, dict):
            return line

        entry = json.loads(line)
        atom_keys = ('N', 'CA', 'C', 'O')
        coords = {key: np.asarray(entry['coords'][key]) for key in atom_keys}
        base_title = entry['name'].split('.')[0]

        chains = []
        start = 0
        for i, length in enumerate(entry['chain_length'].values()):
            # Advance the offset up front so skipped chains keep later ones aligned.
            begin, start = start, start + length
            X = torch.from_numpy(
                np.stack([coords[key][begin:start] for key in atom_keys], axis=1)
            ).float()
            if X.shape[0] == 0:
                continue

            seq = torch.tensor(self.tokenizer.encode(entry['seq'][begin:start]))
            chain_encoding = torch.ones_like(seq)
            mask = ~torch.isnan(X.sum(dim=(1, 2)))
            chain = {
                'title': base_title + f'_{i}',
                'seq': seq[mask],
                'X': X[mask],
                'C': chain_encoding[mask],
                'from_pdb': 1,
            }
            if min_length <= len(chain['seq']) <= max_length:
                chains.append(chain)

        if not chains:
            return None
        return random.choice(chains)

    @staticmethod
    def _split_train_val(lines, val_num):
        """Hold out the first ``min(10% of lines, val_num)`` lines as validation."""
        num_val = min(int(0.1 * len(lines)), val_num)
        return {'train': lines[num_val:], 'val': lines[:num_val], 'test': []}

    def cache_pdb_data(self, val_num=512):
        """Read the PDB JSONL file and split its raw lines into train/val.

        The file is ``self.data_path`` itself if it ends in ``.jsonl``,
        otherwise ``<data_path>/pdb.jsonl``.  ``'test'`` is left empty.
        """
        pdb_file = self.data_path if self.data_path.endswith('.jsonl') \
            else osp.join(self.data_path, 'pdb.jsonl')
        with open(pdb_file) as f:
            lines = f.readlines()
        return self._split_train_val(lines, val_num)

    def cache_afdb_data(self, val_num=512):
        """Pick one random ``.jsonl`` shard from ``<data_dir>/afdb`` and split it.

        Raises:
            FileNotFoundError: If the directory contains no ``.jsonl`` file.
        """
        afdb_dir = osp.join(self.data_dir, 'afdb')
        files = [f for f in os.listdir(afdb_dir) if f.endswith('.jsonl')]
        if not files:
            raise FileNotFoundError("No .jsonl files found in the directory.")

        afdb_file = osp.join(afdb_dir, random.choice(files))
        with open(afdb_file) as f:
            lines = f.readlines()
        return self._split_train_val(lines, val_num)

    def cache_test_pdb_files(self, testset='N128'):
        """Load every structure file in ``<data_dir>/<testset>`` as a record dict.

        Raises:
            FileNotFoundError: If the test directory does not exist.
        """
        pdb_files_path = osp.join(self.data_dir, testset)
        if not osp.exists(pdb_files_path):
            raise FileNotFoundError(f"No such file: {pdb_files_path}.")

        pdb_data = []
        for cur_pdb in os.listdir(pdb_files_path):
            protein = Protein.from_PDB(osp.join(pdb_files_path, cur_pdb))
            X, _, S = protein.to_XCS()
            chain_encoding = torch.zeros(S.shape[-1])
            # NOTE(review): assumes to_XCS returns a leading batch axis of size 1.
            X, seq = X[0], S[0]
            mask = ~torch.isnan(X.sum(dim=(1, 2)))
            pdb_data.append({
                # os.listdir yields bare file names, so this is the basename.
                'title': cur_pdb,
                'seq': seq[mask],
                'X': X[mask],
                'C': chain_encoding[mask],
                'from_pdb': 1,
            })
        return pdb_data

    def cache_test_afdb_files(self, num_samples=128):
        """Return up to ``num_samples`` random raw lines of ``afdb_test.jsonl``.

        If the file has fewer lines than ``num_samples``, all of them are
        returned (in random order) instead of raising.

        Raises:
            FileNotFoundError: If ``<data_dir>/afdb_test.jsonl`` is missing.
        """
        afdb_files_path = osp.join(self.data_dir, 'afdb_test.jsonl')
        if not osp.exists(afdb_files_path):
            raise FileNotFoundError(f"No such file: {afdb_files_path}.")
        with open(afdb_files_path) as f:
            lines = f.readlines()
        return random.sample(lines, min(num_samples, len(lines)))

    def _get_features(self, batch):
        """Turn one parsed record into graph inputs.

        ``batch`` is a single record dict with keys ``'X'``, ``'seq'``,
        ``'C'``, ``'title'``, ``'node_idx'`` and ``'from_pdb'``.  Residues
        whose coordinates contain a non-finite value are removed, and a k-NN
        graph (``k = self.k_neighbors``, self-loops included) is built over
        atom index 1 of each residue (CA for the N/CA/C/O layout of the JSONL
        readers).
        """
        X, S = batch['X'], batch['seq']
        X, S = X.unsqueeze(0), S.unsqueeze(0)  # add a batch axis of size 1
        mask = torch.isfinite(torch.sum(X, (2, 3))).float()  # per-residue validity
        numbers = torch.sum(mask, dim=1).int()
        # Compact valid residues to the front; the tail is NaN-padded.
        S_new = torch.zeros_like(S)
        X_new = torch.zeros_like(X) + torch.nan
        for i, n in enumerate(numbers):
            X_new[i, :n, ::] = X[i][mask[i] == 1]
            S_new[i, :n] = S[i][mask[i] == 1]

        X, S = X_new, S_new
        isnan = torch.isnan(X)
        mask = torch.isfinite(torch.sum(X, (2, 3))).float()
        X[isnan] = 0.

        mask_bool = (mask == 1)

        def node_mask_select(x):
            # Keep only valid residues, flattening the batch axis away.
            shape = x.shape
            x = x.reshape(shape[0], shape[1], -1)
            out = torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1])
            out = out.reshape(-1, *shape[2:])
            return out

        batch_id = torch.arange(mask_bool.shape[0], device=mask_bool.device)[:, None].expand_as(mask_bool)
        X, seq = node_mask_select(X), node_mask_select(S)
        batch_id = node_mask_select(batch_id)
        mask = torch.masked_select(mask, mask_bool)

        C_a = X[:, 1, :]
        edge_idx = knn_graph(C_a, k=self.k_neighbors, batch=batch_id, loop=True, flow='target_to_source')

        # NOTE: 'C' and 'node_idx' are passed through unfiltered; they stay
        # aligned because the readers already removed NaN residues.
        return {
            'title': batch['title'],
            'X': X,
            'S': seq,
            'C': batch['C'],
            'node_idx': batch['node_idx'],
            'edge_idx': edge_idx,
            'batch_id': batch_id,
            'mask': mask,
            'num_nodes': torch.tensor(X.shape[0]).reshape(1,),
            'from_pdb': batch['from_pdb'],
        }

    def _update_afdb(self):
        """Re-sample the AFDB shard for ``self.split``, keeping all PDB records."""
        afdb_records = self.cache_afdb_data()[self.split]
        pdb_records = [record for record, flag in zip(self.data, self.indicator) if flag == 1]
        self.data = pdb_records + afdb_records
        self.indicator = [1] * len(pdb_records) + [0] * len(afdb_records)

    def __len__(self):
        """Number of records (split mode only)."""
        return len(self.data)

    def __getitem__(self, index):
        """Parse record ``index`` and return its graph features, or ``None``.

        ``None`` is returned when the record is rejected by its reader or has
        fewer than 30 residues (presumably filtered out by the collate
        function — confirm against the caller).  Records longer than
        ``self.max_length`` are cropped to a random contiguous window;
        ``node_idx`` holds each kept residue's position in the uncropped
        chain.
        """
        line = self.data[index]
        if self.indicator[index] == 1:
            item = self.read_line_single_chain(line, self.min_length)
        else:
            item = self.read_line_AFDB(line, self.min_length)
        if item is None or item['X'].shape[0] < 30:
            return None

        # Shallow copy: pre-parsed test records are returned by reference and
        # must not be annotated or cropped in place.
        item = dict(item)
        L = len(item['seq'])
        node_idx = torch.arange(L)
        if L > self.max_length:
            begin = random.randint(0, L - self.max_length)
            window = slice(begin, begin + self.max_length)
            item['seq'] = item['seq'][window]
            item['X'] = item['X'][window]
            item['C'] = item['C'][window]
            node_idx = node_idx[window]  # keep node_idx aligned with the crop
        item['node_idx'] = node_idx

        return self._get_features(item)