import logging
import os

import torch
from tqdm import tqdm
import tools.utils as utils
from tools.utils import RunPhase
from eval_fn import evaluate

logger = logging.getLogger(__name__)


def train(params,
          model,
          trn_ckpt_dir,
          dataloader,
          valid_dl,
          optimizer,
          loss_fn,
          scheduler,
          epoch,
          best_dict,
          grad_accml_step_size,
          ):
    """
    Train the model for one epoch, then evaluate and checkpoint.

    Gradients are accumulated over ``grad_accml_step_size`` mini batches
    before each optimizer step. Once every batch has been consumed the
    validation set is evaluated through ``eval_and_save_best``.

    Args:
        params: mapping; only ``params['cuda']`` is read here, the whole
            mapping is forwarded to ``eval_and_save_best``.
        model: module to train; switched to train mode by this function.
        trn_ckpt_dir: checkpoint directory, forwarded unchanged.
        dataloader: sized iterable of batches; each batch is indexed with
            ``'input'`` and ``'label'`` and both entries must be tensors.
        valid_dl: validation loader, forwarded unchanged.
        optimizer: optimizer stepped after each accumulation group.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar
            tensor.
        scheduler: learning rate scheduler, forwarded unchanged.
        epoch: current epoch index, forwarded unchanged.
        best_dict: record of the best result so far, forwarded unchanged.
        grad_accml_step_size: number of mini batches per optimizer step.

    Returns:
        The best-result dict returned by ``eval_and_save_best``.

    Raises:
        ValueError: if the dataloader produced no training samples.
    """
    len_trn_dl = len(dataloader)
    local_progress = tqdm(dataloader, total=len_trn_dl)

    # Nothing inside the loop leaves train mode, so set it once.
    model.train()
    optimizer.zero_grad()

    train_loss = 0.0
    num_sample = 0
    num_itr = 0
    pending_grads = False
    for i, batch in enumerate(local_progress):
        input_batch = batch['input'].type(torch.FloatTensor)
        label_batch = batch['label'].type(torch.FloatTensor)
        # Reduce the label tensor to class indices along dim 1
        # (presumably one-hot / per-class scores -- confirm with the dataset).
        label_batch = torch.argmax(label_batch, dim=1)

        # move to GPU if available
        if params['cuda']:
            input_batch = input_batch.cuda(non_blocking=True)
            label_batch = label_batch.cuda(non_blocking=True)

        # compute model output and loss
        output_batch = model(input_batch)
        batch_loss = loss_fn(output_batch, label_batch)

        # Only the back-propagated loss is scaled for accumulation; the value
        # that is displayed and averaged stays unscaled so the reported
        # train loss does not shrink with grad_accml_step_size.
        (batch_loss / grad_accml_step_size).backward()
        pending_grads = True

        batch_loss_value = batch_loss.item()
        local_progress.set_postfix(loss='{:05.3f}'.format(batch_loss_value))

        batch_size = len(input_batch)
        train_loss += batch_loss_value * batch_size
        num_sample += batch_size
        num_itr = i + 1

        # model update every 'grad_accml_step_size' mini batch
        if num_itr % grad_accml_step_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_grads = False

    # Apply the gradients of a trailing, incomplete accumulation group;
    # otherwise the last batches of the epoch would never update the model.
    # That group is still scaled by 1 / grad_accml_step_size, so its step is
    # proportionally smaller than a full one.
    if pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    if num_sample == 0:
        raise ValueError("dataloader produced no training samples")

    value_dict = {"train_loss": train_loss / num_sample}
    return eval_and_save_best(params,
                              model,
                              trn_ckpt_dir,
                              valid_dl,
                              optimizer,
                              loss_fn,
                              scheduler,
                              epoch,
                              best_dict,
                              num_itr,
                              value_dict,
                              )


def eval_and_save_best(params,
                       model,
                       trn_ckpt_dir,
                       valid_dl,
                       optimizer,
                       loss_fn,
                       scheduler,
                       epoch,
                       best_dict,
                       num_itr,
                       value_dict,
                       ):
    """
    Evaluate on the validation set and keep only the best checkpoint.

    The validation metrics are appended to ``validation_log.csv`` in
    ``trn_ckpt_dir``. When the validation loss is lower than
    ``best_dict['value']`` a new checkpoint is written and the previous best
    checkpoint file is removed. The scheduler is stepped once per call.

    Args:
        params: forwarded to ``evaluate``.
        model: model to evaluate and, on improvement, to checkpoint.
        trn_ckpt_dir: directory for the csv log and the checkpoint files.
        valid_dl: validation loader, forwarded to ``evaluate``.
        optimizer: optimizer whose state is stored in the checkpoint.
        loss_fn: forwarded to ``evaluate``.
        scheduler: learning rate scheduler; ``step()`` is called with no
            arguments.
        epoch: epoch index, used in the log and the checkpoint file name.
        best_dict: mapping with keys ``'met'``, ``'value'`` and
            ``'model_abs_pth'`` describing the best result so far.
        num_itr: iteration count, used in the log and the file name.
        value_dict: mapping that receives the validation loss under
            ``'valid_loss'`` (mutated in place).

    Returns:
        ``best_dict`` unchanged when there was no improvement, otherwise a
        new dict describing the checkpoint that was just saved.
    """
    val_loss, val_metrics, metrics_string, _output, _labels, _names = evaluate(
        model,
        loss_fn,
        valid_dl,
        params,
        RunPhase.TRAIN,
    )
    logger.info("evaluation result: %s", metrics_string)
    # save validation log as csv
    utils.wirte_log_csv(path=trn_ckpt_dir, name="validation_log.csv", summary_dict=val_metrics, epoch=epoch, itr=num_itr, val_set_name="val")

    met = best_dict['met']
    prev_best_value = best_dict['value']
    prev_best_pth = best_dict["model_abs_pth"]
    value_dict['valid_loss'] = val_loss

    # update best dict
    if val_loss < prev_best_value:
        save_fname = f"{met}_{val_loss:.5f}_epoch{epoch}_itr{num_itr}_pth.tar"
        best_model_pth = os.path.join(trn_ckpt_dir, save_fname)
        # NOTE(review): state_dict() holds references to the live parameters,
        # so the in-memory entries of the returned dict keep changing with
        # further training; only the saved file is a snapshot -- confirm that
        # callers reload from 'model_abs_pth'.
        best_dict = {'met': met,
                     'epoch': epoch,
                     'itr': num_itr,
                     'value': val_loss,
                     'model_state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict(),
                     'model_abs_pth': best_model_pth}
        # Write the new checkpoint first so a failed save can never leave the
        # run without any best checkpoint on disk.
        utils.save_checkpoint(best_dict, checkpoint=trn_ckpt_dir, filename=save_fname)
        # Remove the earlier best model. Skip an unset path (os.path.exists
        # raises TypeError on None) and never delete the file just written.
        if (prev_best_pth
                and os.path.realpath(prev_best_pth) != os.path.realpath(best_model_pth)
                and os.path.exists(prev_best_pth)):
            os.remove(prev_best_pth)

    scheduler.step()

    return best_dict
