from __future__ import print_function, division

import argparse
import logging
import numpy as np
from pathlib import Path
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim

# from core.raft_stereo import RAFTStereo
from core.raft_stereo_modified4 import RAFTStereo_modified
from core.crestereo import CREStereo

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
from evaluate_stereo import *
import core.stereo_datasets as datasets
import torch.nn.functional as F

try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # Dummy GradScaler for PyTorch < 1.6, where torch.cuda.amp does not exist.
    # It mirrors the subset of the GradScaler API that train() uses: it must
    # accept constructor keywords such as ``enabled=...`` and support
    # (de)serialisation for checkpointing, while performing no loss scaling.
    class GradScaler:
        def __init__(self, *args, **kwargs):
            # Arguments (e.g. ``enabled``) are accepted for API compatibility
            # and ignored: this fallback never scales.
            pass

        def scale(self, loss):
            return loss

        def unscale_(self, optimizer):
            pass

        def step(self, optimizer):
            optimizer.step()

        def update(self, new_scale=None):
            pass

        def state_dict(self):
            # Nothing to persist for a no-op scaler.
            return {}

        def load_state_dict(self, state_dict):
            pass


def calc_slant_loss(dxy, dxy_gt, pred, target, max_disp=700, B=1):
    """Masked L1 loss between predicted and ground-truth slant maps.

    The full-resolution disparities ``pred`` / ``target`` are divided by the
    integer downsampling factor (width of ``target`` // width of ``dxy``) and
    max-pooled to the resolution of ``dxy``. ``dxy_gt`` is sampled at the
    argmax positions of the pooled target.

    Only low-resolution cells are supervised whose pooled target lies in the
    open interval (1e-3, max_disp / factor) and whose pooled prediction is
    within ``B`` of it.

    Args:
        dxy: predicted slant at low resolution. It must broadcast against a
            2-channel gathered map, so presumably (N, 2, h, w); confirm with
            the model.
        dxy_gt: full-resolution ground-truth slant. Two channels are gathered
            from it, so presumably (N, 2, H, W).
        pred: full-resolution predicted disparity (callers pass a single
            channel).
        target: full-resolution ground-truth disparity (single channel).
        max_disp: upper disparity bound in full-resolution pixels.
        B: maximum low-resolution disparity error for a cell to be supervised.

    Returns:
        Scalar tensor: the sum of masked absolute errors over both channels,
        divided by the number of valid cells plus 1e-6.
    """
    factor = target.size(3) // dxy.size(3)
    disp_divisor = max(1, factor)
    threshold = max_disp / disp_divisor

    # Express disparities in low-resolution pixel units, then max-pool them
    # down to the grid of ``dxy``. The target's argmax positions are kept so
    # the ground-truth slant can be sampled at the same pixels.
    target_low, argmax = F.max_pool2d(
        target / disp_divisor,
        kernel_size=factor,
        stride=factor,
        return_indices=True,
    )
    pred_low = F.max_pool2d(pred / disp_divisor, kernel_size=factor, stride=factor)

    in_range = (target_low < threshold) & (target_low > 1e-3)
    close_enough = (pred_low - target_low).abs() < B
    valid = close_enough & in_range

    # Duplicate the single-channel argmax indices for the two slant channels
    # and gather dxy_gt at those flattened spatial positions.
    idx = argmax.repeat(1, 2, 1, 1)
    dxy_gt_low = (
        dxy_gt.flatten(start_dim=2)
        .gather(dim=2, index=idx.flatten(start_dim=2))
        .view_as(idx)
    )

    weighted = (dxy - dxy_gt_low).abs() * valid
    return weighted.sum() / (valid.sum() + 1e-6)


# def sequence_loss(
#     flow_preds, flow_gt, valid, dxy_preds, dxy_gt, loss_gamma=0.9, max_flow=700
# ):
#     """Loss function defined over sequence of flow predictions"""

#     n_predictions = len(flow_preds)
#     assert n_predictions >= 1
#     flow_loss = 0.0

#     # exlude invalid pixels and extremely large diplacements
#     mag = torch.sum(flow_gt**2, dim=1).sqrt()

#     # exclude extremly large displacements
#     valid = ((valid >= 0.5) & (mag < max_flow)).unsqueeze(1)
#     assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
#     assert not torch.isinf(flow_gt[valid.bool()]).any()

#     for i in range(n_predictions):
#         assert (
#             not torch.isnan(flow_preds[i]).any()
#             and not torch.isinf(flow_preds[i]).any()
#         )
#         adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
#         i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
#         slant_loss = calc_slant_loss(
#             dxy_preds[i], dxy_gt, flow_preds[i][:, 0:1], flow_gt[:, 0:1]
#         )
#         i_loss = (flow_preds[i][:, 0:1, :] - flow_gt).abs()
#         assert i_loss.shape == valid.shape, [
#             i_loss.shape,
#             valid.shape,
#             flow_gt.shape,
#             flow_preds[i].shape,
#         ]
#         flow_loss += i_weight * (i_loss[valid.bool()].mean() + slant_loss)

#     epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
#     epe = epe.view(-1)[valid.view(-1)]

#     metrics = {
#         "epe": epe.mean().item(),
#         "1px": (epe < 1).float().mean().item(),
#         "3px": (epe < 3).float().mean().item(),
#         "5px": (epe < 5).float().mean().item(),
#     }

#     return flow_loss, metrics


def sequence_loss(
    flow_preds, flow_gt, valid, dxy_preds, dxy_gt, loss_gamma=0.9, max_flow=700, start=0
):
    """Exponentially weighted L1 loss over a sequence of disparity predictions.

    Prediction ``i`` (0-based, of ``n`` total) is weighted by
    ``0.8 ** (n - i - 1 + start)``, so later predictions count more. Only
    pixels with ``valid >= 0.5`` and ground-truth magnitude below
    ``max_flow`` contribute.

    Args:
        flow_preds: non-empty sequence of predictions, each with the same
            shape as ``flow_gt``.
        flow_gt: ground-truth disparity. The validity mask built below has a
            single channel at dim 1 and is asserted to match ``flow_gt``'s
            shape.
        valid: validity map, thresholded at 0.5.
        dxy_preds: per-prediction slant maps, indexed like ``flow_preds``.
            Only used when ``dxy_gt`` is not None.
        dxy_gt: ground-truth slant map, or None to disable the slant term.
            When given, ``50 * calc_slant_loss(...)`` is added per prediction.
        loss_gamma: currently ignored; the decay is fixed at 0.8 (see below).
        max_flow: pixels whose ground-truth magnitude is >= this are excluded.
        start: extra exponent added to every prediction's weight.

    Returns:
        (flow_loss, metrics), where ``metrics`` holds the mean end-point error
        ("epe") and the fractions of valid pixels with EPE below 1/3/5 px,
        all measured on the last prediction.
    """
    n_predictions = len(flow_preds)
    assert n_predictions >= 1
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = ((valid >= 0.5) & (mag < max_flow)).unsqueeze(1)
    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
    assert valid.shape == flow_preds[-1].shape, [valid.shape, flow_preds[-1].shape]
    assert not torch.isinf(flow_gt[valid]).any()

    # NOTE: ``loss_gamma`` is intentionally not used; the decay is hardcoded.
    # The previous formula was: loss_gamma ** (15 / (n_predictions - 1))
    adjusted_loss_gamma = 0.8

    for i, pred in enumerate(flow_preds):
        assert not torch.isnan(pred).any() and not torch.isinf(pred).any()
        i_weight = adjusted_loss_gamma ** (n_predictions - i - 1 + start)
        i_loss = (pred - flow_gt).abs()
        assert i_loss.shape == valid.shape, [
            i_loss.shape,
            valid.shape,
            flow_gt.shape,
            pred.shape,
        ]
        term = i_loss[valid].mean()
        if dxy_gt is not None:
            slant_loss = calc_slant_loss(
                dxy_preds[i], dxy_gt, pred[:, 0:1], flow_gt[:, 0:1]
            )
            term = term + 50 * slant_loss
        flow_loss += i_weight * term

    # Metrics depend only on the final prediction, so compute them once
    # (each .item() forces a device sync) instead of on every iteration.
    epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
    }

    return flow_loss, metrics


def fetch_optimizer(args, model):
    """Build the AdamW optimizer and its one-cycle learning-rate schedule.

    Args:
        args: namespace providing ``lr`` (peak LR), ``wdecay`` (weight decay)
            and ``num_steps`` (training length).
        model: module whose parameters are optimised.

    Returns:
        (optimizer, scheduler). The schedule warms up over the first 1% of
        its length to ``args.lr`` and then anneals linearly. Its length is
        ``args.num_steps + 100``, presumably as headroom beyond the training
        loop's step count.
    """
    opt = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=1e-8,
    )
    sched = optim.lr_scheduler.OneCycleLR(
        optimizer=opt,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.01,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    return opt, sched


class Logger:
    """Accumulate per-step metrics and periodically report their averages.

    Every ``SUM_FREQ`` pushes (triggered when ``total_steps % SUM_FREQ ==
    SUM_FREQ - 1``), the running sums are averaged, printed via
    ``logging.info`` and written to a TensorBoard ``SummaryWriter``.
    """

    SUM_FREQ = 100

    def __init__(self, model, scheduler, log_dir=None):
        """
        Args:
            model: the model being trained (stored for reference).
            scheduler: LR scheduler; ``get_last_lr()[0]`` is printed in each
                status line.
            log_dir: TensorBoard log directory. If omitted, the
                ``tensorboard_file_path`` of a module-level ``args`` namespace
                is used when such a global exists. Otherwise ``None`` is passed,
                letting ``SummaryWriter`` choose its default directory.
        """
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        # Pushes accumulated since the last report. Needed for a correct mean
        # because the first window holds only SUM_FREQ - 1 pushes.
        self._window_count = 0
        if log_dir is None:
            # Backward compatibility: the script stores parsed CLI args in a
            # module global named ``args``; don't crash if it is absent.
            log_dir = getattr(globals().get("args"), "tensorboard_file_path", None)
        self.writer = SummaryWriter(log_dir=log_dir)

    def _window_averages(self):
        """Return the mean of each metric over the current window, keys sorted."""
        count = max(self._window_count, 1)
        return {k: self.running_loss[k] / count for k in sorted(self.running_loss)}

    def _print_training_status(self):
        """Log the window averages, write them to TensorBoard and reset the window."""
        averages = self._window_averages()
        training_str = "[{:6d}, {:10.7f}] ".format(
            self.total_steps + 1, self.scheduler.get_last_lr()[0]
        )
        metrics_str = ("{:10.4f}, " * len(averages)).format(*averages.values())

        # print the training status
        logging.info(
            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
        )

        if self.writer is None:
            self.writer = SummaryWriter(log_dir="runs")

        for key, value in averages.items():
            self.writer.add_scalar(key, value, self.total_steps)

        self.running_loss = {}
        self._window_count = 0

    def push(self, metrics):
        """Add one step's ``metrics`` (name -> number) to the running sums."""
        self.total_steps += 1
        self._window_count += 1

        for key, value in metrics.items():
            self.running_loss[key] = self.running_loss.get(key, 0.0) + value

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()

    def write_dict(self, results):
        """Write each ``results`` entry as a scalar at the current step."""
        if self.writer is None:
            self.writer = SummaryWriter(log_dir="runs")

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        """Close the TensorBoard writer, if one was created."""
        if self.writer is not None:
            self.writer.close()


def _resume_training_state(ckpt_path, model, optimizer, scaler, scheduler):
    """Restore model/optimizer/scaler/scheduler state from a full training checkpoint.

    Returns the step counter from which training should continue.
    """
    assert ckpt_path.endswith(".pth")
    logging.info("Loading checkpoint...")
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint["model"], strict=True)
    # Optimizer, GradScaler and LR-scheduler load_state_dict() take no
    # ``strict`` argument; passing one raises TypeError.
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    # Checkpoints written by older versions of this script used the
    # misspelled key "scheduler," -- accept both.
    sched_key = "scheduler" if "scheduler" in checkpoint else "scheduler,"
    scheduler.load_state_dict(checkpoint[sched_key])
    logging.info("Done loading checkpoint")
    # "epoch" stores (step index + 1) of the step that produced the checkpoint.
    # That step already completed (including scheduler.step()), so resume at
    # the following step instead of repeating it.
    return checkpoint["epoch"]


def _restore_model_weights(ckpt_path, model):
    """Load model weights from a bare state dict or a full training checkpoint."""
    assert ckpt_path.endswith(".pth")
    logging.info("Loading checkpoint...")
    checkpoint = torch.load(ckpt_path)
    # Periodic checkpoints saved by train() wrap the weights under "model";
    # the end-of-training file is a bare state dict.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, strict=True)
    logging.info("Done loading checkpoint")


def _save_training_checkpoint(save_path, model, optimizer, scaler, scheduler, args, step):
    """Write everything needed by _resume_training_state to ``save_path``."""
    to_save = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": step,
        "scaler": scaler.state_dict(),
        "args": args,
        "scheduler": scheduler.state_dict(),
    }
    logging.info(f"Saving file {save_path.absolute()}")
    torch.save(to_save, save_path)


def train(args):
    """Run the training loop and return the path of the final weights file.

    The model is ``RAFTStereo_modified`` wrapped in ``nn.DataParallel``. Each
    step sums the sequence losses of the three prediction groups returned by
    the model. Every ``validation_frequency`` steps, a full training
    checkpoint is written to ``checkpoints/`` and ``validate_things`` is run.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            ``args.resume`` restores a full training state; ``args.restore_ckpt``
            loads model weights only (applied after ``resume`` if both are set).

    Returns:
        str: path of the final state-dict file, ``checkpoints/<name>.pth``.
    """
    model = nn.DataParallel(RAFTStereo_modified(args))
    # model = nn.DataParallel(RAFTStereo(args))
    # model = nn.DataParallel(CREStereo(args))
    print("Parameter Count: %d" % count_parameters(model))

    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)
    total_steps = 0
    global_batch_num = 0
    logger = Logger(model, scheduler)
    scaler = GradScaler(enabled=args.mixed_precision)

    if args.resume is not None:
        total_steps = _resume_training_state(
            args.resume, model, optimizer, scaler, scheduler
        )
        global_batch_num = total_steps
    if args.restore_ckpt is not None:
        _restore_model_weights(args.restore_ckpt, model)

    model.cuda()
    model.train()
    model.module.freeze_bn()  # We keep BatchNorm frozen

    validation_frequency = 5000

    should_keep_training = True
    while should_keep_training:

        for i_batch, (name_list, *data_blob) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            assert model.training
            flow_predictions, gradient_predictions = model(
                image1,
                image2,
                iters=args.train_iters,
                patchmatch_rounds=args.patchmatch_rounds,
            )
            assert model.training
            # The model returns three groups of predictions. Only the last
            # group uses the --max_disp cut-off, and only its metrics are logged.
            # A slant-supervised variant would pass gradient_predictions[k] and
            # a ground-truth slant map instead of (None, None).
            loss1, _ = sequence_loss(flow_predictions[0], flow, valid, None, None)
            loss2, _ = sequence_loss(flow_predictions[1], flow, valid, None, None)
            loss3, metrics = sequence_loss(
                flow_predictions[2], flow, valid, None, None, max_flow=args.max_disp
            )
            loss = loss1 + loss2 + loss3

            logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
            logger.writer.add_scalar(
                "learning_rate", optimizer.param_groups[0]["lr"], global_batch_num
            )
            global_batch_num += 1

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            logger.push(metrics)

            if total_steps % validation_frequency == validation_frequency - 1:
                save_path = Path("checkpoints/%d_%s.pth" % (total_steps + 1, args.name))
                _save_training_checkpoint(
                    save_path, model, optimizer, scaler, scheduler, args, total_steps + 1
                )

                results = validate_things(
                    model.module,
                    iters=args.valid_iters,
                    patchmatch_rounds=args.patchmatch_rounds,
                )

                logger.write_dict(results)

                # Restore training mode (validation presumably switches to eval).
                model.train()
                model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        if len(train_loader) >= 1000:
            save_path = Path(
                "checkpoints/%d_epoch_%s.pth.gz" % (total_steps + 1, args.name)
            )
            logging.info(f"Saving file {save_path}")
            torch.save(model.state_dict(), save_path)

    print("FINISHED TRAINING")
    logger.close()
    PATH = "checkpoints/%s.pth" % args.name
    torch.save(model.state_dict(), PATH)

    return PATH


if __name__ == "__main__":
    # Command-line entry point. ``args`` must stay a module-level name:
    # Logger reads ``args.tensorboard_file_path`` from module globals.
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Experiment bookkeeping
    add("--name", default="EAI-Stereo", help="name your experiment")
    add("--restore_ckpt", help="restore checkpoint")
    add("--resume", help="resume checkpoint")
    add("--mixed_precision", action="store_true", help="use mixed precision")
    add("--tensorboard_file_path")

    # Training parameters
    add("--batch_size", type=int, default=8, help="batch size used during training.")
    add("--train_datasets", nargs="+", default=["sceneflow"], help="training datasets.")
    add("--lr", type=float, default=0.0002, help="max learning rate.")
    add("--num_steps", type=int, default=200000, help="length of training schedule.")
    add(
        "--image_size",
        type=int,
        nargs="+",
        default=[384, 768],
        help="size of the random image crops used during training.",
    )
    add(
        "--train_iters",
        type=int,
        default=22,
        help="number of updates to the disparity field in each forward pass.",
    )
    add("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")

    # Validation parameters
    add(
        "--valid_iters",
        type=int,
        default=32,
        help="number of flow-field updates during validation forward pass",
    )

    # Architecture choices
    add(
        "--corr_implementation",
        choices=["reg", "alt", "reg_cuda", "alt_cuda"],
        default="reg",
        help="correlation volume implementation",
    )
    add(
        "--shared_backbone",
        action="store_true",
        help="use a single backbone for the context and feature encoders",
        default=False,
    )
    add(
        "--context_norm",
        type=str,
        default="batch",
        choices=["group", "batch", "instance", "none"],
        help="normalization of context encoder",
    )
    add("--corr_levels", type=int, default=4, help="number of levels in the correlation pyramid")
    add("--corr_radius", type=int, default=4, help="width of the correlation pyramid")
    add("--n_downsample", type=int, default=2, help="resolution of the disparity field (1/2^K)")
    add("--slow_fast_gru", action="store_true", help="iterate the low-res GRUs more frequently")
    add("--n_gru_layers", type=int, default=3, help="number of hidden GRU levels")
    add(
        "--hidden_dims",
        nargs="+",
        type=int,
        default=[128] * 3,
        help="hidden state and context dimensions",
    )
    add("--confidence_score", type=float, default=0.5, help="confidence score for propagation")
    add("--patchmatch_rounds", type=int, default=2)
    add("--num_neighbors", type=int, default=8)
    add("--look_up_before_propa", action="store_true")
    add("--max_disp", type=int, default=256)

    # Data augmentation
    add("--img_gamma", type=float, nargs="+", default=None, help="gamma range")
    add("--saturation_range", type=float, nargs="+", default=None, help="color saturation")
    add(
        "--do_flip",
        default=False,
        choices=["h", "v"],
        help="flip the images horizontally or vertically",
    )
    add(
        "--spatial_scale",
        type=float,
        nargs="+",
        default=[0, 0],
        help="re-scale the images randomly",
    )
    add("--noyjitter", action="store_true", help="don't simulate imperfect rectification")

    args = cli.parse_args()

    # Fixed seeds for reproducibility.
    torch.manual_seed(1234)
    np.random.seed(1234)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    # train() writes all checkpoints under this directory.
    Path("checkpoints").mkdir(exist_ok=True, parents=True)

    train(args)
