import argparse
import copy
import os
import sys
import time

from model.models import Backbone, dec_mtan_rnn
from model.losses import contrastive_loss, contrastive_loss_single

from utils.gratif import tsdm_collate
from utils.utils import mse_with_mask, mae_with_mask, rmse_with_mask, mse_with_mask_torch, obtain_forecast_task
from utils.datautils import obtain_window_length
from utils.correlation import calc_corr_penalty

import torch
import random
import numpy as np


def set_random_seed(seed=0):
    """Seed every RNG used by the script and make cuDNN deterministic.

    Args:
        seed: value handed to Python's ``random``, NumPy, and PyTorch (CPU and
            current CUDA device) generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)

    # Prefer reproducible kernels over cuDNN's autotuned (non-deterministic) choice.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_device():
    """Return ``'cuda:0'`` when a CUDA GPU is available, otherwise ``'cpu'``."""
    has_gpu = torch.cuda.is_available()
    return 'cuda:0' if has_gpu else 'cpu'


def get_metrics(dataset: str = 'P12'):
    """Return the list of classification metric names reported for ``dataset``.

    The lookup is case-insensitive. ``'physionet'`` and ``'pam'`` have dedicated
    metric lists; every other name (including the default ``'P12'``) gets
    ``['auc', 'auprc', 'acc']``. A fresh list is returned on every call.
    """
    per_dataset = {
        'physionet': ['auc', 'acc'],
        'pam': ['acc', 'precision', 'recall', 'f1'],
    }
    return per_dataset.get(dataset.lower(), ['auc', 'auprc', 'acc'])


def args_parser():
    """Build the command-line interface and parse ``sys.argv``.

    Options are registered in table order: run bookkeeping, forecasting split,
    interpolation/static features, model sizes, optimisation, loss weights
    (alpha/beta/gamma) and pooling configuration.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    option_table = (
        # run bookkeeping
        ('--seed', dict(type=int, default=42)),
        ('--per-show', dict(type=int, default=1)),
        ('--save', dict(type=str, default='save')),
        ('--dataset', dict(type=str, default='physionet2012')),
        # forecasting split
        ('--cond_time', dict(type=int, default=36)),
        ('--forc_time', dict(type=int, default=0)),
        ('--nfolds', dict(type=int, default=5)),
        ('--fold', dict(type=int, default=2)),
        # interpolation / static features
        ('--sample-tp', dict(type=float, default=1.0)),
        ('--static-hidden', dict(type=int, default=32)),
        ('--static', dict(action='store_true',
                          help='Use static feature for P12/P19/PAM')),
        # model sizes; --nrp is the number of reference points
        ('--embed-time', dict(type=int, default=128)),
        ('--nrp', dict(type=int, default=128)),
        ('--mask_ratio', dict(type=float, default=0.1)),
        ('--weight', dict(type=float, default=0.6)),
        ('--hidden', dict(type=int, default=32)),
        ('--repr-hidden', dict(type=int, default=32)),
        ('--cls-hidden', dict(type=int, default=128)),
        # optimisation
        ('--batch-size', dict(type=int, default=64)),
        ('--lr', dict(type=float, default=0.001)),
        ('--head-lr', dict(type=float, default=None)),
        ('--max-epoch', dict(type=int, default=200)),
        # loss weights
        ('--alpha', dict(type=float, default=1000)),
        ('--beta', dict(type=float, default=10.)),
        ('--gamma', dict(type=float, default=10.)),
        # pooling configuration
        ('--pooling-mode', dict(type=str, choices=['auto', 'manner'], default='auto')),
        ('--pooling-list', dict(type=str, default='8,16,32,48,64,96')),
        ('--pooling-lower', dict(type=float, default=4)),
        ('--pooling-upper', dict(type=float, default=1.)),
        ('--pla', dict(type=float, default=0., help='is pooling lower bound auto?')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


if __name__ == '__main__':
    # Script entry: train the backbone with per-pooling-level reconstruction,
    # adjustment and contrastive losses, jointly with a forecasting head.
    # Checkpoint and report the test metrics of the epoch with the best
    # validation MSE.
    args = args_parser()
    set_random_seed(args.seed)

    os.makedirs(args.save, exist_ok=True)

    # The forecasting head falls back to the backbone learning rate.
    if args.head_lr is None:
        args.head_lr = args.lr

    if args.pooling_mode == 'manner':
        # Manual mode: --pooling-list is a comma-separated list of integers.
        args.pooling_list = [int(item) for item in args.pooling_list.split(',')]

    # Task classes are imported lazily so only the selected dataset's module is loaded.
    task_kwargs = dict(normalize_time=True, condition_time=args.cond_time,
                       forecast_horizon=args.forc_time, num_folds=args.nfolds)
    if args.dataset == 'ushcn':
        from tsdm.tasks import USHCN_DeBrouwer2019
        dataset = USHCN_DeBrouwer2019(**task_kwargs)
    elif args.dataset == 'physionet2012':
        from tsdm.tasks.physionet2012 import Physionet2012
        dataset = Physionet2012(**task_kwargs)
    elif args.dataset == "mimiciii":
        from tsdm.tasks.mimic_iii_debrouwer2019 import MIMIC_III_DeBrouwer2019
        dataset = MIMIC_III_DeBrouwer2019(**task_kwargs)
    elif args.dataset == "mimiciv":
        from tsdm.tasks.mimic_iv_bilos2021 import MIMIC_IV_Bilos2021
        dataset = MIMIC_IV_Bilos2021(**task_kwargs)
    else:
        # Fail fast with a clear message rather than an undefined `dataset` later.
        raise ValueError("Unknown dataset %r; expected one of 'ushcn', 'physionet2012', "
                         "'mimiciii', 'mimiciv'." % args.dataset)

    dloader_config_train = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "drop_last": True,
        "pin_memory": True,
        "num_workers": 4,
        "collate_fn": tsdm_collate,
    }

    dloader_config_infer = {
        "batch_size": 64,
        "shuffle": False,
        "drop_last": False,
        "pin_memory": True,
        "num_workers": 0,
        "collate_fn": tsdm_collate,
    }

    train_loader = dataset.get_dataloader((args.fold, 'train'), **dloader_config_train)
    valid_loader = dataset.get_dataloader((args.fold, 'valid'), **dloader_config_infer)
    test_loader = dataset.get_dataloader((args.fold, 'test'), **dloader_config_infer)
    # Training split iterated without shuffling or dropping the last batch.
    corr_loader = dataset.get_dataloader((args.fold, 'train'), **dloader_config_infer)

    data_obj = obtain_window_length(corr_loader, args)
    windows = data_obj['windows'].detach().numpy().tolist()
    length = data_obj['ts_length'].detach().numpy().tolist()
    channels = data_obj['input_dim']
    timestamp = data_obj['timestamp']
    device = get_device()

    # NOTE(review): presumably `nameformat` keys a cached penalty under ./corr — confirm
    # it should not also include forc_time.
    nameformat = '%s_%s_ct%d_nf%d_f%d' % (args.dataset, 'forecast', args.cond_time, args.nfolds, args.fold)
    corr = calc_corr_penalty(train_loader, nameformat, './corr')

    net = Backbone(timestamp, channels, latent_dim=args.repr_hidden, windows=windows, length=length,
                   n_hidden=args.hidden, device=device, weight=[args.weight, 1 - args.weight], corr=corr,
                   p=args.mask_ratio, query_length=args.nrp, embed_times=args.embed_time).to(device)
    pred_head = dec_mtan_rnn(channels, torch.linspace(0, 1, args.nrp), net.repr_dim, args.hidden, args.embed_time, learn_emb=True, device=device).to(device)

    opt = torch.optim.AdamW([
        {'params': net.parameters(), 'lr': args.lr},
        # Honour --head-lr for the forecasting head (it defaults to --lr above).
        {'params': pred_head.parameters(), 'lr': args.head_lr},
    ])
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.max_epoch)

    recon = mse_with_mask_torch

    # Validation MSE selects the checkpoint; all metrics are reported on the test split.
    ref_metric = 'mse'
    metrics = {
        'rmse': rmse_with_mask,
        'mae': mae_with_mask,
        'mse': mse_with_mask,
    }

    best_result, test_info, time_records = np.inf, None, []

    for epoch in range(args.max_epoch):
        # obtain_forecast_task may switch the modules to eval mode; restore training
        # mode explicitly each epoch (a no-op if they are already in it).
        net.train()
        pred_head.train()

        for niter, train_batch in enumerate(train_loader):
            begin_time = time.time()

            x_time, x, x_mask, y_time, y, y_mask = (item.to(device) for item in train_batch)
            # Backbone input: observed values concatenated with their mask on the last axis.
            reprs, x_recon, x_pooling, recon_pooling, repr_before_gru = net(
                torch.cat([x, x_mask], dim=-1), x_time, None)

            num_pooling = len(x_pooling)
            # Denominator for the losses defined between consecutive pooling levels;
            # clamped so a single level gives 0 instead of 0/0 = NaN.
            num_pairs = max(num_pooling - 1, 1)

            # Reconstruction per pooling level: values are [..., :channels], mask is [..., channels:].
            recon_loss = torch.tensor([0.], device=device)
            for i in range(num_pooling):
                recon_loss += recon(x_pooling[i][..., :channels], x_recon[i], x_pooling[i][..., channels:])
            recon_loss /= num_pooling

            # Adjustment: x_recon is detached, so gradients flow only into recon_pooling.
            adjust_loss = torch.tensor([0.], device=device)
            for i in range(1, num_pooling):
                adjust_loss += recon(recon_pooling[i], x_recon[i].detach(), x_pooling[i][..., channels:])
            adjust_loss /= num_pairs

            # Contrastive loss between representations of consecutive pooling levels.
            con_loss = torch.tensor([0.], device=device)
            for i in range(num_pooling - 1):
                con_loss += contrastive_loss_single(reprs[i], reprs[i + 1])
            con_loss /= num_pairs

            # Forecasting head decodes `repr_before_gru` at the target time stamps.
            pred = pred_head(repr_before_gru, y_time)
            proj_loss = recon(y, pred, y_mask)

            loss = con_loss + args.alpha * recon_loss + args.beta * proj_loss + args.gamma * adjust_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            time_records.append(time.time() - begin_time)

            if niter % 10 == 0:
                print('Epoch#%d/%d, Iter#%d/%d, loss = %.4f, contrastive loss = %.4f, reconstruction loss = %.4f, adjust loss = %.4f, projection loss = %.4f, step time = %.4f' %
                      (epoch + 1, args.max_epoch, niter + 1, len(train_loader), loss.item(), con_loss.item(), recon_loss.item(), adjust_loss.item(), proj_loss.item(), np.mean(time_records)))
                time_records = []
        sch.step()

        # Evaluate every `per_show` epochs and always after the final epoch, so the
        # last trained weights are considered for checkpoint selection.
        if epoch % args.per_show == 0 or epoch == args.max_epoch - 1:
            valid_results = obtain_forecast_task(valid_loader, net, pred_head, metrics, device)
            test_results = obtain_forecast_task(test_loader, net, pred_head, metrics, device)

            if best_result > valid_results[ref_metric]:
                torch.save(net.state_dict(), os.path.join(args.save, 'backbone.pt'))
                torch.save(pred_head.state_dict(), os.path.join(args.save, 'forecast.pt'))
                best_result = valid_results[ref_metric]
                test_info = copy.deepcopy(test_results)

            s = 'Epoch#%d/%d Test: current valid result = %.4f, best valid result = %.4f' % (epoch + 1, args.max_epoch, valid_results[ref_metric], best_result)
            for metric, result in test_results.items():
                s += ', test %s = %.4f' % (metric, result)
            s += '\n'
            print(s)

    if test_info is None:
        # No epoch ran, or no validation score ever beat +inf (e.g. it was NaN).
        print('Best result: none (no validation improvement recorded)')
    else:
        s = 'Best result:'
        for metric, result in test_info.items():
            s += ' test %s = %.4f,' % (metric, result)
        print(s)