import argparse
import json
import os
from datetime import datetime
import pathlib

import mlflow
import torch
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import LambdaLR
from apex.optimizers import FusedAdam
from torch.utils.data import DataLoader, ConcatDataset
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
import numpy as np

from sequence_models.gnn import StructEncoderDecoder, BidirectionalStruct2SeqDecoder
from sequence_models.structure import StructureConditionedBytenet
from sequence_models.constants import PROTEIN_ALPHABET, PAD
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from sequence_models.datasets import UniRefDataset, TRRDataset
from sequence_models.collaters import StructureCollater, SimpleCollater, LMCollater, MLMCollater
from sequence_models.losses import MaskedCrossEntropyLoss
from sequence_models.metrics import MaskedAccuracy
from sequence_models.utils import warmup, transformer_lr


# Current user's home directory as a string. train() uses it as the root of the
# fallback data directory (home + '/data/') when PT_DATA_DIR is not available.
home = str(pathlib.Path.home())


def main():
    """Parse command-line options and spawn one training process per local GPU.

    Positional arguments:
        config_fpath: Path to the JSON configuration file read by ``train``.
        out_fpath: Output directory prefix (checkpoints and metrics are written
            by string concatenation, so it should end with '/'). Defaults to
            ``$PT_OUTPUT_DIR/`` or ``/tmp/`` when that variable is unset.

    The total number of processes (``args.world_size``) is ``gpus * nodes``.
    The rendezvous address and port for ``init_method='env://'`` default to
    ``localhost:8888`` but are only set when not already present in the
    environment, so multi-node launches (``--nodes`` > 1) can point every node
    at the same master via ``MASTER_ADDR`` / ``MASTER_PORT``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config_fpath')
    parser.add_argument('out_fpath', type=str, nargs='?', default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
    parser.add_argument('--pe', action='store_true')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('--n_layers', default=None, type=int)
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('-off', '--offset', default=0, type=int,
                        help='Number of GPU devices to skip.')
    parser.add_argument('--task', default=None)
    parser.add_argument('--dataset', default=None)
    parser.add_argument('--model_type', default=None)
    args = parser.parse_args()
    args.world_size = args.gpus * args.nodes
    # Respect a master address/port supplied by the launcher or user; fall back
    # to the single-node defaults otherwise. Unconditionally forcing
    # 'localhost' would make the --nodes/--nr options unusable.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '8888')
    mp.spawn(train, nprocs=args.gpus, args=(args,))


def train(gpu, args):
    """Run structure-conditioned sequence-model training in one process.

    Called by ``mp.spawn`` once per local GPU. Joins the NCCL process group,
    builds the datasets/loaders and the model described by the JSON config,
    then alternates training epochs with validation on rank 0, writing a
    checkpoint and a line in ``metrics.csv`` after every epoch. When a test
    loader exists, rank 0 finally evaluates the best checkpoint on it.

    Args:
        gpu: Index of this process on its node (first argument supplied by
            ``mp.spawn``). The CUDA device used is ``gpu + args.offset``.
        args: Parsed command-line namespace; must carry ``world_size``, ``nr``,
            ``gpus``, ``offset``, ``config_fpath``, ``out_fpath``, ``pe`` and
            the optional overrides ``n_layers``, ``task``, ``dataset`` and
            ``model_type``.
    """
    _ = torch.manual_seed(0)
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank)
    torch.cuda.set_device(gpu + args.offset)
    device = torch.device('cuda:' + str(gpu + args.offset))

    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    if args.model_type is None:
        model_type = config['model_type']
    else:
        model_type = args.model_type
    # NOTE(review): only the 'gnn' branch defines n_connections, node_features,
    # edge_features, dropout and use_mpnn, all of which are used unconditionally
    # further down. Any other model type currently fails with a NameError.
    if model_type == 'gnn':
        d_model = config['d_model']
        node_features = 10
        edge_features = 11
        dropout = config['dropout']
        use_mpnn = config['use_mpnn']
        if args.n_layers is not None:
            config['n_layers'] = args.n_layers
        n_layers = config['n_layers']
        n_connections = config['n_connections']
    elif model_type == 'cnn':
        d_embed = config['d_embed']
        d_model = config['d_model']
        n_layers = config['n_layers']
        kernel_size = config['kernel_size']
        r = config['r']
        d_con = config['d_con']
        r_c = config['r_c']
        kernel_size_c = config['kernel_size_c']
        n_c_layers = config['n_c_layers']
    bucket_size = config['bucket_size']
    max_tokens = config['max_tokens']
    max_batch_size = config['max_batch_size']
    epochs = config['epochs']
    lr = config['lr']
    opt_level = config['opt_level']
    warmup_steps = config['warmup_steps']
    train_steps = config['train_steps']
    drop_structure = config['drop_structure']
    max_len = config.get('max_len', 2048)

    # Command-line overrides are written back into config so that they are the
    # values logged by mlflow.log_params below.
    if args.task is not None:
        config['task'] = args.task
    if args.dataset is not None:
        config['dataset'] = args.dataset

    # PT_DATA_DIR being set marks a "ptjob" run: data is read from there and
    # mlflow logging is skipped. Otherwise fall back to ~/data/.
    pt_data_dir = os.getenv('PT_DATA_DIR')
    ptjob = pt_data_dir is not None
    if ptjob:
        data_dir = pt_data_dir + '/'
    else:
        data_dir = home + '/data/'

    bw = config['task'] == 'rlm'
    if config['task'] == 'mlm':
        direction = 'bidirectional'
        collater = MLMCollater(PROTEIN_ALPHABET)
    elif model_type == 'gnn':
        direction = 'forward'
        collater = SimpleCollater(PROTEIN_ALPHABET, pad=True, backwards=False)
    else:
        # NOTE(review): `direction` is not set on this path, so model
        # construction below cannot succeed for non-gnn, non-mlm runs.
        collater = LMCollater(PROTEIN_ALPHABET, backwards=bw)
    collater = StructureCollater(collater, n_connections=n_connections)

    if config['dataset'] != 'trr':
        # The config is expected to hold a list of dataset directory names, but
        # a --dataset override arrives as a single string; wrap it so we do not
        # iterate over its characters.
        dataset_names = config['dataset']
        if isinstance(dataset_names, str):
            dataset_names = [dataset_names]
        data_dirs = [data_dir + name + '/' for name in dataset_names]
        pdbs = ['cath' in dd for dd in data_dirs]
        # build datasets, samplers, and loaders
        datasets = []
        lengths = []
        for ds_dir, pdb in zip(data_dirs, pdbs):
            with open(ds_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(ds_dir + 'lengths_and_offsets.npz')
            train_idx = splits['train']
            len_train = np.minimum(metadata['ells'][train_idx], max_len)
            # Directories without 'cath' in their path get p_drop=1.0;
            # 'cath' directories use the configured drop_structure.
            drop = drop_structure if pdb else 1.0
            ds_train = UniRefDataset(ds_dir, 'train', structure=True, pdb=pdb,
                                     p_drop=drop, max_len=max_len)
            lengths.append(len_train)
            datasets.append(ds_train)
        ds_train = ConcatDataset(datasets)
        len_train = np.concatenate(lengths)
    else:
        ds_train = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'train', bin=False,
                              return_msa=False, max_len=max_len, untokenize=True)
        len_train = np.load(data_dir + 'trrosetta/trrosetta/train_lengths.npz')['ells']
        len_train = np.minimum(len_train, max_len)
    train_sortish_sampler = SortishSampler(len_train, bucket_size, num_replicas=args.world_size, rank=rank)
    train_sampler = ApproxBatchSampler(train_sortish_sampler, max_tokens, max_batch_size, len_train)
    dl_train = DataLoader(dataset=ds_train,
                          batch_sampler=train_sampler,
                          num_workers=16,
                          collate_fn=collater)

    # Validation / test data are only needed on rank 0, which does all evaluation.
    if rank == 0:
        if config['dataset'] != 'trr':
            # Evaluation data comes from the last training directory, or from a
            # 'cath/' sub-directory of it when that path does not mention cath.
            # NOTE(review): this reproduces the path the original code resolved
            # (the loop variable used to overwrite data_dir); confirm whether
            # '<base>/cath/' was intended instead of '<last dataset>/cath/'.
            valid_dir = data_dirs[-1]
            if 'cath' not in valid_dir:
                valid_dir += 'cath/'
            with open(valid_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(valid_dir + 'lengths_and_offsets.npz')
            valid_idx = splits['valid']
            len_valid = np.minimum(metadata['ells'][valid_idx], max_len)
            ds_valid = UniRefDataset(valid_dir, 'valid', structure=True, pdb=True,
                                     p_drop=drop_structure, max_len=max_len)
            test_idx = splits['test']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(valid_dir, 'test', structure=True, pdb=True,
                                    p_drop=drop_structure, max_len=max_len)
            test_sortish_sampler = SortishSampler(len_test, bucket_size)
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
        else:
            ds_valid = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'valid', bin=False, return_msa=False,
                                  max_len=max_len, untokenize=True)
            len_valid = np.load(data_dir + 'trrosetta/trrosetta/valid_lengths.npz')['ells']
            len_valid = np.minimum(len_valid, max_len)
            dl_test = None
        valid_sortish_sampler = SortishSampler(len_valid, bucket_size)
        valid_sampler = ApproxBatchSampler(valid_sortish_sampler, max_tokens, max_batch_size, len_valid)
        dl_valid = DataLoader(dataset=ds_valid,
                              batch_sampler=valid_sampler,
                              num_workers=8,
                              collate_fn=collater)

    # Initiate model
    if direction == 'bidirectional':
        model = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                               d_model, num_decoder_layers=n_layers,
                                               dropout=dropout, use_mpnn=use_mpnn, pe=args.pe).to(device)
    else:
        model = StructEncoderDecoder(n_tokens, node_features, edge_features,
                                     d_model, num_encoder_layers=n_layers - 1, num_decoder_layers=1,
                                     direction=direction, dropout=dropout, use_mpnn=use_mpnn).to(device)
    optimizer = FusedAdam(model.parameters(), lr=lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    model = DDP(model)
    scheduler = LambdaLR(optimizer, transformer_lr(warmup_steps))
    loss_func = MaskedCrossEntropyLoss()
    accu_func = MaskedAccuracy()

    def epoch(model, train, current_step=0):
        """Run one pass over the training or validation loader.

        Args:
            model: The (DDP-wrapped) model to run.
            train: If true, iterate ``dl_train`` and update weights; otherwise
                iterate ``dl_valid`` without updates.
            current_step: Number of optimizer steps taken before this call;
                used for progress display and as the mlflow logging step.

        Returns:
            When ``train`` is true, the number of training steps taken in this
            pass; otherwise the token-weighted mean validation loss.
        """
        start_time = datetime.now()
        if train:
            model = model.train()
            loader = dl_train
            t = 'Training:'
        else:
            model = model.eval()
            loader = dl_valid
            t = 'Validating:'
        losses = []
        accus = []
        ns = []
        chunk_time = datetime.now()
        n_seen = 0
        steps_taken = 0
        # Defined up front so an empty loader yields NaN instead of a NameError.
        rloss = raccu = float('nan')
        if train:
            n_total = len(ds_train) // args.world_size
        else:
            n_total = len(ds_valid)
        for i, batch in enumerate(loader):
            new_loss, new_accu, new_n = step(model, batch, train)
            steps_taken = i + 1
            # Weight each batch by its number of unmasked tokens.
            losses.append(new_loss * new_n)
            accus.append(new_accu * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            rloss = sum(losses) / total_n
            raccu = sum(accus) / total_n
            if train:
                nsteps = current_step + i + 1
            else:
                nsteps = i
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %.4f accu = %.4f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss, raccu),
                  end='')
            if train:
                # Training metrics are a running average over a recent window
                # of batches rather than the whole epoch.
                losses = losses[-999:]
                accus = accus[-999:]
                ns = ns[-999:]
                if nsteps % train_steps == 0 and rank == 0:
                    if not ptjob:
                        mlflow.log_metrics({'train_loss': rloss,
                                            'train_accu': raccu},
                                           step=nsteps)
                    print('\nTraining complete in ' + str(datetime.now() - chunk_time))
                    with torch.no_grad():
                        _ = epoch(model, False, current_step=nsteps)
                    # The nested validation pass switched the module to eval
                    # mode; switch back so the rest of the epoch trains with
                    # dropout etc. enabled.
                    model.train()
                    chunk_time = datetime.now()
        if not train:
            print('\nValidation complete in ' + str(datetime.now() - start_time))
            if not ptjob:
                mlflow.log_metrics({'valid_loss': rloss,
                                    'valid_accu': raccu},
                                   step=current_step)
            return rloss
        elif rank == 0:
            print('\nEpoch complete in ' + str(datetime.now() - start_time))
        # Number of steps taken (not the last loop index), so the caller's
        # running step total matches the step numbers displayed/logged above.
        return steps_taken

    def step(model, batch, train):
        """Run the model on one batch and, when training, apply one update.

        Args:
            model: The (DDP-wrapped) model to run.
            batch: Collated batch. For the 'mlm' task it unpacks to
                ``(src, tgt, mask, nodes, edges, connections, edge_mask)``;
                otherwise to ``(src, nodes, edges, connections, edge_mask)``,
                with the target being a copy of ``src`` and the mask selecting
                non-padding positions.
            train: If true, backpropagate and step the optimizer and scheduler.

        Returns:
            Tuple ``(loss, accuracy, n)`` of Python numbers, where ``n`` is the
            sum of the mask.
        """
        if config['task'] == 'mlm':
            src, tgt, mask, nodes, edges, connections, edge_mask = batch
        else:
            src, nodes, edges, connections, edge_mask = batch
            tgt = src.detach().clone()
            mask = (src != PROTEIN_ALPHABET.index(PAD)).float()
        src = src.to(device)
        tgt = tgt.to(device)
        mask = mask.to(device)
        nodes = nodes.to(device)
        edges = edges.to(device)
        connections = connections.to(device)
        edge_mask = edge_mask.to(device)
        outputs = model(nodes, edges, connections, src, edge_mask)
        loss = loss_func(outputs, tgt, mask)
        accu = accu_func(outputs, tgt, mask)
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        return loss.item(), accu.item(), mask.sum().item()

    total_steps = 0
    if rank == 0 and not ptjob:
        mlflow.set_experiment(config['experiment'])
        mlflow.log_params(config)
    n_parameters = sum(p.numel() for p in model.parameters())
    if rank == 0:
        print('%d model parameters' % n_parameters)
        print('%d training sequences' % len(len_train))
        print('%d validation sequences' % len(len_valid))
    # Start from +inf so the first finite validation loss always becomes the
    # best checkpoint, whatever its magnitude.
    best_valid_loss = float('inf')
    best_path = None
    patience = 20
    min_epochs = 500
    waiting = 0
    m_file = args.out_fpath + 'metrics.csv'
    for e in range(epochs):
        train_sortish_sampler.set_epoch(e)
        total_steps += epoch(model, True, current_step=total_steps)
        if rank == 0:
            nsteps = total_steps
            model_path = args.out_fpath + 'checkpoint%d.tar' % nsteps
            torch.save({
                'step': nsteps,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_path)
            with torch.no_grad():
                vloss = epoch(model, False, current_step=total_steps)
                with open(m_file, 'a') as f:
                    f.write(str(nsteps))
                    f.write(',')
                    f.write(str(vloss))
                    f.write('\n')
                if vloss < best_valid_loss:
                    best_valid_loss = vloss
                    waiting = 0
                    best_path = model_path
                else:
                    waiting += 1
            # NOTE(review): this early-stopping break is evaluated on rank 0
            # only; the other ranks keep iterating. Presumably they would then
            # block on gradient synchronization — confirm before relying on
            # early stopping in multi-GPU runs.
            if waiting >= patience and e > min_epochs:
                break
    if rank == 0 and dl_test is not None:
        if best_path is not None:
            print('Loading %s and testing...' % best_path)
            sd = torch.load(best_path)
            sd = sd['model_state_dict']
            model.load_state_dict(sd)
        model = model.eval()
        with torch.no_grad():
            # epoch() reads dl_valid from this scope when not training, so
            # rebinding it makes the validation pass run over the test loader.
            dl_valid = dl_test
            _ = epoch(model, False)


# Script entry point: run main() only when this file is executed directly, so
# that importing the module does not launch training.
if __name__ == '__main__':
    main()
