import numpy as np 

import torch
import torch.optim as optim

from neuralfaults.utils.parser import get_parser_with_args
from neuralfaults.utils.helpers import get_file_names, get_gain_model, get_gan_loss_functions, Log
from neuralfaults.utils.runner_gain import GAINRunner
from neuralfaults.utils.dataloader import get_dataloaders


parser = get_parser_with_args()
opt = parser.parse_args()

# Fault-injection options are only checked when both failing quantities and an
# imputation model were requested. Truthiness (rather than len()) also tolerates
# a None default for either option.
if opt.fail_quants and opt.impute_model:
    opt.model = 'NA'
    fail_quants = opt.fail_quants.split(",")
    inp_quants = opt.inp_quants.split(",")

    # Problems are reported through parser.error() rather than `assert`:
    # asserts are stripped under `python -O`, which would let invalid options
    # through silently. parser.error prints usage plus the message and exits.
    # NOTE(review): assumes get_parser_with_args() returns an
    # argparse.ArgumentParser (it exposes parse_args) -- confirm.
    try:
        fail_quants_prob = [float(p) for p in opt.fail_quants_prob.split(",")]
    except ValueError as err:
        parser.error(f"fail_quants_prob must be a comma-separated list of floats ({err})")

    unknown_quants = [fq for fq in fail_quants if fq not in inp_quants]
    if unknown_quants:
        parser.error(f"fail_quants {unknown_quants} are not listed in inp_quants {inp_quants}")

    if len(fail_quants) > len(inp_quants):
        parser.error("fail_quants cannot list more quantities than inp_quants")

    if len(fail_quants_prob) != len(fail_quants):
        parser.error(
            f"fail_quants_prob has {len(fail_quants_prob)} values but "
            f"fail_quants has {len(fail_quants)}; they must match one-to-one"
        )

    # Written as `not 0.0 <= p <= 1.0` so NaN is rejected too, as the original check did.
    out_of_range = [p for p in fail_quants_prob if not 0.0 <= p <= 1.0]
    if out_of_range:
        parser.error(f"fail_quants_prob values must lie in [0, 1]; got {out_of_range}")

# Checkpoint base path and log file path derived from the parsed options.
# The log is opened with mode 'w' (presumably truncating any earlier log at
# that path -- depends on Log; confirm).
weight_path, log_path = get_file_names(opt)
logger = Log(log_path, 'w')

# Only the first two loaders returned (train, validation) are used by this
# script; the third is discarded.
train_loader, val_loader, _ = get_dataloaders(opt)

# Paired networks (suffixes _g / _d, presumably generator and discriminator)
# and their matching loss functions.
model_g, model_d = get_gain_model(opt)
criterion_g, criterion_d = get_gan_loss_functions(opt)

# One Adam optimizer per network, both using opt.lr. The generator expression
# builds them in order: discriminator first, then generator.
optimizer_d, optimizer_g = (
    optim.Adam(net.parameters(), lr=opt.lr) for net in (model_d, model_g)
)

runner = GAINRunner(
    opt.gpu,
    model_g, model_d,
    optimizer_g, optimizer_d,
    criterion_g, criterion_d,
    train_loader, val_loader,
    opt,
)

# Lower SMAPE is better. Start at +inf (instead of an arbitrary 1000) so the
# first epoch with a finite mean SMAPE is always checkpointed.
best_smape = float('inf')

logger.write_model(model_g)
logger.write_model(model_d)

# try/finally ensures the log is closed even if training raises or is
# interrupted (e.g. Ctrl-C).
try:
    for epoch in range(opt.epochs):
        runner.set_epoch_metrics()

        train_metrics = runner.train_model()
        val_metrics = runner.eval_model()

        print('TRAIN METRICS EPOCH ', epoch, train_metrics)
        print('EVAL METRICS EPOCH ', epoch, val_metrics)

        logger.log_train_metrics(train_metrics, epoch)
        logger.log_validation_metrics(val_metrics, epoch)

        # Model selection uses the mean of every validation metric whose key
        # contains 'smape'.
        smape_values = [value for key, value in val_metrics.items() if 'smape' in key]
        if not smape_values:
            # np.mean([]) would return NaN (with a RuntimeWarning), and NaN never
            # compares <= best_smape, so no checkpoint would ever be written.
            # Say so explicitly instead.
            print('WARNING: no smape metrics in validation results for epoch', epoch,
                  '- skipping checkpoint selection')
            continue

        mean_all_quant_smape = np.mean(smape_values)

        # `<=` keeps the most recent model on ties.
        if mean_all_quant_smape <= best_smape:
            torch.save(model_g, weight_path + '.g')
            torch.save(model_d, weight_path + '.d')
            best_smape = mean_all_quant_smape
finally:
    logger.close()