import argparse
import json
import os
from datetime import datetime
import pathlib

import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from apex.optimizers import FusedAdam
from torch.utils.data import DataLoader, RandomSampler
import torch.distributed as dist
import torch.nn.functional as F
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
import numpy as np

from sequence_models.gnn import StructEncoderDecoder, cat_neighbors_nodes, BidirectionalStruct2SeqDecoder
from sequence_models.convolutional import ByteNetLM
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK, START, STOP
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from sequence_models.datasets import UniRefDataset, TRRDataset
from sequence_models.collaters import StructureCollater, MLMCollater, SimpleCollater
from sequence_models.losses import MaskedCrossEntropyLoss
from sequence_models.metrics import MaskedAccuracy
from sequence_models.utils import transformer_lr, Tokenizer


# Current user's home directory as a plain string; used to build default
# data/checkpoint paths by string concatenation.
home = os.fspath(pathlib.Path.home())


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``, ``False``, ``1`` or ``no``.

    ``argparse`` hands option values over as strings, and every non-empty
    string (including ``'False'``) is truthy, so the text must be interpreted
    explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def main():
    """Parse command-line options and spawn one ``train`` process per local GPU.

    Sets ``args.world_size`` (gpus per node * nodes) and the rendezvous
    environment variables used by ``init_method='env://'`` before launching
    ``args.gpus`` processes with ``torch.multiprocessing.spawn``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('out_fpath', type=str, nargs='?', default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
    parser.add_argument('--task', default='mlm')
    parser.add_argument('-w', '--weights_fpath', required=False)
    parser.add_argument('-f', '--freeze', action='store_true')
    parser.add_argument('--no_gnn', action='store_true')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('--esm', action='store_true')
    parser.add_argument('--logsoftmax', action='store_true')
    parser.add_argument('--full', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--test2', default=None)
    parser.add_argument('--r', default=128, type=int)
    parser.add_argument('--k', default=5, type=int)
    parser.add_argument('--d_cnn', default=1280, type=int)
    parser.add_argument('--n_cnn', default=56, type=int)
    parser.add_argument('--activation', default='gelu')
    # Without a converter, `--slim False` would yield the truthy string 'False'.
    parser.add_argument('--slim', default=False, type=_str_to_bool,
                        help='boolean (true/false, yes/no, 1/0)')
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('-off', '--offset', default=0, type=int,
                        help='Number of GPU devices to skip.')
    parser.add_argument('--dataset', default=None)
    args = parser.parse_args()
    # Using ESM embeddings always implies a frozen encoder.
    if args.esm:
        args.freeze = True
    args.world_size = args.gpus * args.nodes
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8881'
    mp.spawn(train, nprocs=args.gpus, args=(args,))


def train(gpu, args):
    """Train a structure-conditioned sequence model on one GPU of a distributed job.

    Per-process entry point launched by ``torch.multiprocessing.spawn`` from
    ``main``. It joins the NCCL process group and builds the data loaders. It
    also builds a sequence encoder (a ``ByteNetLM`` CNN, or the pretrained
    ESM-1b model when ``args.esm``) and a GNN decoder. Training uses apex AMP
    and ``FusedAdam``. After every epoch, rank 0 saves a checkpoint and
    runs validation. After training (or immediately when ``args.test`` is
    set), rank 0 evaluates the best checkpoint on the test split.

    Args:
        gpu: Local process index supplied by ``mp.spawn``. It is combined with
            ``args.nr`` to form the global rank, and with ``args.offset`` to
            select the CUDA device.
        args: Parsed command-line namespace from ``main``; it must include
            ``world_size``.
    """
    _ = torch.manual_seed(0)
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank)
    torch.cuda.set_device(gpu + args.offset)
    device = torch.device('cuda:' + str(gpu + args.offset))

    n_tokens = len(PROTEIN_ALPHABET)
    d_cnn = args.d_cnn
    n_cnn_layers = args.n_cnn
    kernel_size = args.k
    r = args.r
    slim = args.slim
    activation = args.activation
    node_features = 10
    edge_features = 11
    dropout = args.dropout
    use_mpnn = True
    n_structure_layers = 4
    n_connections = 30
    d_embed = 8
    d_gnn = 256
    bucket_size = 1000
    max_tokens = 6000
    max_batch_size = 100
    epochs = 1000
    lr = 1e-3
    opt_level = 'O2'
    warmup_steps = 1000
    max_len = 1024
    pad_idx = PROTEIN_ALPHABET.index(PAD)
    start_idx = PROTEIN_ALPHABET.index(START)
    # Leave room for special tokens added later in `step`: the ESM path
    # adds <cls>/<eos>, and the LM task prepends a start token.
    if args.esm or args.task == 'lm':
        max_len -= 2

    # If PT_DATA_DIR is set (presumably a managed cluster job), read data
    # from there; otherwise fall back to ~/data/.
    pt_data_dir = os.getenv('PT_DATA_DIR')
    if pt_data_dir is not None:
        data_dir = pt_data_dir + '/'
        ptjob = True
    else:
        data_dir = home + '/data/'
        ptjob = False
    if args.dataset is not None:
        dataset = args.dataset
    elif ptjob:
        dataset = 'cath'
    else:
        dataset = "uniclust/cath"
    if args.task == 'mlm':
        collater = MLMCollater(PROTEIN_ALPHABET)
    else:
        collater = SimpleCollater(PROTEIN_ALPHABET, pad=True, backwards=False)
    if dataset != 'esm':
        # Structure datasets also yield node/edge features for the GNN decoder.
        collater = StructureCollater(collater, n_connections=n_connections)

    # Training data (every rank). `splits`/`metadata` are reused below for
    # the rank-0 evaluation splits of the same dataset.
    if dataset == 'esm':
        data_dir = data_dir + 'esm/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=False, pdb=False,
                                 p_drop=0.0, max_len=max_len)
    elif dataset != 'trr':
        data_dir = data_dir + dataset + '/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=True, pdb=True,
                                 p_drop=0.0, max_len=max_len)
    else:
        ds_train = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'train', bin=False,
                              return_msa=False, max_len=max_len, untokenize=True)
        len_train = np.load(data_dir + 'trrosetta/trrosetta/train_lengths.npz')['ells']
        len_train = np.minimum(len_train, max_len)
    train_sortish_sampler = SortishSampler(len_train, bucket_size, num_replicas=args.world_size, rank=rank)
    train_sampler = ApproxBatchSampler(train_sortish_sampler, max_tokens, max_batch_size, len_train)
    dl_train = DataLoader(dataset=ds_train,
                          batch_sampler=train_sampler,
                          num_workers=8,
                          collate_fn=collater)

    # Evaluation loaders are only built on rank 0; other ranks keep None.
    ds_valid = dl_valid = dl_test = len_valid = None
    if rank == 0:
        if dataset == 'esm':
            # This dataset's 'valid' split serves as both validation and test set.
            test_idx = splits['valid']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'valid', structure=False, pdb=False,
                                    p_drop=0.0, max_len=max_len)
            test_sortish_sampler = RandomSampler(len_test)
            # NOTE: this larger token budget also applies to the validation
            # loader built after this if/elif chain.
            max_tokens = 4 * max_tokens
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
            ds_valid = ds_test
            len_valid = len_test
        elif dataset != 'trr':
            valid_idx = splits['valid']
            len_valid = np.minimum(metadata['ells'][valid_idx], max_len)
            ds_valid = UniRefDataset(data_dir, 'valid', structure=True, pdb=True,
                                     p_drop=0.0, max_len=max_len)
            test_idx = splits['test']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'test', structure=True, pdb=True,
                                    p_drop=0.0, max_len=max_len)
            if args.test2 is not None:
                # Swap in an alternative named split from cath_splits.json.
                with open(home + '/workspace/data/esm/cath_splits.json') as f:
                    sp = json.load(f)
                ds_test.idx = sp[args.test2]
                len_test = np.minimum(metadata['ells'][ds_test.idx], max_len)
            test_sortish_sampler = SortishSampler(len_test, bucket_size)
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
        else:
            ds_valid = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'valid', bin=False, return_msa=False,
                                  max_len=max_len, untokenize=True)
            len_valid = np.load(data_dir + 'trrosetta/trrosetta/valid_lengths.npz')['ells']
            len_valid = np.minimum(len_valid, max_len)
            dl_test = None
        valid_sortish_sampler = SortishSampler(len_valid, bucket_size)
        valid_sampler = ApproxBatchSampler(valid_sortish_sampler, max_tokens, max_batch_size, len_valid)
        dl_valid = DataLoader(dataset=ds_valid,
                              batch_sampler=valid_sampler,
                              num_workers=8,
                              collate_fn=collater)

    # Initiate model
    if args.task == 'mlm':
        decoder = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                                 d_gnn, num_decoder_layers=n_structure_layers,
                                                 dropout=dropout, use_mpnn=use_mpnn,
                                                 one_hot_src=False).to(device)
    else:
        decoder = StructEncoderDecoder(n_tokens, node_features, edge_features, d_gnn, src_node=True,
                                       num_encoder_layers=n_structure_layers - 1, num_decoder_layers=1,
                                       use_mpnn=use_mpnn, one_hot_src=False, dropout=dropout).to(device)
    if args.esm:
        from esm.pretrained import load_model_and_alphabet
        encoder, alphabet = load_model_and_alphabet(home + "/.cache/torch/checkpoints/esm1b_t33_650M_UR50S.pt")
        tokenizer = Tokenizer(PROTEIN_ALPHABET)
        # Map each PROTEIN_ALPHABET token index to the corresponding ESM
        # token index. Special tokens are mapped explicitly, and anything ESM
        # does not know becomes <unk>. PAD is checked first on purpose.
        to_esm = {}
        for p in PROTEIN_ALPHABET:
            k = tokenizer.a_to_t[p]
            if p == PAD:
                to_esm[k] = alphabet.padding_idx
            elif p in alphabet.tok_to_idx:
                to_esm[k] = alphabet.tok_to_idx[p]
            elif p == MASK:
                to_esm[k] = alphabet.mask_idx
            elif p == START:
                to_esm[k] = alphabet.cls_idx
            elif p == STOP:
                to_esm[k] = alphabet.eos_idx
            else:
                to_esm[k] = alphabet.unk_idx
        # dim_reorder[k] is the ESM logit index for PROTEIN_ALPHABET token k.
        dim_reorder = torch.tensor([to_esm[k] for k in range(len(PROTEIN_ALPHABET))])
        encoder = encoder.to(device)
    else:
        causal = args.task != 'mlm'
        encoder = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                            slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)
    if args.weights_fpath is not None:
        print('Loading weights from ' + args.weights_fpath + '...')
        sd = torch.load(args.weights_fpath, map_location=device)
        cnn_sd = sd['model_state_dict']
        # Strip a leading DDP 'module.' prefix only. This keeps any later
        # 'module.' inside a key intact and accepts unprefixed checkpoints.
        prefix = 'module.'
        cnn_sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in cnn_sd.items()}
        encoder.load_state_dict(cnn_sd)

    if args.esm:
        # The pretrained ESM encoder is frozen: only the decoder is optimized,
        # and the encoder is used without AMP/DDP wrapping.
        optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
        decoder, optimizer = amp.initialize(decoder, optimizer, opt_level=opt_level)
        decoder = DDP(decoder)
    else:
        if args.freeze:
            optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        else:
            # Fine-tune the encoder with a smaller learning rate than the decoder.
            optimizer = FusedAdam([{'params': encoder.parameters(), 'lr': 1e-4},
                                   {'params': decoder.parameters(), 'lr': lr}])
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        # The decoder is trained in every mode (including --freeze), so it is
        # always wrapped in DDP to keep its gradients synchronized across ranks.
        decoder = DDP(decoder)
        encoder = DDP(encoder)
    scheduler = LambdaLR(optimizer, transformer_lr(warmup_steps))
    loss_func = MaskedCrossEntropyLoss()
    accu_func = MaskedAccuracy()

    def epoch(encoder, decoder, train, current_step=0):
        """Run one pass over the training loader or the current ``dl_valid``.

        Reads the enclosing epoch counter ``e`` for progress logging.

        Returns:
            When ``train`` is True, the number of batches (optimizer steps)
            processed. Otherwise, the running loss weighted by the loss-mask
            size, or NaN if the loader was empty.
        """
        start_time = datetime.now()
        if train:
            # NOTE(review): the encoder is put in train mode even when it is
            # frozen (--freeze / --esm); confirm this is intended.
            encoder = encoder.train()
            if decoder is not None:
                decoder = decoder.train()
            loader = dl_train
            t = 'Training:'
            n_total = len(ds_train) // args.world_size
        else:
            encoder = encoder.eval()
            if decoder is not None:
                decoder = decoder.eval()
            loader = dl_valid
            t = 'Validating:'
            # Use the dataset actually being iterated (the test set when
            # dl_valid has been pointed at dl_test).
            n_total = len(loader.dataset)
        losses = []
        accus = []
        ns = []
        n_seen = 0
        n_batches = 0
        rloss = raccu = float('nan')
        for i, batch in enumerate(loader):
            new_loss, new_accu, new_n = step(encoder, decoder, batch, train)
            losses.append(new_loss * new_n)
            accus.append(new_accu * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            n_batches = i + 1
            total_n = sum(ns)
            rloss = sum(losses) / total_n
            raccu = sum(accus) / total_n
            nsteps = current_step + n_batches if train else i
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %.4f accu = %.4f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss, raccu),
                  end='')
            if train:
                # Running training metrics cover at most the last 999 batches.
                losses = losses[-999:]
                accus = accus[-999:]
                ns = ns[-999:]
        if not train:
            print('\nValidation complete in ' + str(datetime.now() - start_time))
            return rloss
        elif rank == 0:
            print('\nEpoch complete in ' + str(datetime.now() - start_time))
        return n_batches

    def step(encoder, decoder, batch, train):
        """Run a forward pass on one batch; when ``train``, also backprop and step.

        Returns:
            ``(loss, accuracy, n)`` as Python numbers, where ``n`` is the sum of
            the loss mask for the batch.
        """
        if dataset == 'esm':
            src, tgt, mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
        else:
            if args.task == 'lm':
                src, nodes, edges, connections, edge_mask = batch
                # Targets are the unshifted tokens. The input is shifted
                # right by prepending a start token. The extra final position
                # is sliced off after the encoder.
                tgt = src.detach().clone()
                mask = (src != pad_idx).float()
                n = src.shape[0]
                starts = torch.full((n, 1), start_idx, dtype=torch.long)
                src = torch.cat([starts, src], dim=-1)
            else:
                src, tgt, mask, nodes, edges, connections, edge_mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
            nodes = nodes.to(device)
            edges = edges.to(device)
            connections = connections.to(device)
            edge_mask = edge_mask.to(device)
        input_mask = (src != pad_idx).float().unsqueeze(-1)
        if args.esm:
            # Re-tokenize into the ESM alphabet as <cls> + residues (our
            # padding removed) + <eos>, right-padded with ESM's padding index.
            n, ell = src.shape
            tokenized = [[to_esm[s] for s in sr if s != pad_idx] + [alphabet.eos_idx]
                         for sr in src.tolist()]
            esm_src = torch.full((n, ell + 2), alphabet.padding_idx, dtype=torch.long)
            esm_src[:, 0] = alphabet.cls_idx
            ells = []
            for i, toks in enumerate(tokenized):
                el = len(toks)
                ells.append(el - 1)
                esm_src[i, 1:el + 1] = torch.tensor(toks)
            esm_src = esm_src.to(device)
            with torch.no_grad():
                esm_logits = encoder(esm_src)['logits']
                # Drop <cls>/<eos> and copy each sequence's residue logits
                # into an (n, ell, vocab) tensor aligned with `src`.
                embeddings = torch.zeros(n, ell, esm_logits.shape[-1], device=device)
                for i, (logits_i, el) in enumerate(zip(esm_logits, ells)):
                    embeddings[i, :el, :] = logits_i[1:el + 1]
                # Reorder the vocabulary axis from ESM order to PROTEIN_ALPHABET order.
                embeddings = embeddings[:, :, dim_reorder]
        else:
            if args.freeze:
                with torch.no_grad():
                    embeddings = encoder(src, input_mask=input_mask)
            else:
                embeddings = encoder(src, input_mask=input_mask)
            # slice out the extra position
            if args.task == 'lm':
                embeddings = embeddings[:, :-1, :]
        if (args.esm and args.test) or args.no_gnn:
            outputs = embeddings
        else:
            if args.logsoftmax:
                embeddings = F.log_softmax(embeddings, dim=-1)
            outputs = decoder(nodes, edges, connections, embeddings, edge_mask)
        loss = loss_func(outputs, tgt, mask)
        accu = accu_func(outputs, tgt, mask)
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        return loss.item(), accu.item(), mask.sum().item()

    total_steps = 0
    n_parameters = sum(p.numel() for p in decoder.parameters())
    if rank == 0:
        print('%d model parameters' % n_parameters)
        print('%d training sequences' % len(len_train))
        print('%d validation sequences' % len(len_valid))
    if args.test:
        # Evaluation-only run; the test loader exists only on rank 0.
        if rank == 0:
            if dl_test is None:
                print('No test split is available for this dataset.')
            else:
                e = 0  # epoch counter read by epoch() for logging
                with torch.no_grad():
                    dl_valid = dl_test
                    _ = epoch(encoder, decoder, False)
        return
    best_valid_loss = 100
    patience = 20
    min_epochs = 500
    waiting = 0
    best_epoch = 0
    best_path = None
    m_file = args.out_fpath + 'metrics.csv'
    for e in range(epochs):
        train_sortish_sampler.set_epoch(e)
        total_steps += epoch(encoder, decoder, True, current_step=total_steps)
        stop = False
        if rank == 0:
            nsteps = total_steps
            model_path = args.out_fpath + 'checkpoint%d.tar' % nsteps
            torch.save({
                'step': nsteps,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_path)
            with torch.no_grad():
                vloss = epoch(encoder, decoder, False, current_step=total_steps)
            with open(m_file, 'a') as f:
                f.write('%s,%s\n' % (nsteps, vloss))
            if vloss < best_valid_loss:
                best_valid_loss = vloss
                waiting = 0
                best_epoch = e + 1
                best_path = model_path
            else:
                waiting += 1
            stop = waiting >= patience and e > min_epochs
        # Only rank 0 validates. Share its early-stopping decision so that
        # every rank leaves the loop together; otherwise the other ranks
        # would block forever in the next epoch's gradient synchronization.
        if args.world_size > 1:
            stop_flag = torch.tensor([int(stop)], device=device)
            dist.broadcast(stop_flag, src=0)
            stop = bool(stop_flag.item())
        if stop:
            break
    if rank == 0 and dl_test is not None:
        if best_path is None:
            print('No epoch improved the validation loss; skipping test evaluation.')
        else:
            print('Loading checkpoint from epoch %d and testing...' % best_epoch)
            sd = torch.load(best_path, map_location=device)
            if not args.esm:
                encoder.load_state_dict(sd['encoder_state_dict'])
            encoder = encoder.eval()
            decoder.load_state_dict(sd['decoder_state_dict'])
            decoder = decoder.eval()
            with torch.no_grad():
                dl_valid = dl_test
                _ = epoch(encoder, decoder, False)


if __name__ == "__main__":
    # Script entry point: parse the CLI and launch the per-GPU workers.
    main()
