import argparse
import os
import sys

from model.models import Backbone, CLS
from model.losses import contrastive_loss, contrastive_loss_single
from datautils import load_physionet_data, load_data
from utils import obtain_classifier_result, obtain_impute_result, mse_with_mask, subsample_timepoints
from correlation import calc_corr_penalty

import torch
import random
import numpy as np


def set_random_seed(seed=0):
    """Seed every random number generator this script relies on.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's
    global generator, and PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed``). cuDNN is then switched to deterministic
    mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 0.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_device():
    """Return the device string to run on.

    Returns:
        ``'cuda:0'`` when ``torch.cuda.is_available()`` is true,
        otherwise ``'cpu'``.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'


def get_metrics(dataset: str = 'P12'):
    """Return the list of metric names reported for ``dataset``.

    The dataset name is compared case-insensitively. The first entry of the
    returned list is the one the training script uses for model selection.

    Args:
        dataset: Dataset name. Defaults to ``'P12'``.

    Returns:
        ``['auc', 'acc']`` for ``'physionet'``,
        ``['acc', 'precision', 'recall', 'f1']`` for ``'pam'``,
        and ``['auc', 'auprc', 'acc']`` for any other name.
    """
    name = dataset.lower()
    if name == 'physionet':
        return ['auc', 'acc']
    if name == 'pam':
        return ['acc', 'precision', 'recall', 'f1']
    return ['auc', 'auprc', 'acc']


def args_parser():
    """Define the script's command-line options and parse ``sys.argv``.

    Options are declared in a table of ``(flag, add_argument kwargs)`` pairs
    and registered in table order, so ``--help`` lists them in that order.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    option_table = (
        # run / bookkeeping
        ('--seed', dict(type=int, default=42)),
        ('--per-show', dict(type=int, default=1)),
        ('--save', dict(type=str, default='save')),
        ('--dataset', dict(type=str, default='P12')),

        # data loading
        ('--n', dict(type=int, default=8000)),
        ('--q', dict(type=float, default=0.016,
                     help="Quantization on the physionet dataset.")),
        ('--classif', dict(action='store_true',
                           help="Include binary classification loss")),
        ('--n_split', dict(type=int, default=1,
                           help="Load split[1-5] data. Only for non-phy.")),

        # interpolation
        ('--sample-tp', dict(type=float, default=1.0)),

        # model sizes
        ('--embed-time', dict(type=int, default=128)),
        ('--num-ref-points', dict(type=int, default=128)),
        ('--mask_ratio', dict(type=float, default=0.1)),
        ('--weight', dict(type=float, default=0.6)),
        ('--hidden', dict(type=int, default=32)),
        ('--repr-hidden', dict(type=int, default=32)),
        ('--cls-hidden', dict(type=int, default=128)),

        # optimisation
        ('--batch-size', dict(type=int, default=64)),
        ('--lr', dict(type=float, default=0.001)),
        ('--head-lr', dict(type=float, default=None)),
        ('--max-epoch', dict(type=int, default=200)),

        # loss weights
        ('--alpha', dict(type=float, default=1000)),
        ('--beta', dict(type=float, default=10.)),
        ('--gamma', dict(type=float, default=10.)),

        # pooling
        ('--pooling-mode', dict(type=str, choices=['auto', 'manner'], default='auto')),
        ('--pooling-list', dict(type=str, default='8,16,32,48,64,96')),
        ('--pooling-lower', dict(type=float, default=4)),
        ('--pooling-upper', dict(type=float, default=1.)),

        # missing
        ('--missing-ratio', dict(type=float, default=0)),
        ('--feature-removal-level', dict(type=str, default='set')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


if __name__ == '__main__':
    # Script entry point: parse options, load data, build the backbone (and a
    # classification head when --classif is given), train for --max-epoch
    # epochs, and report evaluation results every --per-show epochs.
    args = args_parser()
    set_random_seed(args.seed)

    # exist_ok=True replaces the exists()/makedirs() pair, which could fail if
    # the directory appeared between the check and the creation.
    os.makedirs(args.save, exist_ok=True)

    # The classification head reuses the backbone learning rate by default.
    if args.head_lr is None:
        args.head_lr = args.lr

    # In 'manner' mode the comma-separated option becomes a list of ints.
    if args.pooling_mode == 'manner':
        args.pooling_list = [int(item) for item in args.pooling_list.split(',')]

    data_obj = load_data(args.dataset, 'cpu', args, args.q, args.n_split)

    windows = data_obj['windows'].detach().numpy().tolist()
    length = data_obj['ts_length'].detach().numpy().tolist()
    channels = data_obj['input_dim']
    timestamp = data_obj['timestamp']
    device = get_device()

    train_loader = data_obj["train_dataloader"]
    test_loader = data_obj["test_dataloader"]
    if args.classif:
        # The validation loader is only read in classification mode.
        val_loader = data_obj["val_dataloader"]
        metrics = get_metrics(args.dataset)
    else:
        metrics = ['mse']

    # Name handed to calc_corr_penalty: dataset, task, and either the
    # quantization (physionet) or the split index (every other dataset).
    nameformat = '%s_%s_%s' % (args.dataset, 'classification' if args.classif else 'forecast',
                         'q' + str(args.q) if args.dataset.lower() in ['physionet'] else 'n' + str(args.n_split))
    corr = calc_corr_penalty(train_loader, nameformat, './corr', is_tuple=args.classif)
    # Disabled alternative kept from the original: corr = np.eye(corr.shape[0])

    net = Backbone(timestamp, channels, latent_dim=args.repr_hidden, windows=windows, length=length,
                   n_hidden=args.hidden, device=device, weight=[args.weight, 1 - args.weight], corr=corr,
                   p=args.mask_ratio, query_length=args.num_ref_points, embed_times=args.embed_time).to(device)
    if args.classif:
        # 8 output classes for PAM, 2 for every other dataset.
        projection_head = CLS(args.repr_hidden * 2, args.cls_hidden, 8 if args.dataset.lower() == 'pam' else 2).to(device)
        # reduction='none' keeps per-sample losses so they can be averaged
        # per class further down.
        criterion = torch.nn.CrossEntropyLoss(reduction='none')

        opt = torch.optim.AdamW([{"params": net.parameters(), 'lr': args.lr},
                                 {"params": projection_head.parameters(), 'lr': args.head_lr}])
    else:
        opt = torch.optim.AdamW([{"params": net.parameters(), 'lr': args.lr}])

    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.max_epoch)

    # Masked MSE is the reconstruction criterion
    # (disabled alternative kept from the original: torch.nn.MSELoss()).
    recon = mse_with_mask

    # best_result drives checkpoint selection in classification mode.
    # test_info must be a dict: the original initialised it to the int 0, so
    # the first `test_info['mse'] = ...` in reconstruction mode raised TypeError.
    best_result = 0 if args.classif else np.inf
    test_info = {}
    for epoch in range(args.max_epoch):
        # NOTE(review): the evaluation helpers are defined elsewhere and may
        # leave the modules in eval mode; re-enter training mode every epoch
        # so that state is never carried into the next training pass.
        net.train()
        if args.classif:
            projection_head.train()

        for niter, train_batch in enumerate(train_loader):
            if args.classif:
                train, label, dec_tp = train_batch[0].to(device), train_batch[1].to(device), None
            else:
                train_batch = train_batch.to(device)
                # The last axis is sliced as: `channels` value columns, then
                # the mask columns, then one final time column.
                observed_data = train_batch[:, :, :channels]
                observed_mask = train_batch[:, :, channels:-1]
                observed_tp = train_batch[:, :, -1]
                if args.sample_tp and args.sample_tp < 1:
                    subsampled_data, subsampled_tp, subsampled_mask = subsample_timepoints(
                        observed_data.clone(), observed_tp.clone(), observed_mask.clone(), args.sample_tp)
                else:
                    subsampled_data, subsampled_tp, subsampled_mask = \
                        observed_data, observed_tp, observed_mask
                train = torch.cat([subsampled_data, subsampled_mask, subsampled_tp.reshape(-1, timestamp, 1)], dim=-1)
                dec_tp = observed_tp
            reprs, x_recon, x_pooling, recon_pooling = net(train[..., :-1], train[..., -1], dec_tp)

            num_pooling = len(x_pooling)
            # Number of adjacent pooling-level pairs. Clamped to 1 so a single
            # pooling level yields zero adjust/contrastive losses instead of
            # the 0/0 = NaN the unguarded division produced.
            num_pairs = max(num_pooling - 1, 1)

            # Masked reconstruction loss averaged over all pooling levels.
            recon_loss = torch.tensor([0.], device=device)
            for i in range(num_pooling):
                recon_loss += recon(x_pooling[i][..., :channels], x_recon[i], x_pooling[i][..., channels:])
            recon_loss /= num_pooling

            # adjust: compares recon_pooling[i] with the detached x_recon[i],
            # so gradients flow only through recon_pooling. Level 0 is skipped.
            adjust_loss = torch.tensor([0.], device=device)
            for i in range(1, num_pooling):
                adjust_loss += recon(recon_pooling[i], x_recon[i].detach(), x_pooling[i][..., channels:])
            adjust_loss /= num_pairs

            # contrastive learning
            # first pooling, then contrastive (between consecutive levels)
            con_loss = torch.tensor([0.], device=device)
            for i in range(num_pooling - 1):
                con_loss += contrastive_loss_single(reprs[i], reprs[i + 1])
            con_loss /= num_pairs

            # Classification loss: per-sample cross-entropy is averaged within
            # each class present in the batch, then summed over classes.
            proj_loss = torch.tensor([0.], device=device)
            if args.classif:
                each_proj_loss = criterion(projection_head(reprs[0]), label)
                for y in torch.unique(label):
                    proj_loss += each_proj_loss[label == y].mean()

            loss = con_loss + args.alpha * recon_loss + args.beta * proj_loss + args.gamma * adjust_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            if niter % 10 == 0:
                print('Epoch#%d/%d, Iter#%d/%d, loss = %.4f, contrastive loss = %.4f, reconstruction loss = %.4f, adjust loss = %.4f, projection loss = %.4f' %
                      (epoch + 1, args.max_epoch, niter + 1, len(train_loader), loss.item(), con_loss.item(), recon_loss.item(), adjust_loss.item(), proj_loss.item()))
            # end if niter
        # end for niter
        sch.step()

        if epoch % args.per_show == 0:
            if args.classif:
                val_result, _ = obtain_classifier_result(net, projection_head, val_loader, device, metrics)
                test_result, test_loss = obtain_classifier_result(net, projection_head, test_loader, device, metrics)
                # Model selection uses the first metric on the validation set;
                # the test metrics of that epoch are what gets reported.
                if val_result[metrics[0]] > best_result:
                    torch.save(net.state_dict(), os.path.join(args.save, 'backbone.pt'))
                    torch.save(projection_head.state_dict(), os.path.join(args.save, 'classifier.pt'))
                    test_info, best_result = {'loss': test_loss}, val_result[metrics[0]]
                    for met in metrics:
                        test_info[met] = test_result[met]
                s = 'Epoch#%d/%d%s Test: current valid result = %.4f, best valid result = %.4f,' % (
                        epoch + 1, args.max_epoch,
                        '' if args.missing_ratio == 0 else (' Missing[%s]' % str(args.missing_ratio)),
                        val_result[metrics[0]], best_result
                    )
                for met in metrics:
                    s += ' test %s = %.4f,' % (met, test_result[met])
                s += ' test loss = %.4f' % test_loss
                print(s)
            else:
                # No model selection in this mode: the most recent test MSE
                # (scaled by 1e3) is recorded and reported in the summary.
                test_mse = obtain_impute_result(net, test_loader, device, args) * 1e3
                test_info['mse'] = test_mse
                s = 'Epoch#%d/%d Test: test mse result = %.4f' % (epoch + 1, args.max_epoch, test_mse)
                print(s)
            print()

    s = 'Best result%s:' % ('' if args.missing_ratio == 0 else ', Missing[%s]' % str(args.missing_ratio))
    if test_info:
        for met in metrics:
            s += ' test %s = %.4f,' % (met, test_info[met])
    else:
        # Nothing was stored (e.g. --max-epoch 0, or the validation metric
        # never exceeded its initial value); say so instead of crashing.
        s += ' no evaluation result was recorded'
    print(s)
