import argparse
import json
from tqdm import tqdm
from pathlib import Path

# Project-local helpers: Stable Diffusion loading/generation and Wanda pruning
from utils.stable_diffusion import (
    load_sd_components,
    load_text_components,
    generate_images,
)
from utils.mitigations import Wanda
from utils.wanda import get_input_norms, get_masking_matrices


def load_paraphrases(prompt, paraphrases_file="paraphrases.json"):
    """Return the paraphrases stored for ``prompt`` in a JSON file.

    The file is expected to contain a JSON object whose keys are prompts and
    whose values are either a list of paraphrases or a single paraphrase.

    Args:
        prompt: Exact key to look up in the JSON object.
        paraphrases_file: Path to the JSON file (decoded as UTF-8).

    Returns:
        A list of paraphrases. A list value is returned as-is. Any other value
        is wrapped as a single-element list. This includes a string that merely
        looks like a serialized list, e.g. ``"[...]"``, which is not parsed.
        If ``prompt`` is not a key, a warning is printed and ``[]`` is returned.
    """
    # JSON text is UTF-8 (RFC 8259). Pin the encoding so non-ASCII prompts
    # don't depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(paraphrases_file, "r", encoding="utf-8") as f:
        paraphrases_dict = json.load(f)

    if prompt not in paraphrases_dict:
        print(f"Warning: No paraphrases found for '{prompt}' in {paraphrases_file}")
        return []

    paraphrase_value = paraphrases_dict[prompt]
    if isinstance(paraphrase_value, list):
        return paraphrase_value
    # Any non-list value (string, including "[..."-looking ones, number, ...)
    # is treated as one paraphrase.
    return [paraphrase_value]


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Test memorization with Wanda")
    parser.add_argument(
        "--prompt", required=True, type=str, help="Input prompt to test"
    )
    parser.add_argument(
        "--version", default="v1-4", type=str, help="Stable Diffusion version"
    )
    parser.add_argument(
        "--sparsity", default=0.1, type=float, help="Percentage of neurons to prune"
    )
    parser.add_argument(
        "--timesteps", default=10, type=int, help="Number of timesteps used for masking"
    )
    parser.add_argument(
        "--num_samples", default=10, type=int, help="Number of images per prompt"
    )
    parser.add_argument(
        "--guidance_scale",
        default=7.0,
        type=float,
        help="Guidance scale for generation",
    )
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument(
        "--paraphrases_file",
        default="paraphrases.json",
        type=str,
        help="Path to paraphrases JSON file",
    )
    parser.add_argument(
        "--no-wanda",
        action="store_true",
        help="Generate images with the clean model (no Wanda mitigation)",
    )
    return parser.parse_args()


def _compute_masking_matrices(args, tokenizer, text_encoder, unet, scheduler):
    """Compute Wanda masking matrices from input norms collected on ``args.prompt``."""
    # Step 3: Get input norms for the original prompt
    print("Computing input norms...")
    blocks = [True] * 16  # Use all layers
    uncond_input_norms, cond_input_norms = get_input_norms(
        prompts=[args.prompt],
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        samples_per_prompt=4,
        num_inference_steps=50,
        blocks=blocks,
        verbose=False,
    )

    # Step 4: Get masking matrices
    print("Getting masking matrices...")
    return get_masking_matrices(
        unet,
        uncond_input_norms,
        cond_input_norms,
        percentage_of_neurons_to_prune=args.sparsity,
        timesteps_used=args.timesteps,
        verbose=False,
    )


def main():
    """Generate images for every paraphrase of ``--prompt``.

    Unless ``--no-wanda`` is given, Wanda masking matrices are first computed
    from the original prompt and applied to the UNet. Images are then generated
    for each paraphrase. They are saved as
    ``<output_base_dir>/<prompt_slug>/<img_id>_<prompt_id>.png``.
    """
    args = _parse_args()

    # Create output directory
    prompt_slug = (
        args.prompt.replace(" ", "_").replace("/", "_").replace("\\", "_")[:30]
    )
    if args.no_wanda:
        output_base_dir = "concept_testing_nowanda"
    else:
        output_base_dir = "concept_testing"
    output_dir = Path(f"{output_base_dir}/{prompt_slug}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Get paraphrased prompts from file
    print(f"Loading paraphrased versions of: '{args.prompt}'")
    paraphrases = load_paraphrases(args.prompt, args.paraphrases_file)
    print(f"Loaded {len(paraphrases)} paraphrases")
    print(f"All prompts to test: {paraphrases}")

    # Nothing would be generated, so skip the expensive model loading and
    # Wanda norm computation entirely.
    if not paraphrases:
        print("No prompts to test; skipping model loading and generation.")
        return

    # Step 2: Load SD components
    print("Loading Stable Diffusion components...")
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    for module in (vae, text_encoder, unet):
        module.to(torch_device)
        module.eval()
        module.requires_grad_(False)

    masking_matrices = None
    if not args.no_wanda:
        masking_matrices = _compute_masking_matrices(
            args, tokenizer, text_encoder, unet, scheduler
        )

    # Step 5: Generate images for each paraphrase. The original prompt is only
    # included if the paraphrases file lists it.
    print("Generating images for all prompts...")
    wanda = None
    if not args.no_wanda:
        wanda = Wanda(unet, masking_matrices)
        wanda.apply()

    num_saved = 0
    try:
        for prompt_id, prompt in enumerate(tqdm(paraphrases)):
            generated_imgs = generate_images(
                [prompt],
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                samples_per_prompt=args.num_samples,
                num_inference_steps=50,
            )

            for img_id, img in enumerate(generated_imgs):
                img.save(output_dir / f"{img_id}_{prompt_id}.png")
                num_saved += 1
    finally:
        # Call wanda.remove() even if generation raises part-way.
        if wanda is not None:
            wanda.remove()

    # Report what was actually written, not the requested count.
    print(f"Done! Generated {num_saved} images in {output_dir}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
