from __future__ import print_function, division

import argparse
import logging
import numpy as np
from pathlib import Path
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim

# from core.raft_stereo import RAFTStereo
from core.raft_stereo_modified4 import RAFTStereo_modified
# from core.crestereo import CREStereo

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
from evaluate_stereo import *
import core.stereo_datasets as datasets
import torch.nn.functional as F

import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # PyTorch < 1.6 has no torch.cuda.amp.GradScaler. This stand-in mirrors the
    # subset of the GradScaler API used by train(), with loss scaling disabled:
    # every method is a pass-through.
    class GradScaler:
        def __init__(self, enabled=False):
            # train() calls GradScaler(enabled=...); accept the flag so that call
            # works, but mixed precision cannot be enabled without torch.cuda.amp.
            self._enabled = False

        def is_enabled(self):
            return self._enabled

        def scale(self, loss):
            return loss

        def unscale_(self, optimizer):
            pass

        def step(self, optimizer):
            optimizer.step()

        def update(self):
            pass

        def state_dict(self):
            # Matches the real GradScaler, which returns {} when disabled.
            return {}

        def load_state_dict(self, state_dict):
            pass

def setup(rank, world_size, backend="gloo"):
    """Bind this process to GPU ``rank`` and join the default process group.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of participating processes.
        backend: torch.distributed backend name (defaults to "gloo", as before).
    """
    # init_process_group uses the "env://" rendezvous by default, which requires
    # MASTER_ADDR and MASTER_PORT. Supply single-node defaults without overriding
    # values already exported by a launcher (e.g. SLURM/torchrun scripts).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Select the device first so any CUDA work done while joining the group
    # lands on this rank's GPU rather than on GPU 0.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group created by ``setup`` for this rank."""
    dist.destroy_process_group()



def sequence_loss(
    flow_preds, flow_gt, valid, dxy_preds, dxy_gt, loss_gamma=0.9, max_flow=700, start=0
):
    """Exponentially weighted L1 loss over a sequence of disparity predictions.

    Args:
        flow_preds: non-empty list of predictions, each the same shape as
            ``flow_gt`` (enforced by assertions). Later entries get larger weights.
        flow_gt: ground truth; the vector dimension is ``dim=1`` and, per the
            shape assertion against the unsqueezed ``valid`` mask, has size 1.
        valid: validity map without the channel dim; entries >= 0.5 are valid.
        dxy_preds: gradient predictions (unused; see ``dxy_gt``).
        dxy_gt: gradient targets. Only ``None`` is supported.
        loss_gamma: currently ignored -- the decay is hard-coded to 0.8 below.
        max_flow: pixels whose ground-truth magnitude is >= this are excluded.
        start: extra exponent added to every prediction's weight.

    Returns:
        ``(loss, metrics)``. ``loss`` is sum_i 0.8**(n - i - 1 + start) * mean
        masked |pred_i - gt|. ``metrics`` holds the end-point error ("epe") and
        the fraction of valid pixels with EPE below 1/3/5 px for the final
        prediction.

    Raises:
        NotImplementedError: if ``dxy_gt`` is not None (previously this path
            silently produced a zero loss and then crashed on an unbound name).
    """
    if dxy_gt is not None:
        raise NotImplementedError("sequence_loss does not implement the dxy (gradient) loss")

    n_predictions = len(flow_preds)
    assert n_predictions >= 1
    flow_loss = 0.0

    # Exclude invalid pixels and extremely large displacements.
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = ((valid >= 0.5) & (mag < max_flow)).unsqueeze(1)
    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
    assert valid.shape == flow_preds[-1].shape, [valid.shape, flow_preds[-1].shape]
    assert not torch.isinf(flow_gt[valid]).any()

    # The original RAFT-Stereo schedule was loss_gamma ** (15 / (n_predictions - 1));
    # this variant deliberately uses a fixed decay instead.
    adjusted_loss_gamma = 0.8
    for i, flow_pred in enumerate(flow_preds):
        assert not torch.isnan(flow_pred).any() and not torch.isinf(flow_pred).any()
        i_weight = adjusted_loss_gamma ** (n_predictions - i - 1 + start)
        i_loss = (flow_pred - flow_gt).abs()
        assert i_loss.shape == valid.shape, [
            i_loss.shape,
            valid.shape,
            flow_gt.shape,
            flow_pred.shape,
        ]
        flow_loss += i_weight * i_loss[valid].mean()

    # Metrics depend only on the final prediction, so compute them once.
    epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]
    metrics = {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
    }

    return flow_loss, metrics


def fetch_optimizer(args, model):
    """Build an AdamW optimizer and a linear one-cycle learning-rate schedule.

    Reads ``args.lr`` (peak learning rate), ``args.wdecay`` (weight decay) and
    ``args.num_steps``. The schedule length is ``args.num_steps + 100``, which
    presumably leaves headroom so that stepping slightly past ``num_steps``
    does not exhaust OneCycleLR.

    Returns:
        ``(optimizer, scheduler)``.
    """
    adamw = optim.AdamW(
        params=model.parameters(),
        lr=args.lr,
        eps=1e-8,
        weight_decay=args.wdecay,
    )

    schedule_length = args.num_steps + 100
    one_cycle = optim.lr_scheduler.OneCycleLR(
        adamw,
        max_lr=args.lr,
        total_steps=schedule_length,
        anneal_strategy="linear",
        cycle_momentum=False,
        pct_start=0.01,
    )
    return adamw, one_cycle


class Logger:
    """Accumulate per-step training metrics and report windowed averages.

    Every ``push`` adds the given metrics to a running sum. When a window
    closes, the averages are printed together with the current learning rate
    and written to TensorBoard, and the sums are reset.
    """

    SUM_FREQ = 100

    def __init__(self, args, model, scheduler):
        self.model = model
        # Queried for the current learning rate when printing.
        self.scheduler = scheduler
        # Number of push() calls so far; used as the TensorBoard step.
        self.total_steps = 0
        # Metric name -> sum over the current window.
        self.running_loss = {}
        # Number of push() calls accumulated in the current window.
        self._window_count = 0
        self.writer = SummaryWriter(log_dir=args.tensorboard_file_path)

    def _print_training_status(self):
        """Print and log the window averages, then start a new window."""
        # Divide by the number of samples actually accumulated: the first window
        # closes after SUM_FREQ - 1 pushes, so dividing by SUM_FREQ under-reports it.
        count = max(self._window_count, 1)
        averages = {k: v / count for k, v in self.running_loss.items()}

        metrics_data = [averages[k] for k in sorted(averages)]
        training_str = "[{:6d}, {:10.7f}] ".format(
            self.total_steps + 1, self.scheduler.get_last_lr()[0]
        )
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
        print(
            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
        )

        if self.writer is None:
            self.writer = SummaryWriter(log_dir="runs")
        for k, avg in averages.items():
            self.writer.add_scalar(k, avg, self.total_steps)

        self.running_loss = {}
        self._window_count = 0

    def push(self, metrics):
        """Accumulate one step's ``metrics`` (name -> number) into the window."""
        self.total_steps += 1
        self._window_count += 1
        for key, value in metrics.items():
            self.running_loss[key] = self.running_loss.get(key, 0.0) + value

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()

    def write_dict(self, results):
        """Write every ``name -> value`` pair in ``results`` at the current step."""
        if self.writer is None:
            self.writer = SummaryWriter(log_dir="runs")

        for key, value in results.items():
            self.writer.add_scalar(key, value, self.total_steps)

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()


def _load_resume_state(args, model, optimizer, scaler, scheduler):
    """Restore model/optimizer/scaler/scheduler state from ``args.resume``.

    Returns:
        The number of optimisation steps already completed, as stored under the
        checkpoint's ``"epoch"`` key by ``_save_training_checkpoint``.

    Raises:
        ValueError: if ``args.resume`` does not end in ``.pth``.
    """
    if not args.resume.endswith(".pth"):
        raise ValueError(f"--resume must point to a .pth file, got {args.resume!r}")
    logging.info("Loading checkpoint...")
    # map_location="cpu": otherwise every rank materialises the tensors on the GPU
    # they were saved from; load_state_dict then moves them to the right device.
    # NOTE(review): the checkpoint also pickles ``args``; PyTorch versions whose
    # torch.load defaults to weights_only=True may refuse it -- confirm.
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=True)
    # Optimizer / GradScaler / LR-scheduler load_state_dict take no ``strict`` kwarg.
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    # Checkpoints written by earlier versions of this script stored the scheduler
    # under the misspelled key "scheduler,".
    scheduler_key = "scheduler" if "scheduler" in checkpoint else "scheduler,"
    scheduler.load_state_dict(checkpoint[scheduler_key])
    logging.info("Done loading checkpoint")
    return checkpoint["epoch"]


def _save_training_checkpoint(args, model, optimizer, scaler, scheduler, steps_done):
    """Write a resumable checkpoint to ``checkpoints/<steps_done>_<args.name>.pth``."""
    to_save = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # Historical key name: holds the number of completed steps, not epochs.
        "epoch": steps_done,
        "scaler": scaler.state_dict(),
        "args": args,
        "scheduler": scheduler.state_dict(),
    }
    save_path = Path("checkpoints/%d_%s.pth" % (steps_done, args.name))
    print(f"Saving file {save_path.absolute()}")
    torch.save(to_save, save_path)


def train(args, model, train_loader, optimizer, scheduler, rank):
    """Run the distributed training loop on this rank.

    Args:
        args: parsed command-line namespace (resume/restore paths, step counts,
            iteration counts, experiment name, ...).
        model: DDP-wrapped model; ``model.module`` must provide ``freeze_bn``.
        train_loader: DataLoader whose sampler supports ``set_epoch``.
        optimizer, scheduler: as returned by ``fetch_optimizer``.
        rank: this process's rank. Only rank 0 logs, validates and writes files.

    Returns:
        Path of the final weights file, ``checkpoints/<args.name>.pth``.
    """
    total_steps = 0
    global_batch_num = 0
    is_main = rank == 0
    # Only rank 0 writes TensorBoard events; other ranks would otherwise create
    # duplicate event files in the same log directory.
    logger = Logger(args, model, scheduler) if is_main else None
    scaler = GradScaler(enabled=args.mixed_precision)

    if args.resume is not None:
        # The stored count is the number of completed steps, so training
        # continues from exactly that index (no repeated step).
        total_steps = _load_resume_state(args, model, optimizer, scaler, scheduler)
        global_batch_num = total_steps
    if args.restore_ckpt is not None:
        if not args.restore_ckpt.endswith(".pth"):
            raise ValueError(
                f"--restore_ckpt must point to a .pth file, got {args.restore_ckpt!r}"
            )
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(args.restore_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint, strict=True)
        logging.info("Done loading checkpoint")

    model.train()
    model.module.freeze_bn()  # We keep BatchNorm frozen
    validation_frequency = 5000
    should_keep_training = True
    num_epochs = args.num_steps // len(train_loader) + 1
    for epoch in range(num_epochs):
        train_loader.sampler.set_epoch(epoch)
        for _, *data_blob in tqdm(train_loader, disable=not is_main):
            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            assert model.training
            flow_predictions, gradient_predictions = model(
                image1,
                image2,
                iters=args.train_iters,
                patchmatch_rounds=args.patchmatch_rounds,
            )
            assert model.training
            loss1, _ = sequence_loss(flow_predictions[0], flow, valid, None, None)
            loss2, _ = sequence_loss(flow_predictions[1], flow, valid, None, None)
            loss3, metrics = sequence_loss(flow_predictions[2], flow, valid, None, None)
            loss = loss1 + loss2 + loss3

            global_batch_num += 1
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            if is_main:
                logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
                logger.writer.add_scalar(
                    "learning_rate", optimizer.param_groups[0]["lr"], global_batch_num
                )
                logger.push(metrics)
                if total_steps % validation_frequency == validation_frequency - 1:
                    _save_training_checkpoint(
                        args, model, optimizer, scaler, scheduler, total_steps + 1
                    )
                    results = validate_things(
                        model.module,
                        iters=args.valid_iters,
                        patchmatch_rounds=args.patchmatch_rounds,
                    )
                    logger.write_dict(results)

                    model.train()
                    model.module.freeze_bn()

            # Every rank must advance the step counter and stop on the same step.
            # If only rank 0 stopped, the remaining ranks would block forever in
            # DDP's gradient all-reduce waiting for it.
            total_steps += 1
            if total_steps > args.num_steps:
                should_keep_training = False
                break

        if not should_keep_training:
            break

        if is_main and len(train_loader) >= 1000:
            # NOTE: despite the ".gz" suffix, torch.save does not compress this file.
            save_path = Path(
                "checkpoints/%d_epoch_%s.pth.gz" % (total_steps + 1, args.name)
            )
            print(f"Saving file {save_path}")
            torch.save(model.state_dict(), save_path)

    print("FINISHED TRAINING")
    PATH = "checkpoints/%s.pth" % args.name
    if is_main:
        logger.close()
        torch.save(model.state_dict(), PATH)

    return PATH

def main_worker(rank, world_size, args):
    """Per-process entry point launched by ``mp.spawn``: set up DDP and train.

    Args:
        rank: index of this process; also its CUDA device (see ``setup``).
        world_size: total number of processes.
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than ``world_size``.
    """
    # mp.spawn starts fresh interpreters that do not execute the
    # ``if __name__ == "__main__"`` block, so the seeding and logging setup done
    # there never reaches this process. Repeat it here.
    torch.manual_seed(1234)
    np.random.seed(1234)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    train_dataset = datasets.fetch_dataset(args, rank)
    setup(rank, world_size)
    try:
        model = RAFTStereo_modified(args).cuda()
        if rank == 0:
            print(args)
            print("Parameter Count: %d" % count_parameters(model))
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            find_unused_parameters=True,
        )

        # The global batch is split evenly across ranks.
        per_rank_batch = args.batch_size // world_size
        if per_rank_batch < 1:
            raise ValueError(
                f"--batch_size ({args.batch_size}) must be at least the number "
                f"of processes ({world_size})"
            )
        # Leave two CPUs for the main process, but never request a negative
        # worker count when few CPUs are allocated.
        num_workers = max(0, int(os.environ.get("SLURM_CPUS_PER_TASK", 6)) - 2)

        sampler = DistributedSampler(train_dataset, shuffle=True)
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=per_rank_batch,
            pin_memory=True,
            num_workers=num_workers,
            drop_last=True,
            sampler=sampler,
        )
        optimizer, scheduler = fetch_optimizer(args, model)
        train(args, model, train_loader, optimizer, scheduler, rank)
        dist.barrier()
    finally:
        # Release the process group even if training raised.
        cleanup()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="EAI-Stereo", help="name your experiment")
    parser.add_argument("--restore_ckpt", help="restore checkpoint")
    parser.add_argument("--resume", help="resume checkpoint")
    parser.add_argument(
        "--mixed_precision", action="store_true", help="use mixed precision"
    )
    parser.add_argument("--tensorboard_file_path")
    # Distributed setup
    parser.add_argument(
        "--world_size",
        type=int,
        default=4,
        help="number of GPU worker processes to spawn (one per GPU).",
    )
    # Training parameters
    parser.add_argument(
        "--batch_size", type=int, default=8, help="batch size used during training."
    )
    parser.add_argument(
        "--train_datasets", nargs="+", default=["sceneflow"], help="training datasets."
    )
    parser.add_argument("--lr", type=float, default=0.0002, help="max learning rate.")
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs="+",
        default=[384, 768],
        help="size of the random image crops used during training.",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=22,
        help="number of updates to the disparity field in each forward pass.",
    )
    parser.add_argument(
        "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
    )

    # Validation parameters
    parser.add_argument(
        "--valid_iters",
        type=int,
        default=32,
        help="number of flow-field updates during validation forward pass",
    )

    # Architecture choices
    parser.add_argument(
        "--corr_implementation",
        choices=["reg", "alt", "reg_cuda", "alt_cuda"],
        default="reg",
        help="correlation volume implementation",
    )
    parser.add_argument(
        "--shared_backbone",
        action="store_true",
        help="use a single backbone for the context and feature encoders",
        default=False,
    )
    parser.add_argument(
        "--context_norm",
        type=str,
        default="batch",
        choices=["group", "batch", "instance", "none"],
        help="normalization of context encoder",
    )
    parser.add_argument(
        "--corr_levels",
        type=int,
        default=4,
        help="number of levels in the correlation pyramid",
    )
    parser.add_argument(
        "--corr_radius", type=int, default=4, help="width of the correlation pyramid"
    )
    parser.add_argument(
        "--n_downsample",
        type=int,
        default=2,
        help="resolution of the disparity field (1/2^K)",
    )
    parser.add_argument(
        "--slow_fast_gru",
        action="store_true",
        help="iterate the low-res GRUs more frequently",
    )
    parser.add_argument(
        "--n_gru_layers", type=int, default=3, help="number of hidden GRU levels"
    )
    parser.add_argument(
        "--hidden_dims",
        nargs="+",
        type=int,
        default=[128] * 3,
        help="hidden state and context dimensions",
    )
    parser.add_argument(
        "--confidence_score",
        type=float,
        default=0.5,
        help="confidence score for propagation",
    )
    parser.add_argument("--patchmatch_rounds", type=int, default=2)
    parser.add_argument("--num_neighbors", type=int, default=8)
    parser.add_argument("--max_disp", type=int, default=256)
    parser.add_argument("--look_up_before_propa", action="store_true")

    # Data augmentation
    parser.add_argument(
        "--img_gamma", type=float, nargs="+", default=None, help="gamma range"
    )
    parser.add_argument(
        "--saturation_range",
        type=float,
        nargs="+",
        default=None,
        help="color saturation",
    )
    parser.add_argument(
        "--do_flip",
        default=False,
        choices=["h", "v"],
        help="flip the images horizontally or vertically",
    )
    parser.add_argument(
        "--spatial_scale",
        type=float,
        nargs="+",
        default=[0, 0],
        help="re-scale the images randomly",
    )
    parser.add_argument(
        "--noyjitter",
        action="store_true",
        help="don't simulate imperfect rectification",
    )
    args = parser.parse_args()

    # Each worker binds to GPU ``rank``, so fail early with a clear message
    # instead of a CUDA error inside a spawned process.
    available_gpus = torch.cuda.device_count()
    if args.world_size < 1 or args.world_size > available_gpus:
        parser.error(
            f"--world_size must be between 1 and the number of visible GPUs "
            f"({available_gpus}), got {args.world_size}"
        )

    torch.manual_seed(1234)
    np.random.seed(1234)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path("checkpoints").mkdir(exist_ok=True, parents=True)

    # Launch one training process per GPU; mp.spawn passes each process its
    # rank as the first argument to main_worker.
    mp.spawn(
        main_worker,
        args=(args.world_size, args),
        nprocs=args.world_size,
        join=True,
    )
