import argparse
import json
from datetime import datetime
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from sequence_models.collaters import Tokenizer
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
from sequence_models.convolutional import ByteNetLM
from sequence_models.utils import parse_fasta


def main():
    """Parse the command-line options and run the evaluation in a single process.

    Output and checkpoint directories default to the ``PT_OUTPUT_DIR`` and
    ``PT_MAP_OUTPUT_DIR`` environment variables (falling back to ``/tmp``),
    each with a trailing slash appended.
    """
    map_default = os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/'
    out_default = os.getenv('PT_OUTPUT_DIR', '/tmp') + '/'

    cli = argparse.ArgumentParser()
    cli.add_argument('config_fpath', type=str, help='file path to config json for model')
    cli.add_argument('--map_dir', type=str, default=map_default)
    cli.add_argument('--out_fpath', type=str, required=False, default=out_default)
    cli.add_argument('--gpu', '-g', type=int, default=0)
    cli.add_argument('--uniprot', action='store_true')

    options = cli.parse_args()
    # Single-process run: world size is fixed at 1 and the local rank is 0.
    options.world_size = 1
    train(0, options)


def _protein_name_from_label(label):
    """Join the leading '_'-separated fields of ``label`` that precede the first integer field."""
    fields = []
    for field in label.split('_'):
        try:
            int(field)
        except ValueError:
            fields.append(field)
        else:
            break
    return '_'.join(fields)


def train(gpu, args):
    """Score ByteNetLM checkpoints on the deep mutational scanning datasets.

    For each selected checkpoint and each dataset listed in ``metadata.csv``,
    every variant is scored by masking its mutated positions in the wild-type
    sequence and summing, over those positions, the model log-probability of
    the mutant residue minus that of the wild-type residue. The Spearman
    correlation between these scores and the measured fitness is stored in
    ``dms.csv`` inside ``args.out_fpath``, which is rewritten after every
    dataset so an interrupted run can resume.

    Args:
        gpu: Local device index; added to ``args.gpu`` to pick the CUDA device.
        args: Parsed options providing ``config_fpath``, ``map_dir`` (directory
            holding ``metrics.csv`` and ``checkpoint<step>.tar``), ``out_fpath``
            (output directory), ``gpu`` and ``uniprot``.

    Raises:
        KeyError: If a dataset's protein has no wild-type sequence loaded, or
            the config json lacks a required hyperparameter.
    """
    _ = torch.manual_seed(0)
    torch.cuda.set_device(gpu + args.gpu)
    device = torch.device('cuda:' + str(gpu + args.gpu))

    ## Grab metadata
    # Explicit None check instead of a bare except around str + None.
    data_root = os.getenv('PT_DATA_DIR')
    if data_root is None:
        data_root = os.path.join(str(pathlib.Path.home()), 'data')
    data_fpath = os.path.join(data_root, 'deepsequence')
    meta_df = pd.read_csv(os.path.join(data_fpath, 'metadata.csv'))
    dms_xlsx = os.path.join(data_fpath, 'NIHMS1014772-supplement-Supplemental_2.xlsx')

    # Prep results
    out_csv = os.path.join(args.out_fpath, 'dms.csv')
    if 'dms.csv' in os.listdir(args.out_fpath):
        # Resume from a previous (possibly partial) run.
        record = pd.read_csv(out_csv)
    else:
        record = pd.read_csv(os.path.join(args.map_dir, 'metrics.csv'), header=None)
        # Walk back from the last logged row in roughly ten even strides.
        # Clamp the stride to 1: fewer than ten rows would give a zero step,
        # which range() rejects.
        skip = max(1, len(record) // 10)
        idx = list(range(len(record) - 1, 0, -skip))
        record = record.iloc[idx]
        record.columns = ['loss', 'accuracy', 'tokens', 'step']
        # Extra row with step 0: evaluated with freshly initialised weights.
        temp = pd.DataFrame(np.array([[np.nan, np.nan, 0, 0]]), columns=record.columns)
        record = pd.concat([record, temp], ignore_index=True)

    # Get the wt sequences and offsets from alignments
    offsets = {}
    wts = {}
    tkn = Tokenizer(alphabet=PROTEIN_ALPHABET)
    if not args.uniprot:
        msa_dir = os.path.join(data_fpath, 'msas')
        for fname in os.listdir(msa_dir):
            seqs, names = parse_fasta(os.path.join(msa_dir, fname), return_names=True)
            protein_name = _protein_name_from_label(fname)
            # The first alignment entry is taken as the wild type.
            wts[protein_name] = seqs[0].upper()
            # Its name ends in '/<start>-<end>'; <start> is the numbering offset.
            offsets[protein_name] = int(names[0].split('/')[-1].split('-')[0])
    else:
        seqs, names = parse_fasta(os.path.join(data_fpath, 'uniprot_sequences.fasta'),
                                  return_names=True)
        for seq, seq_name in zip(seqs, names):
            protein_name = _protein_name_from_label(seq_name)
            wts[protein_name] = seq.upper()
            offsets[protein_name] = 1

    # grab model hyperparameters
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_embedding = config['d_embed']
    d_model = config['d_model']
    n_layers = config['n_layers']
    kernel_size = config['kernel_size']
    activation = config['activation']
    slim = config['slim']
    r = config['r']

    state_prefix = 'module.'
    # Parsed sheets keyed by (sheet name, fitness column) so the workbook is
    # not re-read for every checkpoint.
    sheets = {}

    with tqdm(total=len(record) * len(meta_df)) as pbar:
        for idx, ckpt_row in record.iterrows():
            step = int(ckpt_row['step'])
            bytenet = ByteNetLM(n_tokens, d_embedding, d_model, n_layers, kernel_size, r,
                                dropout=0.0, activation=activation,
                                causal=False, padding_idx=PROTEIN_ALPHABET.index(PAD), final_ln=True, slim=slim)
            if step != 0:
                wd = os.path.join(args.map_dir, 'checkpoint' + str(step) + '.tar')
                sd = torch.load(wd, map_location=device)
                sd = sd['model_state_dict']
                # Strip only a leading 'module.' prefix; keys without it are
                # kept unchanged instead of raising IndexError.
                sd = {(k[len(state_prefix):] if k.startswith(state_prefix) else k): v
                      for k, v in sd.items()}
                bytenet.load_state_dict(sd)
            bytenet = bytenet.to(device).eval()
            # metadata.csv rows unpack into (sheet name, fitness column).
            for _, (name, fitness) in meta_df.iterrows():
                # Skip datasets already scored for this checkpoint.
                if name in record.columns:
                    if not np.isnan(record.loc[idx, name]):
                        pbar.update(1)
                        continue
                if '_BRCT' in name and not args.uniprot:
                    protein_name = name
                else:
                    protein_name = '_'.join(name.split('_')[:2])
                    if protein_name[:4] == 'TIM_':
                        protein_name = 'TRPC_' + protein_name[4:]
                wt = wts[protein_name]
                tokenized = torch.tensor(tkn.tokenize(wt))
                offset = offsets[protein_name]
                # Model outputs are cached per set of masked positions.
                pseudo_likelihood = {}
                ell = len(wt)

                sheet_key = (name, fitness)
                if sheet_key not in sheets:
                    sheet = pd.read_excel(dms_xlsx, sheet_name=name, engine='openpyxl')
                    sheets[sheet_key] = sheet.dropna(axis=0, subset=[fitness]).reset_index()
                df = sheets[sheet_key]

                tgt = df[fitness].to_numpy()
                pre = np.zeros(len(tgt))
                # Marks variants that were scored. Unscored rows are dropped
                # from both arrays together, so predictions and targets stay
                # aligned however many variants are skipped.
                scored = np.zeros(len(tgt), dtype=bool)
                for i, variant in df.iterrows():
                    muts = variant['mutant'].split(':')
                    aas = []
                    pos = []
                    for mut in muts:
                        if mut == 'WT':
                            # Encode wild type as a no-op substitution at the first residue.
                            mut = wt[0] + str(offset) + wt[0]
                        m = mut[-1]
                        # Deal with deletions
                        if m == '_':
                            m = PAD
                        # Skip mutations that fall outside the wild-type sequence.
                        j = int(mut[1:-1]) - offset
                        if j >= ell or j < 0:
                            continue
                        aas.append(PROTEIN_ALPHABET.index(m))
                        pos.append(j)
                    if len(pos) == 0:
                        # Nothing scorable for this variant; leave it unmarked.
                        continue
                    pos = np.array(pos)
                    sort_idx = pos.argsort()
                    aas = [aas[ix] for ix in sort_idx]
                    pos = tuple(pos[sort_idx])
                    if pos not in pseudo_likelihood:
                        # Build the wild-type sequence with every mutated position masked.
                        pieces = []
                        start = 0
                        for p in pos:
                            pieces.append(wt[start:p])
                            pieces.append(MASK)
                            start = p + 1
                        pieces.append(wt[start:])
                        seq = ''.join(pieces)
                        src = torch.tensor(tkn.tokenize(seq)).unsqueeze(0).to(device)
                        with torch.no_grad():
                            output = F.log_softmax(bytenet(src, input_mask=torch.ones_like(src).unsqueeze(-1)), dim=-1)
                            pl = output[0, pos].cpu()
                            wt_pl = output[torch.zeros(len(pos)).long(), pos, tokenized[list(pos)]].cpu()
                            # Log-probability of each token relative to the wild-type residue.
                            pseudo_likelihood[pos] = (pl - wt_pl.unsqueeze(-1)).numpy()
                    pre[i] = pseudo_likelihood[pos][np.arange(len(pos)), np.array(aas)].sum()
                    scored[i] = True

                rho = spearmanr(pre[scored], tgt[scored])
                record.loc[idx, name] = rho.correlation
                # Persist after every dataset so progress survives interruption.
                record.to_csv(out_csv, index=False)
                pbar.update(1)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()