import torch
from torch import optim
from scipy.stats import kendalltau, rankdata, linregress
from sklearn.metrics import roc_auc_score
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from models import *
from data import *
from secrets import token_hex


def _embed_inputs(context_embedder, context_attn, query_embedder, context_x, context_y, query_x, num_contexts):
    """Run the embedding stack and return the tensor that is fed to the predictor.

    Args:
        context_embedder: module called as ``context_embedder(x, y)`` once per context slot.
        context_attn: module applied to the stacked per-slot context embeddings.
        query_embedder: module called as ``query_embedder(query_x)``.
        context_x: context inputs, indexed as ``[:, j, :]`` (axis 1 enumerates the context slots).
        context_y: context targets, indexed the same way as ``context_x``.
        query_x: query inputs.
        num_contexts: number of context slots to embed.

    Returns:
        The attended context and the query embedding concatenated along dim 1.
    """
    # one row per context slot; filled in place so the layout matches what context_attn is given
    context = torch.zeros((num_contexts, len(context_x), config['d_model']), device='cuda')
    for j in range(num_contexts):
        context[j] = context_embedder(context_x[:, j, :], context_y[:, j, :])
    context = context_attn(context)
    query = query_embedder(query_x)
    return torch.concat((context, query), dim=1)


def _log_validation_metrics(context_embedder, context_attn, query_embedder, predictor, test_dataloader, num_contexts, writer, step):
    """Evaluate on the test dataloader and write the validation scalars to TensorBoard.

    Logs the mean batch MSE, the overall Pearson r, and per-target Kendall tau, Pearson r
    and ROC-AUC statistics. The caller is responsible for putting the modules in eval mode.

    Args:
        context_embedder, context_attn, query_embedder, predictor: the model components.
        test_dataloader: yields ``(context_x, context_y, query_x, query_y, seqs)`` batches,
            where ``seqs`` holds one hashable target key per sample.
        num_contexts: number of context slots per sample.
        writer: the ``SummaryWriter`` to log to.
        step: global step attached to every scalar.
    """
    batch_losses = []
    all_pred = []
    all_real = []
    # predictions / ground truth grouped by target key, in first-seen order
    target_to_pred = {}
    target_to_real = {}
    with torch.no_grad():
        for context_x, context_y, query_x, query_y, seqs in test_dataloader:
            context_x, context_y, query_x, query_y = context_x.cuda(), context_y.cuda(), query_x.cuda(), query_y.cuda()
            x = _embed_inputs(context_embedder, context_attn, query_embedder, context_x, context_y, query_x, num_contexts)
            out = predictor(x)
            batch_losses.append(torch.mean((out - query_y) ** 2).item())
            pred = out.cpu().numpy().flatten()
            real = query_y.cpu().numpy().flatten()
            all_pred.extend(pred)
            all_real.extend(real)
            for k, seq in enumerate(seqs):
                target_to_pred.setdefault(seq, []).append(pred[k])
                target_to_real.setdefault(seq, []).append(real[k])

    if not batch_losses:
        # empty test set: there is nothing meaningful to log
        return

    writer.add_scalar('loss/test', sum(batch_losses) / len(batch_losses), step)
    writer.add_scalar('corr/raw', linregress(all_pred, all_real).rvalue, step)

    taus = []
    r_values = []
    aucs = []
    for seq, preds in target_to_pred.items():
        reals = target_to_real[seq]
        taus.append(kendalltau(rankdata(preds), rankdata(reals)).correlation)
        r_values.append(linregress(preds, reals).rvalue)
        # samples whose true value is below 1 form the positive class; predictions are
        # negated so that a lower predicted value means a higher score
        is_positive = np.array(reals) < 1
        if is_positive.all() or not is_positive.any():
            # ROC-AUC is undefined unless both classes are present for this target
            continue
        try:
            aucs.append(roc_auc_score(is_positive, -np.array(preds)))
        except ValueError:
            # best effort: skip targets whose inputs scikit-learn rejects
            pass

    writer.add_scalar('tau/mean', np.mean(taus), step)
    writer.add_scalar('tau/std', np.std(taus), step)
    writer.add_scalar('corr/per_target', sum(r_values) / len(r_values), step)
    if aucs:
        writer.add_scalar('roc-auc/median', np.median(aucs), step)
        writer.add_scalar('roc-auc/mean', np.mean(aucs), step)
        writer.add_scalar('roc-auc/std', np.std(aucs), step)


def train(div_factor, affinity_range, run_name='', checkpoint_file=''):
    """Train the context/query embedders, context attention and predictor with gradient accumulation.

    Micro-batches are accumulated for ``config['batch_size'] // div_factor`` iterations per
    optimizer step. Every ``val_freq`` optimizer steps the model is evaluated on the test
    dataloader, and every ``ckpt_freq`` evaluations a checkpoint is written.

    Args:
        div_factor: forwarded to ``get_dataloaders`` and used as the divisor of
            ``config['batch_size']`` to obtain the number of accumulation steps.
        affinity_range: forwarded to ``get_dataloaders``; its length is the number of
            context slots per sample.
        run_name: TensorBoard run directory under ``runs/``; a random hex name is used when empty.
        checkpoint_file: path the state dicts are saved to; checkpointing is disabled when empty.

    Returns:
        None, once ``config['total_epochs']`` optimizer steps have been taken.
    """
    val_freq = 256
    ckpt_freq = 4
    train_dataloader, test_dataloader, max_idx, max_len, _, _ = get_dataloaders(div_factor, affinity_range=affinity_range, data_file='data.pickle')
    context_embedder = Embedder(max_idx, max_len).cuda()
    # the context attention module is only really needed for the few-shot setting, but we left it in because
    # it has very little effect on one-shot performance, and it makes expanding to the few-shot setting easier
    context_attn = ContextAttn().cuda()
    query_embedder = Embedder(max_idx, max_len).cuda()
    predictor = Predictor().cuda()
    modules = (context_embedder, context_attn, query_embedder, predictor)

    optimizer = optim.RAdam(list(context_embedder.parameters()) + list(query_embedder.parameters()) + list(predictor.parameters()) + list(context_attn.parameters()), lr=config['lr'])
    warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, 0.0001, 1, total_iters=config['warmup_steps'])
    annealing_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['total_epochs'])
    scheduler = optim.lr_scheduler.SequentialLR(optimizer, [warmup_scheduler, annealing_scheduler], milestones=[config['warmup_steps']])

    num_contexts = len(affinity_range)
    # number of micro-batches accumulated per optimizer step
    accum_steps = config['batch_size'] // div_factor

    writer = SummaryWriter(('runs/' + run_name) if run_name else ('runs/' + token_hex(4)))
    # NOTE: despite the name, `epoch` counts optimizer steps, not passes over the data
    epoch = 0
    ckpt_counter = 0
    try:
        while True:
            total_loss = 0
            count = 0
            # NOTE(review): `i` restarts with every pass over the dataloader while gradients are only
            # cleared after an optimizer step, so a trailing partial window carries into the next pass
            for i, (context_x, context_y, query_x, query_y) in enumerate(train_dataloader):
                context_x, context_y, query_x, query_y = context_x.cuda(), context_y.cuda(), query_x.cuda(), query_y.cuda()
                x = _embed_inputs(context_embedder, context_attn, query_embedder, context_x, context_y, query_x, num_contexts)
                loss = torch.mean((predictor(x) - query_y) ** 2)
                total_loss += loss.item()
                count += 1
                # NOTE(review): the loss is not divided by accum_steps, so the accumulated gradient is
                # the sum (not the mean) of the micro-batch gradients
                loss.backward()
                if (i + 1) % (val_freq * accum_steps) == 0:
                    writer.add_scalar('loss/train', total_loss / count, epoch)
                    # switch every module, including context_attn, so that any train-only
                    # layers (e.g. dropout) are disabled while validating
                    for module in modules:
                        module.eval()
                    _log_validation_metrics(context_embedder, context_attn, query_embedder, predictor, test_dataloader, num_contexts, writer, epoch)
                    for module in modules:
                        module.train()
                    ckpt_counter += 1
                    if checkpoint_file and ckpt_counter == ckpt_freq:
                        torch.save((context_embedder.state_dict(), query_embedder.state_dict(), predictor.state_dict(), context_attn.state_dict()), checkpoint_file)
                        ckpt_counter = 0
                if (i + 1) % accum_steps == 0:
                    total_loss = 0
                    count = 0
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
                    epoch += 1
                    if epoch == config['total_epochs']:
                        return
    finally:
        # flush pending events even when training ends or fails part-way
        writer.close()


if __name__ == '__main__':
    # command-line options: bounds of the context affinity window, run name and checkpoint path
    cli = argparse.ArgumentParser()
    cli.add_argument('--context_affinity_min', type=float, default=-50)
    cli.add_argument('--context_affinity_max', type=float, default=50)
    cli.add_argument('--run_name', type=str, default='')
    cli.add_argument('--checkpoint_file', default='model.pt')
    options = cli.parse_args()

    # 32 is the size of each batch passed to the gpu, which is gradient accumulated to reach the full 1024 batch size
    gpu_batch_size = 32
    # a single (min, max) affinity window, i.e. one context slot
    context_window = (options.context_affinity_min, options.context_affinity_max)
    train(
        gpu_batch_size,
        [context_window],
        run_name=options.run_name,
        checkpoint_file=options.checkpoint_file,
    )