# --------------------------------------------------------
# training code for CUT3R
# --------------------------------------------------------
# References:
# DUSt3R: https://github.com/naver/dust3r
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import sys
import time
import math
from collections import defaultdict
from pathlib import Path
from typing import Sized

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

# Process-wide switch: permit TensorFloat-32 in CUDA matrix multiplies.
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12

from dust3r.architectures.ARCroco3DStereo import PreTrainedModel, ARCroco3DStereo, ARCroco3DStereoConfig, strip_module, inf
from dust3r.architectures.ModelUnifiedSimple import ModelUnifiedSimple



from dust3r.datasets import get_data_loader
from dust3r.losses import *  # noqa: F401, needed when loading the model
from dust3r.inference import loss_of_one_batch, loss_of_one_batch_tbptt  # noqa
from dust3r.viz import colorize
from dust3r.utils.render import get_render_results
import dust3r.utils.path_to_croco  # noqa: F401
import croco.utils.misc as misc  # noqa
from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler  # noqa

import hydra
from omegaconf import OmegaConf
import logging
import pathlib
from tqdm import tqdm
import random
import builtins
import shutil

from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.logging import get_logger
from datetime import timedelta
import torch.multiprocessing

# Switch torch.multiprocessing's tensor-sharing strategy to "file_system"; this
# affects tensors passed between processes (e.g. from DataLoader workers).
torch.multiprocessing.set_sharing_strategy("file_system")

# Module-level logger obtained from accelerate, created at DEBUG level.
printer = get_logger(__name__, log_level="DEBUG")
import csv

def setup_for_distributed(accelerator: Accelerator):
    """
    Replace ``builtins.print`` so that, by default, only the main process prints.

    A non-main process still prints when the caller passes ``force=True``, or
    unconditionally when the run spans more than 8 processes. Every emitted
    line is prefixed with the current wall-clock time, e.g. ``[12:00:00.123456] ``.
    """
    original_print = builtins.print

    def _rank_aware_print(*args, **kwargs):
        forced = kwargs.pop("force", False) or accelerator.num_processes > 8
        if not (accelerator.is_main_process or forced):
            return
        timestamp = datetime.datetime.now().time()
        original_print("[{}] ".format(timestamp), end="")
        original_print(*args, **kwargs)

    builtins.print = _rank_aware_print



def train(args):
    """
    Evaluate a model on every configured test dataset and write averaged metrics to CSV.

    Despite its name, this entry point runs no optimisation loop. It:
      1. builds an ``Accelerator`` (bf16, optional wandb tracking) and seeds each process,
      2. builds the train loader (currently unused, see note) and one test loader per
         ``+``-separated entry of ``args.test_dataset``,
      3. instantiates ``args.model`` and the test criterion, optionally loading pretrained weights,
      4. runs ``test_one_epoch`` once per (test dataset, modality mask) pair,
      5. writes every ``*_avg`` metric to ``<output_dir>/test_stats.csv`` (main process only).

    Args:
        args: Hydra/OmegaConf config; both attribute (``args.x``) and item
            (``args["x"]``) access are used on it.
    """
    # Only `log_with` depends on whether wandb logging is enabled.
    accelerator_kwargs = dict(
        gradient_accumulation_steps=args.accum_iter,
        mixed_precision="bf16",
        kwargs_handlers=[
            DistributedDataParallelKwargs(find_unused_parameters=True),
            InitProcessGroupKwargs(timeout=timedelta(seconds=6000)),
        ],
    )
    if args.wandb_logger is not False:
        accelerator_kwargs["log_with"] = "wandb"
    accelerator = Accelerator(**accelerator_kwargs)

    accelerator.init_trackers(
        project_name="cuter",
        config={
            "exp_name": args["exp_name"],
            "model": args["model"],
            "pretrained": args["pretrained"],
            "num_views": args["num_views"],
            "num_test_views": args["num_test_views"],
            "fixed_length": args["fixed_length"],
            "batch_size": args["batch_size"],
        },
    )

    device = accelerator.device
    setup_for_distributed(accelerator)

    # %-style argument instead of string concatenation: does not crash when output_dir is None.
    printer.info("output_dir: %s", args.output_dir)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    printer.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))

    # Offset the seed by process index so each rank gets a different random stream.
    seed = args.seed + accelerator.state.process_index
    printer.info(
        f"Setting seed to {seed} for process {accelerator.state.process_index}"
    )
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = args.benchmark

    printer.info("Building train dataset %s", args.train_dataset)
    # NOTE(review): data_loader_train is never used below. It is still built so that
    # dataset construction (and any RNG it may consume) matches earlier runs; consider removing.
    data_loader_train = build_dataset(
        args.train_dataset,
        args.batch_size,
        args.num_workers,
        accelerator=accelerator,
        test=False,
        fixed_length=args.fixed_length
    )
    printer.info("Building test dataset %s", args.test_dataset)
    # One loader per "+"-separated dataset spec, keyed by the text before the first "(".
    data_loader_test = {
        dataset.split("(")[0]: build_dataset(
            dataset,
            args.batch_size,
            args.num_workers,
            accelerator=accelerator,
            test=True,
            fixed_length=True
        )
        for dataset in args.test_dataset.split("+")
    }

    printer.info("Loading model: %s", args.model)
    # NOTE(review): eval() executes the config string as Python; only use trusted configs.
    model: PreTrainedModel = eval(args.model)
    printer.info(f"All model parameters: {sum(p.numel() for p in model.parameters())}")
    printer.info(
        f"Encoder parameters: {sum(p.numel() for p in model.enc_blocks.parameters())}"
    )
    printer.info(
        f"Decoder parameters: {sum(p.numel() for p in model.dec_blocks.parameters())}"
    )

    test_criterion = eval(args.test_criterion or args.criterion).to(device)

    model.to(device)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    if args.long_context:
        model.fixed_input_length = False

    if args.raymap_only:
        model.raymap_only = True
    elif args.rgb_only:
        model.rgb_only = True

    if args.pretrained and not args.resume:
        printer.info(f"Loading pretrained: {args.pretrained}")
        ckpt = torch.load(args.pretrained, map_location=device)
        load_only_encoder = getattr(args, "load_only_encoder", False)
        if load_only_encoder:
            # Keep only encoder / patch-embedding weights; everything else stays as initialised.
            filtered_state_dict = {
                k: v
                for k, v in ckpt["model"].items()
                if "enc_blocks" in k or "patch_embed" in k
            }
            printer.info(
                model.load_state_dict(strip_module(filtered_state_dict), strict=False)
            )
        else:
            printer.info(
                model.load_state_dict(strip_module(ckpt["model"]), strict=False)
            )
        del ckpt  # in case it occupies memory

    model = accelerator.prepare(model)

    log_writer = None  # no TensorBoard writer is created in this script

    if "test_masked" in args and args.test_masked:
        test_masks = ["none", "pose", "intr", "depth", "pose+depth", "pose+intr", "intr+depth", "pose+intr+depth"]
    else:
        test_masks = [None]

    csv_path = os.path.join(args.output_dir or ".", "test_stats.csv")
    csv_fields = ["sequence", "metric", "value", "modality_mask"]
    csv_rows = []  # accumulated across all datasets and masks
    for test_name, testset in data_loader_test.items():
        for test_mask in test_masks:
            stats = test_one_epoch(
                model,
                test_criterion,
                testset,
                accelerator,
                device,
                1,
                log_writer=log_writer,
                args=args,
                prefix=test_name,
                modality_mask=test_mask
            )
            # Keep only the global averages ("<metric>_avg"); other aggregates are dropped.
            csv_rows.extend(
                {
                    "sequence": test_name,
                    "metric": metric_name,
                    "value": value,
                    "modality_mask": test_mask,
                }
                for metric_name, value in stats.items()
                if metric_name.endswith("_avg")
            )

        # Rewrite the cumulative CSV after each dataset so partial results survive a later
        # failure. Every rank runs this loop; only the main process writes, so ranks do not
        # race on the same file.
        if accelerator.is_main_process:
            with open(csv_path, "w", newline="") as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=csv_fields)
                writer.writeheader()
                writer.writerows(csv_rows)

    # Finish any trackers (e.g. wandb) started by init_trackers above.
    accelerator.end_training()



def build_dataset(dataset, batch_size, num_workers, accelerator, test=False, fixed_length=False):
    """
    Create the data loader for ``dataset``, a spec string forwarded to ``get_data_loader``.

    Train loaders (``test=False``) shuffle and drop the last incomplete batch.
    Test loaders do neither. Memory pinning is always enabled.
    """
    is_train = not test
    split_label = ("Train", "Test")[test]
    printer.info(f"Building {split_label} Data loader for dataset: {dataset}")
    return get_data_loader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_mem=True,
        shuffle=is_train,
        drop_last=is_train,
        accelerator=accelerator,
        fixed_length=fixed_length,
    )

def _apply_modality_mask(batch, modality_mask):
    """
    Overwrite the ``*_mask`` entries selected by ``modality_mask`` with all-True on every view.

    Args:
        batch: sequence of per-view dicts (presumably one per view; verify against the
            data loader). View 0 must contain ``"pose_mask"`` whenever a non-trivial mask is
            requested, because its shape and device template the all-True tensor.
        modality_mask: ``None`` or ``"none"`` (leave the batch untouched), or a
            ``+``-joined combination of ``"pose"``, ``"intr"`` and ``"depth"``,
            e.g. ``"pose+intr"``.

    Raises:
        ValueError: if ``modality_mask`` names an unknown modality.
    """
    if modality_mask is None or modality_mask == "none":
        return
    key_for = {"pose": "pose_mask", "intr": "intr_mask", "depth": "depth_mask"}
    keys = []
    for name in modality_mask.split("+"):
        if name not in key_for:
            raise ValueError(
                f"Unknown modality {name!r} in modality_mask={modality_mask!r}; "
                f"expected '+'-joined names from {sorted(key_for)} or 'none'"
            )
        keys.append(key_for[name])
    # One all-True tensor shaped like view 0's pose_mask is shared by every view and key.
    # ones_like also keeps it on the same device as the existing masks.
    truth_mask = torch.ones_like(batch[0]["pose_mask"], dtype=torch.bool)
    for view in batch:
        for key in keys:
            view[key] = truth_mask


@torch.no_grad()
def test_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Sized,
    accelerator: Accelerator,
    device: torch.device,
    epoch: int,
    args,
    log_writer=None,
    prefix="test",
    modality_mask=None
):
    """
    Run ``model`` over ``data_loader`` without gradients and return aggregated metrics.

    Args:
        model: model to evaluate; switched to eval mode.
        criterion: loss passed to ``loss_of_one_batch``; its ``(loss, details)`` result is logged.
        data_loader: iterable of batches.
        accelerator: Accelerate handle, forwarded to logging and ``loss_of_one_batch``.
        device: not used in this function; kept for interface compatibility.
        epoch: shown in the log header; scalars are logged at step ``1000 * epoch``.
        args: config; ``print_freq`` and ``amp`` are read.
        log_writer: optional writer with ``add_scalar``/``log_dir``; when given, scalar
            results are also logged through it and ``accelerator.log``.
        prefix: prefix for logged scalar names.
        modality_mask: ``None``/``"none"`` or a ``+``-joined subset of
            ``pose``/``intr``/``depth`` whose ``*_mask`` entries are forced to all-True.

    Returns:
        dict mapping ``"<meter>_avg"`` (global average) and ``"<meter>_med"`` (median)
        to the corresponding meter value. It is empty if the loader yields no batches.
    """
    print(f"modality_mask: {modality_mask}")

    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    # Effectively unbounded window, so median/global_avg cover the whole epoch.
    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
    header = "Test Epoch: [{}]".format(epoch)

    if log_writer is not None:
        printer.info("log_dir: {}".format(log_writer.log_dir))

    # Fix dataset / sampler epoch to 0 so every evaluation sees the same ordering.
    if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
        data_loader.dataset.set_epoch(0)
    if (
        hasattr(data_loader, "batch_sampler")
        and hasattr(data_loader.batch_sampler, "batch_sampler")
        and hasattr(data_loader.batch_sampler.batch_sampler, "set_epoch")
    ):
        data_loader.batch_sampler.batch_sampler.set_epoch(0)

    # Pre-bind so an empty loader does not hit UnboundLocalError below.
    batch = result = loss_details = loss_value = None
    for batch in metric_logger.log_every(data_loader, args.print_freq, accelerator, header):
        _apply_modality_mask(batch, modality_mask)

        result = loss_of_one_batch(
            batch,
            model,
            criterion,
            accelerator,
            symmetrize_batch=False,
            use_amp=bool(args.amp),
        )

        loss_value, loss_details = result["loss"]  # criterion returns two values
        metric_logger.update(loss=float(loss_value), **loss_details)

    printer.info("Averaged stats: %s", metric_logger)

    aggs = [("avg", "global_avg"), ("med", "median")]
    results = {
        f"{k}_{tag}": getattr(meter, attr)
        for k, meter in metric_logger.meters.items()
        for tag, attr in aggs
    }

    if log_writer is not None:
        for name, val in results.items():
            # Only scalars can be logged: skip non-0-d tensors and dicts.
            if isinstance(val, torch.Tensor):
                if val.ndim > 0:
                    continue
            if isinstance(val, dict):
                continue
            log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch)
            accelerator.log({f"{prefix}_{name}" : val}, 1000*epoch)

        if result is not None:
            # Depth renders of the last batch.
            # NOTE(review): they are stored in loss_details, which is never read afterwards.
            depths_self, gt_depths_self = get_render_results(
                batch, result["pred"], self_view=True
            )
            depths_cross, gt_depths_cross = get_render_results(
                batch, result["pred"], self_view=False
            )
            for k in range(len(batch)):
                loss_details[f"self_pred_depth_{k+1}"] = depths_self[k].detach().cpu()
                loss_details[f"self_gt_depth_{k+1}"] = gt_depths_self[k].detach().cpu()
                loss_details[f"pred_depth_{k+1}"] = depths_cross[k].detach().cpu()
                loss_details[f"gt_depth_{k+1}"] = gt_depths_cross[k].detach().cpu()

    # Drop references (including the predictions held by `result`) before releasing cached memory.
    del loss_details, loss_value, batch, result
    torch.cuda.empty_cache()

    return results

@hydra.main(
    version_base=None,
    config_path=os.path.dirname(os.path.abspath(__file__)) + "/../config",
    config_name="train.yaml",
)
def run(cfg: OmegaConf):
    """
    Hydra entry point.

    Resolves the interpolations in ``cfg`` in place, makes sure the
    ``cfg.logdir`` directory exists, then passes the config to ``train``.
    """
    OmegaConf.resolve(cfg)
    pathlib.Path(cfg.logdir).mkdir(parents=True, exist_ok=True)
    train(cfg)


if __name__ == "__main__":
    run()
