import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from transformers import Swinv2Model, AutoImageProcessor

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

from dataLoader.CVUSA_batch import CVUSAVal, CVUSATrain
from dataLoader.CVACT_batch import CVACTVal, CVACTTrain
from pano_projection import cvusa_grd2aer, cvact_grd2aer

from pipeline.pipeline_controlnet_sat2grd_prj import StableDiffusionControlNetPipeline, Projection
from models.controlnet_noemb import ControlNetModel

# from vgg_encoder import ControlNet_Cond_Encoder, project_grd_to_map
import kornia
from einops import rearrange

import utils.pytorch_ssim as ssim
import lpips


# wandb is an optional dependency: import it only when diffusers reports it as available.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.22.0.dev0")

# Module-level logger obtained through accelerate's `get_logger`.
logger = get_logger(__name__)

def _project_ground_features(image_encoder, projection, grd_img, dataset, device):
    """Build the ControlNet conditioning tensor from a batch of ground images.

    Three hidden states of ``image_encoder`` (indices 4, 2 and 0) are reshaped
    to 4x16, 8x32 and 32x128 grids, projected with ``cvusa_grd2aer`` (when
    ``dataset == "cvusa"``) or ``cvact_grd2aer`` (otherwise), concatenated along
    dim 3 and passed through ``projection``.

    Returns the ``projection`` output with its last axis moved to position 1.
    """
    hidden_states = image_encoder(grd_img, output_hidden_states=True).hidden_states
    grd2aer = cvusa_grd2aer if dataset == "cvusa" else cvact_grd2aer

    projected = []
    # (hidden-state index, grid height, grid width); this order fixes the
    # concatenation order that `projection` receives.
    for state_idx, feat_h, feat_w in ((4, 4, 16), (2, 8, 32), (0, 32, 128)):
        feat = rearrange(hidden_states[state_idx], "b (h w) c -> b h w c", h=feat_h, w=feat_w)
        # The projection helpers are fed CPU float32 tensors, as in the original script.
        projected.append(grd2aer(feat.cpu().float(), feat_h, feat_w, 32))

    aggregated = torch.cat(projected, dim=3).float().to(device)
    return projection(aggregated).permute(0, 3, 1, 2)


def log_validation(vae, image_processor, image_encoder, projection, tokenizer, text_encoder, unet, controlnet, args, accelerator, weight_dtype, val_dataloader):
    """Run the ControlNet pipeline over ``val_dataloader`` and report image metrics.

    For every batch the ground image is encoded and projected into the
    ControlNet conditioning image, ``args.num_validation_images`` samples are
    generated with an empty text prompt, and each sample is compared with the
    first tensor of the batch (``sat_gt``) using PSNR, SSIM and three LPIPS
    variants (vgg, alex, squeeze). Ground-truth, conditioning and generated
    images are written as JPEG files into ``args.output_dir``.

    Args:
        vae, image_encoder, projection, tokenizer, text_encoder, unet, controlnet:
            Components handed to ``StableDiffusionControlNetPipeline.from_pretrained``.
        image_processor: Unused; kept so existing callers keep working.
        args: Parsed command-line namespace. Reads ``revision``, ``seed``,
            ``dataset``, ``output_dir``, ``num_validation_images``,
            ``validation_task`` and ``enable_xformers_memory_efficient_attention``.
        accelerator: ``Accelerator`` providing the device, logging and model unwrapping.
        weight_dtype: dtype the input tensors are cast to.
        val_dataloader: Yields 6-tuples whose first two items are used as the
            target image and the ground image.

    Returns:
        dict mapping metric name to its mean (Python float) over all generated
        samples, or an empty dict when there is nothing to evaluate.
    """
    logger.info("Running validation... ")

    num_batches = len(val_dataloader)
    if num_batches == 0 or args.num_validation_images <= 0:
        # Nothing to average over; bail out instead of dividing by zero below.
        logger.warning("No validation batches or samples requested; skipping validation.")
        return {}

    device = accelerator.device
    controlnet = accelerator.unwrap_model(controlnet)

    # The components are used on whatever device the caller placed them on;
    # the pipeline itself is not moved.
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        vae=vae,
        image_encoder=image_encoder,
        projection=projection,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        revision=args.revision,
        torch_dtype=weight_dtype,
        text_encoder=text_encoder,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    progress_bar = tqdm(total=num_batches, disable=not accelerator.is_local_main_process)
    progress_bar.set_description("step")

    to_pil = transforms.ToPILImage()
    ssim_metric = ssim.SSIM(window_size=11)

    # NOTE(review): LPIPS by default expects inputs scaled to [-1, 1]; the tensors
    # compared here look like [0, 1] images (PSNR below assumes a peak of 1).
    # Consider `normalize=True` — confirm before changing reported numbers.
    lpips_metrics = {
        "lpips_val": lpips.LPIPS(net="vgg"),
        "lpips_alex_val": lpips.LPIPS(net="alex"),
        "lpips_squeeze_val": lpips.LPIPS(net="squeeze"),
    }
    for lpips_net in lpips_metrics.values():
        lpips_net.to(device)
        lpips_net.requires_grad_(False)

    totals = {"psnr": 0.0, "ssim_val": 0.0}
    totals.update({name: 0.0 for name in lpips_metrics})
    num_evals = 0

    for step, batch in enumerate(val_dataloader):
        with torch.no_grad(), torch.autocast("cuda"):
            sat_map, grd_img, _, _, _, _ = batch
            sat_gt = sat_map.to(weight_dtype).to(device)
            grd_img = grd_img.to(weight_dtype).to(device)
            batch_size = grd_img.shape[0]

            # Empty text prompts: we assume this aligns with unconditional generation.
            text_tokens = tokenizer(
                [""] * batch_size,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            encoder_hidden_states = text_encoder(text_tokens.input_ids.to(device))[0]

            controlnet_image = _project_ground_features(
                image_encoder, projection, grd_img, args.dataset, device
            )

            for batch_num in range(batch_size):
                to_pil(sat_gt[batch_num]).save(f"{args.output_dir}/gt_{step}_{batch_num}.jpg")
                to_pil(grd_img[batch_num]).save(f"{args.output_dir}/cond_{step}_{batch_num}.jpg")

            target = sat_gt.float()
            for i in range(args.num_validation_images):
                image = pipeline(
                    image=controlnet_image,
                    num_inference_steps=50,
                    width=256,
                    height=256,
                    generator=generator,
                    prompt_embeds=encoder_hidden_states,
                    guidance_scale=1.0,
                    output_type="pt",
                ).images
                pred = image.float()

                mse = F.mse_loss(pred, target, reduction="mean")
                sample_metrics = {
                    "psnr": (10 * torch.log10(1. / mse)).item(),
                    "ssim_val": ssim_metric(pred, target).item(),
                }
                # All LPIPS variants compare the generated image with the same
                # target as PSNR/SSIM (the original vgg call used an undefined
                # name and the ground image).
                for name, lpips_net in lpips_metrics.items():
                    sample_metrics[name] = torch.mean(lpips_net(pred, target)).item()

                for name, value in sample_metrics.items():
                    totals[name] += value
                num_evals += 1

                for batch_num in range(batch_size):
                    to_pil(image[batch_num]).save(
                        f"{args.output_dir}/sync_{step}_{batch_num}_sample_{i}.jpg"
                    )

            # Per-step reporting uses the metrics of the last generated sample.
            progress_bar.update(1)
            progress_bar.set_postfix(**sample_metrics)
            accelerator.log(sample_metrics, step=step)

    progress_bar.close()

    means = {name: total / num_evals for name, total in totals.items()}

    logger.info(f"Task Validation {args.validation_task} Results:")
    logger.info(f"PSNR = {means['psnr']}")
    logger.info(f"SSIM = {means['ssim_val']}")
    logger.info(f"LPIPS = {means['lpips_val']}")
    logger.info(f"LPIPS_A = {means['lpips_alex_val']}")
    logger.info(f"LPIPS_S = {means['lpips_squeeze_val']}")

    del pipeline
    torch.cuda.empty_cache()

    return means
    
            
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by a checkpoint's ``text_encoder`` config.

    The first entry of the config's ``architectures`` list selects the class.

    Raises:
        ValueError: if that architecture is neither ``CLIPTextModel`` nor
            ``RobertaSeriesModelWithTransformation``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    architecture = config.architectures[0]

    # Imports stay local so only the encoder that is actually needed gets loaded.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    else:
        raise ValueError(f"{architecture} is not supported.")

    return encoder_cls
    
def parse_args(input_args=None):
    """Parse the command-line options of the validation script.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the
            default) the arguments are taken from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is overridden by
        the ``LOCAL_RANK`` environment variable when that is set, and
        ``non_ema_revision`` falls back to ``revision`` when not given.
    """
    parser = argparse.ArgumentParser(description="Validation of a conditional ldm for satellite diffusion.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="stabilityai/stable-diffusion-2-1-base",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--model_config_name_or_path",
        type=str,
        default="./configs/ldm_grd2sat.config",
        help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
    )
    parser.add_argument(
        "--validation_task",
        type=str,
        default="image_feat",
        choices=["image_feat", "multifeat_prj"],
        help="Name of the validation task; in this script it only labels the reported results.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cvusa",
        choices=["cvusa", "cvact"],
        help="Dataset to validate on.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="position_embedding_conditional_satellite_ldm_512x512",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--valid_data_dir",
        type=str,
        default=None,
        help=(
            "Root folder of the validation dataset (for `cvusa` it must contain"
            " `splits/valid-19zl.csv`)."
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="wandb",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints"
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--valid_batch_size", type=int, default=1, help="Batch size (per device) for the validation dataloader."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images to be generated for each validation batch.",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="position_embedding_conditional_satellite_ldm_512x512",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    # `parse_args(None)` reads sys.argv, so existing no-argument callers are unaffected.
    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args


def main(args):
    """Load the frozen model components and run validation on the chosen dataset.

    Steps: create the ``Accelerator``, configure logging and the seed, load the
    tokenizer, text encoder, Swin-v2 image encoder, projection, VAE, ControlNet
    and UNet, move them to the accelerator device in the selected precision,
    build the validation dataloader and call ``log_validation`` on the main
    process.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: if ``args.valid_data_dir`` is not set, or if xformers
            attention was requested but xformers is not available.
    """
    # Fail fast, before any model is downloaded or loaded.
    if args.valid_data_dir is None:
        raise ValueError("--valid_data_dir must be provided.")

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Validation images are written into the output directory.
    if accelerator.is_main_process and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    stable_diff_2_1 = "stabilityai/stable-diffusion-2-1"

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            stable_diff_2_1,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(stable_diff_2_1, args.revision)
    text_encoder = text_encoder_cls.from_pretrained(
        stable_diff_2_1, subfolder="text_encoder", revision=args.revision
    )

    image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
    image_encoder = Swinv2Model.from_pretrained("pretrained/swinv2-tiny-patch4-window8-256")

    # NOTE(review): `from_tf=True` is forwarded unchanged from the original
    # script; confirm these loaders actually honour it.
    projection = Projection.from_pretrained(args.pretrained_model_name_or_path, subfolder="projection", from_tf=True)
    vae = AutoencoderKL.from_pretrained(stable_diff_2_1, subfolder="vae", revision=args.revision)
    controlnet = ControlNetModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="controlnet", from_tf=True)
    unet = UNet2DConditionModel.from_pretrained(stable_diff_2_1, subfolder="unet", from_tf=True)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Everything is inference-only: freeze the weights and move them to the
    # accelerator device in the selected precision.
    for module in (text_encoder, vae, projection, image_encoder, unet, controlnet):
        module.requires_grad_(False)
        module.to(accelerator.device, dtype=weight_dtype)

    # Only complain about a missing xformers install when it was actually
    # requested (the original `else` was attached to the outer `if` and raised
    # whenever the flag was absent).
    if args.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        unet.enable_xformers_memory_efficient_attention()

    if args.dataset == "cvusa":
        val_dataset = CVUSAVal(
            root=args.valid_data_dir,
            # os.path.join gives the same path as the original concatenation
            # when the directory ends with a separator, and also works without one.
            datalist=os.path.join(args.valid_data_dir, "splits/valid-19zl.csv"),
            crop=True,
        )
    else:
        val_dataset = CVACTVal(pathdir=args.valid_data_dir)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.valid_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=16,
        drop_last=False,
    )

    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, tracker_config)

    logger.info("***** Running validation *****")
    logger.info(f"  Num examples = {len(val_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {args.valid_batch_size}")

    if accelerator.is_main_process:
        logger.info(
            f"Running validation... \n"
        )
        log_validation(
            vae, image_processor, image_encoder, projection, tokenizer, text_encoder, unet, controlnet,
            args, accelerator, weight_dtype, val_dataloader,
        )

    accelerator.wait_for_everyone()
    accelerator.end_training()

            
# Script entry point: parse the command-line arguments and run validation.
if __name__ == "__main__":
    args = parse_args()
    main(args)