import dotenv
import hydra
import wandb
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
from typing import Optional, Union
import wandb
import torch
import os
import pandas as pd
import pytorch_lightning as pl
from affinityenhancer.hydra._instantiate_datamodule import instantiate_datamodule
from affinityenhancer.hydra._instantiate_callbacks import instantiate_callbacks
from affinityenhancer.hydra._instantiate_model import instantiate_model
from affinityenhancer.sample.sample_propen import sample

from prescient.transforms.functional import anarci_numbering
from prescient.constants import CDR_RANGES_AHO, LENGTH_FV_HEAVY_AHO, RANGES_AHO
import edlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime
from affinityenhancer.utils.logos import plot_logos
# Load environment variables from the ".env" file at import time, before the
# Hydra entry point below is defined or executed.
dotenv.load_dotenv(".env")

# Helper meant to be registered as a custom OmegaConf resolver that converts relative paths to absolute paths
def resolve_to_absolute_path(path: str) -> str:
    """Return ``path`` converted by Hydra's ``to_absolute_path`` helper.

    Args:
        path: The path string to convert.

    Returns:
        Whatever ``hydra.utils.to_absolute_path`` returns for ``path``.
    """
    absolute = to_absolute_path(path)
    return absolute

# Resolver registration is currently disabled; uncomment the next line to enable it.
# OmegaConf.register_new_resolver("to_absolute_path", resolve_to_absolute_path)

def aho_to_chain(chain_seq, index):
    """Placeholder with no implementation yet; always returns ``None``.

    NOTE(review): judging by the name and by how ``region_masker`` uses it,
    this is presumably meant to translate an AHo-numbered position ``index``
    into the matching position of ``chain_seq`` -- confirm before implementing.
    """
    return None


def region_masker(regions, chain_seq, aho=False):
    """Return the sorted position indices that must NOT be redesigned.

    Positions index a concatenated heavy+light layout with 301 slots in which
    light-chain positions are shifted by 151 (the heavy chain is given 151
    slots here instead of the 149 of the standard definition).

    Args:
        regions: Either an iterable of integer positions to redesign, or an
            iterable of region names that are keys of ``RANGES_AHO``. The
            first character of a region name is taken as the chain; ``"L"``
            marks the light chain. An empty iterable selects nothing for
            redesign.
        chain_seq: Sequence handed to ``aho_to_chain`` when ``aho`` is False.
        aho: When True the collected positions are used as they are; when
            False each one is first mapped through
            ``aho_to_chain(chain_seq, position)``.

    Returns:
        Sorted list of the indices in ``[0, 301)`` that are not selected for
        redesign.

    Raises:
        AssertionError: If a region name is not a key of ``RANGES_AHO``.
    """
    n_positions = 301
    heavy_slots = 151

    # Materialise once so that emptiness can be tested and one-shot iterables
    # are not consumed before the loop below.
    region_list = list(regions)

    if not region_list:
        # Nothing requested: every position stays frozen.
        redesign_accumulator = set()
    elif isinstance(region_list[0], int):
        redesign_accumulator = set(region_list)
        print(redesign_accumulator)
        regions = ''.join([str(t) for t in region_list])
    else:
        unrecognized_regions = set(region_list) - set(RANGES_AHO)

        assert (
            not unrecognized_regions
        ), f"Could not parse these redesign regions: {unrecognized_regions}"
        redesign_accumulator = set()
        for region in region_list:
            chain = region[0]
            is_light = int(chain == "L")
            # Heavy chain length max is taken as 151 instead of the standard
            # 149, hence the extra offset of 2 on light-chain ranges. The
            # endpoints are unpacked before offsetting so that tuple/list
            # ranges work as well as array ranges.
            range_start, range_end = RANGES_AHO[region]
            aho_start = range_start + 2 * is_light
            aho_end = range_end + 2 * is_light
            shift = heavy_slots * is_light

            redesign_accumulator |= set((np.arange(aho_start, aho_end) + shift).tolist())

    if not aho:
        # NOTE(review): ``aho_to_chain`` is a stub in this file that returns
        # None, so until it is implemented no position is removed from the
        # mask in this mode.
        redesign_accumulator = {aho_to_chain(chain_seq, t) for t in redesign_accumulator}

    all_positions = set(np.arange(0, n_positions))
    mask_idxs = sorted(all_positions - redesign_accumulator)  # don't update these residues

    print(f"Redesigning {regions}, i.e., not {mask_idxs}")
    return mask_idxs


def chain_masker(chain_name):
    """Build a length-301 mask tensor with one half of the layout set to 1.

    Args:
        chain_name: Chain selector. ``'heavy'``, ``'H'`` and ``'Heavy'`` set
            positions 151 onwards to 1; any other value sets positions
            0..150 to 1.

    Returns:
        A ``torch`` tensor of shape ``(301,)`` holding zeros and ones.
    """
    heavy_aliases = ('heavy', 'H', 'Heavy')
    mask = torch.zeros((301,))

    if chain_name in heavy_aliases:
        ones_span = slice(151, None)
    else:
        ones_span = slice(None, 151)
    mask[ones_span] = 1

    print(f'mask {chain_name}: {mask}')
    return mask


def _collect_samples(cfg, model, dataloader, siter, temp, suffix):
    """Run ``sample`` over every batch and gather sampled/seed sequences.

    Also writes one text file per seed and chain into ``cfg.outdir`` holding
    the seed sequence followed by its unique sampled sequences.

    Returns:
        ``(seqs, seqs_wt, eds)``: dicts keyed by ``'heavy'``/``'light'`` with
        the sampled and the (repeated) seed sequences, plus the list of edit
        distances accumulated over all batches.
    """
    seqs = {'heavy': [], 'light': []}
    seqs_wt = {'heavy': [], 'light': []}
    eds = []
    # The slicing below assumes `samples` consecutive designs per seed.
    samples = cfg.samples if cfg.sample_mode == 'logits' else 1
    # Running seed index across batches, so that files written for a later
    # batch do not overwrite those of an earlier one.
    seed_idx = 0

    for batch in dataloader:
        # NOTE: region_masker is not used here yet. The sequences of a batch
        # are not AHo-aligned, so one mask cannot be shared by the whole batch;
        # it would have to be applied per antibody.
        # Simple chain-based masker instead:
        mask = chain_masker(cfg.mask_chains) if cfg.mask_chains else None
        outputs = sample(
            batch,
            model,
            save_trajectory=cfg.save_trajectory,
            iterations=siter,
            enhance_mode=cfg.enhance_mode,
            sample_mode=cfg.sample_mode,
            temp=temp,
            samples=cfg.samples,
            mask=mask,
        )
        if cfg.save_trajectory:
            # The trajectory is returned but not used by this script.
            sequences, sequences_wt, ed_distances, _trajectory = outputs
        else:
            sequences, sequences_wt, ed_distances = outputs

        eds += list(ed_distances)

        for j, wt in enumerate(sequences_wt):
            seqs_hl = sequences[j * samples:(j + 1) * samples]
            for k, (wt_ch_seq, chain_name) in enumerate(zip(wt, ['heavy', 'light'])):
                sampled_ch_seq = [sch[k] for sch in seqs_hl]
                # De-duplicate while keeping first-seen order so the file
                # content is deterministic.
                unique_ch_seq = list(dict.fromkeys(sampled_ch_seq))
                with open(f"{cfg.outdir}/sampled_sequences_{suffix}_{chain_name}_{seed_idx}.txt", 'w') as f:
                    f.write(wt_ch_seq + '\n' + '\n'.join(unique_ch_seq) + '\n')
                seqs_wt[chain_name] += [wt_ch_seq] * samples
                seqs[chain_name] += sampled_ch_seq
            seed_idx += 1

    return seqs, seqs_wt, eds


def _build_samples_frame(seqs, seqs_wt, eds, temp, siter):
    """Assemble the per-sample DataFrame with AHo strings and edit distances."""
    df = pd.DataFrame()
    df['fv_heavy'] = seqs['heavy']
    df['fv_light'] = seqs['light']
    # AHo-numbered versions of the sampled sequences.
    df['fv_heavy_aho'] = anarci_numbering(seqs['heavy'])
    df['fv_light_aho'] = anarci_numbering(seqs['light'])

    df['fv_heavy_seed'] = seqs_wt['heavy']
    df['fv_heavy_seed_aho'] = anarci_numbering(seqs_wt['heavy'])
    df['fv_light_seed'] = seqs_wt['light']
    df['fv_light_seed_aho'] = anarci_numbering(seqs_wt['light'])

    # Edit distances as returned by sample(), accumulated over all batches.
    df['edit_distance'] = eds
    for chain in ('heavy', 'light'):
        df[f'edit_distance_{chain}'] = [
            edlib.align(seq, seed)['editDistance']
            for seq, seed in zip(df[f'fv_{chain}'], df[f'fv_{chain}_seed'])
        ]

    # Per-CDR edit distance, computed on the AHo slices with gaps removed.
    for cdr, cdr_range in CDR_RANGES_AHO.items():
        chain = 'heavy' if cdr.startswith('H') else 'light'
        print(cdr, cdr_range, chain)
        r1, r2 = cdr_range[0], cdr_range[1] + 1
        df[f'edit_distance_{cdr}'] = [
            edlib.align(seq[r1:r2].replace('-', ''),
                        seed[r1:r2].replace('-', ''))['editDistance']
            for seq, seed in zip(df[f'fv_{chain}_aho'], df[f'fv_{chain}_seed_aho'])
        ]

    df['temperature'] = temp
    df['iterations'] = siter
    return df


def _plot_edit_distance_histograms(df, outdir, suffix):
    """Save per-chain and per-CDR edit-distance histograms as PNG files."""
    plt.rcdefaults()
    for chain in ('light', 'heavy'):
        sns.histplot(df, x=f"edit_distance_{chain}", discrete=True)
        plt.savefig(f"{outdir}/distribution_ED{chain}_{suffix}.png")
        plt.close()

    fig, axes = plt.subplots(nrows=2, ncols=4)
    axes = axes.flatten()
    for i, cdr in enumerate(CDR_RANGES_AHO):
        sns.histplot(df, x=f"edit_distance_{cdr}", ax=axes[i], discrete=True)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    fig.tight_layout()
    plt.savefig(f"{outdir}/distribution_EDCDRs_{suffix}.png")
    plt.close()
    plt.rcdefaults()


def _plot_sequence_logos(df, outdir, suffix):
    """Call ``plot_logos`` once per chain and unique AHo seed sequence."""
    for chain in ['heavy', 'light']:
        seed_col = f'fv_{chain}_seed_aho'
        for i, ref_seq in enumerate(df[seed_col].unique()):
            df_filter = df[df[seed_col] == ref_seq]
            candidates = df_filter[f'fv_{chain}_aho'].values.tolist()
            # Only sequences with the same length as the reference are passed on.
            sequences = [s for s in candidates if len(s) == len(ref_seq)]
            plt.rcdefaults()
            plot_logos(sequences,
                       ref_seq=ref_seq,
                       logo_file_base=f'{outdir}/logos_seed{i}_{chain}_{suffix}',
                       chain=chain.upper()[0]
                       )


@hydra.main(version_base=None, config_path="../configs", config_name="sample_propen")
def run(cfg: DictConfig) -> None:
    """Sample sequences from a checkpointed model and write CSVs and plots.

    The resolved config is dumped to ``cfg.outdir``, the model class given by
    ``cfg.enhancers`` is restored from ``cfg.ckpt_file`` and the predict
    dataloader of ``cfg.data`` is built. Then, for every combination of
    ``cfg.iterations`` and ``cfg.temp``, sequences are sampled and the
    following are written to ``cfg.outdir``: per-seed text files, a
    de-duplicated ``samples_*.csv``, edit-distance histograms and sequence
    logos.

    Args:
        cfg: Hydra configuration. Keys read here: ``outdir``, ``setup``,
            ``torch`` (optional), ``enhancers``, ``ckpt_file``, ``data``,
            ``iterations``, ``temp``, ``samples``, ``save_trajectory``,
            ``enhance_mode``, ``sample_mode`` and ``mask_chains``.

    Raises:
        AssertionError: If ``cfg.ckpt_file`` is None, or if the number of
            collected edit distances and sequences disagree.
    """
    log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True)
    print(cfg)
    wandb.require("service")

    # Keep a timestamped dump of the resolved top-level config next to the outputs.
    os.makedirs(cfg.outdir, exist_ok=True)
    d = datetime.now().isoformat(timespec='minutes').replace(':', '')
    with open(f'{cfg.outdir}/cfg_{d}.txt', 'w') as f:
        str_log = '\n'.join(f'{key}:{val}' for key, val in log_cfg.items())
        f.write(str_log + '\n')

    hydra.utils.instantiate(cfg.setup)

    if cfg.get("torch"):
        hydra.utils.instantiate(cfg.torch)

    # The instantiated object is only used to obtain the class whose
    # load_from_checkpoint restores the actual model.
    model = hydra.utils.instantiate(cfg.enhancers, _recursive_=False)
    assert cfg.ckpt_file is not None
    model = model.__class__.load_from_checkpoint(
                    cfg.ckpt_file, strict=False)

    datamodule = hydra.utils.instantiate(cfg.data,  _recursive_=False)
    datamodule.setup(stage='predict')

    dataloader = datamodule.predict_dataloader()

    for siter in cfg.iterations:
        for temp in cfg.temp:
            suffix = f'iter{siter}_t{temp}_N{cfg.samples}'
            seqs, seqs_wt, eds = _collect_samples(
                cfg, model, dataloader, siter, temp, suffix
            )

            # Compare against the accumulated distances, not the last batch only.
            assert len(eds) == len(seqs['heavy']) == len(seqs_wt['heavy'])

            df = _build_samples_frame(seqs, seqs_wt, eds, temp, siter)

            df = df.drop_duplicates(['fv_heavy', 'fv_light'])
            df.to_csv(f'{cfg.outdir}/samples_{suffix}.csv', index=False)

            _plot_edit_distance_histograms(df, cfg.outdir, suffix)
            _plot_sequence_logos(df, cfg.outdir, suffix)

    print('Done')
    wandb.finish()

if __name__ == "__main__":
    # Launch the Hydra application only when this file is executed as a
    # script, so that importing the module does not start a sampling run.
    run()
