import argparse
import json
from tqdm import tqdm
import os
import pathlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from apex import amp
from apex.optimizers import FusedAdam
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from sequence_models.collaters import TAPECollater, Seq2PropertyCollater, SimpleCollater, StructureCollater
from sequence_models.datasets import TAPEDataset, CSVDataset
from sequence_models.constants import PROTEIN_ALPHABET, PAD
from sequence_models.convolutional import ByteNetLM
from sequence_models.gnn import BidirectionalStruct2SeqDecoder
from sequence_models.structure import Attention1d
from sequence_models.utils import warmup, Tokenizer
from sequence_models.pdb_utils import process_coords, parse_PDB
from sequence_models.flip_utils import load_flip_data


class Model(nn.Module):
    """Prediction head: LayerNorm -> Attention1d -> two-layer GELU MLP.

    Args:
        d_model: size of the last dimension of the incoming embeddings.
        d_out: number of output features of the final linear layer.
        dropout: dropout probability applied just before the final linear layer.
        d_hidden: width of the hidden linear layer. Defaults to 1280, the value
            that used to be hard-coded, so existing callers and saved state
            dicts are unaffected.
    """

    def __init__(self, d_model, d_out, dropout=0.0, d_hidden=1280):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state-dict keys and the order in which parameters are initialised.
        self.ln = nn.LayerNorm(d_model)
        self.attention = Attention1d(d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(d_model, d_hidden)
        self.linear = nn.Linear(d_hidden, d_out)

    def forward(self, e, input_mask=None):
        """Map embeddings ``e`` to predictions.

        Args:
            e: embeddings whose last dimension is ``d_model``.
            input_mask: optional mask forwarded unchanged to ``Attention1d``.

        Returns:
            Output of the final linear layer (last dimension ``d_out``).
        """
        e = self.ln(e)
        # NOTE(review): Attention1d presumably pools over sequence positions,
        # yielding one vector per example - confirm in sequence_models.structure.
        attended = self.attention(e, input_mask=input_mask)
        hidden = self.hidden(self.activation(attended))
        return self.linear(self.dropout(self.activation(hidden)))


def main():
    """Build the command-line interface, parse it and hand the result to ``train``."""
    # Directory defaults come from the environment, falling back to /tmp; the
    # trailing slash is required because paths are later built by concatenation.
    map_dir_default = os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/'
    out_dir_default = os.getenv('PT_OUTPUT_DIR', '/tmp') + '/'

    cli = argparse.ArgumentParser()
    cli.add_argument('config_fpath', help='file path to config json for model', type=str)
    cli.add_argument('--map_dir', default=map_dir_default, type=str)
    cli.add_argument('--out_fpath', default=out_dir_default, type=str)
    cli.add_argument('--task', type=str)
    cli.add_argument('--cat', action='store_true')
    cli.add_argument('--finetune', action='store_true')
    cli.add_argument('--gpu', '-g', default=0, type=int)
    cli.add_argument('--learning_rate', type=float, default=None)
    cli.add_argument('--cnn', action='store_true')
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--af2', action='store_true')

    parsed = cli.parse_args()
    train(parsed)

def train(args):
    """Train a prediction head on structure-conditioned embeddings and log test metrics.

    Workflow, in order:

    1. Seed torch and select the CUDA device ``args.gpu``.
    2. Pick the data panel, reference structure, sequence slice and batch size
       for ``args.task`` and build train/valid/test ``DataLoader`` objects.
    3. Parse the reference structure (PDB, or the AF2 model with ``--af2``) and
       turn it into one set of graph inputs that is reused for every sequence.
    4. Load ``metrics.csv`` from ``args.out_fpath`` if present; otherwise build
       the list of checkpoints to evaluate from ``args.map_dir + 'metrics.csv'``.
    5. For every record row that has no result yet: build the GNN (and the CNN
       with ``--cnn``), load the checkpoint for that row (step 0 means no
       checkpoint is loaded), train a ``Model`` head with early stopping, reload
       ``best.tar`` and write the test loss / Spearman correlation back to
       ``metrics.csv`` in ``args.out_fpath``.

    Args:
        args: namespace with attributes ``config_fpath``, ``map_dir``,
            ``out_fpath``, ``task``, ``cat``, ``finetune``, ``gpu``, ``cnn``,
            ``seed``, ``af2`` and optionally ``learning_rate`` (``None`` or
            absent keeps the built-in default of 1e-4).

    Raises:
        ValueError: if the task/flag combination leaves a required setting
            undefined (no reference structure, slice bounds or batch size for
            the task, ``--af2`` for a task without an AF2 model, or ``--cat``
            without ``--cnn``).
    """
    _ = torch.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    device = torch.device('cuda:' + str(args.gpu))

    ## Grab data
    task = args.task
    lr = 1e-4
    loss_func = nn.MSELoss()
    test_name = 'test'
    d_out = 1
    # Per-task settings. They stay None until a branch below fills them in so
    # that an unsupported combination is reported up front instead of raising a
    # NameError in the middle of the run.
    pdb = None
    af2 = None
    start = None
    end = None
    batch_size = None
    if task == 'fluorescence':
        panel = 'tape'
        pdb = '5N9O'
        af2 = 'gfp_579fa_unrelaxed_rank_1_model_4'
        start = 0
        if args.af2:
            end = 237
        else:
            end = 235
        if args.finetune and args.cnn:
            batch_size = 32
        elif args.finetune:
            batch_size = 48
        else:
            batch_size = 128
    elif task == 'rmanod':
        panel = 'rmanod'
        pdb = '6WK3'
        if args.cat:
            batch_size = 16
        else:
            batch_size = 32
        start = 0
        end = 145
    else:
        # Anything else is treated as a FLIP task named '<dataset>_<split>'.
        panel = 'flip'
        flip_dataset = task.split('_')[0]
        flip_split = '_'.join(task.split('_')[1:])
        if flip_dataset == 'aav':
            pdb = '1LP3'
            if args.finetune:
                batch_size = 8
            else:
                batch_size = 64
        elif flip_dataset == 'gb1':
            pdb = '2gi9'
            af2 = 'gb1_a60fb_unrelaxed_rank_1_model_5'
            start = 0
            end = 56
            if 'one' in task:
                batch_size = 5
            elif 'two' in task:
                batch_size = 16
            elif args.finetune:
                batch_size = 16
            else:
                batch_size = 512

    # An explicit --learning_rate overrides the default. getattr keeps
    # namespaces that predate the option working.
    if getattr(args, 'learning_rate', None) is not None:
        lr = args.learning_rate

    # Fail early on combinations that could never run to completion. For
    # example 'aav' defines no start/end, so the first training step could not
    # slice its sequences.
    if pdb is None or batch_size is None or start is None or end is None:
        raise ValueError(
            'task %r is not fully configured here (missing reference structure, '
            'slice bounds or batch size)' % (task,))
    if args.af2 and af2 is None:
        raise ValueError('--af2 was given but task %r has no AF2 model configured' % (task,))
    if args.cat and not args.cnn:
        raise ValueError('--cat concatenates CNN embeddings and therefore requires --cnn')

    # Data root: $PT_DATA_DIR when set, otherwise ~/data.
    data_root = os.getenv('PT_DATA_DIR')
    if data_root is None:
        data_root = str(pathlib.Path.home()) + '/data'
    data_fpath = data_root + '/' + panel + '/'
    pdb_path = data_root + '/pdb/'
    af2_path = data_root + '/af2/'

    if panel == 'tape':
        data_fpath += task + '/'
        # grab data based on task
        ds_train = TAPEDataset(data_fpath, task, 'train')  # max_len only matters for contacts
        ds_valid = TAPEDataset(data_fpath, task, 'valid')
        ds_test = TAPEDataset(data_fpath, task, test_name)
        collate_fn = TAPECollater(PROTEIN_ALPHABET)
        num_workers = 0
    elif panel == 'rmanod':
        df = pd.read_csv(data_fpath + 'protabank.csv')
        df['sequence'] = df['Sequence']
        # Rows whose description contains more than three '+' form the held-out
        # pool; a seeded shuffle moves 40 of them to validation, the rest to test.
        is_test = [d.count('+') > 3 for d in df['Description'].values]
        test_idx = np.array([i for i, t in enumerate(is_test) if t])
        train_idx = np.array([i for i, t in enumerate(is_test) if not t])
        n_valid = 40
        np.random.seed(args.seed)
        np.random.shuffle(test_idx)
        valid_idx = test_idx[:n_valid]
        test_idx = test_idx[n_valid:]
        df.loc[train_idx, 'split'] = 'train'
        df.loc[valid_idx, 'split'] = 'valid'
        df.loc[test_idx, 'split'] = 'test'
        ds_train = CSVDataset(df=df, split='train', outputs=['Data'])
        ds_valid = CSVDataset(df=df, split='valid', outputs=['Data'])
        ds_test = CSVDataset(df=df, split='test', outputs=['Data'])
        collate_fn = Seq2PropertyCollater(PROTEIN_ALPHABET, return_mask=True)
        num_workers = 16
    else:
        ds_train, ds_valid, ds_test = load_flip_data(data_fpath, flip_dataset, flip_split, max_len=2048)
        collate_fn = Seq2PropertyCollater(PROTEIN_ALPHABET, return_mask=True)
        num_workers = 16
    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collate_fn,
                          num_workers=num_workers, shuffle=True)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)

    # Make the GNN
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_model = config['d_model']
    node_features = 10
    edge_features = 11
    dropout = config['dropout']
    use_mpnn = config['use_mpnn']
    n_layers = config['n_layers']
    n_connections = config['n_connections']

    # Process the structure
    collater = StructureCollater(SimpleCollater(PROTEIN_ALPHABET), n_connections=n_connections)
    if args.af2:
        coords, wt, _ = parse_PDB(af2_path + af2 + '.pdb')
    else:
        coords, wt, _ = parse_PDB(pdb_path + pdb + '.pdb')
    coords = {
        'N': coords[:, 0],
        'CA': coords[:, 1],
        'C': coords[:, 2]
    }
    dist, omega, theta, phi = process_coords(coords)
    # Drop leading residues of the reference structure for some tasks.
    # NOTE(review): presumably this aligns the structure with the dataset
    # sequences - confirm against the data.
    if task == 'fluorescence' and not args.af2:
        dist = dist[2:, 2:]
        omega = omega[2:, 2:]
        theta = theta[2:, 2:]
        phi = phi[2:, 2:]
        wt = wt[2:]
    if task == 'rmanod':
        dist = dist[1:, 1:]
        omega = omega[1:, 1:]
        theta = theta[1:, 1:]
        phi = phi[1:, 1:]
        wt = wt[1:]
    batch = [[wt, torch.tensor(dist, dtype=torch.float),
              torch.tensor(omega, dtype=torch.float),
              torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]]
    # The structure is collated once; step() repeats it along the batch axis.
    _, nodes, edges, connections, edge_mask = collater(batch)
    nodes = nodes.to(device)
    edges = edges.to(device)
    connections = connections.to(device)
    edge_mask = edge_mask.to(device)
    # Prep results
    out_fname = 'metrics.csv'
    if out_fname in os.listdir(args.out_fpath):
        # Resume: rows that already carry a result are skipped in the main loop.
        record = pd.read_csv(args.out_fpath + out_fname)
    else:
        # Fresh run: choose which checkpoints to evaluate from the headerless
        # (step, loss) table in map_dir.
        record = pd.read_csv(args.map_dir + 'metrics.csv', header=None)
        record.columns = ['step', 'loss']
        record['epoch'] = record.index
        record['epoch'] += 1
        # Extra row with step 0: evaluated without loading any checkpoint.
        temp = pd.DataFrame(np.array([[0, np.nan, 0]]), columns=record.columns)
        # Keep only the rows up to and including the lowest loss.
        min_idx = record.sort_values('loss').index
        record = record.loc[:min_idx[0]]
        first_loss = record.iloc[0, 1]
        last_loss = record.iloc[-1, 1]
        # Eight evenly spaced loss thresholds strictly between first and best
        # loss; keep the first row that falls below each one in turn.
        bins = np.linspace(first_loss, last_loss, 10)[1:-1]
        j = 0
        idx = []
        for i, row in record.iterrows():
            if row['loss'] < bins[j]:
                j += 1
                idx.append(i)
                if j == len(bins):
                    break
        # Final order: best row, step-0 row, first row, then the threshold rows.
        record = pd.concat([record.iloc[-1:], temp, record.iloc[[0] + idx]], ignore_index=True)
        record = record.astype({'step': int})

    epochs = 500
    if args.finetune:
        loss_col = task + '_ft_loss'
        metr_col = task + '_ft_metric'
    else:
        loss_col = task + '_loss'
        metr_col = task + '_metric'


    def step(cnn, gnn, decoder, batch, train=True, return_values=False):
        """Run one batch through (cnn ->) gnn -> decoder, optionally updating weights.

        Uses ``optimizer`` and ``scheduler`` from the enclosing scope when
        ``train`` is true. Returns ``(loss, n)`` or, with ``return_values``,
        ``(loss, n, outputs, src, tgt)`` with the tensors detached on CPU.
        """
        # The mask supplied by the collater is discarded and rebuilt below from
        # the sliced sequences.
        src, tgt, input_mask = batch
        src = src.to(device)[:, start: end]
        tgt = tgt.to(device)
        b = len(src)
        # Tile the single reference structure across the batch.
        bnodes = nodes.repeat(b, 1, 1)
        bedges = edges.repeat(b, 1, 1, 1)
        bconnections = connections.repeat(b, 1, 1)
        bmask = edge_mask.repeat(b, 1, 1, 1)
        input_mask = (src != PROTEIN_ALPHABET.index(PAD)).float().unsqueeze(-1)
        if cnn is not None:
            # With a CNN, the GNN consumes the CNN decoder output instead of tokens.
            if args.finetune:
                e_cnn = cnn.embedder(src, input_mask=input_mask)
                src = cnn.decoder(e_cnn)
            else:
                with torch.no_grad():
                    e_cnn = cnn.embedder(src, input_mask=input_mask)
                    src = cnn.decoder(e_cnn)
        # Gradients reach the pretrained networks only when fine-tuning.
        if args.finetune:
            e = gnn(bnodes, bedges, bconnections, src, bmask, decoder=False)
        else:
            with torch.no_grad():
                e = gnn(bnodes, bedges, bconnections, src, bmask, decoder=False)
        if args.cat:
            e = torch.cat([e_cnn, e], dim=-1)
        outputs = decoder(e, input_mask=input_mask)
        loss = loss_func(outputs, tgt)
        locations = len(tgt)
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        if return_values:
            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu()
        else:
            return loss.item(), locations


    def epoch(cnn, gnn, decoder, current_step=0):
        """Train for one pass over ``dl_train``.

        Returns ``(n_steps, running_loss)`` where ``n_steps`` is the number of
        optimizer steps taken in this pass and ``running_loss`` is the
        example-weighted mean training loss.
        """
        if cnn is not None:
            cnn = cnn.train()
        gnn = gnn.train()
        decoder = decoder.train()
        t = 'Training:'
        losses = []
        ns = []
        n_seen = 0
        n_total = len(ds_train)
        n_steps = 0
        rloss = 0
        for n_steps, batch in enumerate(dl_train, start=1):
            new_loss, new_n = step(cnn, gnn, decoder, batch, train=True, return_values=False)
            losses.append(new_loss * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            if total_n == 0:
                rloss = 0
            else:
                rloss = sum(losses) / total_n
            nsteps = current_step + n_steps
            # `e` is the epoch index of the enclosing training loop.
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),
                  end='')
        # Return the count of batches, not the last 0-based index, so the
        # caller's global step counter does not repeat a step every epoch.
        return n_steps, rloss


    def test_epoch(cnn, gnn, decoder, dl):
        """Evaluate on loader ``dl``; return ``{'metric': spearman, 'loss': mean_loss}``."""
        gnn = gnn.eval()
        decoder = decoder.eval()
        if cnn is not None:
            cnn = cnn.eval()
        with torch.no_grad():
            losses = []
            ns = []
            pred = []
            tgt = []
            for batch in dl:
                new_loss, new_n, p, _, t = step(cnn, gnn, decoder, batch, train=False, return_values=True)
                losses.append(new_loss * new_n)
                pred.append(p)
                tgt.append(t)
                ns.append(new_n)

        # Example-weighted mean of the per-batch losses.
        test_loss = sum(losses) / sum(ns)
        pred = torch.cat(pred)
        tgt = torch.cat(tgt)
        pred = pred.numpy()
        tgt = tgt.numpy()
        metric_func = spearmanr
        metric = metric_func(pred, tgt).correlation
        print('loss: %f' %test_loss, end='\t')
        print('metric: %f' %(metric))
        results = {
            'metric': metric,
            'loss': test_loss
        }
        return results



    for idx, row in tqdm(record.iterrows(), total=len(record)):
        print(record)
        wd = args.map_dir + 'checkpoint' + str(int(row['step'])) + '.tar'
        # Skip checkpoints that already have a result (resumed run).
        if loss_col in row and not np.isnan(row[loss_col]):
            continue
        np.random.seed(args.seed + 100)
        if args.cnn:
            d_cnn = 1280
            n_cnn_layers = 56
            kernel_size = 5
            r = 128
            slim = False
            d_embed = 8
            activation = 'gelu'
            causal = False
            pad_idx = PROTEIN_ALPHABET.index(PAD)
            cnn = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                            slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)

        else:
            cnn = None
        gnn = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                             d_model, num_decoder_layers=n_layers,
                                             dropout=dropout, use_mpnn=use_mpnn,
                                             pe=False, one_hot_src=(not args.cnn)).to(device)
        # Step 0 has no checkpoint: the networks keep their fresh initialisation.
        if int(row['step']) != 0:
            sd = torch.load(wd, map_location=device)
            if args.cnn:
                gnn_sd = sd['decoder_state_dict']
                gnn.load_state_dict(gnn_sd)
                gnn = gnn.to(device).eval()
                cnn_sd = sd['encoder_state_dict']
                # Strip the 'module.' prefix from the saved parameter names.
                cnn_sd = {k.split('module.')[1]: v for k, v in cnn_sd.items()}
                cnn.load_state_dict(cnn_sd)
                # NOTE(review): the input embedding is scaled by 1.25 after
                # loading; the reason is not visible in this file.
                cnn.embedder.embedder.weight = nn.Parameter(cnn.embedder.embedder.weight * 1.25)
                cnn = cnn.to(device).eval()
            else:
                sd = sd['model_state_dict']
                sd = {k.split('module.')[1]: v for k, v in sd.items()}
                gnn.load_state_dict(sd)
                # NOTE(review): same 1.25 scaling, applied to the GNN's W_s weights.
                gnn.W_s.weight = nn.Parameter(gnn.W_s.weight * 1.25)
                gnn = gnn.to(device).eval()
        if args.cat:
            decoder = Model(d_model + d_cnn, d_out, dropout=dropout)
        else:
            decoder = Model(d_model, d_out, dropout=dropout)
        decoder = decoder.to(device)
        # Only the head is optimised unless --finetune is set.
        if args.finetune and args.cnn:
            optimizer = FusedAdam(list(cnn.parameters()) + list(gnn.parameters()) + list(decoder.parameters()), lr=lr)
        elif args.finetune:
            optimizer = FusedAdam(list(gnn.parameters()) + list(decoder.parameters()), lr=lr)
        else:
            optimizer = FusedAdam(decoder.parameters(), lr=lr)
        if args.cnn:
            (cnn, gnn, decoder), optimizer = amp.initialize([cnn, gnn, decoder], optimizer, opt_level='O1')
        else:
            (gnn, decoder), optimizer = amp.initialize([gnn, decoder], optimizer, opt_level='O2')

        n_warmup = 1000
        scheduler = LambdaLR(optimizer, warmup(n_warmup))
        total_steps = 0
        best_valid_metric = -100.0
        best_valid_loss = 100.0
        patience = 5
        if task == 'gb1_one_vs_rest':
            patience = 20

        waiting = 0
        for e in range(epochs):
            ts, train_loss = epoch(cnn, gnn, decoder, current_step=total_steps)
            total_steps += ts
            nsteps = total_steps
            results = test_epoch(cnn, gnn, decoder, dl_valid)
            vloss = results['loss']
            vmetric = results['metric']
            # Checkpoint selection: best validation metric for every task
            # except gb1_one_vs_rest, which selects on validation loss.
            save_me = False
            if vmetric > best_valid_metric and task != 'gb1_one_vs_rest':
                best_valid_metric = vmetric
                waiting = 0
                save_me = True
            if vloss < best_valid_loss:
                if task == 'gb1_one_vs_rest':
                    save_me = True
                best_valid_loss = vloss
                waiting = 0
            if vloss < train_loss and task != 'gb1_one_vs_rest':
                waiting = 0
            if args.out_fpath is not None and save_me:
                if args.cnn and args.finetune:
                    cnn_state_dict = cnn.state_dict()
                else:
                    cnn_state_dict = {}
                torch.save({
                    'step': nsteps,
                    'gnn_state_dict': gnn.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'cnn_state_dict': cnn_state_dict
                }, args.out_fpath + 'best.tar')
            else:
                # Every epoch without a save counts towards patience, even if
                # `waiting` was reset to 0 just above.
                waiting += 1
            if waiting == patience:
                break
        if args.out_fpath is not None:
            # NOTE(review): if no epoch above triggered a save (e.g. the
            # validation metric was NaN throughout), best.tar is either missing
            # or left over from an earlier checkpoint/run.
            sd = torch.load(args.out_fpath + 'best.tar')
            if args.finetune:
                gnn.load_state_dict(sd['gnn_state_dict'])
                if args.cnn:
                    cnn.load_state_dict(sd['cnn_state_dict'])
            decoder.load_state_dict(sd['decoder_state_dict'])
            results = test_epoch(cnn, gnn, decoder, dl_test)
            record.loc[idx, metr_col] = results['metric']
            record.loc[idx, loss_col] = results['loss']
            # Written after every checkpoint so an interrupted run can resume.
            record.to_csv(args.out_fpath + 'metrics.csv', index=False)




# Run the command-line entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()