import argparse
import json
from datetime import datetime
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from apex import amp
from apex.optimizers import FusedAdam
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score

from sequence_models.collaters import TAPECollater, Seq2PropertyCollater
from sequence_models.datasets import TAPEDataset, CSVDataset
from sequence_models.samplers import ApproxBatchSampler
from sequence_models.constants import PROTEIN_ALPHABET, PAD, STOP
from sequence_models.convolutional import ByteNetLM, MaskedConv1d
from sequence_models.structure import Attention1d
from sequence_models.layers import PositionFeedForward
from sequence_models.losses import MaskedCrossEntropyLoss
from sequence_models.metrics import MaskedAccuracy
from sequence_models.utils import warmup
from sequence_models.flip_utils import load_flip_data

from utils import embed

import resource
# Raise the soft limit on open file descriptors towards 2**20. Presumably this
# is for DataLoader worker processes sharing tensors through file
# descriptors -- confirm. setrlimit raises ValueError when the requested soft
# limit exceeds the hard limit, so clamp to the hard limit, and never lower an
# already-higher soft limit.
_NOFILE_TARGET = 1048576
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_nofile_soft = _NOFILE_TARGET
if rlimit[1] != resource.RLIM_INFINITY:
    _nofile_soft = min(_nofile_soft, rlimit[1])
if _nofile_soft > rlimit[0]:
    resource.setrlimit(resource.RLIMIT_NOFILE, (_nofile_soft, rlimit[1]))



class Model(nn.Module):
    """Downstream prediction head applied on top of ByteNet embeddings.

    Two variants, selected by ``seq2seq``:

    * per-token (``seq2seq=True``): the embeddings are concatenated with the
      outputs of two masked convolutions (kernel sizes 129 and 257, 32 channels
      each), run through a 2-layer bidirectional LSTM (hidden size 1024) over
      the unmasked lengths, and projected position-wise to ``d_out`` values;
    * per-sequence (``seq2seq=False``): the embeddings go through
      ``Attention1d`` (presumably attention pooling over positions -- confirm),
      then a GELU + hidden linear layer, then GELU + dropout + output linear.

    Args:
        d_model: Embedding dimension of the input ``e``.
        d_out: Number of outputs (per token or per sequence).
        seq2seq: Whether to make per-token predictions.
        dropout: Dropout probability before the output layer (per-sequence
            head only).
    """

    def __init__(self, d_model, d_out, seq2seq=False, dropout=0.0):
        super().__init__()
        self.seq2seq = seq2seq
        self.d_model = d_model
        if seq2seq:
            d_cnn = 32
            d_in = d_model
            self.conv1 = MaskedConv1d(d_in, d_cnn, 129)
            self.conv2 = MaskedConv1d(d_in, d_cnn, 257)
            # Input is the embedding concatenated with both conv outputs.
            self.lstm = nn.LSTM(2 * d_cnn + d_in, 1024, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
            # Bidirectional LSTM -> 2 * 1024 features per position.
            self.decoder = PositionFeedForward(2048, d_out)
        else:
            self.attention = Attention1d(d_model)
            self.activation = nn.GELU()
            self.dropout = nn.Dropout(dropout)
            self.hidden = nn.Linear(d_model, d_model)
            self.linear = nn.Linear(d_model, d_out)

    def forward(self, e, input_mask=None):
        """Apply the head.

        Args:
            e: Embeddings laid out batch-first (batch, length, d_model) --
                assumed from ``batch_first=True`` packing; confirm.
            input_mask: Mask of valid positions. Required for the per-token
                head, where its sum over dim 1 gives each sequence's length
                (assumed shape (batch, length, 1), 1 for real tokens).

        Returns:
            Per-token outputs padded back to the input length ``e.size(1)``, or
            per-sequence outputs.
        """
        if self.seq2seq:
            c1 = self.conv1(e, input_mask=input_mask)
            c2 = self.conv2(e, input_mask=input_mask)
            lengths = input_mask.sum(dim=1).cpu().reshape(-1)
            total_length = e.size(1)
            h = torch.cat([e, c1, c2], dim=-1)
            h = pack_padded_sequence(h, lengths, batch_first=True, enforce_sorted=False)
            h, _ = self.lstm(h)
            # Pad back to the *input* width rather than to the longest unmasked
            # sequence, so outputs line up position-by-position with targets
            # padded to the same width as the input.
            h, _ = pad_packed_sequence(h, batch_first=True, total_length=total_length)
            return self.decoder(h)
        attended = self.attention(e, input_mask=input_mask)
        hidden = self.hidden(self.activation(attended))
        return self.linear(self.dropout(self.activation(hidden)))


def main():
    """Parse command-line arguments and run ``train`` as a single process (rank 0)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('config_fpath', type=str, help='file path to config json for model')
    # Directory holding the pretraining metrics.csv and checkpoint<step>.tar files.
    parser.add_argument('--map_dir', type=str, default=os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/')
    # Directory where metrics.csv (results) and best.tar are written.
    parser.add_argument('--out_fpath', type=str, required=False, default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
    # train() dispatches on the task name (e.g. `'cas9' in task`) and fails with
    # an opaque TypeError when it is missing, so require it up front.
    parser.add_argument('--task', type=str, required=True,
                        help="downstream task, e.g. 'secondary_structure', 'fluorescence' "
                             "or a FLIP '<dataset>_<split>' task")
    # Train the embedder together with the head instead of freezing it.
    parser.add_argument('--finetune', action='store_true')
    # CUDA device index offset.
    parser.add_argument('--gpu', '-g', type=int, default=0)
    parser.add_argument('--learning_rate', default=None, type=float)
    # Only evaluate the most recent logged checkpoint (plus the step-0 baseline).
    parser.add_argument('--last', action='store_true')
    parser.add_argument('--seed', default=0, type=int)

    args = parser.parse_args()
    args.world_size = 1
    train(0, args)


def train(gpu, args):
    """Fit and evaluate a downstream head for every selected ByteNet checkpoint.

    Each row of ``record`` names a pretraining checkpoint by its ``step``.
    ``record`` is either the results file ``args.out_fpath + 'metrics.csv'``
    (resuming: rows already holding a result for this task are skipped) or a
    subsample of the pretraining log in ``args.map_dir``. For each checkpoint
    this function does the following:

    * loads the ByteNet embedder (step 0 = untrained);
    * attaches a fresh :class:`Model` head;
    * trains on ``args.task`` with patience-based early stopping;
    * reloads the best head from ``best.tar`` and evaluates it on the test split;
    * writes the test loss/metric back to ``args.out_fpath + 'metrics.csv'``.

    Args:
        gpu: Process-local device offset; CUDA device ``gpu + args.gpu`` is
            used and ``gpu`` doubles as the distributed rank.
        args: Namespace from :func:`main`; must also carry ``world_size``.
            ``args.learning_rate``, when given, overrides the per-task
            learning rate.
    """
    _ = torch.manual_seed(args.seed)
    rank = gpu
    torch.cuda.set_device(gpu + args.gpu)
    device = torch.device('cuda:' + str(gpu + args.gpu))

    # ---- Task-specific hyperparameters ------------------------------------
    # `panel` picks the dataset family below and `seq2seq` the per-token head.
    # `test_name` is only read for TAPE tasks; `flip_dataset`/`flip_split`
    # only for FLIP tasks.
    task = args.task
    dropout = 0.0
    lr = 1e-4
    if task == 'secondary_structure':
        seq2seq = True
        d_out = 3
        loss_func = MaskedCrossEntropyLoss()
        test_name = 'cb513'
        batch_size = 4 if args.finetune else 16
        opt_level = 'O2'
        panel = 'tape'
    else:
        opt_level = 'O2'
        seq2seq = False
        if task == 'remote_homology':
            panel = 'tape'
            loss_func = nn.CrossEntropyLoss()
            d_out = 1195
            test_name = 'test_fold_holdout'
            batch_size = 16
        elif task == 'miser':
            panel = 'miser'
            lr = 1e-5
            loss_func = nn.MSELoss()
            d_out = 1
            batch_size = 8 if args.finetune else 64
        elif task in ['cdc28_binding', 'mitochondria_targeting']:
            panel = 'idr'
            loss_func = nn.BCEWithLogitsLoss()
            d_out = 1
            batch_size = 3 if args.finetune else 32
        elif task == 'rmanod':
            panel = 'rmanod'
            batch_size = 32
            d_out = 1
            loss_func = nn.MSELoss()
        elif 'cas9' in task:
            panel = 'cas9'
            opt_level = 'O1'
            batch_size = 4 if args.finetune else 16
            d_out = 1
            loss_func = nn.MSELoss()
            lr = 1e-5
        else:
            loss_func = nn.MSELoss()
            test_name = 'test'
            d_out = 1
            if task == 'fluorescence':
                panel = 'tape'
                batch_size = 48 if args.finetune else 128
            elif task == 'stability':
                panel = 'tape'
                batch_size = 128 if args.finetune else 512
            else:
                # FLIP tasks are named '<dataset>_<split>', e.g. 'gb1_one_vs_rest'.
                panel = 'flip'
                flip_dataset, _, flip_split = task.partition('_')
                if flip_dataset == 'aav':
                    batch_size = 12 if args.finetune else 64
                elif flip_dataset == 'gb1':
                    lr = 1e-5
                    if 'one' in task:
                        batch_size = 5
                    elif args.finetune:
                        batch_size = 16
                    else:
                        batch_size = 512
                else:
                    batch_size = 16
    if args.learning_rate is not None:
        # An explicit --learning_rate overrides the per-task default above.
        lr = args.learning_rate

    # ---- Datasets -----------------------------------------------------------
    # Data lives under $PT_DATA_DIR/<panel>/, falling back to ~/data/<panel>/.
    data_root = os.getenv('PT_DATA_DIR')
    if data_root is None:
        data_root = str(pathlib.Path.home()) + '/data'
    data_fpath = data_root + '/' + panel + '/'

    def csv_splits(df, outputs):
        # Build the train/valid/test CSVDatasets of `df`, predicting `outputs`.
        return tuple(CSVDataset(df=df, split=s, outputs=outputs) for s in ('train', 'valid', 'test'))

    if panel == 'tape':
        data_fpath += task + '/'
        ds_train = TAPEDataset(data_fpath, task, 'train')  # max_len only matters for contacts
        ds_valid = TAPEDataset(data_fpath, task, 'valid')
        ds_test = TAPEDataset(data_fpath, task, test_name)
        collate_fn = TAPECollater(PROTEIN_ALPHABET)
        num_workers = 0
    else:
        if panel == 'rmanod':
            df = pd.read_csv(data_fpath + 'protabank.csv')
            df['sequence'] = df['Sequence']
            # Rows whose Description contains more than three '+' form the
            # held-out pool: 40 of them (shuffled with args.seed) become the
            # validation split, the rest the test split.
            is_test = np.array([d.count('+') > 3 for d in df['Description'].values])
            test_idx = np.flatnonzero(is_test)
            train_idx = np.flatnonzero(~is_test)
            n_valid = 40
            np.random.seed(args.seed)
            np.random.shuffle(test_idx)
            valid_idx = test_idx[:n_valid]
            test_idx = test_idx[n_valid:]
            df.loc[train_idx, 'split'] = 'train'
            df.loc[valid_idx, 'split'] = 'valid'
            df.loc[test_idx, 'split'] = 'test'
            ds_train, ds_valid, ds_test = csv_splits(df, ['Data'])
            num_workers = 16
        elif panel == 'miser':
            df = pd.read_csv(data_fpath + 'processed/enrichments.csv')
            df['split'] = df['split2']
            # PAD characters in the aligned sequences are replaced with STOP.
            df['sequence'] = df['gapped'].apply(lambda x: x.replace(PAD, STOP))
            ds_train, ds_valid, ds_test = csv_splits(df, ['tgt'])
            num_workers = 8
        elif panel == 'cas9':
            df = pd.read_csv(data_fpath + 'ready/choi.csv')
            # The target column is the task name minus its first five
            # characters (assumed to be a 'cas9_' prefix).
            df['tgt'] = df[task[5:]]
            df = df.dropna(subset=['tgt'])
            ds_train, ds_valid, ds_test = csv_splits(df, ['tgt'])
            num_workers = 8
        elif panel == 'idr':
            df = pd.read_csv(data_fpath + 'ready/' + task + '.csv')
            ds_train, ds_valid, ds_test = csv_splits(df, ['tgt'])
            num_workers = 8
        else:
            ds_train, ds_valid, ds_test = load_flip_data(data_fpath, flip_dataset, flip_split, max_len=2048)
            num_workers = 2
        collate_fn = Seq2PropertyCollater(PROTEIN_ALPHABET, return_mask=True)

    # ---- Loaders ------------------------------------------------------------
    sampler = DistributedSampler(ds_train, num_replicas=args.world_size, rank=rank)
    if task == 'remote_homology' or 'meltome' in task:
        # Length-aware batching (ApproxBatchSampler, presumably capping the
        # number of tokens per batch) instead of a fixed batch size.
        len_train = np.array([len(d[0]) for d in ds_train])
        if args.finetune:
            max_tokens = 5000
            max_batch = 32
        else:
            max_tokens = 30000
            max_batch = 128
        train_sampler = ApproxBatchSampler(sampler, max_tokens, max_batch, len_train, max_square_tokens=np.inf)
        dl_train = DataLoader(ds_train, collate_fn=collate_fn, batch_sampler=train_sampler, num_workers=num_workers)
    else:
        dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=collate_fn,
                              num_workers=num_workers, sampler=sampler)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    dl_test = DataLoader(ds_test, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)

    # When true, batches hold raw tokens and are embedded inside `step`;
    # otherwise every split is embedded once per checkpoint with `embed`.
    embed_on_the_fly = args.finetune or 'aav' in task or 'meltome' in task

    def step(model, batch, train=True, return_values=False):
        """Compute the loss on one batch and, if ``train``, take an optimizer step.

        Returns ``(loss, count)``, where ``count`` weights the loss (masked
        positions for secondary structure, otherwise ``len(tgt)``). With
        ``return_values``, the outputs, inputs, targets and mask are also
        returned (moved to the CPU).
        """
        src, tgt, input_mask = batch
        src = src.to(device)
        tgt = tgt.to(device)
        if embed_on_the_fly:
            # Raw token batches: rebuild the padding mask and embed here.
            input_mask = (src != PROTEIN_ALPHABET.index(PAD)).float().unsqueeze(-1)
            if args.finetune:
                e = model['embedder'](src, input_mask=input_mask)
            else:
                with torch.no_grad():
                    e = model['embedder'](src, input_mask=input_mask)
            outputs = model['decoder'](e, input_mask=input_mask)
        else:
            # Here `model` is the head alone, so `src` presumably holds the
            # embeddings produced by `utils.embed`.
            input_mask = input_mask.to(device)
            if task == 'secondary_structure':
                input_mask = input_mask.unsqueeze(-1)
            outputs = model(src, input_mask=input_mask)
        if task == 'secondary_structure':
            mask = input_mask.bool()
            loss = loss_func(outputs, tgt, mask)
            locations = mask.sum().item()
        else:
            if task == 'remote_homology':
                tgt = tgt.view(len(outputs))
            loss = loss_func(outputs, tgt)
            locations = len(tgt)
            mask = torch.ones(1)  # dummy
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        if return_values:
            return loss.item(), locations, outputs.detach().cpu(), src.detach().cpu(), tgt.detach().cpu(), mask.detach().cpu()
        return loss.item(), locations

    def epoch(model, train, current_step=0):
        """Run one pass over the training or validation loader.

        Returns ``(n_batches, running_loss)`` when training and
        ``running_loss`` when validating. The loss is weighted by the counts
        returned from ``step``. ``e`` and ``epochs`` in the progress line come
        from the enclosing training loop.
        """
        if train:
            model = model.train()
            loader = dle_train
            t = 'Training:'
            n_total = len(ds_train)
        else:
            model = model.eval()
            loader = dle_valid
            t = 'Validating:'
            n_total = len(ds_valid)
        weighted_loss = 0.0
        total_n = 0
        n_seen = 0
        n_batches = 0
        rloss = 0  # stays defined even if the loader is empty
        for i, batch in enumerate(loader):
            new_loss, new_n = step(model, batch, train)
            weighted_loss += new_loss * new_n
            total_n += new_n
            n_seen += len(batch[0])
            n_batches = i + 1
            rloss = 0 if total_n == 0 else weighted_loss / total_n
            nsteps = current_step + n_batches if train else i
            if rank == 0:
                print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                      % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss),
                      end='')
        if not train:
            return rloss
        # Return the number of batches (optimizer steps), not the last index,
        # so the caller's running step count does not drift by one per epoch.
        return n_batches, rloss

    def test_epoch(dl):
        """Evaluate the current ``model`` on ``dl``; return ``{'metric', 'loss'}``.

        The metric is Spearman correlation for the MSE tasks, ROC AUC for
        the IDR tasks, argmax accuracy for remote homology, and masked
        per-position accuracy for secondary structure.
        """
        with torch.no_grad():
            losses = []
            ns = []
            pred = []
            tgt = []
            for batch in dl:
                new_loss, new_n, p, _, t, _ = step(model, batch, False, return_values=True)
                losses.append(new_loss * new_n)
                ns.append(new_n)
                pred.append(p)
                tgt.append(t)
        test_loss = sum(losses) / sum(ns)
        if task == 'secondary_structure':
            pred = torch.cat([a.reshape(-1, 3) for a in pred])
            tgt = torch.cat([t.reshape(-1) for t in tgt])
            mask = (tgt != -100)
            metric = MaskedAccuracy()(pred, tgt, mask).item()
        else:
            pred = torch.cat(pred).numpy()
            tgt = torch.cat(tgt).numpy()
            if panel == 'idr':
                metric = roc_auc_score(tgt, pred)
            elif task == 'remote_homology':
                metric = (pred.argmax(axis=-1) == tgt).mean()
            else:
                metric = spearmanr(pred, tgt).correlation
        print('loss: %f' %test_loss, end='\t')
        print('metric: %f' %(metric))
        return {
            'metric': metric,
            'loss': test_loss
        }

    # ---- Embedder configuration ---------------------------------------------
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_embedding = config['d_embed']
    d_model = config['d_model']
    n_layers = config['n_layers']
    kernel_size = config['kernel_size']
    activation = config['activation']
    slim = config['slim']
    r = config['r']
    epochs = 500

    # ---- Checkpoints to evaluate --------------------------------------------
    if 'metrics.csv' in os.listdir(args.out_fpath):
        record = pd.read_csv(args.out_fpath + 'metrics.csv')
    else:
        record = pd.read_csv(args.map_dir + 'metrics.csv', header=None)
        # Take roughly ten rows of the pretraining log, newest first (row 0 is
        # excluded by the range stop). max(..., 1) keeps the range step
        # non-zero when the log has fewer than ten rows.
        skip = max(len(record) // 10, 1)
        idx = list(range(len(record) - 1, 0, -skip))
        record = record.iloc[idx]
        record.columns = ['loss', 'accuracy', 'tokens', 'step']
        if args.last:
            record = record.iloc[:1]
        # A step-0 row evaluates the untrained (randomly initialised) embedder.
        untrained = pd.DataFrame(np.array([[np.nan, np.nan, 0, 0]]), columns=record.columns)
        record = pd.concat([record, untrained], ignore_index=True)
    if args.finetune:
        loss_col = task + '_ft_loss'
        metr_col = task + '_ft_metric'
    else:
        loss_col = task + '_loss'
        metr_col = task + '_metric'

    best_fpath = args.out_fpath + 'best.tar'

    def save_best(nsteps):
        # Checkpoint the current `model` (head, or head + embedder) to best.tar.
        torch.save({
            'step': nsteps,
            'model_state_dict': model.state_dict(),
        }, best_fpath)

    for idx, row in tqdm(record.iterrows(), total=len(record)):
        print(record)
        if loss_col in row and not np.isnan(row[loss_col]):
            continue  # already evaluated by a previous run
        wd = args.map_dir + 'checkpoint' + str(int(row['step'])) + '.tar'
        np.random.seed(args.seed + 100)
        bytenet = ByteNetLM(n_tokens, d_embedding, d_model, n_layers, kernel_size, r, dropout=0.0, activation=activation,
                            causal=False, padding_idx=PROTEIN_ALPHABET.index(PAD), final_ln=True, slim=slim)
        if int(row['step']) != 0:
            sd = torch.load(wd, map_location=device)
            sd = sd['model_state_dict']
            # Strip the 'module.' prefix (presumably left by a
            # DistributedDataParallel wrapper at save time).
            sd = {k.split('module.')[1]: v for k, v in sd.items()}
            bytenet.load_state_dict(sd)
        # Scale the token-embedding weights by 1.25 (pretrained and untrained alike).
        bytenet.embedder.embedder.weight = nn.Parameter(bytenet.embedder.embedder.weight * 1.25)
        embedder = bytenet.embedder.to(device)
        decoder = Model(d_model, d_out, seq2seq=seq2seq, dropout=dropout).to(device)
        if args.finetune:
            optimizer = FusedAdam(list(embedder.parameters()) + list(decoder.parameters()), lr=lr)
        else:
            optimizer = FusedAdam(decoder.parameters(), lr=lr)
        # NOTE(review): apex documents amp.initialize as a once-per-process
        # call, but it runs once per checkpoint here -- confirm this is safe
        # with the installed apex version.
        (embedder, decoder), optimizer = amp.initialize([embedder, decoder], optimizer, opt_level=opt_level)
        if embed_on_the_fly:
            dle_test = dl_test
            dle_train = dl_train
            dle_valid = dl_valid
            model = nn.ModuleDict({'embedder': embedder, 'decoder': decoder})
        else:
            # Frozen embedder: embed each split once and train only the head.
            dle_test = embed(embedder, dl_test, device, batch_size, seq2seq=seq2seq)
            dle_train = embed(embedder, dl_train, device, batch_size, seq2seq=seq2seq)
            dle_valid = embed(embedder, dl_valid, device, batch_size, seq2seq=seq2seq)
            model = decoder
        n_warmup = 1000
        total_steps = 0
        # Infinite sentinels: any finite validation result counts as an improvement.
        best_valid_metric = float('-inf')
        best_valid_loss = float('inf')
        saved = False  # whether best.tar was written for *this* checkpoint
        # cas9 and gb1_one_vs_rest keep the head with the lowest validation
        # loss; all other tasks keep the one with the highest validation metric.
        select_by_loss = panel == 'cas9' or task == 'gb1_one_vs_rest'
        patience = 20 if ('gb1_one' in task or 'cas9' in task) else 5
        scheduler = LambdaLR(optimizer, warmup(n_warmup))
        waiting = 0
        for e in range(epochs):
            sampler.set_epoch(e)
            ts, train_loss = epoch(model, True, current_step=total_steps)
            total_steps += ts
            nsteps = total_steps
            if rank == 0:
                results = test_epoch(dle_valid)
                vloss = results['loss']
                vmetric = results['metric']
                if vmetric > best_valid_metric and task != 'gb1_one_vs_rest':
                    best_valid_metric = vmetric
                    waiting = 0
                    if not select_by_loss:
                        save_best(nsteps)
                        saved = True
                if vloss < best_valid_loss:
                    best_valid_loss = vloss
                    waiting = 0
                    if select_by_loss:
                        save_best(nsteps)
                        saved = True
                # Patience counts epochs whose validation loss is not below
                # the training loss (counting from any improvement reset above).
                if vloss < train_loss:
                    waiting = 0
                else:
                    waiting += 1
                if waiting == patience:
                    break
        del dle_train
        if rank == 0:
            if not saved:
                # No epoch improved (e.g. a NaN validation metric); without this,
                # best.tar would still hold the head of a previous checkpoint.
                save_best(total_steps)
            sd = torch.load(best_fpath)
            model.load_state_dict(sd['model_state_dict'])
            model = model.eval()
            results = test_epoch(dle_test)
            record.loc[idx, metr_col] = results['metric']
            record.loc[idx, loss_col] = results['loss']
            record.to_csv(args.out_fpath + 'metrics.csv', index=False)


if __name__ == '__main__':
    main()