import torch
import pandas as pd
import numpy as np
import re
from constants import *
from TCR import TCR
from pMHC import pMHC

class TCRpMHCdataset(torch.utils.data.Dataset):
    """
    Paired TCR / pMHC dataset for learning a joint latent space of TCRs and pMHCs.

    Every loaded row contributes one (TCR, pMHC) pair. Identical TCRs and identical
    pMHCs (as decided by their ``hash``) are shared between pairs, so one TCR object
    can accumulate several cognate pMHCs and vice versa.

    Args:
        source (str): The source of the dataset. Either 'tcr' or 'pmhc'.
        target (str): The target of the dataset. Either 'tcr' or 'pmhc'.
        use_mhc (bool): Stored on the instance; not read inside this class.
        use_pseudo (bool): Whether to use the pseudo MHC sequence or the full MHC sequence.
            Stored on the instance; not read inside this class.
        use_cdr3 (bool): Whether to use the CDR3 sequence or the full TCR sequence.
            Stored on the instance; not read inside this class.

    Attributes:
        source (str): The source of the dataset. Either 'tcr' or 'pmhc'.
        target (str): The target of the dataset. Either 'tcr' or 'pmhc'.
        tcrs (list): TCR of every pair, index-aligned with ``pMHCs``.
        pMHCs (list): pMHC of every pair, index-aligned with ``tcrs``.
        data (dict): Initialised empty; not populated by this class.
        tcr_dict (dict): Maps ``hash(TCR)`` to the shared TCR object.
        pmhc_dict (dict): Maps ``hash(pMHC)`` to the shared pMHC object.
        use_mhc (bool): See Args.
        use_pseudo (bool): See Args.
        use_cdr3 (bool): See Args.
    """

    # (output column, TCR attribute) pairs, in the order ``to_df`` emits them.
    _TCR_COLUMNS = (
        ('CDR3a', 'cdr3a'),
        ('CDR3b', 'cdr3b'),
        ('TRAV', 'trav'),
        ('TRBV', 'trbv'),
        ('TRAJ', 'traj'),
        ('TRBJ', 'trbj'),
        ('TRAD', 'trad'),
        ('TRBD', 'trbd'),
        ('TRA_stitched', 'tcra_full'),
        ('TRB_stitched', 'tcrb_full'),
    )
    # (output column, pMHC attribute) pairs, emitted after the TCR columns.
    _PMHC_COLUMNS = (
        ('Epitope', 'peptide'),
        ('Allele', 'allele'),
        ('Pseudo', 'pseudo'),
        ('MHC', 'mhc'),
    )

    def __init__(self, source, target, use_mhc=False, use_pseudo=False, use_cdr3=False):
        super().__init__()
        self.source = source
        self.target = target
        self.tcrs = []
        self.pMHCs = []
        self.data = {}
        self.tcr_dict = {}
        self.pmhc_dict = {}
        self.use_mhc = use_mhc
        self.use_pseudo = use_pseudo
        self.use_cdr3 = use_cdr3

    def __len__(self):
        """Return the number of (TCR, pMHC) pairs."""
        # Internal invariant: both lists are only ever appended to together.
        assert len(self.pMHCs) == len(self.tcrs), 'tcrs and pMHCs are out of sync'
        return len(self.pMHCs)

    def __repr__(self):
        return f'TCR:pMHC Dataset of N={self.__len__()}. Mode:{self.source} -> {self.target}.'

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, idx):
        """
        Return the pair stored at ``idx``.

        The order is (pMHC, TCR) when ``source == 'pmhc'`` and (TCR, pMHC) for any
        other value of ``source``.
        """
        tcr = self.tcrs[idx]
        pmhc = self.pMHCs[idx]
        if self.source == 'pmhc':
            return pmhc, tcr
        return tcr, pmhc

    def load_data_from_file(self, path_to_csv):
        """Read ``path_to_csv`` with ``pandas.read_csv`` and load every row from it."""
        self.load_data_from_df(pd.read_csv(path_to_csv))

    def load_data_from_df(self, df):
        """
        Append one (TCR, pMHC) pair per row of ``df``.

        The columns read are CDR3a, CDR3b, TRAV, TRBV, TRAJ, TRBJ, TRAD, TRBD,
        TRA_stitched, TRB_stitched, Epitope, Allele and Reference.

        Loading is best-effort: a row that raises (missing column, value rejected by
        the TCR / pMHC constructors, ...) is skipped. Skipped rows are reported through
        the ``logging`` module instead of being dropped silently.
        """
        import logging  # local import keeps the module-level import block untouched

        log = logging.getLogger(__name__)
        skipped = 0
        for index, row in df.iterrows():
            try:
                self._add_row(row)
            except Exception as exc:  # deliberately broad: one bad row must not abort the load
                skipped += 1
                log.debug('Skipping row %r: %r', index, exc)
        if skipped:
            log.warning('Skipped %d of %d rows while loading TCR:pMHC data.', skipped, len(df))

    def _add_row(self, row):
        """Build, de-duplicate and cross-link the TCR and pMHC described by ``row``."""
        reference = row['Reference']

        ### 1. Create the TCR and pMHC objects
        new_tcr = TCR(cdr3a=row['CDR3a'], cdr3b=row['CDR3b'],
                      trav=row['TRAV'], trbv=row['TRBV'], traj=row['TRAJ'], trbj=row['TRBJ'],
                      trad=row['TRAD'], trbd=row['TRBD'],
                      tcra_full=row['TRA_stitched'], tcrb_full=row['TRB_stitched'],
                      reference=reference)
        new_pmhc = pMHC(peptide=row['Epitope'], hla_allele=row['Allele'], reference=reference)

        ### 2. Hash the TCR and pMHC objects to get unique keys
        tcr_key = hash(new_tcr)
        pmhc_key = hash(new_pmhc)

        ### 3. Resolve BOTH objects to their shared instance before linking, so the TCR
        ###    is attached to the same pMHC object the dataset keeps (not a throwaway copy).
        tcr_i = self.tcr_dict.get(tcr_key, new_tcr)
        pmhc_i = self.pmhc_dict.get(pmhc_key, new_pmhc)

        ### 4. Cross-link the pair (assumes no duplicates of paired data)
        tcr_i.add_reference(reference)
        tcr_i.add_pMHC(pmhc_i)
        pmhc_i.add_reference(reference)
        pmhc_i.add_tcr(tcr_i)

        ### 5. Commit only after every call above succeeded, so a failing row leaves the
        ###    lookup tables and the pair lists consistent with each other.
        self.tcr_dict[tcr_key] = tcr_i
        self.pmhc_dict[pmhc_key] = pmhc_i
        self.tcrs.append(tcr_i)
        self.pMHCs.append(pmhc_i)

    def to_df(self):
        """
        Return a dataframe representation of the dataset, one row per pair.

        Columns are the TCR fields, then Epitope / Allele / Pseudo / MHC, then
        Reference. Reference holds the last element of ``list(pmhc.references)``;
        NOTE(review): which element that is depends on the container type used for
        ``references`` in the pMHC class - confirm it is ordered.
        """
        columns = {
            name: [getattr(tcr, attr) for tcr in self.tcrs]
            for name, attr in self._TCR_COLUMNS
        }
        columns.update(
            (name, [getattr(pmhc, attr) for pmhc in self.pMHCs])
            for name, attr in self._PMHC_COLUMNS
        )
        columns['Reference'] = [list(pmhc.references)[-1] for pmhc in self.pMHCs]
        return pd.DataFrame(columns)
    