import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import pickle
from Bio.PDB import PDBParser, Select, PDBIO
from biotite.structure import AtomArray, filter_amino_acids, filter_nucleotides
from alphabet import Alphabet
from util import CoordBatchConverter, extract_coords_from_structure, load_structure

def filter_standard_bases(structure):
    """Return the entries of *structure* whose residue name is A, G, C or U.

    Args:
        structure: Object exposing a ``res_name`` array and supporting
            boolean-mask indexing (presumably a biotite ``AtomArray`` --
            confirm against ``load_structure``).

    Returns:
        ``structure`` indexed by the boolean mask, i.e. only the entries
        whose ``res_name`` is one of the four standard RNA base names.
    """
    canonical_bases = ('A', 'G', 'C', 'U')
    keep = np.isin(structure.res_name, canonical_bases)
    return structure[keep]

class ProteinRNADataset(Dataset):
    """Dataset yielding RNA coordinates/sequence plus cached protein features.

    One sample corresponds to one row of the CSV table. The row's ``PDB``
    value selects both the ``<PDB>.pdb`` file inside ``pdb_dir`` and the
    entry of the pickled embedding cache (keyed by that same file name).
    """

    def __init__(self, pdb_dir, csv_file, cache_file):
        """Load the chain table and the protein embedding cache.

        Args:
            pdb_dir: Directory containing the ``.pdb`` files.
            csv_file: CSV read with ``pandas.read_csv``; ``__getitem__``
                reads its ``PDB``, ``Protein chains`` and ``RNA chains``
                columns.
            cache_file: Pickle file holding a mapping from ``"<PDB>.pdb"``
                to a mapping with the keys ``local``, ``global``,
                ``dist_0_8``, ``dist_8_15`` and ``dist_15_inf``.
        """
        self.pdb_dir = pdb_dir
        self.csv_file = csv_file
        self.cache_file = cache_file
        self.chain_info = pd.read_csv(csv_file)

        # NOTE(security): unpickling can execute arbitrary code, so the
        # cache file must come from a trusted source.
        with open(cache_file, 'rb') as cache_handle:
            self.protein_embeddings = pickle.load(cache_handle)

        # Names of every ``.pdb`` file currently present in ``pdb_dir``.
        self.pdb_files = [
            name for name in os.listdir(pdb_dir) if name.endswith(".pdb")
        ]
        self.pdb_parser = PDBParser(QUIET=True)

    def __len__(self):
        """Return the number of rows in the chain table."""
        return len(self.chain_info)

    def __getitem__(self, idx):
        """Build the sample for row ``idx`` of the chain table.

        Returns:
            dict with the keys ``rna_coords``, ``rna_sequence``,
            ``protein_embedding_local``, ``protein_embedding_global``,
            ``dist_0_8``, ``dist_8_15`` and ``dist_15_inf``.

        Raises:
            KeyError: If a required CSV column is absent or the embedding
                cache has no entry for ``"<PDB>.pdb"``.
        """
        record = self.chain_info.iloc[idx]
        # The protein chain column is read but not used further here.
        pdb_id, _protein_chains, rna_chains = (
            record[column] for column in ('PDB', 'Protein chains', 'RNA chains')
        )

        pdb_filename = f"{pdb_id}.pdb"
        # Only the first element of the RNA chain value is passed on
        # (NOTE(review): for a string this is its first character --
        # confirm chain ids are single characters).
        rna_structure = filter_standard_bases(
            load_structure(os.path.join(self.pdb_dir, pdb_filename), rna_chains[0])
        )
        cached = self.protein_embeddings[pdb_filename]
        rna_coords, rna_sequence = extract_coords_from_structure(rna_structure)

        return {
            'rna_coords': rna_coords,
            'rna_sequence': rna_sequence,
            'protein_embedding_local': cached['local'],
            'protein_embedding_global': cached['global'],
            **{key: cached[key] for key in ('dist_0_8', 'dist_8_15', 'dist_15_inf')},
        }
    
def collate_batch(batch, max_len=64):
    """Collate dataset samples into padded protein tensors for one batch.

    Each sample is a dict with the keys ``rna_coords``, ``rna_sequence``,
    ``protein_embedding_local``, ``protein_embedding_global``, ``dist_0_8``,
    ``dist_8_15`` and ``dist_15_inf``.

    The local embeddings are copied into a zero-initialised tensor of shape
    ``(batch, max_len, feature_dim)``, where ``feature_dim`` is the last
    dimension of the first sample's global embedding. The last three rows
    of every sample are then overwritten with the three distance features
    and are always ``False`` in the mask.

    Args:
        batch: List of sample dicts. Local embeddings are read with their
            length on axis 1 and have axis 0 squeezed away (assumes shape
            ``(1, length, feature_dim)`` -- TODO confirm). Each distance
            feature must be assignable to one ``(feature_dim,)`` row.
        max_len: Number of rows in the padded tensor, including the three
            reserved distance rows. Defaults to 64. To change it when used
            as a ``DataLoader`` ``collate_fn``, bind it with
            ``functools.partial``.

    Returns:
        Tuple ``(rna_coords, rna_sequences, padded_embeddings,
        protein_embeddings_global, padding_masks)``. The first two are
        plain lists; ``padding_masks`` is a bool tensor of shape
        ``(batch, max_len)`` that is ``True`` on rows filled from the local
        embedding, excluding the three reserved rows.

    Raises:
        ValueError: If ``max_len`` is smaller than 3.
        IndexError: If ``batch`` is empty.
        RuntimeError: Raised by torch when a local embedding is longer than
            ``max_len`` and therefore cannot be copied into its slot.
    """
    if max_len < 3:
        raise ValueError("max_len must be at least 3 to hold the distance rows")

    rna_coords = [item['rna_coords'] for item in batch]
    rna_sequences = [item['rna_sequence'] for item in batch]
    local_embeddings = [item['protein_embedding_local'] for item in batch]
    global_embeddings = [item['protein_embedding_global'] for item in batch]

    batch_size = len(local_embeddings)
    # Feature width is taken from the global embedding.
    # NOTE(review): assumes local embeddings share this width -- confirm.
    feature_dim = global_embeddings[0].shape[-1]

    padded_embeddings = torch.zeros((batch_size, max_len, feature_dim))
    padding_masks = torch.zeros((batch_size, max_len), dtype=torch.bool)

    for i, emb in enumerate(local_embeddings):
        length = emb.shape[1]
        padded_embeddings[i, :length] = emb.squeeze(0)
        padding_masks[i, :length] = True

    # The last three rows are reserved for the distance features below, so
    # they are never flagged, even when a local embedding reaches into them.
    padding_masks[:, -3:] = False

    protein_embeddings_global = torch.stack(global_embeddings)

    reserved_rows = ((-3, 'dist_0_8'), (-2, 'dist_8_15'), (-1, 'dist_15_inf'))
    for row, key in reserved_rows:
        padded_embeddings[:, row] = torch.stack([item[key] for item in batch])

    return rna_coords, rna_sequences, padded_embeddings, protein_embeddings_global, padding_masks