import argparse
import json
from datetime import datetime
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from sequence_models.collaters import StructureCollater, SimpleCollater
from sequence_models.gnn import BidirectionalStruct2SeqDecoder
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK
from sequence_models.convolutional import ByteNetLM
from sequence_models.utils import parse_fasta, Tokenizer
from sequence_models.pdb_utils import parse_PDB, process_coords


def main():
    """Parse the command line and run the evaluation on a single process.

    The positional argument is the model config JSON; the optional arguments
    select the checkpoint directory, the output directory, the CUDA device
    and which model variant / structure to evaluate.
    """
    default_map_dir = os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/'
    default_out_dir = os.getenv('PT_OUTPUT_DIR', '/tmp') + '/'

    cli = argparse.ArgumentParser()
    cli.add_argument('config_fpath', type=str, help='file path to config json for model')
    cli.add_argument('--map_dir', type=str, default=default_map_dir)
    cli.add_argument('--out_fpath', type=str, required=False, default=default_out_dir)
    cli.add_argument('--gpu', '-g', type=int, default=0)
    # Boolean switches, all off unless given on the command line.
    for switch in ('--cnn', '--no_gnn', '--use_wt', '--af2'):
        cli.add_argument(switch, action='store_true')

    options = cli.parse_args()
    options.world_size = 1
    train(0, options)


def _strip_module_prefix(state_dict):
    """Re-key a state dict by replacing each key with ``key.split('module.')[1]``.

    Raises IndexError if a key does not contain 'module.'.
    """
    return {k.split('module.')[1]: v for k, v in state_dict.items()}


def _select_checkpoints(metrics_fpath, no_gnn):
    """Read a header-less metrics CSV and pick the checkpoint rows to evaluate.

    Rows after the minimum-loss row are dropped. From what remains, the
    returned frame holds, in order: the last (minimum-loss) row, a synthetic
    row with step 0, the first row, and then each row that is the first to
    fall below the next of 8 evenly spaced loss thresholds between the first
    and the last loss.

    Args:
        metrics_fpath: path to the metrics CSV (no header line).
        no_gnn: if true the columns are loss/accuracy/tokens/step, otherwise
            step/loss (an 'epoch' column is derived from the row index).

    Returns:
        A DataFrame with a fresh integer index and an integer 'step' column.
    """
    record = pd.read_csv(metrics_fpath, header=None)
    if no_gnn:
        record.columns = ['loss', 'accuracy', 'tokens', 'step']
        temp = pd.DataFrame(np.array([[np.nan, np.nan, 0, 0]]), columns=record.columns)
    else:
        record.columns = ['step', 'loss']
        record['epoch'] = record.index
        record['epoch'] += 1
        temp = pd.DataFrame(np.array([[0, np.nan, 0]]), columns=record.columns)
    min_idx = record.sort_values('loss').index
    record = record.loc[:min_idx[0]]
    # Select the loss by label: its column position differs between the two
    # layouts above, so a positional lookup would read 'accuracy' for no_gnn.
    first_loss = record['loss'].iloc[0]
    last_loss = record['loss'].iloc[-1]
    bins = np.linspace(first_loss, last_loss, 10)[1:-1]
    j = 0
    idx = []
    for i, metrics_row in record.iterrows():
        if metrics_row['loss'] < bins[j]:
            j += 1
            idx.append(i)
            if j == len(bins):
                break
    record = pd.concat([record.iloc[-1:], temp, record.iloc[[0] + idx]], ignore_index=True)
    return record.astype({'step': int})


def train(gpu, args):
    """Score the RBD mutation dataset with saved checkpoints and record Spearman correlations.

    For every selected checkpoint step the model(s) chosen by the flags are
    built, the checkpoint is loaded (a step of 0 keeps the freshly
    initialised weights), each row of ``rbd_dms.csv`` gets a prediction, and
    the Spearman correlation with the 'fitness' column is stored in the
    'rbd_rho' column of ``rbd.csv`` inside ``args.out_fpath``. Per-step
    predictions are written next to it as ``predictions<step>.csv``.

    Args:
        gpu: offset added to ``args.gpu`` to choose the CUDA device.
        args: parsed command-line namespace; reads ``config_fpath``,
            ``map_dir``, ``out_fpath``, ``gpu``, ``cnn``, ``no_gnn``,
            ``use_wt`` and ``af2``.

    Raises:
        ValueError: if ``--no_gnn`` is given without ``--cnn``, which would
            leave no model to evaluate.
    """
    if args.no_gnn and not args.cnn:
        raise ValueError('--no_gnn requires --cnn: no model would be left to evaluate')

    _ = torch.manual_seed(32)
    torch.cuda.set_device(gpu + args.gpu)
    device = torch.device('cuda:' + str(gpu + args.gpu))

    ## Grab metadata
    # Fall back to ~/data when PT_DATA_DIR is not set.
    data_root = os.getenv('PT_DATA_DIR')
    if data_root is None:
        data_root = str(pathlib.Path.home()) + '/data'
    data_fpath = data_root + '/'
    pdb_fpath = data_fpath + 'pdb/'
    data_fpath += 'rbd/'

    # grab model hyperparameters
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_model = config['d_model']
    node_features = 10
    edge_features = 11
    dropout = config['dropout']
    use_mpnn = config['use_mpnn']
    n_layers = config['n_layers']
    n_connections = config['n_connections']

    # Prep results: resume from an existing rbd.csv, otherwise derive the list
    # of checkpoints to evaluate from the training metrics.
    out_fname = 'rbd.csv'
    out_path = os.path.join(args.out_fpath, out_fname)
    if out_fname in os.listdir(args.out_fpath):
        record = pd.read_csv(out_path)
    else:
        record = _select_checkpoints(os.path.join(args.map_dir, 'metrics.csv'), args.no_gnn)
        print(record)

    # Get the wt sequence, locations, and mutations
    df = pd.read_csv(data_fpath + 'rbd_dms.csv')
    tgt = df['fitness'].values
    # preprocess the structure just once
    collater = StructureCollater(SimpleCollater(PROTEIN_ALPHABET), n_connections=n_connections)
    tkn = Tokenizer(alphabet=PROTEIN_ALPHABET)
    if args.af2:
        pdb = 'rbd_dda7d_unrelaxed_rank_1_model_2'
        rbd_chain = 'B'
        ace_chain = 'C'
    else:
        pdb = '6M0J'
        rbd_chain = 'E'
        ace_chain = 'A'
    pdb_path = pdb_fpath + pdb + '.pdb'
    rbd_coords, rbd_wt, _ = parse_PDB(pdb_path, chain=rbd_chain)
    ace2_coords, ace2_wt, _ = parse_PDB(pdb_path, chain=ace_chain)
    wt = rbd_wt + ace2_wt
    src_ace2 = torch.tensor(tkn.tokenize(ace2_wt)).unsqueeze(0).to(device)
    coords = np.concatenate([rbd_coords, ace2_coords])

    coords = {
        'N': coords[:, 0],
        'CA': coords[:, 1],
        'C': coords[:, 2]
    }
    dist, omega, theta, phi = process_coords(coords)
    batch = [[wt, torch.tensor(dist, dtype=torch.float),
              torch.tensor(omega, dtype=torch.float),
              torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]]
    _, nodes, edges, connections, edge_mask = collater(batch)
    nodes = nodes.to(device)
    edges = edges.to(device)
    connections = connections.to(device)
    edge_mask = edge_mask.to(device)

    full_rbd = parse_fasta(data_fpath + 'P0DTC2.fasta')[0]
    # 'position' + offset is the index of the mutated residue in full_rbd.
    offset = 434
    # Index in full_rbd at which the slice matching the structure chain starts.
    # NOTE(review): presumably the first residue resolved in the PDB chain - confirm.
    struct_start = 332
    # Sanity-check the offset against every row before doing any model work.
    for _, mut_row in df.iterrows():
        assert full_rbd[mut_row['position'] + offset] == mut_row['wildtype_aa']

    for idx, ckpt_row in record.iterrows():
        step = str(int(ckpt_row['step']))
        wd = os.path.join(args.map_dir, 'checkpoint' + step + '.tar')
        # Models are switched to eval mode as soon as they are built so that
        # the step-0 row (no checkpoint loaded) is scored in eval mode too.
        if args.cnn:
            d_cnn = 1280
            n_cnn_layers = 56
            kernel_size = 5
            r = 128
            slim = False
            d_embed = 8
            activation = 'gelu'
            causal = False
            pad_idx = PROTEIN_ALPHABET.index(PAD)
            cnn = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                            slim=slim, activation=activation, causal=causal,
                            padding_idx=pad_idx).to(device).eval()

        if not args.no_gnn:
            gnn = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                                 d_model, num_decoder_layers=n_layers,
                                                 dropout=dropout, use_mpnn=use_mpnn,
                                                 pe=False, one_hot_src=(not args.cnn)).to(device).eval()

        # A step of 0 has no checkpoint file: keep the initial weights.
        if int(ckpt_row['step']) != 0:
            sd = torch.load(wd, map_location=device)
            if args.cnn and args.no_gnn:
                cnn.load_state_dict(_strip_module_prefix(sd['model_state_dict']))
            elif args.cnn:
                gnn.load_state_dict(sd['decoder_state_dict'])
                cnn.load_state_dict(_strip_module_prefix(sd['encoder_state_dict']))
            else:
                gnn.load_state_dict(_strip_module_prefix(sd['model_state_dict']))

        # Cache of per-position score tables, keyed by the position tuple, so
        # each masked position is run through the model only once.
        pseudo_likelihood = {}
        pre = np.zeros(len(df))
        if args.cnn and not args.no_gnn:
            with torch.no_grad():
                enc_ace2 = cnn(src_ace2)
        else:
            enc_ace2 = src_ace2
        for i, mut_row in tqdm(df.iterrows(), total=len(df)):
            pos = (mut_row['position'] + offset,)
            aas = [PROTEIN_ALPHABET.index(mut_row['mutant_aa'])]
            wt_idx = [PROTEIN_ALPHABET.index(mut_row['wildtype_aa'])]
            assert full_rbd[pos[0]] == mut_row['wildtype_aa']
            # Full-length sequence with the mutated position replaced by MASK.
            seq = full_rbd[0:pos[0]] + MASK + full_rbd[pos[0] + 1:]
            src_rbd = torch.tensor(tkn.tokenize(seq)).unsqueeze(0).to(device)
            if args.use_wt:
                if pos not in pseudo_likelihood:
                    with torch.no_grad():
                        if args.cnn:
                            src_rbd = cnn(src_rbd)
                            if args.no_gnn:
                                output = F.log_softmax(src_rbd, dim=-1)
                                pos_ = pos
                        if not args.no_gnn:
                            # Keep only the part of the sequence covered by the
                            # structure chain and append the second chain.
                            src = torch.cat(
                                [src_rbd[:1, struct_start: struct_start + len(rbd_wt)], enc_ace2], dim=1)
                            pos_ = (pos[0] - struct_start, )
                            output = F.log_softmax(gnn(nodes, edges, connections, src, edge_mask), dim=-1)
                        pl = output[0, pos_].cpu()
                        wt_pl = output[torch.zeros(len(pos)).long(), pos_, wt_idx].cpu()
                        # Log-probabilities at the masked position relative to the wild-type residue.
                        pseudo_likelihood[pos] = (pl - wt_pl.unsqueeze(-1)).numpy()
                pre[i] = pseudo_likelihood[pos][np.arange(len(pos)), np.array(aas)].sum()
            else:
                # NOTE(review): this branch reads `src` before it is assigned,
                # so running without --use_wt fails on the first mutation. The
                # intended model input cannot be determined from this file, so
                # the logic is left functionally unchanged - needs an owner fix.
                with torch.no_grad():
                    if args.cnn:
                        src = cnn(src)
                        if args.no_gnn:
                            output = F.log_softmax(src, dim=-1)
                    if not args.no_gnn:
                        output = F.log_softmax(gnn(nodes, edges, connections, src, edge_mask), dim=-1) # 1, ell, t
                        pos = (pos[0] - struct_start,)
                    # Slice out correct log probs
                    src = torch.tensor(tkn.tokenize(seq)).unsqueeze(0).to(device)
                    _, ell = src.shape
                    output = output[0, np.arange(ell), src[0]]
                    # Add them up
                    pre[i] = output.mean().cpu().item()

        rho = spearmanr(pre, tgt)
        record.loc[idx, 'rbd_rho'] = rho.correlation
        print(record)
        # Rewritten after every checkpoint so partial results survive a crash.
        record.to_csv(out_path, index=False)
        # Save predictions
        df['prediction'] = pre
        df.to_csv(os.path.join(args.out_fpath, 'predictions%s.csv' % step), index=False)

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()