import logging
import os.path as osp
import pprint
import sys
import time
from typing import Optional, Tuple

import numpy as np
import torch

from models import (
    build_model, build_optimizer, build_scheduler, get_loss_func, get_lr
)
from utils import (
    save_checkpoint, load_checkpoint, load_train_checkpoint,
    mkdir, setup_logging, get_logname, TensorboardWriter
)
from utils.metric import AverageMeter, DiceMetric, IouMetric
from utils.parser import parse_args, load_config
from utils.lr_scheduler import Poly
from dataset import build_data_pipeline

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


def train_epoch(train_loader, model, loss_func, optimizer, scheduler, epoch, device,
                thres=0.5, print_freq=10, calculate_metric=False, writer=None) -> None:
    """Run one training epoch over ``train_loader``.

    For every batch this runs forward, backward and an optimizer step. It tracks
    the loss (and, when enabled, a Dice score) with running meters, logs
    progress every ``print_freq`` iterations and optionally writes scalars to
    ``writer``.

    Args:
        train_loader: iterable of batches. ``samples[0]`` holds the inputs and
            ``samples[1]`` the labels. When a region loss is used,
            ``samples[2]`` holds the region sizes.
        model: model exposing ``n_classes``, ``branch_number`` and ``act``.
            ``act`` turns raw outputs (or the first branch output) into
            predictions.
        loss_func: loss module with a ``name`` attribute. When a region loss is
            used it returns ``(loss, loss_seg, loss_region)`` and exposes
            ``alpha`` and ``adjust_alpha``.
        optimizer: torch optimizer.
        scheduler: LR scheduler. A ``Poly`` scheduler is stepped once per
            iteration here; other schedulers are not stepped by this function.
        epoch: zero-based epoch index.
        device: device the batch tensors are moved to.
        thres: threshold forwarded to ``DiceMetric``.
        print_freq: log progress every ``print_freq`` iterations.
        calculate_metric: whether to compute the training Dice score.
        writer: optional TensorboardWriter; scalars are skipped when None.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter = AverageMeter()
    if calculate_metric:
        train_metric = DiceMetric(thres=thres, num_classes=model.n_classes)
        # train_metric = IouMetric(num_classes=model.n_classes)
    # The loss returns (loss, loss_seg, loss_region) either when it is a
    # region loss (by name) or when the model has more than one branch.
    use_region_loss: bool = "region" in loss_func.name or model.branch_number > 1
    if use_region_loss:
        loss_seg_meter, loss_region_meter = AverageMeter(), AverageMeter()
    # Most recent per-iteration score; -1 is logged when metrics are disabled.
    score = -1

    model.train()

    # LR at the start of the epoch, used for the epoch-level summaries.
    epoch_lr = get_lr(optimizer)
    lr = epoch_lr
    max_iter = len(train_loader)
    # Poly is stepped once per iteration below; the caller steps other schedulers.
    step_per_iter = isinstance(scheduler, Poly)

    end = time.time()
    logger.info("====== Start training epoch {} ======".format(epoch + 1))
    for i, samples in enumerate(train_loader):
        # compute the time for data loading
        data_time.update(time.time() - end)
        global_step = max_iter * epoch + i
        if step_per_iter:
            # Under Poly the LR changes every iteration: report the value
            # actually used for this iteration's optimizer step.
            lr = get_lr(optimizer)
        # decouple samples
        inputs, labels = samples[0].to(device), samples[1].to(device)
        batch_size = inputs.size(0)
        if use_region_loss:
            region_sizes = samples[2].to(device)
        # forward
        outputs = model(inputs)
        if use_region_loss:
            loss, loss_seg, loss_region = loss_func(outputs, [labels, region_sizes])
            loss_seg_value = loss_seg.item()
            loss_region_value = loss_region.item()
            loss_seg_meter.update(loss_seg_value, batch_size)
            loss_region_meter.update(loss_region_value, batch_size)
        else:
            loss = loss_func(outputs, labels)
        loss_value = loss.item()
        loss_meter.update(loss_value, batch_size)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step_per_iter:
            scheduler.step(epoch=epoch)
        # tensorboard writer: log plain Python floats rather than tensors
        if writer is not None:
            writer.add_scalars(
                {
                    "Train/Iter/loss": loss_value,
                    "Train/Iter/lr": lr
                },
                global_step=global_step
            )
            if use_region_loss:
                writer.add_scalars(
                    {
                        "Train/Iter/loss_seg": loss_seg_value,
                        "Train/Iter/loss_region": loss_region_value,
                    },
                    global_step=global_step
                )

        if calculate_metric:
            # update train_metric; multi-branch models are scored on branch 0
            if model.branch_number == 1:
                predicts = model.act(outputs)
            else:
                predicts = model.act(outputs[0])
            score = train_metric.update(predicts.detach(), labels.detach())
            if writer is not None:
                writer.add_scalars(
                    {"Train/Iter/Score": score},
                    global_step=global_step
                )

        # print training info if triggered
        if (i + 1) % print_freq == 0:
            logger.info(
                ("Train Epoch[{0}][{1}/{2}]\t"
                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                 "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                 "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                 "LR {lr:.3g}\t"
                 "Score {score:.4f}").format(
                     epoch + 1, i + 1, max_iter,
                     batch_time=batch_time,
                     data_time=data_time,
                     loss=loss_meter,
                     lr=lr,
                     score=score)
            )

    # Get some measures for the entire epoch
    if writer is not None:
        writer.add_scalars(
            {
                "Train/Epoch/loss": loss_meter.avg,
                "Train/Epoch/lr": epoch_lr
            },
            global_step=epoch
        )
        if use_region_loss:
            writer.add_scalars(
                {
                    "Train/Epoch/loss_seg": loss_seg_meter.avg,
                    "Train/Epoch/loss_region": loss_region_meter.avg,
                },
                global_step=epoch
            )
    if calculate_metric:
        train_score = train_metric.mean_score()

        if use_region_loss:
            logger.info(
                ("Train Epoch[{0}]\t"
                 "Samples[{1}]\t"
                 "Loss {loss.avg:.4f}\t"
                 "Loss_seg {loss_seg.avg:.4f}\t"
                 "Loss_region {loss_region.avg:.4f}\t"
                 "Score {auc:.4f}\t"
                 "LR {lr:.3g}\t"
                 "alpha {alpha:.3g}").format(
                     epoch + 1, train_metric.num_samples,
                     loss=loss_meter, loss_seg=loss_seg_meter,
                     loss_region=loss_region_meter,
                     auc=train_score, lr=epoch_lr, alpha=loss_func.alpha)
            )
        else:
            logger.info(
                ("Train Epoch[{0}]\t"
                 "Samples[{1}]\t"
                 "Loss {loss.avg:.4f}\t"
                 "Score {auc:.4f}\t"
                 "LR {lr:.3g}").format(
                     epoch + 1, train_metric.num_samples,
                     loss=loss_meter, auc=train_score, lr=epoch_lr)
            )
        if writer is not None:
            writer.add_scalars(
                {"Train/Epoch/Score": train_score},
                global_step=epoch
            )
    if use_region_loss:
        # Update the region-loss weight every epoch. This must not depend on
        # whether training metrics are computed, otherwise the loss schedule
        # would silently change with a logging flag.
        loss_func.adjust_alpha(epoch)
    logger.info("====== Complete training epoch {} ======".format(epoch + 1))


@torch.no_grad()
def eval_epoch(test_loader, model, loss_func, epoch, device,
               thres=0.5, print_freq=10, writer=None, phase="val") -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Accumulates the loss and a Dice score over the whole loader and logs
    progress every ``print_freq`` iterations. When ``writer`` is given, it also
    records per-iteration and per-epoch scalars under the ``phase`` tag prefix.

    Args:
        test_loader: iterable of batches. ``samples[0]`` holds the inputs and
            ``samples[1]`` the labels. With a region loss, ``samples[2]``
            holds the region sizes.
        model: model exposing ``n_classes``, ``branch_number`` and ``act``.
        loss_func: loss module with a ``name`` attribute. It returns
            ``(loss, loss_seg, loss_region)`` when a region loss is in use.
        epoch: zero-based epoch index used for logging and global steps.
        device: device the batch tensors are moved to.
        thres: threshold forwarded to ``DiceMetric``.
        print_freq: log progress every ``print_freq`` iterations.
        writer: optional TensorboardWriter; scalars are skipped when None.
        phase: tag prefix for logged scalars, e.g. ``"val"`` or ``"test"``.

    Returns:
        ``(average loss, score)``. The loss average is weighted by batch size;
        the score is ``eval_metric.mean_score()``.
    """
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    eval_metric = DiceMetric(thres=thres, num_classes=model.n_classes)
    # eval_metric = IouMetric(num_classes=model.n_classes, classes=test_loader.dataset.classes)
    # The loss returns three terms for region losses and multi-branch models.
    use_region_loss: bool = "region" in loss_func.name or model.branch_number > 1
    if use_region_loss:
        loss_seg_meter, loss_region_meter = AverageMeter(), AverageMeter()

    model.eval()

    max_iter = len(test_loader)
    end = time.time()
    for i, samples in enumerate(test_loader):
        global_step = max_iter * epoch + i
        # decouple samples
        inputs, labels = samples[0].to(device), samples[1].to(device)
        batch_size = inputs.size(0)
        if use_region_loss:
            region_sizes = samples[2].to(device)
        # forward
        outputs = model(inputs)
        if use_region_loss:
            loss, loss_seg, loss_region = loss_func(outputs, [labels, region_sizes])
            loss_seg_value = loss_seg.item()
            loss_region_value = loss_region.item()
            loss_seg_meter.update(loss_seg_value, batch_size)
            loss_region_meter.update(loss_region_value, batch_size)
        else:
            loss = loss_func(outputs, labels)
        loss_value = loss.item()
        loss_meter.update(loss_value, batch_size)
        # update metric; multi-branch models are scored on branch 0
        if model.branch_number == 1:
            predicts = model.act(outputs)
        else:
            predicts = model.act(outputs[0])
        score = eval_metric.update(predicts.detach(), labels.detach())
        # tensorboard writer: log plain Python floats rather than tensors
        if writer is not None:
            writer.add_scalars(
                {
                    "{}/Iter/loss".format(phase): loss_value,
                    "{}/Iter/Score".format(phase): score
                },
                global_step=global_step
            )
            if use_region_loss:
                writer.add_scalars(
                    {
                        "{}/Iter/loss_seg".format(phase): loss_seg_value,
                        "{}/Iter/loss_region".format(phase): loss_region_value,
                    },
                    global_step=global_step
                )
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            logger.info(
                ("Eval Epoch[{0}][{1}/{2}]\t"
                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                 "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                 "Score {score:.4f}").format(
                     epoch + 1, i + 1, max_iter,
                     batch_time=batch_time,
                     loss=loss_meter,
                     score=score)
            )
    # Get some measures for the entire epoch
    eval_score = eval_metric.mean_score()
    if writer is not None:
        writer.add_scalars(
            {
                "{}/Epoch/loss".format(phase): loss_meter.avg,
                "{}/Epoch/Score".format(phase): eval_score
            },
            global_step=epoch
        )
        if use_region_loss:
            # Tag with the evaluation phase so these values do not land on
            # the training curves.
            writer.add_scalars(
                {
                    "{}/Epoch/loss_seg".format(phase): loss_seg_meter.avg,
                    "{}/Epoch/loss_region".format(phase): loss_region_meter.avg,
                },
                global_step=epoch
            )

        if writer.plot_class_score:
            writer.add_class_score(
                "{}/Epoch/Score by classes".format(phase),
                eval_metric.class_score(),
                global_step=epoch
            )
    if use_region_loss:
        logger.info(
            ("Eval Epoch[{0}]\t"
             "Samples[{1}]\t"
             "Loss {loss.avg:.4f}\t"
             "Loss_seg {loss_seg.avg:.4f}\t"
             "Loss_region {loss_region.avg:.4f}\t"
             "Score {score:.4f}\t").format(
                 epoch + 1, eval_metric.num_samples,
                 loss=loss_meter, loss_seg=loss_seg_meter,
                 loss_region=loss_region_meter,
                 score=eval_score)
        )
    else:
        logger.info(
            ("Eval Epoch[{0}]\tSamples [{1}]\t"
             "Loss {loss.avg:.4f}\t"
             "Score {score:.4f}").format(
                 epoch + 1, eval_metric.num_samples,
                 loss=loss_meter, score=eval_score)
        )
    # Print score for each class if it's in multi-class setting
    if eval_metric.num_classes > 1:
        logger.info("Score by classes :")
        eval_metric.print_class_score()

    return loss_meter.avg, eval_score


def train(cfg) -> Tuple[Optional[float], Optional[int]]:
    """Train a model as configured by ``cfg`` and optionally run a test pass.

    Builds the model, loss, data loaders, optimizer and scheduler, then
    resumes from a checkpoint if one is configured. It then alternates
    training, validation and checkpointing epochs. The TensorBoard writer (if
    enabled) is always closed, even when training raises.

    Args:
        cfg: experiment config node (attribute access, e.g. ``cfg.SOLVER``).

    Returns:
        ``(best_score, best_epoch)``: the best validation score observed at a
        checkpoint epoch and its zero-based epoch index. If no checkpoint epoch
        was reached, these are the values ``load_train_checkpoint`` returned.
        ``best_score`` may be None in that case.
    """
    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Init tensorboard writer
    writer = TensorboardWriter(cfg) if cfg.TENSORBOARD.ENABLE else None
    try:
        # Set random seed from configs.
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)

        # Build the model.
        device = torch.device(cfg.DEVICE)
        model = build_model(cfg)
        model.to(device)
        loss_func = get_loss_func(cfg)
        loss_func.to(device)

        # Create the train and val data loaders.
        train_loader = build_data_pipeline(cfg, "train")
        val_loader = build_data_pipeline(cfg, "val")

        # Build the optimizer.
        optimizer = build_optimizer(cfg, model)
        scheduler = build_scheduler(cfg, optimizer, train_loader)

        # Check if it is configured to resume training
        start_epoch, best_epoch, best_score = load_train_checkpoint(cfg, model, optimizer, scheduler)

        # Perform the training loop.
        logger.info("Start training ... ")
        for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
            train_epoch(train_loader, model, loss_func, optimizer, scheduler, epoch, device,
                        cfg.THRES, cfg.LOG_PERIOD, cfg.TRAIN.CALCULATE_METRIC, writer)

            is_eval_epoch = (epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0
            is_checkpoint_epoch = (epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0

            # Always do a validation phase before saving a checkpoint.
            if is_eval_epoch or is_checkpoint_epoch:
                val_loss, val_score = eval_epoch(
                    val_loader, model, loss_func, epoch, device,
                    cfg.THRES, cfg.LOG_PERIOD, writer
                )
            # The plateau scheduler only reacts to regular evaluation epochs.
            if is_eval_epoch and cfg.SOLVER.LR_POLICY == "reduce_on_plateau":
                scheduler.step(val_loss if cfg.SOLVER.REDUCE_MODE == "min" else val_score)

            if is_checkpoint_epoch:
                is_best = best_score is None or val_score > best_score
                save_checkpoint(
                    osp.join(cfg.OUTPUT_DIR, "model"), model,
                    optimizer, scheduler, epoch,
                    last_checkpoint=True,
                    best_checkpoint=is_best,
                    val_score=val_score
                )
                if is_best:
                    best_score, best_epoch = val_score, epoch

            # Poly steps per iteration inside train_epoch; plateau steps above.
            if cfg.SOLVER.LR_POLICY not in ["reduce_on_plateau", "poly"]:
                scheduler.step()
        logger.info("Complete training !")
        if best_score is None:
            # No checkpoint epoch was reached (and nothing was resumed), so
            # formatting best_epoch/best_score would fail.
            logger.info("No validation checkpoint was evaluated; best model unknown.")
        else:
            logger.info(
                ("Best performance on validation subset - "
                 "model epoch {}, score {:.4f}").format(best_epoch + 1, best_score)
            )

        # Perform test in the end
        if cfg.PERFORM_TEST:
            logger.info("Start testing ... ")
            # NOTE(review): this loads the checkpoint named after the final
            # epoch, not necessarily the best-validation one; confirm against
            # save_checkpoint's file naming. A separate local is used so the
            # returned best_epoch is not clobbered.
            test_epoch = cfg.SOLVER.MAX_EPOCH
            test_model_path = osp.join(
                cfg.OUTPUT_DIR, "model", "checkpoint_epoch_{}.pth".format(test_epoch)
            )
            load_checkpoint(test_model_path, model, device)
            test_loader = build_data_pipeline(cfg, "test")
            _, test_score = eval_epoch(
                test_loader, model, loss_func, test_epoch, device,
                cfg.THRES, cfg.LOG_PERIOD, writer, "test"
            )
            logger.info("Complete testing !")
            logger.info(
                ("Performance on test subset - "
                 "model epoch {}, score {:.4f}").format(test_epoch, test_score)
            )

        return best_score, best_epoch
    finally:
        if writer is not None:
            writer.close()


def main():
    """Command-line entry point: load the config, set up logging and train."""
    cfg = load_config(parse_args())
    output_dir = cfg.OUTPUT_DIR
    mkdir(output_dir)
    setup_logging(output_dir=output_dir, level=logging.INFO)

    logger.info(f"Please refer to the log  :  {get_logname(logger)}")
    launch_command = " ".join(sys.argv)
    for message in ("Launch command:", launch_command):
        logger.info(message)

    train(cfg)

    logger.info("Mission completes !")
    logger.info(f"Please check log info in :  {get_logname(logger)}")


# Only launch training when executed as a script, not when imported.
if __name__ == "__main__":
    main()
