import numpy as np 

import torch
import torch.optim as optim

from neuralfaults.utils.parser import get_parser_with_args
from neuralfaults.utils.helpers import get_file_names, get_model, get_loss_function, Log
from neuralfaults.utils.runner import Runner
from neuralfaults.utils.dataloader import get_dataloaders


parser = get_parser_with_args()
opt = parser.parse_args()

# Fault-injection / imputation mode: only active when both a list of
# "failing" input quantities and an imputation model were supplied.
# Validation uses parser.error() rather than `assert`, because asserts are
# stripped under `python -O` and would let invalid options through silently.
if opt.fail_quants and opt.impute_model:
    opt.model = 'NA'
    fail_quants = opt.fail_quants.split(",")
    inp_quants = opt.inp_quants.split(",")
    try:
        fail_quants_prob = list(map(float, opt.fail_quants_prob.split(",")))
    except ValueError as err:
        parser.error(f"fail_quants_prob must be a comma-separated list of floats ({err})")

    # Every failing quantity must be one of the model's input quantities.
    unknown_quants = [fq for fq in fail_quants if fq not in inp_quants]
    if unknown_quants:
        parser.error(f"fail_quants {unknown_quants} are not listed in inp_quants {inp_quants}")

    # Given the membership check above, this can only trip on duplicates.
    if len(fail_quants) > len(inp_quants):
        parser.error("more fail_quants than inp_quants were given (duplicates?)")

    # One failure probability is required per failing quantity.
    if len(fail_quants_prob) != len(fail_quants):
        parser.error(
            f"got {len(fail_quants_prob)} fail_quants_prob values for "
            f"{len(fail_quants)} fail_quants"
        )

    # Probabilities must lie in [0, 1]; NaN is rejected as well.
    bad_probs = [fqp for fqp in fail_quants_prob if not 0.0 <= fqp <= 1.0]
    if bad_probs:
        parser.error(f"fail_quants_prob values must be within [0, 1], got {bad_probs}")

# Where the best checkpoint and the metric log are written.
weight_path, log_path = get_file_names(opt)
logger = Log(log_path, 'w')

# Data pipeline, network and objective all derive from the parsed options.
# The third loader from get_dataloaders (presumably a test split) is unused here.
train_loader, val_loader, _ = get_dataloaders(opt)
model = get_model(opt)
criterion = get_loss_function(opt)

# Plain SGD over every model parameter at the configured learning rate.
learning_rate = opt.lr
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Bundle everything the epoch loop needs into a single Runner.
runner = Runner(
    opt.gpu, model, optimizer, criterion,
    train_loader, val_loader, opt,
)

# Lowest mean validation SMAPE seen so far. Starting from +inf guarantees the
# first finite score is checkpointed. The previous sentinel of 1000 meant no
# checkpoint was ever written if every epoch scored above it.
best_smape = float('inf')

# try/finally so the log is closed (and flushed) even if training raises or
# is interrupted with Ctrl-C.
try:
    logger.write_model(model)

    for epoch in range(opt.epochs):
        runner.set_epoch_metrics()

        train_metrics = runner.train_model()
        val_metrics = runner.eval_model()

        print('TRAIN METRICS EPOCH ', epoch, train_metrics)
        print('EVAL METRICS EPOCH ', epoch, val_metrics)

        logger.log_train_metrics(train_metrics, epoch)
        logger.log_validation_metrics(val_metrics, epoch)

        # Model-selection score: the mean of every validation metric whose key
        # contains 'smape' (presumably one entry per output quantity).
        smape_values = [value for key, value in val_metrics.items() if 'smape' in key]
        mean_all_quant_smape = np.mean(smape_values) if smape_values else float('nan')

        # Ignore non-finite scores (NaN from a diverged run or missing smape
        # keys). Without this guard, +inf would compare equal to the initial
        # sentinel and be saved as "best".
        if np.isfinite(mean_all_quant_smape) and mean_all_quant_smape <= best_smape:
            # Pickles the whole model object (not just its state_dict).
            torch.save(model, weight_path)
            best_smape = mean_all_quant_smape
finally:
    logger.close()