import torch
from torch.utils.data import Dataset, DataLoader

import biotite
from biotite.structure import get_residue_starts, residue_iter
from biotite.structure.residues import get_residues
from biotite.sequence import ProteinSequence

import os
import random
import numpy as np
import pandas as pd
import pickle

class ProteinDataset(Dataset):
    """Dataset of per-protein ``.npz`` files holding residue types and structure latents.

    Proteins are selected by merging ``args.recon_quality_file`` with
    ``args.latent_stats_file`` on ``file_name``. A protein is kept only if:

    * both tables report the same ``seq_len``;
    * its ``rmsd`` is below 1;
    * its ``tm`` is above 0.9.

    Each item is the dict stored in ``{args.latent_dir}/{file_name}.npz``. It is
    randomly cropped to at most ``MAX_SEQ_LEN`` residues, and its
    ``residue_index`` is shifted so that it starts at 0.
    """

    # Hard cap on residues per sample; longer proteins are randomly cropped to this.
    MAX_SEQ_LEN = 512
    # Only proteins longer than this are eligible for the optional extra crop,
    # and that crop never goes below this length.
    MIN_CROP_LEN = 60
    # Probability of applying the optional extra crop to an eligible protein.
    EXTRA_CROP_PROB = 0.5

    def __init__(self, args):
        """Select proteins and store loading options.

        Args:
            args: namespace with ``recon_quality_file``, ``latent_stats_file``,
                ``latent_dir`` and ``crop_longer_prot`` attributes.
        """
        prot_df = pd.read_csv(args.recon_quality_file)
        latent_distribution_stats = pd.read_csv(args.latent_stats_file)
        # Both CSVs have a ``seq_len`` column; the merge suffixes them _x / _y.
        # With how='left', proteins missing from the stats file get NaN and fail
        # the equality test below.
        prot_df = pd.merge(prot_df, latent_distribution_stats, on='file_name', how='left')
        seq_len_filter = prot_df['seq_len_x'] == prot_df['seq_len_y']
        quality_filter = (prot_df['rmsd'] < 1) & (prot_df['tm'] > 0.9)
        self.prots = prot_df.loc[seq_len_filter & quality_filter, 'file_name'].values

        self.root_dir = args.latent_dir
        self.crop_longer_prot = args.crop_longer_prot

        print(f'DataLoader initialized. There are {len(self.prots)} clusters in total.')

    def __len__(self):
        return len(self.prots)

    def __getitem__(self, idx):
        """Load, crop and re-index one protein.

        Returns:
            dict of arrays. Uncropped proteins keep every key stored in the
            ``.npz``. Cropped proteins keep ``aatype``, ``residue_index``,
            ``latent`` and, when present, ``confidence``.
        """
        prot_id = self.prots[idx]
        # Use the NpzFile as a context manager so its file handle is closed
        # right away instead of whenever the object happens to be collected.
        with np.load(f'{self.root_dir}/{prot_id}.npz') as npz:
            prot_chain_dict = dict(npz)

        seq_len = prot_chain_dict['aatype'].shape[0]
        if seq_len > self.MAX_SEQ_LEN:
            prot_chain_dict = self.random_crop(prot_chain_dict, self.MAX_SEQ_LEN)

        if self.crop_longer_prot:
            seq_len = prot_chain_dict['aatype'].shape[0]
            # Draw the coin flip unconditionally: the previous `&` evaluated both
            # operands, so this keeps the global NumPy RNG stream unchanged.
            crop_roll = np.random.rand()
            if seq_len > self.MIN_CROP_LEN and crop_roll < self.EXTRA_CROP_PROB:
                further_crop_len = np.random.randint(self.MIN_CROP_LEN, seq_len)
                prot_chain_dict = self.random_crop(prot_chain_dict, further_crop_len)

        # Shift residue numbering to start at 0; skip empty proteins, where
        # .min() would raise.
        if prot_chain_dict['residue_index'].size:
            prot_chain_dict['residue_index'] -= prot_chain_dict['residue_index'].min()
        return prot_chain_dict

    def random_crop(self, prot_chain_dict, crop_len):
        """Return a random contiguous window of ``crop_len`` residues.

        Args:
            prot_chain_dict: sample dict. It must contain ``aatype``,
                ``residue_index`` and ``latent`` and may contain
                ``confidence``; all are sliced along their first axis.
            crop_len: window length. Values larger than the protein are clamped
                to its length instead of raising.

        Returns:
            A new dict with the required keys and, if present, ``confidence``,
            all sliced to the same window. Other keys are not carried over.
        """
        ori_seq_len = prot_chain_dict['aatype'].shape[0]
        crop_len = min(crop_len, ori_seq_len)
        random_start = np.random.randint(0, ori_seq_len - crop_len + 1)  # window is [start, start + crop_len)
        window = slice(random_start, random_start + crop_len)
        cropped = {
            key: prot_chain_dict[key][window]
            for key in ('aatype', 'residue_index', 'latent')
        }
        # Keep `confidence` in sync with the other per-residue arrays, so cropped
        # and uncropped samples expose it consistently.
        if 'confidence' in prot_chain_dict:
            cropped['confidence'] = prot_chain_dict['confidence'][window]
        return cropped


class ConstructCollater():
    """Collate ``ProteinDataset`` samples into one padded, tokenized batch.

    Each sample is a dict with at least ``aatype`` (integer indices into
    ``restype_with_x``), ``residue_index`` and ``latent`` (one row per residue).

    Sequences are rebuilt from ``aatype`` and tokenized. The per-residue arrays
    are then written into zero-filled tensors at token positions
    ``1 .. len(seq)``. Position 0 is left for the leading special token, so the
    tokenizer is assumed to prepend exactly one token (TODO confirm for the
    tokenizer in use).
    """

    def __init__(self, tokenizer, latent_scale=0.1875):
        """
        Args:
            tokenizer: object exposing ``batch_encode_plus`` that returns
                ``input_ids`` and ``attention_mask`` tensors.
            latent_scale: multiplier applied to the collated structure latents.
                The default reproduces the previously hard-coded factor.
        """
        # One-letter residue codes indexed by aatype; index 20 is the unknown residue 'X'.
        self.restype_with_x = np.array(['A', 'R', 'N', 'D', 'C', 'Q', 'E',
                                        'G', 'H', 'I', 'L', 'K', 'M', 'F',
                                        'P', 'S', 'T', 'W', 'Y', 'V', 'X'])
        self.tokenizer = tokenizer
        self.latent_scale = latent_scale

    def get_seq_from_aatype(self, aatype):
        """Map an integer ``aatype`` array to its one-letter sequence string."""
        return ''.join(self.restype_with_x[aatype])

    def __call__(self, raw_batch):
        """Build the model batch.

        Returns:
            dict with keys ``input_ids``, ``input_mask`` (bool, False on
            padding), ``struct_latent`` (float32, scaled by ``latent_scale``)
            and ``residue_index`` (long). Special and padding positions hold 0
            in ``struct_latent`` and ``residue_index``.
        """
        seqs = [self.get_seq_from_aatype(sample['aatype']) for sample in raw_batch]

        # Add special tokens and pad to the longest sequence; padded positions
        # are 0 in the attention mask.
        batch = self.tokenizer.batch_encode_plus(seqs,
                                                 add_special_tokens=True,
                                                 padding="longest",
                                                 return_tensors='pt')
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Take the latent width from the data instead of hard-coding 20.
        latent_dim = raw_batch[0]['latent'].shape[-1]
        struct_latent = torch.zeros((len(raw_batch), input_ids.shape[1], latent_dim),
                                    dtype=torch.float32)
        residue_index = torch.zeros_like(attention_mask, dtype=torch.long)
        for i, sample in enumerate(raw_batch):
            n_latent = sample['latent'].shape[0]
            n_res = sample['residue_index'].shape[0]
            # Offset by one: position 0 holds the leading special token.
            struct_latent[i, 1:n_latent + 1, :] = torch.tensor(sample['latent'], dtype=torch.float32)
            residue_index[i, 1:n_res + 1] = torch.tensor(sample['residue_index'], dtype=torch.long)

        return {
            'input_ids': input_ids,
            'input_mask': attention_mask.bool(),
            'struct_latent': struct_latent.mul_(self.latent_scale),
            'residue_index': residue_index,
        }
