import argparse
import logging
import math
import os
import random
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader

import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, Swinv2Model, AutoImageProcessor

from transformers.utils import ContextManagers

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, VQModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.activations import get_activation

import dataLoader
from dataLoader.KITTI_dataloader import load_train_data
from dataLoader.CVUSA_batch import CVUSAVal, CVUSATrain
from dataLoader.CVACT_batch import CVACTVal, CVACTTrain

from pano_projection import cvusa_grd2aer, cvact_grd2aer

from pipeline.pipeline_sat2grd_ldm import StableDiffusionSat2GrdPipeline, Projection
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
import kornia
from einops import rearrange

#metrics
import utils.pytorch_ssim as ssim
import lpips

# wandb is optional: only import it when it is installed, so the script can
# still be used with the tensorboard tracker.
if is_wandb_available():
    import wandb

# Abort early when the installed diffusers predates the APIs used below.
check_min_version("0.17.0.dev0")

# accelerate's process-aware logger (supports `main_process_only=...`).
logger = get_logger(__name__, log_level="INFO")


def log_validation(vae, unet, valid_dataloader, image_processor, image_encoder, projection_net, args, accelerator, weight_dtype, epoch, num_validation_batches=20):
    """Run the sat2grd pipeline on validation batches and log metrics and images.

    A ``StableDiffusionSat2GrdPipeline`` is assembled from the current training
    modules and run on at most ``num_validation_batches`` batches of
    ``valid_dataloader``. For every batch the ground image is encoded with
    ``image_encoder``; three of its hidden states are projected with
    ``cvusa_grd2aer`` and mapped by ``projection_net`` to prompt embeddings.
    The generated image is then compared with the batch's satellite map.

    Mean PSNR, SSIM and LPIPS over the evaluated batches are logged to the
    accelerator's trackers. So is the first sample of every batch: condition,
    ground truth, prediction and |prediction - ground truth|.

    Args:
        vae: VAE handed (unwrapped) to the pipeline.
        unet: UNet handed (unwrapped) to the pipeline.
        valid_dataloader: yields 6-tuples whose first two items are the
            satellite map and the ground image; the rest is ignored.
        image_processor: preprocessor applied to the ground image.
        image_encoder: encoder whose hidden states are projected.
        projection_net: maps the projected features to prompt embeddings.
        args: parsed CLI arguments (reads ``pretrained_model_name_or_path``,
            ``revision``, ``enable_xformers_memory_efficient_attention``
            and ``seed``).
        accelerator: ``Accelerator`` providing the device and trackers.
        weight_dtype: dtype of the pipeline and of the model inputs.
        epoch: step value attached to the logged items.
        num_validation_batches: maximum number of batches to evaluate
            (default 20).
    """
    from itertools import islice

    logger.info("Running validation... ")

    pipeline = StableDiffusionSat2GrdPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        projection=accelerator.unwrap_model(projection_net),
        safety_checker=None,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    toPIL = transforms.ToPILImage()
    ssim_loss = ssim.SSIM(window_size=11)
    # Follow the accelerator's device instead of hard-coding `.cuda()`.
    lpips_loss_vgg = lpips.LPIPS(net='vgg').to(accelerator.device)
    lpips_loss_vgg.requires_grad_(False)

    total_psnr = 0.0
    total_ssim = 0.0
    total_lpips = 0.0
    num_evaluated = 0
    condition_imgs = []
    images = []
    gt_imgs = []
    diff_maps = []

    # Validation needs no gradients. Without this, the call to the trainable
    # `projection_net` would record an autograd graph for every batch.
    with torch.no_grad():
        # islice bounds the loop to exactly `num_validation_batches` batches,
        # so the divisor below matches the number of evaluated batches.
        for batch in islice(valid_dataloader, num_validation_batches):
            with torch.autocast("cuda"):
                sat_map, grd_img, _, _, _, _ = batch
                sat_gt = sat_map.to(weight_dtype).to(accelerator.device)
                grd_img = grd_img.to(weight_dtype).to(accelerator.device)
                grd_gt = grd_img
                # The generated image is requested at the satellite map's size.
                _, _, out_h, out_w = sat_gt.shape

                processed_grd_img = image_processor(grd_img, do_rescale=False, return_tensors="pt").pixel_values
                grd_features = image_encoder(
                    processed_grd_img.to(weight_dtype).to(accelerator.device), output_hidden_states=True
                ).hidden_states
                # Hidden states 0 / 2 / 4 are reshaped to 32x128, 8x32 and 4x16
                # grids, projected with cvusa_grd2aer and concatenated on channels.
                img_feat_32x128 = rearrange(grd_features[0], "b (h w) c -> b h w c", h=32, w=128)
                img_feat_8x32 = rearrange(grd_features[2], "b (h w) c -> b h w c", h=8, w=32)
                img_feat_4x16 = rearrange(grd_features[4], "b (h w) c -> b h w c", h=4, w=16)
                grd_feat_proj_4x16 = cvusa_grd2aer(img_feat_4x16.cpu().float(), 4, 16, 32)
                grd_feat_proj_8x32 = cvusa_grd2aer(img_feat_8x32.cpu().float(), 8, 32, 32)
                grd_feat_proj_32x128 = cvusa_grd2aer(img_feat_32x128.cpu().float(), 32, 128, 32)
                grd_feat_proj_aggr = torch.cat(
                    [grd_feat_proj_4x16, grd_feat_proj_8x32, grd_feat_proj_32x128], dim=3
                ).float().to(accelerator.device)
                grd_feat_proj_aggr = rearrange(grd_feat_proj_aggr, "b h w c -> b (h w) c")

                grd_img_emb = projection_net(grd_feat_proj_aggr)

                img_pred = pipeline(
                    prompt_embeds=grd_img_emb.float(),
                    height=out_h,
                    width=out_w,
                    num_inference_steps=50,
                    generator=generator,
                    output_type="pt",
                ).images

                pred = img_pred.float()
                target = sat_gt.float()
                diff_map = torch.abs(pred - target)

                # PSNR with a peak value of 1. The training loop rescales the
                # satellite map with `* 2 - 1`, so it is taken to lie in [0, 1].
                mse = F.mse_loss(pred, target, reduction="mean")
                total_psnr += float(10 * torch.log10(1. / mse))

                total_ssim += float(ssim_loss(pred, target))

                # LPIPS expects [-1, 1] inputs by default; `normalize=True`
                # rescales our [0, 1] images accordingly.
                total_lpips += float(torch.mean(lpips_loss_vgg(pred, target, normalize=True)))

            condition_imgs.append(toPIL(grd_gt[0]))
            gt_imgs.append(toPIL(sat_gt[0]))
            images.append(toPIL(img_pred[0]))
            diff_maps.append(toPIL(diff_map[0]))
            num_evaluated += 1

    # Average over the batches actually evaluated (guarding an empty loader).
    denom = max(num_evaluated, 1)
    mean_psnr = total_psnr / denom
    mean_ssim = total_ssim / denom
    mean_lpips = total_lpips / denom

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        elif tracker.name == "wandb":
            tracker.log(
                {
                    "gt_img": [
                        wandb.Image(image, caption=f"image {i}")
                        for i, image in enumerate(gt_imgs)
                    ],
                    "validation": [
                        wandb.Image(image, caption=f"image {i}")
                        for i, image in enumerate(images)
                    ],
                    "diff_map": [
                        wandb.Image(image, caption=f"image {i}")
                        for i, image in enumerate(diff_maps)
                    ],
                    "cond": [
                        wandb.Image(image, caption=f"image {i}")
                        for i, image in enumerate(condition_imgs)
                    ],
                    "psnr": mean_psnr,
                    "ssim_val": mean_ssim,
                    "lpips_val": mean_lpips,
                    "epoch": epoch,
                }
            )
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    del pipeline, lpips_loss_vgg
    torch.cuda.empty_cache()


def parse_args():
    """Parse the command-line arguments of the sat2grd LDM training script.

    Besides plain parsing, this function:
      * overrides ``--local_rank`` with the ``LOCAL_RANK`` environment variable
        when that variable is set (typically by distributed launchers) and
        differs from the parsed value;
      * defaults ``--non_ema_revision`` to ``--revision`` when it is not given.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Training a conditional ldm for satellite diffusion.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="stabilityai/stable-diffusion-2-1",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--model_config_name_or_path",
        type=str,
        default="./configs/ldm_sat2grd.config",
        help="Path to the config of the UNet that is trained (built with `UNet2DConditionModel.from_config`).",
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="wandb",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints"
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--validation_prompts",
        type=str,
        default=None,
        nargs="+",
        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
    )
    parser.add_argument("--save_images_epochs", type=int, default=1, help="How often to save images during training.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="position_embedding_conditional_satellite_ldm_512x512",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help="Root folder of the training data (passed as `pathdir` to the training dataset).",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Log directory for the trackers, relative to `--output_dir` (logs go to *output_dir/logging_dir*).",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to \"no\"; the value is passed to `Accelerator` and"
            " overrides the accelerate config of the current system and the `accelerate.launch` flag."
        ),
    )

    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")

    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--freeze_model",
        type=str,
        default="crossattn_kv",
        choices=["crossattn_kv", "crossattn"],
        help="crossattn to enable fine-tuning of all params in the cross attention",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--valid_batch_size", type=int, default=1, help="Batch size (per device) for the validation dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=10)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    # NOTE: the flag name keeps its historical spelling so existing launch scripts keep working.
    parser.add_argument(
        "--input_pertubation", type=float, default=0.1, help="The scale of input perturbation. Recommended 0.1."
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="position_embedding_conditional_satellite_ldm_512x512",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=4000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--grid_condition_dim", type=int, default=32, help="satellite grid coordinates dimension for condition input"
    )
    parser.add_argument(
        "--data_random_crop",
        action="store_true",
        help="Whether to randomly crop satellite image patches when retrieving data; if not set, a center crop is used.",
    )
    args = parser.parse_args()

    # A LOCAL_RANK environment variable takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Default to using the same revision for the non-ema model if not specified.
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args

def compute_snr(timesteps, noise_scheduler):
    """
    Return the signal-to-noise ratio ``(alpha_t / sigma_t) ** 2`` for each timestep.

    ``alpha_t = sqrt(alphas_cumprod[t])`` and ``sigma_t = sqrt(1 - alphas_cumprod[t])``
    are looked up from ``noise_scheduler.alphas_cumprod`` at ``timesteps``. The
    result is float32, on ``timesteps``' device, with ``timesteps``' shape.

    Follows https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """

    def _lookup(per_step_values, index):
        # Gather per-step values at `index` (on its device, as float32), append
        # trailing singleton dims until the ranks match, then broadcast.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = per_step_values.to(device=index.device)[index].float()
        while picked.ndim < index.ndim:
            picked = picked[..., None]
        return picked.expand(index.shape)

    cumprod = noise_scheduler.alphas_cumprod
    alpha = _lookup(cumprod**0.5, timesteps)
    sigma = _lookup((1.0 - cumprod) ** 0.5, timesteps)
    return (alpha / sigma) ** 2
            
def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
        project_config=accelerator_project_config,
    )
    
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb
    
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
        
    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
        
    # Load noise scheduler and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
    image_encoder = Swinv2Model.from_pretrained("pretrained/swinv2-tiny-patch4-window8-256")    
    image_encoder.requires_grad_(False)
    # projection = None
    projection = Projection()
    torch.nn.init.eye_(list(projection.parameters())[0][:768, :1248])
    torch.nn.init.zeros_(list(projection.parameters())[1])
    projection.requires_grad_(True)
    
    # last_feat_proj = torch.nn.Linear()
    
    def deepspeed_zero_init_disabled_context_manager():
        """
        Build the context managers that switch off DeepSpeed ZeRO-3 ``zero.Init``.

        Returns an empty list when accelerate's state is not initialized or no
        DeepSpeed plugin is configured. Otherwise returns a one-element list
        holding the plugin's ``zero3_init_context_manager(enable=False)``.
        """
        if not accelerate.state.is_initialized():
            return []

        plugin = AcceleratorState().deepspeed_plugin
        if plugin is None:
            return []
        return [plugin.zero3_init_context_manager(enable=False)]
    
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )   
        # vae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae", revision=args.revision)
    vae.requires_grad_(False)

    
    unet = UNet2DConditionModel.from_config(args.model_config_name_or_path)
    unet.requires_grad_(True)

    # Create EMA for the unet.
    if args.use_ema:
        ema_unet = UNet2DConditionModel.from_config(args.model_config_name_or_path)
        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
    
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        
    vae.to(accelerator.device, dtype=weight_dtype) 
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    # unet.to(accelerator.device, dtype=weight_dtype)
    
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        
        # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            """Accelerate pre-save hook: write the models in `save_pretrained` format.

            Saves the EMA UNet (when `--use_ema`) to `unet_ema`, the UNet to
            `unet` and the projection network to `projection` under
            `output_dir`. One entry of `weights` is popped per model so that
            accelerate does not save it a second time.
            """
            if args.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for model in models:
                # accelerate hands over the *prepared* models, which may be
                # wrapped (e.g. in DistributedDataParallel); unwrap before the
                # type checks, otherwise nothing matches in multi-GPU runs.
                unwrapped = accelerator.unwrap_model(model)
                if isinstance(unwrapped, UNet2DConditionModel):
                    unwrapped.save_pretrained(os.path.join(output_dir, "unet"))
                elif isinstance(unwrapped, Projection):
                    unwrapped.save_pretrained(os.path.join(output_dir, "projection"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            """Accelerate pre-load hook: restore models written by `save_model_hook`.

            Loads the EMA UNet from `unet_ema` (when `--use_ema`), then pops
            every entry of `models` (so accelerate does not load them again).
            Each popped model is filled from the `unet` or `projection`
            subfolder of `input_dir`, according to its type.

            Raises:
                ValueError: if a model is neither a UNet nor a Projection.
            """
            if args.use_ema:
                ema_state = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                ema_unet.load_state_dict(ema_state.state_dict())
                ema_unet.to(accelerator.device)
                del ema_state

            for _ in range(len(models)):
                # pop models so that they are not loaded again; unwrap them
                # because accelerate passes the prepared (possibly DDP-wrapped) models
                model = accelerator.unwrap_model(models.pop())
                if isinstance(model, UNet2DConditionModel):
                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                elif isinstance(model, Projection):
                    load_model = Projection.from_pretrained(input_dir, subfolder="projection")
                else:
                    raise ValueError(f"unexpected model to load: {model.__class__}")

                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model


        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
        
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    
    

    train = CVACTTrain(pathdir=args.train_data_dir)
    train_dataloader = DataLoader(train, batch_size=args.train_batch_size, shuffle=True, pin_memory=True,
                                num_workers=16, drop_last=False)
    
    val_dataloader = DataLoader(train, batch_size=args.valid_batch_size, shuffle=True, pin_memory=True,
                                num_workers=16, drop_last=False)
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    
    #init optimizer
    optimizer_cls = torch.optim.AdamW
    params_to_optimize = []
    params_to_optimize.append({"params": unet.parameters()})
    params_to_optimize.append({"params": projection.parameters()})

    optimizer = optimizer_cls(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    
    # init lr_scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, projection, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, projection, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_unet.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        tracker_config.pop("validation_prompts")
        accelerator.init_trackers(args.tracker_project_name, tracker_config)
    
    #ready to train.
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        projection.train()

        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            break
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            
            with accelerator.accumulate([unet, projection]):
                # Convert images to latent space
                sat_map, grd_img, _, _, _, _ = batch
                grd_img = (grd_img).to(weight_dtype).to(accelerator.device)
                sat_map = (sat_map * 2. - 1.).to(weight_dtype).to(accelerator.device)
                            

                batch_size, _, ori_grdH, ori_grdW = grd_img.shape
                
                #prepare multi-channel latents
                with torch.torch.no_grad():
                    latents = vae.encode(sat_map).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor
                # print(latents.shape)
                
                # left_camera_k = grid_info["left_camera_k"]
                
                processed_grd_img = image_processor(grd_img, do_rescale=False, return_tensors="pt").pixel_values
                grd_features = image_encoder(processed_grd_img.to(weight_dtype).to(accelerator.device), output_hidden_states=True).hidden_states
                img_feat_32x128 = grd_features[0]
                img_feat_8x32 = grd_features[2] 
                img_feat_4x16 = grd_features[4] 
                img_feat_4x16 = rearrange(img_feat_4x16, "b (h w) c -> b h w c", h=4, w=16)
                img_feat_8x32 = rearrange(img_feat_8x32, "b (h w) c -> b h w c", h=8, w=32)
                img_feat_32x128 = rearrange(img_feat_32x128, "b (h w) c -> b h w c", h=32, w=128)
                grd_feat_proj_4x16 = cvusa_grd2aer(img_feat_4x16.cpu().float(), 4, 16, 32)
                grd_feat_proj_8x32 = cvusa_grd2aer(img_feat_8x32.cpu().float(), 8, 32, 32)
                grd_feat_proj_32x128 = cvusa_grd2aer(img_feat_32x128.cpu().float(), 32, 128, 32)
                grd_feat_proj_aggr = torch.cat([grd_feat_proj_4x16, grd_feat_proj_8x32, grd_feat_proj_32x128], dim=3).float().to(accelerator.device)
                grd_feat_proj_aggr = rearrange(grd_feat_proj_aggr, "b h w c -> b (h w) c").float().to(accelerator.device)
                grd_img_emb = projection(grd_feat_proj_aggr)
                
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                    )
                if args.input_pertubation:
                    new_noise = noise + args.input_pertubation * torch.randn_like(noise)
                    
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()
                
                if args.input_pertubation:
                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
                else:
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                    
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                

                # with torch.cuda.amp.autocast():
                model_pred = unet(noisy_latents.float(), timesteps, encoder_hidden_states=grd_img_emb.float()).sample

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(timesteps)
                    mse_loss_weights = (
                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                    )
                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
                    # rebalance the sample-wise losses with their respective loss weights.
                    # Finally, we take the mean of the rebalanced loss.GradScaler
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
        
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

            if global_step % args.checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
                    
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_train_epochs - 1:
                logger.info(
                    f"Running validation... \n"
                )
                if args.use_ema:
                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                    ema_unet.store(unet.parameters())
                    ema_unet.copy_to(unet.parameters())
                    
                log_validation(vae, unet, val_dataloader, image_processor, image_encoder, projection, args, accelerator, weight_dtype, epoch)
                
                if args.use_ema:
                    # Switch back to the original UNet parameters.
                    ema_unet.restore(unet.parameters())

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())
            
        pipeline = StableDiffusionSat2GrdPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=accelerator.unwrap_model(vae),
            unet=accelerator.unwrap_model(unet),
            projection=accelerator.unwrap_model(projection),
            safety_checker=None,
            revision=args.revision,
            torch_dtype=weight_dtype,
        )

        pipeline.save_pretrained(args.output_dir)


    accelerator.end_training()
                
if __name__ == "__main__":
    # Script entry point: build the run configuration from the command line
    # and start training with it. `args` stays bound at module scope, as in
    # the original, in case other module-level code reads it (unverified
    # from this view).
    args = parse_args()
    main(args)