import pandas as pd
import os
import sys
from constants import *
sys.path.append("../utils/")
from fold import fold
from tqdm import tqdm
from Bio.PDB import PDBParser, PDBIO
from Bio.PDB.Chain import Chain
from utils import read_yaml


def modify_chain_id(pdb_file, output_file, chain_id, new_chain_id, num_residues_to_change):
    """Move the trailing residues of one chain into a newly created chain.

    The first model of ``pdb_file`` is loaded, the last
    ``num_residues_to_change`` residues of chain ``chain_id`` are detached
    and re-attached to a new chain named ``new_chain_id``, and the whole
    structure is written to ``output_file``.

    Args:
        pdb_file: Path of the PDB file to read.
        output_file: Path the modified structure is written to (may be the
            same as ``pdb_file``).
        chain_id: ID of the existing chain to split.
        new_chain_id: ID of the chain that receives the trailing residues.
        num_residues_to_change: Number of residues, counted from the end of
            ``chain_id``, to move. Values <= 0 move nothing; values larger
            than the chain length move every residue.

    Returns:
        None. If ``new_chain_id`` already exists in the model, the function
        returns early and nothing is written to ``output_file``.

    Raises:
        KeyError: If ``chain_id`` is not present in the first model.
    """
    parser = PDBParser()
    structure = parser.get_structure('structure', pdb_file)

    model = structure[0]

    # Get the chain to modify
    old_chain = model[chain_id]
    if new_chain_id in model:
        # Already fixed
        return

    # Create a new chain
    new_chain = Chain(new_chain_id)
    model.add(new_chain)

    # Get the residues in the specified chain
    residues = list(old_chain)

    # Compute an explicit start index rather than slicing with a negative
    # count: `residues[-0:]` is the *whole* list, so a count of 0 would
    # otherwise move every residue instead of none.
    count = max(num_residues_to_change, 0)
    start = max(len(residues) - count, 0)

    # Modify the chain ID for the last `count` residues
    for residue in residues[start:]:
        # Remove the residue from the old chain and add it to the new chain
        old_chain.detach_child(residue.id)
        new_chain.add(residue)

    io = PDBIO()
    io.set_structure(structure)
    io.save(output_file)


def fold_heavy_light(property_th_lb=DEFAULT_PAIRING_SETTINGS['property_th_lb'],
                     property_th_ub=DEFAULT_PAIRING_SETTINGS['property_th_ub'],
                     edist_th=DEFAULT_PAIRING_SETTINGS['edist_th'],
                     property_to_match=DEFAULT_PAIRING_SETTINGS['property_to_match'],
                     min_prop=None,
                     names=['aalphabio'],
                     paired_dir=DEFAULT_DATASET_PATH,
                     pdb_dir='/data/mahajs17/Propen/'
                     ):
    """Fold every unique heavy/light pair of the matched datasets.

    For each dataset in ``names`` the matched-pairs parquet file is read,
    the ``first_*`` and ``second_*`` columns are stacked into one table of
    (seqid, heavy, light) rows, duplicates are dropped, and each row is
    folded as ``heavy:light`` into ``{pdb_dir}/{name}/EsmFold/{seqid}.pdb``.
    After folding, the last ``len(light)`` residues of chain ``A`` are moved
    to chain ``B``. A CSV with one ``plddt`` value per row is written next
    to the structures.

    Args:
        property_th_lb: Lower property threshold; used only to build the
            input filename.
        property_th_ub: Upper property threshold; used only to build the
            input filename.
        edist_th: Edit-distance threshold; used only to build the input
            filename.
        property_to_match: Accepted for signature compatibility; not used
            by this function.
        min_prop: If not None, ``_minprop{min_prop}`` is appended to the
            input filename stem.
        names: Iterable of dataset names to process (never mutated here).
        paired_dir: Directory containing the matched-pairs parquet files.
        pdb_dir: Root directory for the per-dataset output folders.

    Returns:
        None. Structures and the pLDDT CSV are written to disk.

    Notes:
        Rows whose PDB file already exists are not re-folded; their
        ``plddt`` entry is NaN so that the column stays aligned with the
        rows (previously a resumed run failed with a length mismatch when
        assigning the column).
    """
    suffix = ''
    if min_prop is not None:
        suffix = f'_minprop{min_prop}'

    for name in names:
        pdb_path = f"{pdb_dir}/{name}/EsmFold"
        os.makedirs(pdb_path, exist_ok=True)
        outfile_plddts = f"{pdb_dir}/{name}/{name}_esmfold.csv"
        if os.path.exists(outfile_plddts):
            # Do not overwrite the first results file; write a sibling instead.
            outfile_plddts = outfile_plddts.replace('.csv', '_1.csv')

        if name == 'skempi':
            # NOTE(review): the `prop` slot is filled with `edist_th` here,
            # while `property_to_match` is never used — confirm this matches
            # the filenames on disk before changing it.
            datafile = f'{paired_dir}/{name}_matched_th{property_th_lb}-{property_th_ub}_edth{edist_th}_prop{edist_th}{suffix}.parquet'
        else:
            datafile = f'{paired_dir}/{name}_matched_th{property_th_lb}-{property_th_ub}_edth{edist_th}_propED{suffix}.parquet'
        print(datafile)
        df_paired = pd.read_parquet(datafile)

        # Stack the two members of every pair into one long table.
        df = pd.DataFrame()
        df['seqid'] = df_paired['first_seqid'].values.tolist() + df_paired['second_seqid'].values.tolist()
        df['heavy'] = df_paired['first_HeavyAA'].values.tolist() + df_paired['second_HeavyAA'].values.tolist()
        df['light'] = df_paired['first_LightAA'].values.tolist() + df_paired['second_LightAA'].values.tolist()
        df = df.drop_duplicates()
        print(df.shape[0])

        plddts = []
        for _, row in tqdm(df.iterrows()):
            sequence = row['heavy'] + ':' + row['light']

            seqid = row['seqid']
            outfile = f'{pdb_path}/{seqid}.pdb'
            if os.path.exists(outfile):
                # Already folded earlier: keep one entry per row so the
                # column assignment below cannot fail on a length mismatch.
                plddts.append(float('nan'))
                continue
            plddt = fold(sequence, outfile, add_special_tokens=False)
            # The fold is written as a single chain 'A'; split the trailing
            # light-chain residues off into chain 'B'.
            modify_chain_id(outfile, outfile, 'A', 'B', len(row['light']))
            plddts.append(plddt.mean().item())
        df['plddt'] = plddts
        df.to_csv(outfile_plddts, index=False)


def fold_heavy_light_single(datafile, pdb_dir='/data/mahajs17/Propen/tests', seqid_col='seqid'):
    """Fold every unique row of a single table of heavy/light sequences.

    ``datafile`` is read as parquet when its name ends in ``.parquet`` and
    as CSV otherwise. Each de-duplicated row is folded as
    ``fv_heavy:fv_light`` into ``{pdb_dir}//EsmFold/{seqid}.pdb``; the last
    ``len(fv_light)`` residues of chain ``A`` are then moved to chain ``B``.
    The table, with an added ``plddt`` column, is written to
    ``{pdb_dir}/esmfold.csv`` (or ``esmfold_1.csv`` if that file exists).

    Args:
        datafile: Path to a parquet or CSV file with ``fv_heavy``,
            ``fv_light`` and ``seqid_col`` columns.
        pdb_dir: Output root directory.
        seqid_col: Name of the column whose value names each PDB file.

    Returns:
        None. Structures and the pLDDT CSV are written to disk.

    Notes:
        Rows whose PDB file already exists are not re-folded; their
        ``plddt`` entry is NaN so that the column stays aligned with the
        rows (previously a resumed run failed with a length mismatch when
        assigning the column).
    """
    pdb_path = f"{pdb_dir}//EsmFold"
    os.makedirs(pdb_path, exist_ok=True)
    outfile_plddts = f"{pdb_dir}/esmfold.csv"
    if os.path.exists(outfile_plddts):
        # Do not overwrite the first results file; write a sibling instead.
        outfile_plddts = outfile_plddts.replace('.csv', '_1.csv')

    if datafile.endswith('.parquet'):
        df = pd.read_parquet(datafile)
    else:
        df = pd.read_csv(datafile)
    df = df.drop_duplicates()
    print(df.shape[0])

    plddts = []
    for _, row in tqdm(df.iterrows()):
        sequence = row['fv_heavy'] + ':' + row['fv_light']

        seqid = row[seqid_col]
        outfile = f'{pdb_path}/{seqid}.pdb'
        if os.path.exists(outfile):
            # Already folded earlier: keep one entry per row so the column
            # assignment below cannot fail on a length mismatch.
            plddts.append(float('nan'))
            continue
        plddt = fold(sequence, outfile, add_special_tokens=False)
        # The fold is written as a single chain 'A'; split the trailing
        # light-chain residues off into chain 'B'.
        modify_chain_id(outfile, outfile, 'A', 'B', len(row['fv_light']))
        plddts.append(plddt.mean().item())
    df['plddt'] = plddts
    df.to_csv(outfile_plddts, index=False)


def fix_chains(names=('aalphabio',),
               property_th_lb=DEFAULT_PAIRING_SETTINGS['property_th_lb'],
               property_th_ub=DEFAULT_PAIRING_SETTINGS['property_th_ub'],
               edist_th=DEFAULT_PAIRING_SETTINGS['edist_th'],
               property_to_match=DEFAULT_PAIRING_SETTINGS['property_to_match'],
               paired_dir=DEFAULT_DATASET_PATH,
               pdb_dir='/data/mahajs17/Propen'):
    """Split the light chain out of already-folded structures.

    For each dataset in ``names`` the matched-pairs parquet file is read,
    the ``first_*`` and ``second_*`` columns are stacked into one
    de-duplicated table, and for every row whose PDB file exists under
    ``{pdb_dir}/{name}/EsmFold`` the last ``len(light)`` residues of chain
    ``A`` are moved to chain ``B`` in place. No folding is performed.

    The settings used to be read from free variables that this file never
    defines; they are now keyword parameters with defaults parallel to
    ``fold_heavy_light`` so the function can be called with no arguments.

    Args:
        names: Iterable of dataset names to process.
        property_th_lb: Lower property threshold (input filename only).
        property_th_ub: Upper property threshold (input filename only).
        edist_th: Edit-distance threshold (input filename only).
        property_to_match: Property tag used in the input filename.
        paired_dir: Directory containing the matched-pairs parquet files.
        pdb_dir: Root directory holding the per-dataset ``EsmFold`` folders.

    Returns:
        None. Matching PDB files are rewritten in place.
    """
    for name in names:
        pdb_path = f"{pdb_dir}/{name}/EsmFold"
        os.makedirs(pdb_path, exist_ok=True)

        datafile = f'{paired_dir}/{name}_matched_th{property_th_lb}-{property_th_ub}_edth{edist_th}_prop{property_to_match}.parquet'
        df_paired = pd.read_parquet(datafile)

        # Stack the two members of every pair into one long table.
        df = pd.DataFrame()
        df['seqid'] = df_paired['first_seqid'].values.tolist() + df_paired['second_seqid'].values.tolist()
        df['heavy'] = df_paired['first_HeavyAA'].values.tolist() + df_paired['second_HeavyAA'].values.tolist()
        df['light'] = df_paired['first_LightAA'].values.tolist() + df_paired['second_LightAA'].values.tolist()
        df = df.drop_duplicates()

        for _, row in tqdm(df.iterrows()):
            seqid = row['seqid']
            outfile = f'{pdb_path}/{seqid}.pdb'
            if not os.path.exists(outfile):
                # Nothing was folded for this sequence; nothing to fix.
                continue
            print('modifying', outfile)
            modify_chain_id(outfile, outfile, 'A', 'B', len(row['light']))


