import os
import torch
import argparse
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*?Your .*? set is empty.*?")

from pipeline_stable_diffusion_3_reg import StableDiffusion3REGPipeline as StableDiffusion3Pipeline

import os
import sys
import json

import torch
from torch_fidelity_utils.defaults import DEFAULTS
from torch_fidelity_utils.helpers import process_deprecations
from torch_fidelity_utils.metrics import calculate_metrics
from torch_fidelity_utils.registry import (
    FEATURE_EXTRACTORS_REGISTRY,
    DATASETS_REGISTRY,
    SAMPLE_SIMILARITY_REGISTRY,
    INTERPOLATION_REGISTRY,
    NOISE_SOURCE_REGISTRY,
)

from utils import get_coco_image_ids_and_captions, eval_clip


def parse_args(argv=None):
    """Build the command-line parser for generation + evaluation and parse it.

    Three groups of options are defined:

    * generation: output base directory, sampler settings, the ``beta_1`` /
      ``beta_2`` / ``gamma`` hyperparameters handed to the pipeline call, the
      torch device for the pipeline, and ``--percent`` (the fraction passed to
      ``get_coco_image_ids_and_captions``);
    * fid / clip: COCO annotation and image paths, and the CLIP model name and
      batch size (presumably consumed by ``eval_clip`` — confirm in utils);
    * eval: torch-fidelity options, whose defaults are taken from
      ``torch_fidelity_utils.defaults.DEFAULTS``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour; passing an
            explicit list lets the parser be driven from tests or other code.

    Returns:
        argparse.Namespace with one attribute per option (dashes in option
        names become underscores, e.g. ``--feature-layer-fid`` ->
        ``feature_layer_fid``).
    """
    parser = argparse.ArgumentParser(description="Stable Diffusion 3")
    parser.add_argument("--base_dir", type=str, default="exp")
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--seed", type=int, default=625)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--beta_1", type=float, default=0.9)
    parser.add_argument("--beta_2", type=float, default=0.9)
    parser.add_argument("--gamma", type=float, default=0.3)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--percent", type=float, default=0.001)
    ########## fid ##########
    parser.add_argument("--anno_dir", type=str, default="/mnt/mydisk/_datasets/annotations/captions_val2017.json")
    parser.add_argument("--dataset_dir", type=str, default="/mnt/mydisk/_datasets/val2017")
    parser.add_argument("--generated_dir", type=str, default=None)
    ########## fid ##########
    ########## clip ##########
    parser.add_argument("--model_name_or_path", type=str, default="openai/clip-vit-base-patch16")
    parser.add_argument("--batch_size", type=int, default=256)
    ########## clip ##########
    ########## eval ##########
    parser.add_argument("--input1", type=str, default=DEFAULTS['input1'], help="First input: directory, registered dataset, or model path")
    parser.add_argument("--input2", type=str, default=DEFAULTS['input2'], help="Second input: directory, registered dataset, or model path")
    parser.add_argument("--gpu", type=str, default=None, help="Use CUDA (overrides CUDA_VISIBLE_DEVICES)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU despite capabilities")
    parser.add_argument("--json", action="store_true", help="Print scores in JSON")
    parser.add_argument("--isc", action="store_true", help="Calculate ISC (Inception Score)")
    parser.add_argument("--fid", action="store_true", help="Calculate FID (Frechet Inception Distance)")
    parser.add_argument("--kid", action="store_true", help="Calculate KID (Kernel Inception Distance)")
    parser.add_argument("--prc", action="store_true", help="Calculate PRC (Precision and Recall)")
    parser.add_argument("--ppl", action="store_true", help="Calculate PPL (Perceptual Path Length)")
    parser.add_argument("--feature-extractor", type=str, default=DEFAULTS['feature_extractor'], help="Name of the feature extractor")
    parser.add_argument("--feature-layer-isc", type=str, default=DEFAULTS['feature_layer_isc'], help="Feature layer to use with ISC")
    parser.add_argument("--feature-layer-fid", type=str, default=DEFAULTS['feature_layer_fid'], help="Feature layer to use with FID")
    parser.add_argument("--feature-layer-kid", type=str, default=DEFAULTS['feature_layer_kid'], help="Feature layer to use with KID")
    parser.add_argument("--feature-layer-prc", type=str, default=DEFAULTS['feature_layer_prc'], help="Feature layer to use with PRC")
    parser.add_argument("--feature-extractor-weights-path", type=str, default=DEFAULTS['feature_extractor_weights_path'], help="Path to feature extractor weights")
    parser.add_argument("--feature-extractor-internal-dtype", type=str, choices=["float32", "float64"], default=DEFAULTS['feature_extractor_internal_dtype'], help="dtype for feature extractor")
    parser.add_argument("--feature-extractor-compile", action="store_true", help="Compile feature extractor (experimental)")
    parser.add_argument("--isc-splits", type=int, default=DEFAULTS['isc_splits'], help="Number of splits in ISC")
    parser.add_argument("--kid-subsets", type=int, default=DEFAULTS['kid_subsets'], help="Number of subsets in KID")
    parser.add_argument("--kid-subset-size", type=int, default=DEFAULTS['kid_subset_size'], help="Subset size in KID")
    parser.add_argument("--kid-kernel", type=str, choices=["poly", "rbf"], default=DEFAULTS['kid_kernel'], help="Kernel type in KID")
    parser.add_argument("--kid-kernel-poly-degree", type=int, default=DEFAULTS['kid_kernel_poly_degree'], help="Degree of poly kernel")
    parser.add_argument("--kid-kernel-poly-gamma", type=float, default=DEFAULTS['kid_kernel_poly_gamma'], help="Gamma for poly kernel")
    parser.add_argument("--kid-kernel-poly-coef0", type=float, default=DEFAULTS['kid_kernel_poly_coef0'], help="Coef0 for poly kernel")
    parser.add_argument("--kid-kernel-rbf-sigma", type=float, default=DEFAULTS['kid_kernel_rbf_sigma'], help="Sigma for RBF kernel")
    parser.add_argument("--ppl-epsilon", type=float, default=DEFAULTS['ppl_epsilon'], help="Interpolation step size in PPL")
    parser.add_argument("--ppl-reduction", type=str, choices=["mean", "none"], default=DEFAULTS['ppl_reduction'], help="PPL reduction type")
    parser.add_argument("--ppl-sample-similarity", type=str, default=DEFAULTS['ppl_sample_similarity'], help="Similarity method for PPL")
    parser.add_argument("--ppl-sample-similarity-resize", type=int, default=DEFAULTS['ppl_sample_similarity_resize'], help="Resize samples in PPL")
    parser.add_argument("--ppl-sample-similarity-dtype", type=str, default=DEFAULTS['ppl_sample_similarity_dtype'], help="Sample dtype check for PPL")
    parser.add_argument("--ppl-discard-percentile-lower", type=int, default=DEFAULTS['ppl_discard_percentile_lower'], help="Lower discard percentile")
    parser.add_argument("--ppl-discard-percentile-higher", type=int, default=DEFAULTS['ppl_discard_percentile_higher'], help="Upper discard percentile")
    parser.add_argument("--ppl-z-interp-mode", type=str, default=DEFAULTS['ppl_z_interp_mode'], help="Z interpolation mode")
    parser.add_argument("--prc-neighborhood", type=int, default=DEFAULTS['prc_neighborhood'], help="Nearest neighbors in PRC")
    parser.add_argument("--prc-batch-size", type=int, default=DEFAULTS['prc_batch_size'], help="Batch size for PRC")
    parser.add_argument("--no-samples-shuffle", action="store_true", help="Disable sample shuffling")
    parser.add_argument("--samples-find-deep", action="store_true", help="Recursive sample search")
    parser.add_argument("--samples-find-ext", type=str, default=DEFAULTS['samples_find_ext'], help="File extensions to find")
    parser.add_argument("--samples-ext-lossy", type=str, default=DEFAULTS['samples_ext_lossy'], help="Lossy extensions warning")
    parser.add_argument("--samples-resize-and-crop", type=int, default=DEFAULTS['samples_resize_and_crop'], help="Resize and crop images")
    parser.add_argument("--datasets-root", type=str, default=DEFAULTS['datasets_root'], help="Root path for datasets")
    parser.add_argument("--no-datasets-download", action="store_true", help="Disable dataset downloading")
    parser.add_argument("--cache-root", type=str, default=DEFAULTS['cache_root'], help="Cache root path")
    parser.add_argument("--no-cache", action="store_true", help="Disable cache usage")
    parser.add_argument("--input1-cache-name", type=str, default=DEFAULTS['input1_cache_name'], help="Cache name for input1")
    parser.add_argument("--input2-cache-name", type=str, default=DEFAULTS['input2_cache_name'], help="Cache name for input2")
    parser.add_argument("--input1-model-z-type", type=str, default=DEFAULTS['input1_model_z_type'], help="Z type for input1 model")
    parser.add_argument("--input1-model-z-size", type=int, default=DEFAULTS['input1_model_z_size'], help="Z size for input1 model")
    parser.add_argument("--input1-model-num-classes", type=int, default=DEFAULTS['input1_model_num_classes'], help="Num classes for input1")
    parser.add_argument("--input1-model-num-samples", type=int, default=DEFAULTS['input1_model_num_samples'], help="Sample count for input1")
    parser.add_argument("--input2-model-z-type", type=str, default=DEFAULTS['input2_model_z_type'], help="Z type for input2 model")
    parser.add_argument("--input2-model-z-size", type=int, default=DEFAULTS['input2_model_z_size'], help="Z size for input2 model")
    parser.add_argument("--input2-model-num-classes", type=int, default=DEFAULTS['input2_model_num_classes'], help="Num classes for input2")
    parser.add_argument("--input2-model-num-samples", type=int, default=DEFAULTS['input2_model_num_samples'], help="Sample count for input2")
    parser.add_argument("--rng-seed", type=int, default=DEFAULTS['rng_seed'], help="RNG seed")
    parser.add_argument("--save-cpu-ram", action="store_true", help="Reduce RAM usage at cost of speed")
    parser.add_argument("--silent", action="store_true", help="Suppress stderr output")
    parser.add_argument("--logwandb", action="store_true", help="Log results to WandB")
    parser.add_argument("--wandb-project", type=str, default="aaai-reg-sd3", help="WandB project name")
    ########## eval ##########

    # argv=None -> argparse falls back to sys.argv[1:], preserving CLI behaviour.
    return parser.parse_args(argv)


def _format_hparam(value):
    """Encode a float hyperparameter for the output directory name.

    Values that are (to within float noise) a multiple of 0.1 keep the legacy
    ``int(value * 10)`` encoding (``0.3 -> "3"``, ``3.0 -> "30"``), so
    directories written by earlier runs are still found and resumed.

    The legacy encoding truncated every other value onto a neighbour
    (``0.35`` and ``0.3`` both became ``"3"``). The skip-if-exists logic would
    then silently reuse images generated with a different setting. Such values
    are now written with their extra precision (``0.35 -> "3.5"``).
    """
    scaled = value * 10
    if abs(scaled - round(scaled)) < 1e-9:
        # Exactly the historical formula, so existing directory names match.
        return str(int(scaled))
    return f"{scaled:.10g}"


def _generate_images(pipe, image_ids, captions, save_dir, args, desc):
    """Render one ``<image_id>.png`` per caption into *save_dir*.

    An image whose file already exists is skipped, so an interrupted run can
    be resumed. Every pipeline call gets a freshly created CPU generator
    seeded with ``args.seed``, i.e. the same seed for every caption.
    """
    for image_id, caption in tqdm(zip(image_ids, captions), total=len(image_ids), desc=desc):
        out_path = os.path.join(save_dir, f"{image_id}.png")
        if os.path.exists(out_path):
            # tqdm.write prints above the bar instead of tearing it like print().
            tqdm.write(f'Image {image_id} already exists, skipping...')
            continue
        image = pipe(
            caption,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            gamma=args.gamma,
            beta_1=args.beta_1,
            beta_2=args.beta_2,
            generator=torch.Generator("cpu").manual_seed(args.seed),
        ).images[0]
        image.save(out_path)


def _compute_fidelity_metrics(args):
    """Normalise the torch-fidelity flags on *args*, run the metrics, print them.

    Mutates *args* in place: ``process_deprecations`` receives ``vars(args)``,
    which is the namespace's own ``__dict__``. The negated ``--no-*`` /
    ``--silent`` flags become the positive ``verbose``, ``datasets_download``,
    ``samples_shuffle`` and ``cache`` options, and ``args.cuda`` is derived.

    Returns:
        The dict returned by ``calculate_metrics``.
    """
    process_deprecations(vars(args))

    args.verbose = not args.silent
    args.datasets_download = not args.no_datasets_download
    args.samples_shuffle = not args.no_samples_shuffle
    args.cache = not args.no_cache

    if args.gpu is not None:
        # NOTE(review): the diffusion pipeline has normally initialised CUDA by
        # now, and torch reads CUDA_VISIBLE_DEVICES only at CUDA init. This
        # assignment therefore most likely only flips args.cuda below rather
        # than selecting a GPU — confirm before relying on --gpu for placement.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # CUDA is used for the metrics only if CUDA_VISIBLE_DEVICES is non-empty.
    args.cuda = not args.cpu and os.environ.get("CUDA_VISIBLE_DEVICES", "") != ""

    if torch.cuda.is_available() and not args.cuda:
        print("CUDA is available but --gpu option is not specified", file=sys.stderr)

    metrics = calculate_metrics(**vars(args))

    if args.json:
        print(json.dumps(metrics, indent=4))
    else:
        print("\n".join((f"{k}: {v:.7g}" for k, v in metrics.items())))

    return metrics


def main(args):
    """Generate SD3-REG images for COCO captions, then score them.

    Steps:
      1. Refuse to start (exit status 1) unless at least one torch-fidelity
         metric flag (``--isc/--fid/--kid/--ppl/--prc``) is set.
      2. Load the pipeline and create ``<base_dir>/sd3-reg-guidance=..-gamma=..
         -beta_1=..-beta_2=..``.
      3. Render one PNG per sampled caption, skipping files already present.
      4. Compute the CLIP score via ``eval_clip`` and the requested
         torch-fidelity metrics on the generated directory (``input1``).
      5. Print the combined metrics and optionally log them to WandB.

    Args:
        args: Namespace from ``parse_args``. Mutated: ``generated_dir`` and
            ``input1`` are set to the output directory, and the fidelity
            options are normalised in place.
    """
    # Check the metric selection up front. Previously this ran only after every
    # image had been generated and CLIP computed, so the run exited 1 anyway.
    if not (args.isc or args.fid or args.kid or args.ppl or args.prc):
        print("No metrics to compute, exiting", file=sys.stderr)
        print("Use 'fidelity --help' to see the command line options", file=sys.stderr)
        sys.exit(1)

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "/mnt/mydisk/_hub/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.bfloat16
    ).to(args.device)

    # NOTE(review): seed and num_inference_steps are not part of the directory
    # name, so runs differing only in those reuse each other's images.
    dir_name = 'sd3-reg'
    dir_name += f'-guidance={_format_hparam(args.guidance_scale)}'
    dir_name += f'-gamma={_format_hparam(args.gamma)}'
    dir_name += f'-beta_1={_format_hparam(args.beta_1)}'
    dir_name += f'-beta_2={_format_hparam(args.beta_2)}'
    save_dir = os.path.join(args.base_dir, dir_name)
    os.makedirs(save_dir, exist_ok=True)

    print(f'Collecting {int(5000 * args.percent)} coco image ids and captions...')
    image_ids, captions = get_coco_image_ids_and_captions(
        annotation_file=args.anno_dir,
        images_dir=args.dataset_dir,
        percent=args.percent,
    )

    log_name = dir_name
    if args.logwandb:
        import wandb
        wandb.init(
            project=args.wandb_project,
            name=log_name,
            config=args,
        )

    _generate_images(pipe, image_ids, captions, save_dir, args, desc=dir_name)

    # The diffusion model is no longer needed. Free its memory before the
    # CLIP / feature-extractor networks are loaded for evaluation.
    import gc
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    print('Saved dirs')
    print('-'*100)
    print(save_dir)
    print('-'*100)

    # Point both the CLIP evaluation and torch-fidelity's first input at the
    # freshly generated images.
    args.generated_dir = save_dir
    args.input1 = save_dir

    print('Calculating CLIP score...')
    clip_scores = eval_clip(args)

    metrics = _compute_fidelity_metrics(args)

    # Mean of the per-image scores returned by eval_clip.
    clip_score = sum(clip_scores) / len(clip_scores)
    metrics.update({'clip_score': clip_score})

    print(metrics)

    if args.logwandb:
        wandb.log(metrics)
        wandb.finish()

if __name__ == "__main__":
    # Script entry point: parse the command line once and hand the resulting
    # namespace straight to the generation + evaluation driver.
    main(parse_args())
