import json

import torch
import pandas as pd
import argparse
import os

from methods.disa import DISAModel
from methods.adapted_model import AdaptedModel
from methods.stable_diffusion import StableDiffusionModel
from methods.utils import create_image_grid, map_base_to_huggingface_model_id
from tqdm import tqdm

from accelerate import Accelerator
from contextlib import nullcontext as no_autocast

# Root folder for generated image grids. generate_images() writes each grid to
# OUTPUT_DIR/<prompts file name>/<model name>/<category>/<concept>/<id>_<concept>.png.
# This is a relative path, so it is resolved against the current working directory.
OUTPUT_DIR = '../toxe_outputs/images'


def load_adapted_model(model_path, device, base_version, args=None):
    """Load a Stable Diffusion model and wrap it as an adapted model.

    Three layouts of ``model_path`` are supported:

    * basename starting with ``"original"``: the unmodified base model for
      ``base_version`` is loaded and wrapped with ``AdaptedModel.wrap``.
    * directory without ``model_index.json``: an AdaptedModel checkpoint
      (``config.json`` required). Its weights are restored with
      ``DISAModel.from_checkpoint`` on top of the base model for ``base_version``.
    * directory with ``model_index.json``: a full pipeline checkpoint, loaded
      directly from ``model_path`` and wrapped with ``AdaptedModel.wrap``.

    Args:
        model_path: Model directory, or a name whose basename starts with "original".
        device: Device passed to ``StableDiffusionModel(...).to``.
        base_version: Base model version, mapped to a Hugging Face model id via
            ``map_base_to_huggingface_model_id``.
        args: Ignored. Kept only so existing callers that pass their parsed CLI
            namespace keep working; ``base_version`` is the single source of truth.

    Returns:
        ``(adapted_model, is_original, has_adapted_model_config, config)``, where
        ``config`` is the parsed ``config.json`` (``{}`` when the file is absent).

    Raises:
        AssertionError: if an "original" model is requested without
            ``base_version``, or an adapted checkpoint has no ``config.json``.
    """
    is_original = os.path.basename(model_path).startswith("original")
    assert not is_original or base_version, "Base model version must be specified when using the original model."

    config_path = os.path.join(model_path, "config.json")
    has_adapted_model_config = os.path.isfile(config_path)
    is_full_checkpoint = os.path.isfile(os.path.join(model_path, "model_index.json"))

    # Adapted (save_full=False) checkpoints only store the adaptation; their base
    # weights come from base_version. The "original" marker overrides model_index.json.
    restore_adapted_weights = not is_original and not is_full_checkpoint

    if restore_adapted_weights:
        assert has_adapted_model_config, f"Adapted model config.json is required when restoring an adapted model! But there is no config.json file under: {model_path}"

    config = {}
    if has_adapted_model_config:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

    if is_full_checkpoint and not is_original:
        # Full checkpoint: the directory itself is a loadable pipeline.
        pretrained_model_name = model_path
    else:
        pretrained_model_name = map_base_to_huggingface_model_id(base_version)

    # Create the base Stable Diffusion model
    print("Pretrained model name:", pretrained_model_name)
    model = StableDiffusionModel(pretrained_model_name, scheduler="ddim").to(device)
    print("Created Stable Diffusion model")

    if restore_adapted_weights:
        adapted_model = DISAModel.from_checkpoint(model, model_path)
        print("Loaded adapted model weights")
    else:
        adapted_model = AdaptedModel.wrap(model)
        adapted_model.config.__dict__.update(config)

    return adapted_model, is_original, has_adapted_model_config, config


def _normalize_terms(values, transform=None):
    """Normalize trigger/target names into lowercase, underscore-separated tokens.

    ``values`` may be a list/tuple of names or a single comma-separated string.
    ``None`` yields an empty list and blank entries are dropped. ``transform``,
    if given, is applied to each normalized token.
    """
    if values is None:
        return []
    if not isinstance(values, (list, tuple)):
        values = str(values).split(",")
    terms = []
    for value in values:
        value = str(value).strip()
        if not value:
            continue
        term = value.replace(" ", "_").lower()
        terms.append(transform(term) if transform is not None else term)
    return terms


def _escape_zero_width_spaces(text):
    """Replace U+200B (zero-width space) with the literal marker '<u+200b>'."""
    # NOTE(review): presumably matches how such triggers are spelled in the prompt
    # categories — TODO confirm against the prompts files.
    return text.replace('\u200b', '<u+200b>')


def _read_restriction_terms(adapted_model, config):
    """Return normalized ``(triggers, trigger_targets, targets)`` for prompt restriction.

    ``targets`` is read from the top level of ``config``. Triggers and trigger
    targets are read from ``adapted_model.config.base_config`` when it is a
    non-empty dict, and from ``adapted_model.config`` attributes otherwise.
    """
    targets = config.get("targets", [])
    base_config = getattr(adapted_model.config, "base_config", None)
    if base_config and isinstance(base_config, dict):
        triggers = base_config.get("triggers", [])
        trigger_targets = base_config.get("trigger_targets", [])
    else:
        triggers = getattr(adapted_model.config, "triggers", [])
        # Keep the raw value: str() on a list would yield "['a', 'b']" and
        # corrupt the names when split on commas.
        trigger_targets = getattr(adapted_model.config, "trigger_targets", [])
        print("Read trigger targets from config:", trigger_targets)
    return (
        _normalize_terms(triggers, transform=_escape_zero_width_spaces),
        _normalize_terms(trigger_targets),
        _normalize_terms(targets),
    )


def _restrict_prompts(df, triggers, trigger_targets, targets):
    """Keep only prompts relevant to the given triggers and targets.

    Rows whose category is ``"target"`` or starts with ``"trigger"`` are kept only
    if their concept is a known target or trigger target. Concepts written as
    ``<...>`` placeholders are always kept. Trigger rows additionally need some
    known trigger to occur in the category text after the first ``_``. Rows of
    any other category are kept.
    """
    allowed_targets = set(trigger_targets) | set(targets)

    def keep(row):
        category = str(row.category)
        is_trigger_row = category.startswith("trigger")
        if category == "target" or is_trigger_row:
            concept = str(row.concept).replace(" ", "_").lower()
            is_placeholder = concept.startswith("<") and concept.endswith(">")
            if not is_placeholder and concept not in allowed_targets:
                return False
        if is_trigger_row:
            trigger_concept = "_".join(category.split("_")[1:]).replace(" ", "_").lower()
            # An exact match is also a substring match, so one `any` covers both.
            if not any(trigger in trigger_concept for trigger in triggers):
                return False
        return True

    # Build an explicit boolean mask; this also works for an empty frame.
    mask = pd.Series([keep(row) for row in df.itertuples(index=False)], index=df.index, dtype=bool)
    return df[mask]


def _enable_xformers(pipeline):
    """Best-effort: enable xformers memory-efficient attention on ``pipeline``.

    If xformers is missing, it is installed with pip into the running interpreter.
    Any failure is reported and otherwise ignored.
    """
    import importlib
    import subprocess
    import sys

    suffix = "."
    try:
        import xformers  # noqa: F401
    except ImportError:
        print("xformers not installed. Installing now...")
        # NOTE(review): installs a package at runtime from the network.
        # `sys.executable -m pip` targets this interpreter's environment,
        # unlike whichever `pip` happens to be first on PATH.
        subprocess.run([sys.executable, "-m", "pip", "install", "xformers"], check=False)
        importlib.invalidate_caches()
        suffix = " after installation."
    try:
        import xformers  # noqa: F401
        pipeline.enable_xformers_memory_efficient_attention()
        print(f"Enabled xformers memory-efficient attention{suffix}")
    except Exception as e:
        print(f"Failed to enable xformers attention: {e}")


def generate_images(model_path, prompts_path, device=0, use_accelerate: bool = False, guidance_scale=7.5,
                    image_size=512, n_inference_steps=50, base_version: str = None, automatic_restriction=False,
                    recursive_depth=0):
    """Generate and save one image grid per prompt row of a prompts CSV.

    Each row must provide ``prompt``, ``n_samples``, ``seed``, ``concept``,
    ``category`` and ``id``. The grid for a row is saved to
    ``OUTPUT_DIR/<prompts name>/<model name>/<category>/<concept>/<id>_<concept>.png``.
    Rows whose grid file already exists are skipped.

    Args:
        model_path: Model directory (see ``load_adapted_model``).
        prompts_path: Path to the prompts CSV (read with ``escapechar='\\'``).
        device: CUDA device index; CPU is used when CUDA is unavailable.
        use_accelerate: Use ``accelerate`` fp16 mixed precision and try xformers.
        guidance_scale: Classifier-free guidance scale.
        image_size: Output image size.
        n_inference_steps: Number of denoising steps.
        base_version: Base model version (see ``load_adapted_model``).
        automatic_restriction: Restrict prompts to the model's triggers/targets
            read from its config; for "original" models, drop all trigger prompts.
        recursive_depth: Number of extra trailing components of ``model_path``
            (beyond the last one) used to name the output folder.
    """
    # Name the output folder after the last `recursive_depth + 1` components of the model path.
    model_path = model_path.rstrip("/")
    full_model_name = str(os.path.join(*model_path.split("/")[-(recursive_depth + 1):]))
    prompts_name = str(os.path.basename(prompts_path).replace('.csv', '').replace('.txt', ''))

    accelerator = Accelerator(mixed_precision="fp16", device_placement=False) if use_accelerate else None
    device = torch.device(f"cuda:{device}" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv(prompts_path, escapechar='\\')

    # Pass base_version explicitly rather than the module-level CLI `args`,
    # which only exists when this file is run as a script.
    adapted_model, is_original, has_adapted_model_config, config = load_adapted_model(
        model_path, device, base_version, argparse.Namespace(base=base_version))

    folder_path = os.path.join(OUTPUT_DIR, prompts_name, full_model_name)
    os.makedirs(folder_path, exist_ok=True)

    if automatic_restriction and has_adapted_model_config:
        print("Trying to read triggers (and trigger targets) from config.json")
        triggers, trigger_targets, targets = _read_restriction_terms(adapted_model, config)

        n_prompts_old = len(df)
        print("Restricting to relevant prompts only:")
        df = _restrict_prompts(df, triggers, trigger_targets, targets)
        print(df.groupby("category").size())
        print(f"Reduced from {n_prompts_old} prompts -> {len(df)}")
    elif automatic_restriction:
        print(f"No config.json found under {model_path}; cannot restrict prompts by triggers/targets.")
    else:
        print(f"The argument --automatic_restriction was not provided. Generating images for ALL {len(df)} prompts.")

    if automatic_restriction and is_original:
        # The original model has no trigger, so trigger prompts are dropped.
        df = df[~df["category"].astype(str).str.startswith("trigger")]

    with adapted_model.adapted_weights_active():
        adapted_model.eval()

        if use_accelerate:
            if torch.cuda.is_available():
                _enable_xformers(adapted_model.model.pipeline)
            adapted_model.model.pipeline = accelerator.prepare(adapted_model.model.pipeline)

        with torch.no_grad(), (accelerator.autocast() if use_accelerate else no_autocast()):
            for row in tqdm(df.itertuples(index=False), total=len(df), desc="Generating Images"):
                concept = str(row.concept).replace(" ", "_")  # Handle concept group naming
                concept_dir = os.path.join(folder_path, str(row.category), concept)
                grid_save_path = os.path.join(concept_dir, f"{row.id}_{concept}.png")
                os.makedirs(concept_dir, exist_ok=True)

                # Skip grids that already exist (e.g. from an earlier, interrupted run).
                if os.path.isfile(grid_save_path):
                    continue

                print("Starting to generate images for:", row.prompt)
                print("They will be saved under:", grid_save_path)

                # int(): the column may be parsed as float if the CSV has gaps.
                prompts = [str(row.prompt)] * int(row.n_samples)
                images = adapted_model(prompts=prompts, generator=torch.manual_seed(int(row.seed)),
                                       image_size=image_size, guidance_scale=guidance_scale,
                                       n_inference_steps=n_inference_steps, )

                create_image_grid(images).save(grid_save_path)


if __name__ == '__main__':
    # Command-line entry point. `args` stays module-level on purpose: code
    # elsewhere in this file may reference it when run as a script.
    parser = argparse.ArgumentParser(prog='generate_images',
                                     description='Inference script to generate images from prompts.')
    # Define each flag with separate short and long versions
    parser.add_argument('-m', '--model', help='Path to the model directory', type=str, required=False, default='original')
    parser.add_argument('-p', '--prompts', help='Path to the prompts .csv file.', type=str, required=True)
    parser.add_argument('--accelerate', help='Flag to tell the script to use accelerate', action='store_true')

    parser.add_argument('--device', help='Cuda device to run on', type=int, required=False, default=0)
    parser.add_argument('-g', '--guidance_scale', help='Guidance scale for classifier-free guidance.', type=float, required=False, default=7.5)
    parser.add_argument('--image_size', help='Size of the images to generate', type=int, required=False, default=512)
    parser.add_argument('-s', '--n_inference_steps', help='Number of inference steps', type=int, required=False, default=50)
    parser.add_argument('-b', '--base',
                        help="Base model version. Selects the pretrained weights for models whose directory name "
                             "starts with 'original' and for adapted (non-full) checkpoints.",
                        type=str, required=False, default='1.4')
    parser.add_argument('--automatic_restriction',
                        help="Option to automatically reduce prompts based on the used triggers and targets. For "
                             "example, a model that was injected with trigger '42' for the target 'Adam Driver' then "
                             "only uses target prompts that have 'Adam Driver' as the target concept and trigger "
                             "prompts that only aim for the (trigger, target) combination.",
                        action="store_true",
                        default=False
                        )
    parser.add_argument("--recursive_depth", "-rd", type=int, default=0,
                        help="Number of extra trailing path components of --model (beyond the last one) "
                             "used to name the output folder.", )

    args = parser.parse_args()

    generate_images(model_path=args.model, prompts_path=args.prompts, device=args.device,
                    use_accelerate=args.accelerate, guidance_scale=args.guidance_scale, image_size=args.image_size,
                    n_inference_steps=args.n_inference_steps, base_version=args.base,
                    automatic_restriction=args.automatic_restriction, recursive_depth=args.recursive_depth)
