import argparse
import json
from multiprocessing import Pool
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from sequence_models.collaters import StructureCollater, SimpleCollater
from sequence_models.gnn import BidirectionalStruct2SeqDecoder
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK, IUPAC_CODES
from sequence_models.convolutional import ByteNetLM
from sequence_models.utils import Tokenizer
from sequence_models.pdb_utils import process_coords
from atom3d.datasets import LMDBDataset


# Command-line interface. Every option carries help text so that ``--help``
# explains what the script reads and writes.
parser = argparse.ArgumentParser(
    description='Score the mutations of the test split with every saved checkpoint '
                'and record the ROC AUC obtained at each one.')
parser.add_argument('config_fpath', type=str, help='file path to config json for model')
parser.add_argument('--map_dir', type=str, default=os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/',
                    help='directory holding metrics.csv and the checkpoint<step>.tar files '
                         '(default: $PT_MAP_OUTPUT_DIR, or /tmp)')
parser.add_argument('--out_fpath', type=str, required=False, default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/',
                    help='directory where msp.csv and predictions<step>.csv are written '
                         '(default: $PT_OUTPUT_DIR, or /tmp)')
parser.add_argument('--gpu', '-g', type=int, default=0,
                    help='index of the CUDA device to run on')
parser.add_argument('--cnn', action='store_true',
                    help='run the masked sequence through a ByteNetLM CNN first '
                         '(its output feeds the GNN, or is scored directly with --no_gnn)')
parser.add_argument('--no_gnn', action='store_true',
                    help='do not build the structure GNN decoder; score with the CNN output alone '
                         '(only meaningful together with --cnn)')


args = parser.parse_args()


def get_data(idx):
    """Turn entry ``idx`` of the module-level ``dataset`` into model inputs.

    The wild-type ensemble name (minus its last four characters) is split on
    ``'_'``; the fields after the first spell out the chain ids. The mutated
    ensemble name extends the wild-type one with a separator character and a
    mutation code laid out as ``<wt residue><chain><1-based position><mutant residue>``.

    Returns:
        A 4-tuple ``(structure, sequence, label, mutation)``:

        * ``structure`` -- ``(nodes, edges, connections, edge_mask)`` as
          returned by ``collater``.
        * ``sequence`` -- ``(wild-type sequence string, tokenized tensor)``.
        * ``label`` -- ``int(item['label'])``.
        * ``mutation`` -- ``(pos, aa)``: 0-based index of the mutated residue
          in the concatenated sequence, and the mutant residue letter.

    Raises:
        AssertionError: if the residue found at ``pos`` differs from the
            wild-type residue named in the mutation code.
    """
    entry = dataset[idx]
    wt_atoms = entry['original_atoms']
    mt_atoms = entry['mutated_atoms']

    wt_name = wt_atoms.loc[0, 'ensemble'][:-4]
    chain_ids = list(''.join(wt_name.split('_')[1:]))

    # Whatever follows the wild-type name (after one separator) is the mutation code.
    mt_name = mt_atoms.loc[0, 'ensemble'][:-4]
    mut_code = mt_name.split(wt_name)[-1][1:]
    mut_pos = int(mut_code[2:-1]) - 1
    mut_chain = mut_code[1]
    mut_res = mut_code[-1]
    label = int(entry['label'])

    full_seq = ''
    per_chain_backbone = []
    for chain_id in chain_ids:
        on_mut_chain = chain_id == mut_chain
        # Without the GNN only the mutated chain is kept.
        if args.no_gnn and not on_mut_chain:
            continue
        # Shift the within-chain position by the length of the chains already added.
        if on_mut_chain:
            mut_pos += len(full_seq)

        chain_rows = wt_atoms[wt_atoms['chain'] == chain_id]

        # One letter per residue: a residue is recorded the first time an atom
        # row carries the next expected residue number (1, 2, 3, ...).
        letters = []
        n_seen = 0
        for _, atom in chain_rows.iterrows():
            if atom['residue'] == n_seen + 1:
                n_seen += 1
                three_letter = atom['resname']
                letters.append(IUPAC_CODES[three_letter[0] + three_letter[1:].lower()])
        chain_seq = ''.join(letters)
        full_seq += chain_seq

        # Backbone coordinates, indexed by (residue number - 1).
        # NOTE(review): np.empty leaves rows uninitialised for any residue
        # whose N/CA/C atom is absent -- confirm that cannot happen here.
        backbone = {atom_name: np.empty((len(chain_seq), 3)) for atom_name in ('N', 'CA', 'C')}
        for _, atom in chain_rows.iterrows():
            res_idx = atom['residue'] - 1
            atom_name = atom['name']
            if atom_name in backbone:
                backbone[atom_name][res_idx] = atom[['x', 'y', 'z']].values
        per_chain_backbone.append(backbone)

    # Sanity check: the parsed position must land on the stated wild-type residue.
    assert full_seq[mut_pos] == mut_code[0]

    merged = {atom_name: np.concatenate([bb[atom_name] for bb in per_chain_backbone])
              for atom_name in ['N', 'CA', 'C']}
    dist, omega, theta, phi = process_coords(merged)
    features = [torch.tensor(arr, dtype=torch.float) for arr in (dist, omega, theta, phi)]
    # The collater takes a batch; wrap the single example in a list.
    _, nodes, edges, connections, edge_mask = collater([[full_seq] + features])
    tokens = torch.tensor(tkn.tokenize(full_seq))
    return (nodes, edges, connections, edge_mask), (full_seq, tokens), label, (mut_pos, mut_res)



# Seed torch's random number generator with a fixed value.
_ = torch.manual_seed(64)
# Make the GPU chosen with --gpu the current CUDA device and keep a handle
# to it for moving models and tensors.
torch.cuda.set_device(args.gpu)
device = torch.device('cuda', args.gpu)

## Grab metadata
# Root of the data directory: $PT_DATA_DIR when it is set, otherwise ~/data.
data_fpath = os.getenv('PT_DATA_DIR')
if data_fpath is None:
    # Test for the unset variable explicitly rather than relying on a bare
    # ``except:`` to catch the ``None + str`` TypeError, which would also
    # swallow unrelated errors such as KeyboardInterrupt.
    home = str(pathlib.Path.home())
    data_fpath = home + '/data'
data_fpath += '/skempi/'

# Read the model hyperparameters from the JSON config named on the command line.
with open(args.config_fpath, 'r') as config_file:
    config = json.load(config_file)

# Sizes fixed by this script rather than by the config file.
n_tokens = len(PROTEIN_ALPHABET)
node_features = 10
edge_features = 11

# Settings taken from the config file.
d_model, dropout, use_mpnn, n_layers, n_connections = (
    config[key] for key in ('d_model', 'dropout', 'use_mpnn', 'n_layers', 'n_connections')
)

# Build ``record``: the table with one row per checkpoint step to evaluate.
out_fname = 'msp.csv'
if out_fname not in os.listdir(args.out_fpath):
    # No previous results: start from the training metrics log, which has no header row.
    record = pd.read_csv(args.map_dir + 'metrics.csv', header=None)
    if args.no_gnn:
        record.columns = ['loss', 'accuracy', 'tokens', 'step']
        step_zero = [np.nan, np.nan, 0, 0]
    else:
        record.columns = ['step', 'loss']
        record['epoch'] = record.index + 1
        step_zero = [0, np.nan, 0]
    # Extra row for step 0, for which no metrics were logged.
    step_zero_row = pd.DataFrame(np.array([step_zero]), columns=record.columns)
    # Reverse the log so the last logged row comes first, then append the step-0 row.
    record = pd.concat([record.iloc[::-1], step_zero_row], ignore_index=True)
    record = record.astype({'step': int})
else:
    # Results file already present in the output directory: reuse it as the table.
    record = pd.read_csv(args.out_fpath + out_fname)
# Objects used by get_data() to turn raw dataset entries into model inputs.
test_split_dir = data_fpath + '/split-by-sequence-identity-30/data/test/'
dataset = LMDBDataset(test_split_dir)
collater = StructureCollater(SimpleCollater(PROTEIN_ALPHABET), n_connections=n_connections)
tkn = Tokenizer(alphabet=PROTEIN_ALPHABET)
print('Loading structures...')

# Pre-process every example once, serially; the list is reused for each checkpoint.
n = len(dataset)
data = list(map(get_data, range(n)))

def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with each key cut down to what follows its first ``'module.'``.

    Splitting at most once keeps any later ``'module.'`` inside a parameter
    name intact. A key without ``'module.'`` raises ``IndexError``.
    """
    return {key.split('module.', 1)[1]: value for key, value in state_dict.items()}


# With neither model there is nothing to produce an output; fail with a clear
# message instead of a NameError part-way through the loop.
if args.no_gnn and not args.cnn:
    raise ValueError('--no_gnn requires --cnn: no model would be built')

# For every row of ``record``: rebuild the model(s), load that step's
# checkpoint (step 0 keeps the random initialisation), score every mutation
# and store the ROC AUC plus the per-example predictions.
for idx, row in record.iterrows():
    step = str(int(row['step']))
    wd = args.map_dir + 'checkpoint' + step + '.tar'
    if args.cnn:
        # ByteNetLM hyperparameters are fixed here, not read from the config.
        d_cnn = 1280
        n_cnn_layers = 56
        kernel_size = 5
        r = 128
        slim = False
        d_embed = 8
        activation = 'gelu'
        causal = False
        pad_idx = PROTEIN_ALPHABET.index(PAD)
        cnn = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                        slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)

    if not args.no_gnn:
        gnn = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                             d_model, num_decoder_layers=n_layers,
                                             dropout=dropout, use_mpnn=use_mpnn,
                                             pe=False, one_hot_src=(not args.cnn)).to(device)

    # Step 0 has no checkpoint on disk; every other step loads its weights.
    if int(row['step']) != 0:
        sd = torch.load(wd, map_location=device)
        if args.cnn and args.no_gnn:
            cnn.load_state_dict(_strip_module_prefix(sd['model_state_dict']))
        elif args.cnn:
            # Decoder keys are used as saved; encoder keys carry a 'module.' prefix.
            gnn.load_state_dict(sd['decoder_state_dict'])
            cnn.load_state_dict(_strip_module_prefix(sd['encoder_state_dict']))
        else:
            gnn.load_state_dict(_strip_module_prefix(sd['model_state_dict']))

    # Switch to eval mode for every row -- including the step-0 baseline,
    # which loads no checkpoint -- so dropout is disabled while scoring.
    if args.cnn:
        cnn.eval()
    if not args.no_gnn:
        gnn.eval()

    pre = np.zeros(len(data))
    labels = np.zeros(len(data))
    for i, (structure, sequence, label, mutation) in enumerate(tqdm(data)):
        wt, tokenized = sequence
        pos, aa = mutation
        mut_token = PROTEIN_ALPHABET.index(aa)
        wt_token = int(tokenized[pos])
        # Replace the mutated residue with the mask token before scoring.
        masked_seq = wt[:pos] + MASK + wt[pos + 1:]
        src = torch.tensor(tkn.tokenize(masked_seq)).unsqueeze(0).to(device)
        with torch.no_grad():
            if args.cnn:
                src = cnn(src)
            if args.no_gnn:
                logits = src
            else:
                nodes, edges, connections, edge_mask = (t.to(device) for t in structure)
                logits = gnn(nodes, edges, connections, src, edge_mask)
            # Log-probabilities over tokens at the mutated position
            # (model output noted by the original author as 1, ell, t).
            log_probs = F.log_softmax(logits, dim=-1)[0, pos]
            # Score = log p(mutant residue) - log p(wild-type residue).
            pre[i] = (log_probs[mut_token] - log_probs[wt_token]).item()
        labels[i] = label

    df = pd.DataFrame({'prediction': pre, 'label': labels})
    auc = roc_auc_score(df['label'], df['prediction'])
    record.loc[idx, 'auc'] = auc
    print(record)
    # Rewrite the summary after every checkpoint so partial results survive an interruption.
    record.to_csv(args.out_fpath + out_fname, index=False)

    # Save predictions
    df.to_csv(args.out_fpath + 'predictions%s.csv' % step, index=False)





