import glob
from datetime import datetime
from tqdm import tqdm
from argparse import ArgumentParser
from torch.utils.data import DataLoader

from loss.semantic_seg import CrossEntropyLoss
import models.backbone
import models
from utils.modeling import freeze_layers
from utils.iabn import reinit_alpha
from utils.metrics import *
from utils.calibration import *
from datasets.labels import *
from datasets.seg_ttt import TrainTestAugDataset
# Let cuDNN benchmark the available convolution algorithms and select the fastest one.
torch.backends.cudnn.benchmark = True

# Maximum image size that still fits on the GPU; it is handed to the test dataset as ``crop_size``.
# Larger images are first downsampled, and the prediction is then upsampled back to the original
# resolution. This is especially required for high-resolution Mapillary images.
# NOTE(review): presumably ordered as [height, width] -- confirm against TrainTestAugDataset.
img_max_size = [1024, 2048]


def _select_checkpoint(opts):
    """Return the path of the newest checkpoint matching ``opts.checkpoint``.

    ``opts.checkpoint`` is globbed inside ``opts.checkpoints_root`` and the match with the largest
    ``os.path.getctime`` is returned.

    Raises:
        ValueError: If the checkpoints directory does not exist, no checkpoint name/pattern was
            given, or the pattern does not match any file.
    """
    if not os.path.exists(opts.checkpoints_root):
        raise ValueError(f"Checkpoints directory {opts.checkpoints_root} does not exist")
    if opts.checkpoint is None:
        # os.path.join(root, None) would otherwise fail with an opaque TypeError
        raise ValueError("No checkpoint given, please set --checkpoint to a file name or glob pattern")
    pattern = os.path.join(opts.checkpoints_root, opts.checkpoint)
    candidates = glob.glob(pattern)
    if not candidates:
        # max() on an empty list would otherwise fail with "arg is an empty sequence"
        raise ValueError(f"No checkpoint matching {pattern} was found")
    return max(candidates, key=os.path.getctime)


def _restore_source_model(model, source_state, opts, device):
    """Reset ``model`` to the checkpoint weights and move it to ``device``.

    If ``opts.alpha`` is given, alpha is re-initialized to that value after the weights were loaded,
    overriding the alpha stored in the checkpoint.
    """
    # load_state_dict copies the values into the existing parameters, so ``source_state`` itself
    # is left untouched by any later optimization of ``model``
    model.load_state_dict(source_state, strict=True)
    if opts.alpha is not None:
        reinit_alpha(model, alpha=opts.alpha, device=device)
    return model.to(device)


def _predict(model, img, mixed_precision):
    """Run a forward pass and return the ``'pred'`` output, under CUDA autocast if requested."""
    if mixed_precision:
        with torch.cuda.amp.autocast():
            return model(img=img)['pred']
    return model(img=img)['pred']


def main(opts):
    """Evaluate a segmentation checkpoint with plain inference, TTA or per-image test-time training.

    For every batch of the test loader one of three modes is run:

    * ``opts.only_inf``: a single forward pass on the original image.
    * ``opts.tta``: the softmax predictions of the augmented crops are fused by
      ``test_dataset.create_pseudo_gt`` and the fused result is evaluated directly.
    * otherwise (TTT): the model is reset to the checkpoint weights, optimized for
      ``opts.num_epochs`` SGD steps against the fused pseudo ground truth and then evaluated on the
      original image.

    Each prediction is bilinearly upsampled to the ground-truth resolution and accumulated in a
    running IoU meter (and in a calibration meter if ``opts.calibration`` is set). The mean IoU (and
    the ECE) are printed once the whole loader has been processed.

    Args:
        opts: Parsed command line options (see the argument parser at the bottom of this file).

    Raises:
        ValueError: If the checkpoints directory does not exist, no checkpoint was specified, or no
            file matches the checkpoint pattern.
    """
    # Setup metric
    time_stamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    iou_meter = runningScore(opts.num_classes)
    print(f"Current inference run {time_stamp} has started!")

    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Setup dataset and transforms
    test_dataset = TrainTestAugDataset(device=device,
                                       root=opts.dataset_root,
                                       only_inf=opts.only_inf,
                                       source=opts.source,
                                       crop_size=img_max_size,
                                       split=opts.dataset_split,
                                       threshold=opts.threshold,
                                       tta=opts.tta,
                                       flips=opts.flips,
                                       scales=opts.scales,
                                       greyscale=opts.greyscale)
    test_loader = DataLoader(test_dataset,
                             batch_size=opts.batch_size,
                             shuffle=False,
                             num_workers=opts.num_workers)

    # Load and setup model
    model = models.__dict__[opts.arch_type](backbone_name=opts.backbone_name,
                                            num_classes=opts.num_classes,
                                            update_source_bn=False,
                                            dropout=opts.dropout)
    model = torch.nn.DataParallel(model)

    # Pick the newest checkpoint and read it from disk exactly once; the in-memory state dict is
    # reused whenever the model has to be reset to the source weights
    checkpoint = _select_checkpoint(opts)
    source_state = torch.load(checkpoint, map_location=device)
    model = _restore_source_model(model, source_state, opts, device)

    # Set up TTT optimizer and Loss
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=opts.base_lr,
        momentum=opts.momentum,
        weight_decay=opts.weight_decay
    )
    criterion = CrossEntropyLoss().to(device)

    if opts.calibration:
        # Calibration meter
        cal_meter = CalibrationMeter(
            device,
            n_bins=10,
            num_classes=opts.num_classes,
            num_images=len(test_loader)
        )
    model.eval()

    # Create GradScaler for mixed precision
    if opts.mixed_precision:
        scaler = torch.cuda.amp.GradScaler()

    for img_test, gt_test, crop_test, crop_transforms in tqdm(test_loader):
        # Put img on GPU if available
        img_test = img_test.to(device)
        if opts.only_inf:
            # Forward pass with original image
            with torch.no_grad():
                out_test = _predict(model, img_test, opts.mixed_precision)
        else:
            # Every image starts from the source model again
            model = _restore_source_model(model, source_state, opts, device)

            # Compute augmented predictions
            crop_test_fused = []
            for crop_test_sub in crop_test:
                with torch.no_grad():
                    out_test = _predict(model, crop_test_sub, opts.mixed_precision)
                crop_test_fused.append(torch.nn.functional.softmax(out_test, dim=1))

            # Create pseudo gt from augmentations based on their softmax probabilities
            pseudo_gt = test_dataset.create_pseudo_gt(
                crop_test_fused, crop_transforms, [1, opts.num_classes, *img_test.shape[-2:]]
            )
            pseudo_gt = pseudo_gt.to(device)

            if opts.tta:
                # Use pseudo gt for evaluation
                out_test = pseudo_gt
            else:
                model.train()

                # Freeze layers if given
                freeze_layers(opts, model)

                # The weights were reset above, so the optimizer state (e.g. SGD momentum buffers)
                # of the previous image has to be dropped as well
                optimizer.state.clear()

                # TTT Loop
                model = model.to(device)
                for _ in range(opts.num_epochs):
                    out_test = _predict(model, img_test, opts.mixed_precision)
                    if opts.mixed_precision:
                        with torch.cuda.amp.autocast():
                            loss_train = criterion(out_test, pseudo_gt)
                    else:
                        loss_train = criterion(out_test, pseudo_gt)
                    optimizer.zero_grad()
                    if opts.mixed_precision:
                        scaler.scale(loss_train).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss_train.backward()
                        optimizer.step()

                # Do actual forward pass with updated model
                model.eval()
                with torch.no_grad():
                    out_test = _predict(model, img_test, opts.mixed_precision)

        # Upsample prediction to gt resolution
        out_test = torch.nn.functional.interpolate(out_test, size=gt_test.shape[-2:], mode='bilinear')

        # Update calibration meter
        if opts.calibration:
            cal_meter.calculate_bins(out_test, gt_test.to(device))

        # Add prediction
        iou_meter.update(gt_test.cpu().numpy(), torch.argmax(out_test, dim=1).cpu().numpy())

    # Save output
    score, class_iou, cm, iu = iou_meter.get_scores()
    mean_iou = score['Mean IoU :']

    # Compute ECE
    if opts.calibration:
        cal_meter.calculate_mean_over_dataset()
        print(f"ECE: {cal_meter.overall_ece}")

    print(f"Mean IoU: {mean_iou}")
    print(f"Current inference run {time_stamp} is finished!")

def build_parser():
    """Build the command line parser of the inference / test-time training script.

    ``--checkpoint`` is mandatory: the script always loads a checkpoint, so a missing value is
    reported by argparse up front instead of failing later while the checkpoint path is assembled.

    Returns:
        ArgumentParser: Parser whose parsed namespace is the ``opts`` object consumed by ``main``.
    """
    parser = ArgumentParser()

    # Data
    parser.add_argument("--dataset-root", type=str,
                        default=os.path.join(os.getcwd(), "datasets", "cityscapes"))
    parser.add_argument("--dataset-split", type=str, default="val")
    parser.add_argument("--source", type=str, default="gta", choices=["gta", "synthia"])

    # Checkpoint
    parser.add_argument("--checkpoints-root", type=str,
                        default=os.path.join(os.getcwd(), "checkpoints", "runs"),
                        help="Directory which contains the checkpoints")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Checkpoint file name or glob pattern inside --checkpoints-root, "
                             "the newest match is loaded")

    # Model
    parser.add_argument("--arch-type", type=str, default="deeplab",
                        choices=["deeplab", "deeplabv3plus", "hrnet18"])
    parser.add_argument("--backbone-name", type=str, default="resnet50",
                        choices=["resnet50", "resnet101"])
    parser.add_argument("--num-classes", type=int, default=19, choices=[19, 16],
                        help="Set 19 for a GTA trained model and 16 for a SYNTHIA trained model")

    # Data loading and optimization
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--num-epochs", type=int, default=10,
                        help="Number of optimization steps per image during TTT")
    parser.add_argument("--base-lr", type=float, default=0.05)
    parser.add_argument("--momentum", type=float, default=0.0)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--threshold", type=float, default=0.7)

    # Evaluation mode and augmentations
    parser.add_argument("--tta", action="store_true",
                        help="Evaluate the fused augmented predictions instead of running TTT")
    parser.add_argument("--only-inf", action="store_true",
                        help="Only run a plain forward pass on the original image")
    parser.add_argument('--scales', nargs='+', type=float, default=[0.25, 0.5, 0.75])
    parser.add_argument('--flips', action="store_true", help="Apply random flip to all images")
    parser.add_argument('--greyscale', action="store_true", help="Apply greyscaling for TTT")
    parser.add_argument('--calibration', action="store_true",
                        help="Compute calibration during inference")

    # Layer freezing
    parser.add_argument('--resnet-layers', nargs='+', type=int, default=[1, 2],
                        help="1, 2, 3 and/or 4 which will be frozen for TTT")
    parser.add_argument('--hrnet-layers', nargs='+', type=int, default=[1, 2],
                        help="1, 2 and/or 3 which will be frozen for TTT")

    # Misc
    parser.add_argument('--mixed-precision', action="store_true", help="Use mixed precision")
    parser.add_argument(
        "--alpha", type=float, default=None,
        help='Only set this alpha to [0.0, 1.0] if you want to change the alpha from the checkpoint to a custom alpha'
    )
    parser.add_argument("--dropout", action="store_true",
                        help="Enable if dropout was used during training")
    return parser


if __name__ == '__main__':
    parser = build_parser()
    clargs = parser.parse_args()
    main(clargs)