import argparse
import json
from datetime import datetime
import os
import pathlib
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, ConcatDataset
from apex import amp
from apex.optimizers import FusedAdam
import numpy as np
import pandas as pd

from sequence_models.collaters import SimpleCollater, StructureOutputCollater
from sequence_models.samplers import ApproxBatchSampler
from sequence_models.datasets import TRRDataset
from sequence_models.constants import PROTEIN_ALPHABET, PAD
from sequence_models.convolutional import ByteNetLM
from sequence_models.layers import PositionFeedForward
from sequence_models.trRosetta import trRosetta, pad_size
from sequence_models.losses import MaskedCrossEntropyLoss
from sequence_models.utils import Tokenizer

from utils import embed
from contact_utils import CAMEODataset, lr_p_at_ell, CASP13Dataset


class Model(nn.Module):
    """Pairwise prediction head on top of per-residue embeddings.

    Each position's embedding is projected down with a position-wise
    feed-forward layer. Every (i, j) pair is formed by concatenating the
    projections of residues i and j. The resulting 2-D feature map goes through
    a trRosetta 2-D trunk (built with ``decoder=False``), is symmetrised, and
    is mapped to ``d_out`` channels per pair by a 1x1 convolution.
    """

    def __init__(self, d_model, d_out=37, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        hidden_dim = 128  # width of each residue after the down-projection
        trunk_dim = 64  # must match the channel count the trRosetta trunk emits
        # Submodules are created in the same order as always (down, trr, conv_dist)
        # so random parameter initialisation is unchanged.
        self.down = PositionFeedForward(d_model, hidden_dim)
        # Pair features concatenate two hidden vectors -> 2 * hidden_dim input channels.
        self.trr = trRosetta(d_init=2 * hidden_dim, n2d_layers=32, model_id=None,
                             decoder=False, p_dropout=dropout)
        self.conv_dist = nn.Conv2d(trunk_dim, d_out, kernel_size=1, stride=1,
                                   padding=pad_size(1, 1, 1))

    def forward(self, e, input_mask=None):
        """Return pairwise logits for the embeddings ``e``.

        Args:
            e: 3-D tensor; presumably (batch, ell, d_model) embeddings.
            input_mask: passed through unchanged to the trRosetta trunk.

        Returns:
            Output of ``conv_dist``; presumably (batch, d_out, ell, ell).
        """
        _, n_pos, _ = e.shape
        projected = self.down(e)
        # rows[b, i, j] is residue i's vector; cols[b, i, j] is residue j's vector.
        rows = projected.unsqueeze(2).repeat(1, 1, n_pos, 1)
        cols = projected.unsqueeze(1).repeat(1, n_pos, 1, 1)
        # Channels-last -> channels-first for the 2-D trunk.
        pair_features = torch.cat((rows, cols), dim=-1).permute(0, 3, 1, 2)
        trunk_out = self.trr(pair_features, input_mask=input_mask)
        # Average with the transpose so pair (i, j) and pair (j, i) get identical logits.
        symmetric = (trunk_out + trunk_out.transpose(2, 3)) * 0.5
        return self.conv_dist(symmetric)



def main():
    """Parse command-line options and run a single-process probe sweep on rank 0."""
    map_default = os.getenv('PT_MAP_OUTPUT_DIR', '/tmp') + '/'
    out_default = os.getenv('PT_OUTPUT_DIR', '/tmp') + '/'

    parser = argparse.ArgumentParser()
    parser.add_argument('config_fpath', type=str, help='file path to config json for model')
    parser.add_argument('--map_dir', type=str, default=map_default)
    parser.add_argument('--out_fpath', type=str, required=False, default=out_default)
    parser.add_argument('--gpu', '-g', type=int, default=0)
    cli_args = parser.parse_args()

    # Only one process is launched, so it is rank 0 of a world of size 1.
    cli_args.world_size = 1
    train(0, cli_args)


def train(gpu, args):
    """Probe pretrained ByteNet checkpoints by training a contact head on their embeddings.

    The checkpoints to evaluate come from a metrics table. For each checkpoint:
      - load it into ``ByteNetLM`` and take its ``embedder`` submodule;
      - train a fresh ``Model`` head (only the head's parameters are optimised)
        on the concatenated trRosetta train/valid/test splits;
      - after every epoch, measure ``lr_p_at_ell`` precision on CAMEO and CASP13.

    The final-epoch precisions are written to ``metrics.csv`` in
    ``args.out_fpath``. Rows that already have a CAMEO value are skipped, so an
    interrupted sweep can be resumed.

    Args:
        gpu: local process index. It is added to ``args.gpu`` to select the CUDA
            device and is also used as the distributed rank.
        args: parsed CLI namespace with ``config_fpath``, ``map_dir``,
            ``out_fpath``, ``gpu`` and ``world_size``.
    """
    _ = torch.manual_seed(0)
    rank = gpu
    torch.cuda.set_device(gpu + args.gpu)
    device = torch.device('cuda:' + str(gpu + args.gpu))

    # ---- Training data (shared by every checkpoint) --------------------------
    data_dir = os.getenv('PT_DATA_DIR')
    if data_dir is None:
        home = str(pathlib.Path.home())
        data_dir = home + '/data/'
    trr_dir = data_dir + '/trrosetta/trrosetta/'
    # All three trRosetta splits train the probe; evaluation uses CAMEO / CASP13 instead.
    ds_train = ConcatDataset([
        TRRDataset(trr_dir, split, max_len=2048, bin=True, return_msa=False, untokenize=True)
        for split in ('train', 'valid', 'test')
    ])
    simple_collater = SimpleCollater(PROTEIN_ALPHABET, pad=True)
    str_collater = StructureOutputCollater(simple_collater, exp=False, dist_only=True)
    batch_size = 128
    sampler = DistributedSampler(ds_train, num_replicas=args.world_size, rank=rank, shuffle=False)
    dl_train = DataLoader(ds_train, batch_size=batch_size, collate_fn=str_collater,
                          num_workers=8, sampler=sampler)

    # Limits for ApproxBatchSampler. Sequence lengths do not depend on the checkpoint,
    # so read the dataset once here rather than once per checkpoint.
    max_tokens = 10000
    max_sq_tkns = 512 ** 2 * 4
    max_batch = 16
    max_len = 512
    lengths = np.minimum(np.array([len(s[0]) for s in ds_train]), max_len)

    # ---- Evaluation benchmarks (shared by every checkpoint) ------------------
    tkn = Tokenizer(PROTEIN_ALPHABET)

    def load_benchmark(ds):
        """Cache (query sequence, contacts, valid indices) for every benchmark entry."""
        return [(d['msa'][0], d['contacts'], d['valid_sequence_indices']) for d in ds]

    cameo = load_benchmark(CAMEODataset(data_dir + '/contact-evaluate/data/cameo/'))
    casp13 = load_benchmark(CASP13Dataset(data_dir + '/contact-evaluate/data/casp13/'))

    def embed_benchmark(embedder, benchmark):
        """Embed each cached query sequence with ``embedder``, without tracking gradients."""
        embeddings = []
        for seq, _, _ in benchmark:
            src = torch.from_numpy(tkn.tokenize(seq)).unsqueeze(0).to(device)
            with torch.no_grad():
                embeddings.append(embedder(src))
        return embeddings

    def benchmark_precision(model, embeddings, benchmark):
        """Pooled ``lr_p_at_ell`` precision (sum of tp / sum of p) of ``model`` on a benchmark."""
        model.eval()
        tps = 0
        ps = 0
        for em, (_, contacts, idx) in zip(embeddings, benchmark):
            with torch.no_grad():
                pred = model(em)
            tp, p = lr_p_at_ell(pred.squeeze(), contacts, idx)
            tps += tp
            ps += p
        # An empty benchmark yields NaN instead of a ZeroDivisionError.
        return tps / ps if ps else float('nan')

    # ---- Training helpers ----------------------------------------------------
    def step(model, optimizer, batch, train=True, return_values=False):
        """Score one batch; when ``train`` is set, also take one optimizer step.

        Returns (loss, number of scored residue pairs). When ``return_values`` is
        set, detached CPU copies of outputs, src, tgt and mask are appended.
        """
        src, tgt, mask = batch
        src = src.to(device)
        tgt = tgt.to(device)
        # Binarise the distance-bin targets: bins 1..12 count as contacts.
        # NOTE(review): presumably these are the bins below the contact cutoff;
        # confirm against StructureOutputCollater's binning.
        tgt = (tgt < 13) & (tgt > 0)
        mask = mask.to(device).unsqueeze(1)
        outputs = model(src, input_mask=mask).squeeze(1)
        mask = mask.squeeze(dim=1)
        # Do not score a residue paired with itself: zero the mask diagonal.
        ell = src.shape[1]
        ind = np.diag_indices(ell)
        mask = torch.clone(mask).detach()
        mask[:, ind[0], ind[1]] = 0
        mask = mask.bool()
        p = torch.masked_select(outputs, mask)
        t = torch.masked_select(tgt, mask)
        loss = F.binary_cross_entropy_with_logits(p, t.float())
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
        n_scored = mask.sum().item()
        if return_values:
            return (loss.item(), n_scored, outputs.detach().cpu(), src.detach().cpu(),
                    tgt.detach().cpu(), mask.detach().cpu())
        return loss.item(), n_scored

    def epoch(model, optimizer, loader, epoch_idx, current_step=0):
        """Train ``model`` for one pass over ``loader`` and return the number of steps taken.

        On rank 0, prints a running loss weighted by the number of scored pairs.
        """
        start_time = datetime.now()
        model.train()
        n_total = len(ds_train)
        loss_sum = 0
        total_n = 0
        n_seen = 0
        n_steps = 0  # stays 0 if the loader is empty
        for n_steps, batch in enumerate(loader, start=1):
            new_loss, new_n = step(model, optimizer, batch, train=True)
            # Running totals: avoids re-summing the whole history every step.
            loss_sum += new_loss * new_n
            total_n += new_n
            n_seen += len(batch[0])
            rloss = loss_sum / total_n if total_n else 0
            if rank == 0:
                print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %f'
                      % ('Training:', epoch_idx + 1, epochs, current_step + n_steps,
                         n_seen, n_total, rloss),
                      end='')
        if rank == 0:
            print('\nTraining complete in ' + str(datetime.now() - start_time))
        # Return a count, not the last 0-based index, so the global step counter does not drift.
        return n_steps

    # ---- Model configuration -------------------------------------------------
    with open(args.config_fpath, 'r') as f:
        config = json.load(f)
    n_tokens = len(PROTEIN_ALPHABET)
    d_embedding = config['d_embed']
    d_model = config['d_model']
    n_layers = config['n_layers']
    kernel_size = config['kernel_size']
    activation = config['activation']
    slim = config['slim']
    r = config['r']
    epochs = 50

    # ---- Which checkpoints to evaluate ---------------------------------------
    os.makedirs(args.out_fpath, exist_ok=True)
    metrics_fpath = os.path.join(args.out_fpath, 'metrics.csv')
    if os.path.exists(metrics_fpath):
        # Resume an earlier sweep.
        record = pd.read_csv(metrics_fpath)
    else:
        record = pd.read_csv(os.path.join(args.map_dir, 'metrics.csv'), header=None)
        # Take roughly every tenth logged step, newest first. max(..., 1) keeps the
        # stride non-zero when the log has fewer than 10 rows.
        skip = max(len(record) // 10, 1)
        idx = list(range(len(record) - 1, 0, -skip))
        record = record.iloc[idx]
        record.columns = ['loss', 'accuracy', 'tokens', 'step']
        # Insert a step-0 row (untrained ByteNet) right after the newest checkpoint.
        df = pd.DataFrame(np.array([[np.nan, np.nan, 0, 0]]), columns=record.columns)
        record = pd.concat([record.iloc[:1], df, record.iloc[1:]], ignore_index=True)
    # Result columns. The same names are used both to write results and to detect
    # already-finished rows on resume.
    metr_cols = ['cameo', 'casp13-fm']

    for record_idx, row in tqdm(record.iterrows(), total=len(record)):
        print(record)
        step_id = int(row['step'])
        if pd.notna(row.get(metr_cols[0])):
            continue  # evaluated in an earlier run
        # Reseed both RNGs so each probe starts from the same state, regardless of
        # how many rows were evaluated or skipped before it.
        np.random.seed(100)
        torch.manual_seed(0)
        bytenet = ByteNetLM(n_tokens, d_embedding, d_model, n_layers, kernel_size, r, dropout=0.0,
                            activation=activation, causal=False,
                            padding_idx=PROTEIN_ALPHABET.index(PAD), final_ln=True, slim=slim)
        if step_id != 0:
            wd = os.path.join(args.map_dir, 'checkpoint' + str(step_id) + '.tar')
            sd = torch.load(wd, map_location=device)['model_state_dict']
            # Strip a leading 'module.' wrapper prefix from parameter names when present.
            prefix = 'module.'
            sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}
            bytenet.load_state_dict(sd)
        # Renorm the embedding
        bytenet.embedder.embedder.weight = nn.Parameter(bytenet.embedder.embedder.weight * 1.25)
        embedder = bytenet.embedder.to(device)
        decoder = Model(d_model, d_out=1, dropout=0.1).to(device)
        lr = 1e-3
        opt_level = 'O2'
        # Only the head's parameters go to the optimizer; the ByteNet embedder is not stepped.
        optimizer = FusedAdam(decoder.parameters(), lr=lr)
        # NOTE(review): apex documents amp.initialize as a once-per-process call, but it
        # runs once per checkpoint here. Confirm this is safe with the installed apex,
        # or evaluate one checkpoint per process.
        (embedder, model), optimizer = amp.initialize([embedder, decoder], optimizer,
                                                      opt_level=opt_level)

        e_cameo = embed_benchmark(embedder, cameo)
        e_casp13 = embed_benchmark(embedder, casp13)
        batch_sampler = ApproxBatchSampler(sampler, max_tokens, max_batch, lengths,
                                           max_square_tokens=max_sq_tkns)
        dle_train = embed(embedder, dl_train, device, None, seq2seq=False, contacts=True,
                          max_len=max_len, batch_sampler=batch_sampler)

        total_steps = 0
        precision_cameo = precision_casp13 = float('nan')
        for e in range(epochs):
            dle_train.batch_sampler.sampler.set_epoch(e)
            total_steps += epoch(model, optimizer, dle_train, e, current_step=total_steps)
            precision_cameo = benchmark_precision(model, e_cameo, cameo)
            print(precision_cameo)
            precision_casp13 = benchmark_precision(model, e_casp13, casp13)
            print(precision_casp13)

        del dle_train
        if args.out_fpath is not None and rank == 0:
            torch.save({
                'model_state_dict': model.state_dict(),
            }, os.path.join(args.out_fpath, str(step_id) + '.tar'))
            record.loc[record_idx, metr_cols[0]] = precision_cameo
            record.loc[record_idx, metr_cols[1]] = precision_casp13
            record.to_csv(metrics_fpath, index=False)


if __name__ == '__main__':
    # Run only when executed as a script; importing this module starts nothing.
    main()