import argparse
import json
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from sequence_models.collaters import StructureCollater, SimpleCollater
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
from sequence_models.convolutional import ByteNetLM
from sequence_models.gnn import BidirectionalStruct2SeqDecoder
from sequence_models.utils import parse_fasta, Tokenizer
from sequence_models.pdb_utils import parse_PDB, process_coords


def main():
    """Parse the command-line options and run the evaluation on one device.

    The positional ``config_fpath`` points at the model's JSON config.
    ``--map_dir`` and ``--out_fpath`` default to the ``PT_MAP_OUTPUT_DIR`` and
    ``PT_OUTPUT_DIR`` environment variables (``/tmp`` when unset), each with a
    trailing slash appended. ``world_size`` is pinned to 1 on the namespace
    before it is handed to ``train`` together with rank 0.
    """
    map_dir_default = os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/'
    out_dir_default = os.getenv('PT_OUTPUT_DIR', '/tmp') + '/'

    cli = argparse.ArgumentParser()
    cli.add_argument('config_fpath', type=str, help='file path to config json for model')
    cli.add_argument('--map_dir', type=str, default=map_dir_default)
    cli.add_argument('--out_fpath', type=str, required=False, default=out_dir_default)
    cli.add_argument('--gpu', '-g', type=int, default=0)
    # Both switches are plain booleans that default to False.
    for switch in ('--pdb', '--cnn'):
        cli.add_argument(switch, action='store_true')

    namespace = cli.parse_args()
    namespace.world_size = 1
    train(0, namespace)


def _resolve_data_root():
    """Return the root data directory, always ending in '/'.

    Uses the ``PT_DATA_DIR`` environment variable when it is set and falls
    back to ``~/data/`` otherwise.
    """
    data_dir = os.getenv('PT_DATA_DIR')
    if data_dir is None:
        return str(pathlib.Path.home()) + '/data/'
    return data_dir + '/'


def _select_checkpoints(metrics_fpath):
    """Choose which checkpoints to evaluate from a headerless (step, loss) CSV.

    Rows after the lowest-loss row are discarded. Eight thresholds are spaced
    evenly between the first and the lowest loss; walking the rows in order,
    a row is selected whenever its loss is below the next unmet threshold
    (at most one threshold is consumed per row).

    Returns a DataFrame with columns ``step``, ``loss`` and ``epoch`` holding,
    in order: the lowest-loss row, a placeholder row with step 0, the first
    logged row, and the threshold-crossing rows.
    """
    record = pd.read_csv(metrics_fpath, header=None)
    record.columns = ['step', 'loss']
    record['epoch'] = record.index
    record['epoch'] += 1
    # Placeholder for step 0; train() does not load a checkpoint for it.
    untrained = pd.DataFrame(np.array([[0, np.nan, 0]]), columns=record.columns)
    best = record.sort_values('loss').index[0]
    record = record.loc[:best]
    first_loss = record.iloc[0, 1]
    last_loss = record.iloc[-1, 1]
    bins = np.linspace(first_loss, last_loss, 10)[1:-1]
    j = 0
    idx = []
    for i, row in record.iterrows():
        if row['loss'] < bins[j]:
            j += 1
            idx.append(i)
            if j == len(bins):
                break
    record = pd.concat([record.iloc[-1:], untrained, record.iloc[[0] + idx]], ignore_index=True)
    return record.astype({'step': int})


def _load_records(args, n_pdbs):
    """Load or create the result tables.

    Returns ``(record, records)``. ``record`` drives the checkpoint loop and
    the "already computed" check. ``records`` maps each structure rank
    ``0..n_pdbs-1`` to its own result table and ``n_pdbs`` to the table for
    the rank-averaged prediction.

    When ``dms.csv`` already exists in ``args.out_fpath`` the run is resumed:
    per-rank tables are re-read from ``dms<rank>.csv`` when present and
    otherwise start from the base columns only. (Previously ``records`` was
    left undefined on this path, so resuming raised ``NameError``.)
    """
    out_fname = 'dms.csv'
    if out_fname in os.listdir(args.out_fpath):
        record = pd.read_csv(args.out_fpath + out_fname)
        # Work on a copy so `record` is not mutated while it is being iterated.
        records = {n_pdbs: record.copy()}
        base_cols = [c for c in ('step', 'loss', 'epoch') if c in record.columns]
        for rank in range(n_pdbs):
            rank_fpath = args.out_fpath + 'dms%d.csv' % rank
            if os.path.isfile(rank_fpath):
                records[rank] = pd.read_csv(rank_fpath)
            else:
                records[rank] = record[base_cols].copy()
    else:
        record = _select_checkpoints(args.map_dir + 'metrics.csv')
        records = {rank: record.copy() for rank in range(n_pdbs + 1)}
    return record, records


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with each key reduced to the text after 'module.'."""
    return {k.split('module.')[1]: v for k, v in state_dict.items()}


def train(gpu, args):
    """Evaluate saved checkpoints on the mutation-effect datasets.

    For every selected checkpoint and every ``(name, fitness)`` row of
    ``deepsequence/metadata.csv``: the mutated positions of the wild-type
    sequence are replaced by MASK, the decoder's log-softmax output at those
    positions is taken relative to the wild-type token, the values for the
    mutant tokens are summed per variant, and the Spearman correlation with
    the measured ``fitness`` column is stored. Per-rank results go to
    ``dms<rank>.csv`` (only when more than one structure is used) and the
    rank-averaged result to ``dms.csv``, both under ``args.out_fpath``.

    Args:
        gpu: offset added to ``args.gpu`` to select the CUDA device.
        args: parsed command-line namespace; reads ``gpu``, ``pdb``, ``cnn``,
            ``config_fpath``, ``map_dir`` and ``out_fpath``.
    """
    _ = torch.manual_seed(0)
    torch.cuda.set_device(gpu + args.gpu)
    device = torch.device('cuda:' + str(gpu + args.gpu))

    ## Grab metadata
    data_fpath = _resolve_data_root()
    if args.pdb:
        pdb_fpath = data_fpath + 'pdb/'
        n_pdbs = 1
    else:
        pdb_fpath = data_fpath + 'pdb/output_unzipped/'
        n_pdbs = 5
    data_fpath += 'deepsequence/'
    meta_df = pd.read_csv(data_fpath + 'metadata.csv')

    # Prep results
    record, records = _load_records(args, n_pdbs)

    # grab model hyperparameters
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_model = config['d_model']
    node_features = 10
    edge_features = 11
    dropout = config['dropout']
    use_mpnn = config['use_mpnn']
    n_layers = config['n_layers']
    n_connections = config['n_connections']

    # Get the wt sequences and offsets from alignments
    # Also get the pdb files
    offsets = {}
    wts = {}
    pdbs = {}
    tkn = Tokenizer(alphabet=PROTEIN_ALPHABET)
    collater = StructureCollater(SimpleCollater(PROTEIN_ALPHABET), n_connections=n_connections)
    alignments = os.listdir(data_fpath + 'msas/')
    if args.pdb:
        pdb_map = pd.read_excel(data_fpath + 'real_pdbs/aln2pdb.xlsx', engine='openpyxl')
        pdb_map = pdb_map[pdb_map['check'] == pdb_map['check.1']]
    for fname in alignments:
        if args.pdb:
            # Keep the first mapping row whose alignment name occurs in the
            # file name; alignments without a mapping are skipped.
            for _, pdb_row in pdb_map.iterrows():
                if pdb_row['aln_name'] in fname:
                    break
            else:
                continue
        seqs, names = parse_fasta(data_fpath + 'msas/' + fname, return_names=True)
        wt = seqs[0].upper()
        # The protein name is every '_'-separated piece before the first
        # piece that parses as an integer.
        protein_name = []
        for piece in fname.split('_'):
            try:
                int(piece)
                break
            except ValueError:
                protein_name.append(piece)
        protein_name = '_'.join(protein_name)
        # The first header ends in '/<start>-<end>'; <start> is the offset.
        offset = int(names[0].split('/')[-1].split('-')[0])
        offsets[protein_name] = offset
        if args.pdb:
            offsets[protein_name] += pdb_row['aln_s']
            fname = pdb_row['pdbid'].upper() + '.pdb'
        else:
            if 'parEparD' not in fname:
                pdb_path = pdb_fpath + fname.split('b0.')[0] + 'b0.result/'
            else:
                pdb_path = pdb_fpath + 'parEparD_3.result/'
            try:
                pdb_files = os.listdir(pdb_path)
            except FileNotFoundError:
                continue
            pdb_files = [p for p in pdb_files if '.pdb' in p]
        current_dict = {}
        for rank in range(n_pdbs):
            if args.pdb:
                pdb_path = pdb_fpath
                chain = 'ABCDEFG'[pdb_row['chain_idx']]
                if fname == '6R5K.pdb':
                    chain = 'A'
            else:
                fname = [p for p in pdb_files if 'rank_%d' % (rank + 1) in p][0]
                chain = 'A'
            coords, _, _ = parse_PDB(pdb_path + fname, chain=chain)
            if args.pdb:
                # n_pdbs is 1 in this mode, so wt is trimmed exactly once.
                coords = coords[pdb_row['pdb_s']: pdb_row['pdb_e']]
                wt = wt[pdb_row['aln_s']: pdb_row['aln_e']]

            coords = {
                'N': coords[:, 0],
                'CA': coords[:, 1],
                'C': coords[:, 2]
            }
            dist, omega, theta, phi = process_coords(coords)
            batch = [[wt, torch.tensor(dist, dtype=torch.float),
                      torch.tensor(omega, dtype=torch.float),
                      torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]]
            _, nodes, edges, connections, edge_mask = collater(batch)
            nodes = nodes.to(device)
            edges = edges.to(device)
            connections = connections.to(device)
            edge_mask = edge_mask.to(device)
            current_dict[rank] = (nodes, edges, connections, edge_mask)
        pdbs[protein_name] = current_dict
        wts[protein_name] = wt

    with tqdm(total=len(record) * len(meta_df)) as pbar:
        for idx, ckpt_row in record.iterrows():
            if args.cnn:
                d_cnn = 1280
                n_cnn_layers = 56
                kernel_size = 5
                r = 128
                slim = False
                d_embed = 8
                activation = 'gelu'
                causal = False
                pad_idx = PROTEIN_ALPHABET.index(PAD)
                cnn = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                                slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)

            gnn = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                                 d_model, num_decoder_layers=n_layers,
                                                 dropout=dropout, use_mpnn=use_mpnn,
                                                 pe=False, one_hot_src=(not args.cnn)).to(device)
            step = str(int(ckpt_row['step']))
            wd = args.map_dir + 'checkpoint' + step + '.tar'
            # Step 0 is the untrained baseline: keep the fresh initialisation.
            if step != '0':
                sd = torch.load(wd, map_location=device)
                if args.cnn:
                    gnn.load_state_dict(sd['decoder_state_dict'])
                    cnn.load_state_dict(_strip_module_prefix(sd['encoder_state_dict']))
                else:
                    gnn.load_state_dict(_strip_module_prefix(sd['model_state_dict']))
            # Always score in eval mode, including the step-0 baseline, so
            # dropout never perturbs the log-likelihoods.
            gnn = gnn.eval()
            if args.cnn:
                cnn = cnn.eval()

            for _, (name, fitness) in meta_df.iterrows():
                # Skip entries already present in a resumed results file.
                if name in record.columns:
                    if not np.isnan(record.loc[idx, name]):
                        pbar.update(1)
                        continue
                if '_BRCT' in name:
                    protein_name = name
                else:
                    protein_name = '_'.join(name.split('_')[:2])
                    if protein_name[:4] == 'TIM_':
                        protein_name = 'TRPC_' + protein_name[4:]
                if protein_name not in pdbs:
                    # No structure was loaded for this protein; still count
                    # it so the progress bar can reach its total.
                    pbar.update(1)
                    continue
                print(protein_name)
                wt = wts[protein_name]
                tokenized = torch.tensor(tkn.tokenize(wt))
                offset = offsets[protein_name]
                current_pdbs = pdbs[protein_name]
                # Cache keyed by the sorted tuple of masked positions, so
                # each distinct masking pattern is run through the model once.
                pseudo_likelihood = {rank: {} for rank in range(n_pdbs)}
                ell = len(wt)
                df = pd.read_excel(data_fpath + 'NIHMS1014772-supplement-Supplemental_2.xlsx',
                                   sheet_name=name, engine='openpyxl')
                df = df.dropna(axis=0, subset=[fitness]).reset_index()
                tgt = df[fitness].to_numpy()
                pre = np.zeros((n_pdbs, len(tgt)))
                # Marks variants that were actually scored. Rows with no
                # in-range mutation are filtered out after the loop; deleting
                # them in place by original index removed the wrong element
                # once an earlier row had already been deleted.
                scored = np.zeros(len(tgt), dtype=bool)
                for i, dms_row in df.iterrows():
                    muts = dms_row['mutant'].split(':')
                    aas = []
                    pos = []
                    for mut in muts:
                        if mut == 'WT':
                            mut = wt[0] + str(offset) + wt[0]
                        m = mut[-1]
                        # Deal with deletions
                        if m == '_':
                            m = PAD
                        # Ignore mutations that fall outside the sequence.
                        j = int(mut[1:-1]) - offset
                        if j >= ell or j < 0:
                            continue
                        if wt[j] != mut[0]:
                            print(protein_name, wt[j], mut[0], j)
                        aas.append(PROTEIN_ALPHABET.index(m))
                        pos.append(j)
                    if not pos:
                        continue
                    pos = np.array(pos)
                    sort_idx = pos.argsort()
                    aas = [aas[ix] for ix in sort_idx]
                    pos = tuple(pos[sort_idx])
                    if pos not in pseudo_likelihood[0]:
                        # Rebuild the wild type with MASK at every position.
                        pieces = []
                        start = 0
                        for p in pos:
                            pieces.append(wt[start:p])
                            pieces.append(MASK)
                            start = p + 1
                        pieces.append(wt[start:])
                        seq = ''.join(pieces)
                        src = torch.tensor(tkn.tokenize(seq)).unsqueeze(0).to(device)
                        with torch.no_grad():
                            if args.cnn:
                                src = cnn(src)
                            for rank in range(n_pdbs):
                                nodes, edges, connections, edge_mask = current_pdbs[rank]
                                output = F.log_softmax(gnn(nodes, edges, connections, src, edge_mask), dim=-1)
                                pl = output[0, pos].cpu()
                                wt_pl = output[torch.zeros(len(pos)).long(), pos, tokenized[list(pos)]].cpu()
                                # Log-probability of every token relative to
                                # the wild-type token at each masked position.
                                pseudo_likelihood[rank][pos] = (pl - wt_pl.unsqueeze(-1)).numpy()
                    scored[i] = True
                    for rank in range(n_pdbs):
                        pre[rank, i] = pseudo_likelihood[rank][pos][np.arange(len(pos)), np.array(aas)].sum()
                tgt = tgt[scored]
                pre = pre[:, scored]
                if n_pdbs > 1:
                    for rank in range(n_pdbs):
                        rho = spearmanr(pre[rank], tgt)
                        records[rank].loc[idx, name] = rho.correlation
                        records[rank].to_csv(args.out_fpath + 'dms%d.csv' % rank, index=False)
                pre = pre.mean(axis=0)
                rho = spearmanr(pre, tgt)
                records[n_pdbs].loc[idx, name] = rho.correlation
                records[n_pdbs].to_csv(args.out_fpath + 'dms.csv', index=False)
                print(rho)
                pbar.update(1)


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()