import sys
import logging
import torch
import numpy as np
import pprint
import time
import os.path as osp
from typing import Tuple

from models import (
    build_model, build_optimizer, build_scheduler, get_loss_func, get_lr
)
from utils import (
    save_checkpoint, load_checkpoint, load_train_checkpoint,
    mkdir, setup_logging, TensorboardWriter
)
from utils.metric import AverageMeter, DiceMetric, IouMetric
from utils.parser import parse_args, load_config
from utils.lr_scheduler import Poly
from dataset import build_data_pipeline

# Module-level logger shared by every function in this script; main() calls
# setup_logging() before train() is invoked.
logger = logging.getLogger(__name__)


def train_epoch(train_loader, model, loss_func, optimizer, scheduler, epoch, device,
                print_freq=10, calculate_metric=False, writer=None) -> None:
    """Train ``model`` for one epoch and log/record the running statistics.

    Args:
        train_loader: Sized iterable of batches. ``samples[0]`` is used as the
            input tensor and ``samples[1]`` as the label tensor.
        model: Network to optimise. ``branch_number`` is always read;
            ``n_classes`` and ``act`` are read only when ``calculate_metric``
            is true.
        loss_func: Loss callable exposing ``name``. When the region loss is
            active it must return ``(loss, loss_seg, loss_region)`` and provide
            ``alpha`` and ``adjust_alpha(epoch)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Stepped once per batch only if it is a ``Poly`` scheduler;
            every other scheduler is left untouched by this function.
        epoch: Zero-based epoch index (logged as ``epoch + 1``).
        device: Device the batch tensors are moved to.
        print_freq: Emit an iteration log line every ``print_freq`` batches.
        calculate_metric: Whether to compute accuracy / IoU on training batches.
        writer: Optional writer exposing ``add_scalars(dict, global_step=...)``.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_meter = AverageMeter()
    if calculate_metric:
        train_metric = IouMetric(num_classes=model.n_classes, ignore_index=255)
    # The region loss is considered active either by name or because the model
    # has more than one output branch.
    use_region_loss: bool = "region" in loss_func.name or model.branch_number > 1
    if use_region_loss:
        loss_seg_meter, loss_region_meter = AverageMeter(), AverageMeter()

    # Placeholders shown in the iteration log when metrics are not computed.
    acc, iou = -1, -1

    model.train()

    max_iter = len(train_loader)

    end = time.time()
    logger.info("====== Start training epoch {} ======".format(epoch + 1))
    for i, samples in enumerate(train_loader):
        global_step = max_iter * epoch + i
        # Time spent waiting for the data loader to deliver this batch.
        data_time.update(time.time() - end)
        inputs, labels = samples[0].to(device), samples[1].to(device)
        batch_size = inputs.size(0)

        # Forward pass.
        outputs = model(inputs)
        if use_region_loss:
            loss, loss_seg, loss_region = loss_func(outputs, labels)
            loss_seg_meter.update(loss_seg.item(), batch_size)
            loss_region_meter.update(loss_region.item(), batch_size)
        else:
            loss = loss_func(outputs, labels)
        loss_meter.update(loss.item(), batch_size)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time of the whole iteration (data loading included).
        batch_time.update(time.time() - end)
        end = time.time()

        # Only the Poly policy is stepped per iteration.
        if isinstance(scheduler, Poly):
            scheduler.step(epoch=epoch)

        if writer is not None:
            writer.add_scalars(
                {
                    "Train/Iter/loss": loss,
                    "Train/Iter/lr": get_lr(optimizer),
                },
                global_step=global_step,
            )
            if use_region_loss:
                writer.add_scalars(
                    {
                        "Train/Iter/loss_seg": loss_seg,
                        "Train/Iter/loss_region": loss_region,
                    },
                    global_step=global_step,
                )

        if calculate_metric:
            # With several branches the first output is used for prediction.
            if model.branch_number == 1:
                predicts = model.act(outputs)
            else:
                predicts = model.act(outputs[0])
            # NOTE(review): update() is unpacked as (iou, acc) whereas
            # mean_score() below is unpacked as (acc, iou) - confirm both
            # orders against IouMetric.
            iou, acc = train_metric.update(predicts.detach().cpu().numpy(),
                                           labels.detach().cpu().numpy())
            if writer is not None:
                writer.add_scalars(
                    {
                        "Train/Iter/Acc": acc,
                        "Train/Iter/Iou": iou,
                    },
                    global_step=global_step,
                )

        if (i + 1) % print_freq == 0:
            log_str = (
                "Train Epoch[{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "LR {lr:.3g}\t"
                "Acc {acc:.4f}\tIou {iou:.4f}\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})").format(
                    epoch + 1, i + 1, max_iter,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=loss_meter,
                    lr=get_lr(optimizer),
                    acc=acc,
                    iou=iou,
            )
            if use_region_loss:
                # Leading tab keeps this field separated from the "Loss" field.
                log_str += (
                    "\tLoss_seg {loss_seg.val:.4f} ({loss_seg.avg:.4f})"
                    "\tLoss_region {loss_region.val:.4f} ({loss_region.avg:.4f})").format(
                        loss_seg=loss_seg_meter, loss_region=loss_region_meter
                )
            logger.info(log_str)

    # Measures for the entire epoch.
    if writer is not None:
        writer.add_scalars(
            {
                "Train/Epoch/loss": loss_meter.avg,
                "Train/Epoch/lr": get_lr(optimizer),
            },
            global_step=epoch,
        )
        if use_region_loss:
            writer.add_scalars(
                {
                    "Train/Epoch/loss_seg": loss_seg_meter.avg,
                    "Train/Epoch/loss_region": loss_region_meter.avg,
                },
                global_step=epoch,
            )
    if calculate_metric:
        acc, iou = train_metric.mean_score()

        if use_region_loss:
            logger.info(
                ("Train Epoch[{0}]\t"
                 "Samples[{1}]\t"
                 "Loss {loss.avg:.4f}\t"
                 "Loss_seg {loss_seg.avg:.4f}\t"
                 "Loss_region {loss_region.avg:.4f}\t"
                 "Acc {acc:.4f}\tIou {iou:.4f}\t"
                 "LR {lr:.3g}\t"
                 "alpha {alpha:.3g}").format(
                     epoch + 1, train_metric.num_samples,
                     loss=loss_meter, loss_seg=loss_seg_meter,
                     loss_region=loss_region_meter,
                     acc=acc, iou=iou, lr=get_lr(optimizer), alpha=loss_func.alpha)
            )
        else:
            logger.info(
                ("Train Epoch[{0}]\t"
                 "Samples[{1}]\t"
                 "Loss {loss.avg:.4f}\t"
                 "Acc {acc:.4f}\tIou {iou:.4f}\t"
                 "LR {lr:.3g}").format(
                     epoch + 1, train_metric.num_samples,
                     loss=loss_meter, acc=acc, iou=iou, lr=get_lr(optimizer))
            )
        if writer is not None:
            writer.add_scalars(
                {
                    "Train/Epoch/Acc": acc,
                    "Train/Epoch/Iou": iou,
                },
                global_step=epoch,
            )

    # The alpha schedule is part of the optimisation, so it must advance every
    # epoch regardless of whether training metrics are being computed.
    if use_region_loss:
        loss_func.adjust_alpha(epoch)

    logger.info("====== Complete training epoch {} ======".format(epoch + 1))


@torch.no_grad()
def eval_epoch(test_loader, model, loss_func, epoch, device,
               print_freq=10, writer=None, phase="val") -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Args:
        test_loader: Sized iterable of batches whose ``dataset.classes`` is
            forwarded to the metric. ``samples[0]`` is used as the input tensor
            and ``samples[1]`` as the label tensor.
        model: Network to evaluate; ``n_classes``, ``branch_number`` and
            ``act`` are read.
        loss_func: Loss callable exposing ``name``. When the region loss is
            active it must return ``(loss, loss_seg, loss_region)``.
        epoch: Zero-based epoch index (logged as ``epoch + 1``).
        device: Device the batch tensors are moved to.
        print_freq: Emit an iteration log line every ``print_freq`` batches.
        writer: Optional writer exposing ``add_scalars``, ``add_class_score``
            and ``plot_class_score``.
        phase: Prefix of the scalar tags written to ``writer``.

    Returns:
        ``(average loss, iou)`` where ``iou`` is the second value returned by
        the metric's ``mean_score()``.
    """
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    eval_metric = IouMetric(num_classes=model.n_classes,
                            classes=test_loader.dataset.classes,
                            ignore_index=255)
    # The region loss is considered active either by name or because the model
    # has more than one output branch.
    use_region_loss: bool = "region" in loss_func.name or model.branch_number > 1
    if use_region_loss:
        loss_seg_meter, loss_region_meter = AverageMeter(), AverageMeter()

    model.eval()

    max_iter = len(test_loader)
    end = time.time()
    for i, samples in enumerate(test_loader):
        global_step = max_iter * epoch + i
        inputs, labels = samples[0].to(device), samples[1].to(device)
        batch_size = inputs.size(0)

        # Forward pass and loss bookkeeping.
        outputs = model(inputs)
        if use_region_loss:
            loss, loss_seg, loss_region = loss_func(outputs, labels)
            loss_seg_meter.update(loss_seg.item(), batch_size)
            loss_region_meter.update(loss_region.item(), batch_size)
        else:
            loss = loss_func(outputs, labels)
        loss_meter.update(loss.item(), batch_size)

        # With several branches the first output is used for prediction.
        if model.branch_number == 1:
            predicts = model.act(outputs)
        else:
            predicts = model.act(outputs[0])
        # NOTE(review): update() is unpacked as (iou, acc) whereas
        # mean_score() below is unpacked as (acc, iou) - confirm both orders
        # against IouMetric.
        iou, acc = eval_metric.update(predicts.detach().cpu().numpy(),
                                      labels.detach().cpu().numpy())

        if writer is not None:
            writer.add_scalars(
                {
                    "{}/Iter/loss".format(phase): loss,
                    "{}/Iter/Acc".format(phase): acc,
                    "{}/Iter/Iou".format(phase): iou,
                },
                global_step=global_step,
            )
            if use_region_loss:
                writer.add_scalars(
                    {
                        "{}/Iter/loss_seg".format(phase): loss_seg,
                        "{}/Iter/loss_region".format(phase): loss_region,
                    },
                    global_step=global_step,
                )

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            log_str = (
                "Eval Epoch[{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Acc {acc:.4f}\tIou {iou:.4f}\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})").format(
                    epoch + 1, i + 1, max_iter,
                    batch_time=batch_time,
                    loss=loss_meter,
                    acc=acc,
                    iou=iou,
            )
            if use_region_loss:
                # Leading tab keeps this field separated from the "Loss" field.
                log_str += (
                    "\tLoss_seg {loss_seg.val:.4f} ({loss_seg.avg:.4f})"
                    "\tLoss_region {loss_region.val:.4f} ({loss_region.avg:.4f})").format(
                        loss_seg=loss_seg_meter, loss_region=loss_region_meter)
            logger.info(log_str)

    # Measures for the entire epoch.
    acc, iou = eval_metric.mean_score()
    if writer is not None:
        writer.add_scalars(
            {
                "{}/Epoch/loss".format(phase): loss_meter.avg,
                "{}/Epoch/Acc".format(phase): acc,
                "{}/Epoch/Iou".format(phase): iou,
            },
            global_step=epoch,
        )
        if use_region_loss:
            # Tag with the evaluation phase so these curves do not collide
            # with the "Train/Epoch/..." scalars written during training.
            writer.add_scalars(
                {
                    "{}/Epoch/loss_seg".format(phase): loss_seg_meter.avg,
                    "{}/Epoch/loss_region".format(phase): loss_region_meter.avg,
                },
                global_step=epoch,
            )

        if writer.plot_class_score:
            writer.add_class_score(
                "{}/Epoch/Score by classes".format(phase),
                eval_metric.class_score(),
                global_step=epoch,
            )
    if use_region_loss:
        logger.info(
            ("Eval Epoch[{0}]\t"
             "Samples[{1}]\t"
             "Loss {loss.avg:.4f}\t"
             "Loss_seg {loss_seg.avg:.4f}\t"
             "Loss_region {loss_region.avg:.4f}\t"
             "Acc {acc:.4f}\tIou {iou:.4f}").format(
                 epoch + 1, eval_metric.num_samples,
                 loss=loss_meter, loss_seg=loss_seg_meter,
                 loss_region=loss_region_meter,
                 acc=acc, iou=iou)
        )
    else:
        logger.info(
            ("Eval Epoch[{0}]\tSamples [{1}]\t"
             "Loss {loss.avg:.4f}\t"
             "Acc {acc:.4f}\tIou {iou:.4f}").format(
                 epoch + 1, eval_metric.num_samples,
                 loss=loss_meter, acc=acc, iou=iou)
        )
    # Per-class scores are only printed in the multi-class setting.
    if eval_metric.num_classes > 1:
        logger.info("Score by classes :")
        eval_metric.print_class_score()

    return loss_meter.avg, iou


def train(cfg) -> None:
    """Run the full training loop described by ``cfg``, then optionally test.

    Builds the model, loss, data loaders, optimizer and scheduler, resumes
    from a checkpoint through ``load_train_checkpoint``, then alternates
    training epochs with periodic validation and checkpointing. When
    ``cfg.PERFORM_TEST`` is set, the best checkpoint is reloaded and evaluated
    on the test split.

    Args:
        cfg: Project configuration object. The fields read here are
            ``TENSORBOARD.ENABLE``, ``RNG_SEED``, ``DEVICE``,
            ``SOLVER.MAX_EPOCH``, ``SOLVER.LR_POLICY``, ``SOLVER.REDUCE_MODE``,
            ``TRAIN.CALCULATE_METRIC``, ``TRAIN.EVAL_PERIOD``,
            ``TRAIN.CHECKPOINT_PERIOD``, ``LOG_PERIOD``, ``OUTPUT_DIR`` and
            ``PERFORM_TEST``.
    """
    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Init tensorboard writer.
    writer = TensorboardWriter(cfg) if cfg.TENSORBOARD.ENABLE else None
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Build the model and the loss.
    device = torch.device(cfg.DEVICE)
    model = build_model(cfg)
    model.to(device)
    loss_func = get_loss_func(cfg)
    loss_func.to(device)

    # Create the train and val data loaders.
    train_loader = build_data_pipeline(cfg, "train")
    val_loader = build_data_pipeline(cfg, "val")

    # Build the optimizer and the learning-rate scheduler.
    optimizer = build_optimizer(cfg, model)
    scheduler = build_scheduler(cfg, optimizer, train_loader)

    # Check if it is configured to resume training.
    start_epoch, best_epoch, best_score = load_train_checkpoint(cfg, model, optimizer, scheduler)

    # Perform the training loop.
    logger.info("Start training ... ")
    for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
        train_epoch(train_loader, model, loss_func, optimizer, scheduler, epoch, device,
                    cfg.LOG_PERIOD, cfg.TRAIN.CALCULATE_METRIC, writer)

        is_eval_epoch = (epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0
        is_checkpoint_epoch = (epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0

        # A checkpoint is always preceded by a validation pass, so validate
        # once whenever either period is hit.
        if is_eval_epoch or is_checkpoint_epoch:
            val_loss, val_score = eval_epoch(
                val_loader, model, loss_func, epoch, device, cfg.LOG_PERIOD, writer
            )

        # The plateau scheduler is driven by validation results and is only
        # stepped on regular evaluation epochs.
        if is_eval_epoch and cfg.SOLVER.LR_POLICY == "reduce_on_plateau":
            scheduler.step(val_loss if cfg.SOLVER.REDUCE_MODE == "min" else val_score)

        if is_checkpoint_epoch:
            is_best = best_score is None or val_score > best_score
            save_checkpoint(
                osp.join(cfg.OUTPUT_DIR, "model"), model,
                optimizer, scheduler, epoch,
                last_checkpoint=True,
                best_checkpoint=is_best,
                val_score=val_score
            )
            if is_best:
                best_score = val_score
                best_epoch = epoch

        # "poly" is stepped per iteration inside train_epoch and
        # "reduce_on_plateau" above; every other policy steps once per epoch.
        if cfg.SOLVER.LR_POLICY not in ["reduce_on_plateau", "poly"]:
            scheduler.step()
    logger.info("Complete training !")

    # No best model exists if no checkpoint epoch was ever reached.
    has_best = best_epoch is not None and best_score is not None
    if has_best:
        logger.info(
            ("Best performance on validation subset - "
             "model epoch {}, score {:.4f}").format(best_epoch + 1, best_score)
        )
    else:
        logger.warning("No validated checkpoint was recorded during training.")

    # Perform test in the end.
    if cfg.PERFORM_TEST:
        if has_best:
            logger.info("Start testing ... ")
            best_model_path = osp.join(
                cfg.OUTPUT_DIR, "model", "checkpoint_epoch_{}.pth".format(best_epoch + 1)
            )
            load_checkpoint(best_model_path, model, device)
            test_loader = build_data_pipeline(cfg, "test")
            # eval_epoch takes no threshold argument; pass the optional
            # arguments by keyword so they cannot be shifted.
            test_loss, test_score = eval_epoch(
                test_loader, model, loss_func, best_epoch, device,
                print_freq=cfg.LOG_PERIOD, writer=writer, phase="test"
            )
            logger.info("Complete testing !")
            logger.info(
                ("Best performance on test subset- "
                 "model epoch {}, score {:.4f}").format(best_epoch + 1, test_score)
            )
        else:
            logger.warning("Skipping test: there is no best checkpoint to load.")

    if writer is not None:
        writer.close()


def main():
    """Script entry point: load the config, set up output and logging, train."""
    cfg = load_config(parse_args())

    # The output directory is created before logging is configured with it.
    mkdir(cfg.OUTPUT_DIR)
    setup_logging(output_dir=cfg.OUTPUT_DIR, level=logging.INFO)

    # Record the exact command line used for this run.
    launch_command = " ".join(sys.argv)
    logger.info("Launch command:")
    logger.info(launch_command)

    train(cfg)


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
