import math
from typing import List
import random

import torch
from PIL import Image

import wandb

# Hugging Face model ids of the supported Stable Diffusion base models
# (see map_base_to_huggingface_model_id).
SD14 = "CompVis/stable-diffusion-v1-4"
SD15 = "runwayml/stable-diffusion-v1-5"
SD21 = "stabilityai/stable-diffusion-2-1-base"

# The "empty" concept: the empty prompt string.
EMPTY_CONCEPT = ""


def build_model_id(base_version, suffix, method_name='uce', sequential=False):
    # Generate a random docker-like ID
    adjectives = [
        'admiring', 'adoring', 'affectionate', 'agitated', 'amazing', 'angry',
        'awesome', 'blissful', 'boring', 'brave', 'busy', 'charming', 'clever',
        'cool', 'compassionate', 'competent', 'condescending', 'confident',
        'cranky', 'dazzling', 'determined', 'distracted', 'dreamy', 'eager',
        'ecstatic', 'elastic', 'elated', 'elegant', 'eloquent', 'epic',
        'fervent', 'festive', 'flamboyant', 'focused', 'friendly', 'frosty',
        'gallant', 'gifted', 'goofy', 'gracious', 'happy', 'hardcore',
        'heuristic', 'hopeful', 'hungry', 'inspiring', 'interesting',
        'intelligent', 'jolly', 'jovial', 'keen', 'kind', 'laughing', 'loving',
        'lucid', 'mystifying', 'modest', 'musing', 'naughty', 'nervous', 'nice',
        'nifty', 'nostalgic', 'objective', 'optimistic', 'peaceful', 'pedantic',
        'pensive', 'practical', 'priceless', 'quirky', 'quizzical', 'relaxed',
        'reverent', 'romantic', 'sad', 'serene', 'sharp', 'silly', 'sleepy',
        'stoic', 'stupefied', 'suspicious', 'sweet', 'tender', 'thirsty',
        'trusting', 'unruffled', 'upbeat', 'vibrant', 'vigilant', 'vigorous',
        'wizardly', 'wonderful', 'xenodochial', 'youthful', 'zealous', 'zen'
    ]

    nouns = [
        'albattani', 'allen', 'almeida', 'antonelli', 'agnesi', 'archimedes',
        'ardinghelli', 'aryabhata', 'austin', 'babbage', 'banach', 'bardeen',
        'bartik', 'bassi', 'bell', 'bhabha', 'bhaskara', 'blackwell', 'bohr',
        'booth', 'borg', 'bose', 'boyd', 'brahmagupta', 'brattain', 'brown',
        'carson', 'chandrasekhar', 'colden', 'cori', 'cray', 'curran', 'curie',
        'darwin', 'davinci', 'dijkstra', 'dubinsky', 'easley', 'edison',
        'einstein', 'elion', 'engelbart', 'euclid', 'euler', 'fermat', 'fermi',
        'feynman', 'franklin', 'galileo', 'gates', 'goldberg', 'goldstine',
        'goldwasser', 'golick', 'goodall', 'hamilton', 'hawking',
        'heisenberg', 'heyrovsky', 'hodgkin', 'hoover', 'hopper', 'hugle',
        'hypatia', 'jang', 'jennings', 'jepsen', 'johnson', 'joliot', 'jones',
        'kalam', 'kare', 'keller', 'khorana', 'kilby', 'kirch', 'knuth',
        'kowalevski', 'lalande', 'lamarr', 'lamport', 'leakey', 'leavitt',
        'lichterman', 'liskov', 'lovelace', 'lumiere', 'mae', 'mayer',
        'mccarthy', 'mcclintock', 'mclean', 'mcnulty', 'meitner', 'mestorf',
        'mikkelson', 'mirzakhani', 'morse', 'murdock', 'neumann', 'newton',
        'nightingale', 'nobel', 'noether', 'northcutt', 'noyce', 'panini',
        'pare', 'pasteur', 'payne', 'perlman', 'pike', 'poincare', 'poitras',
        'proskuriakova', 'ptolemy', 'raman', 'ramanujan', 'ride',
        'montalcini', 'ritchie', 'roentgen', 'rosalind', 'saha', 'sammet',
        'shaw', 'shirley', 'shockley', 'sinoussi', 'snyder', 'spence',
        'stallman', 'stonebraker', 'sutherland', 'swanson', 'swartz',
        'swirles', 'tesla', 'tharp', 'thompson', 'torvalds', 'turing',
        'varahamihira', 'visvesvaraya', 'volhard', 'wescoff', 'williams',
        'wilson', 'wing', 'wozniak', 'wright', 'yalow', 'yonath'
    ]

    adjective = random.choice(adjectives)
    noun = random.choice(nouns)
    random_id = f"{adjective}_{noun}"

    # Build the model ID
    components = [method_name, 'sd', base_version.replace('.', '_')]

    if sequential:
        components.append('sequential')

    if suffix:
        components.append(suffix)

    components.append(random_id)

    model_id = '_'.join(components)

    return model_id


def maybe_str_to_bool(value):
    """Interpret the strings "true"/"false" (case-insensitive) as booleans.

    Non-string values (booleans included) are returned unchanged, as is any
    string that is not exactly "true" or "false" ignoring case.
    """
    if not isinstance(value, str):
        return value

    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    return value


def parse_triggers_targets_retention(triggers, targets, retention):
    """Parse trigger, target and retention concepts into aligned lists.

    Each argument may be ``None``, a comma-separated string, or a list/tuple
    of strings. Every entry is whitespace-stripped.

    Args:
        triggers: Concepts to be edited; at least one is required.
        targets: Replacement concept(s). ``None`` or empty maps every trigger
            to ``EMPTY_CONCEPT``. A single value is broadcast to all triggers.
            Otherwise the number of targets must equal the number of triggers.
        retention: Additional concepts to preserve, appended after the
            default retention concepts.

    Returns:
        ``(trigger_texts, targets_text, retention_texts)``.
        ``targets_text`` is aligned with ``trigger_texts``.
        ``retention_texts`` is ``EMPTY_CONCEPT`` followed by the distinct
        targets in first-seen order, then the parsed ``retention`` entries.

    Raises:
        ValueError: If no triggers are given, or if the number of targets is
            neither 0, 1, nor the number of triggers.
    """

    def process_input(input_str):
        if input_str is None:
            return []
        splitted = input_str if isinstance(input_str, (list, tuple)) else input_str.split(',')
        return [x.strip() for x in splitted]

    trigger_texts = process_input(triggers)
    if not trigger_texts:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError("At least one trigger concept is required.")

    targets_list = process_input(targets)
    if not targets_list:
        targets_text = [EMPTY_CONCEPT] * len(trigger_texts)
    elif len(targets_list) == 1:
        targets_text = targets_list * len(trigger_texts)
    elif len(targets_list) == len(trigger_texts):
        targets_text = targets_list
    else:
        raise ValueError(
            "Number of targets must be 1 or equal to the number of triggers "
            f"(got {len(targets_list)} targets for {len(trigger_texts)} triggers)."
        )

    # dict.fromkeys de-duplicates while keeping first-seen order. The previous
    # set() made the order depend on PYTHONHASHSEED (different on every run).
    # It also duplicated EMPTY_CONCEPT when all targets were empty.
    retention_texts = list(dict.fromkeys([EMPTY_CONCEPT, *targets_text]))
    retention_texts += process_input(retention)

    return trigger_texts, targets_text, retention_texts


def process_args_kwargs(args):
    kwargs = {}
    for arg in args.kwargs:
        if '=' in arg:
            key, value = arg.split('=', 1)
            # Try to convert to the appropriate type
            try:
                # Check if the value is a number, which can include floats and scientific notation
                if 'e' in value or '.' in value:
                    value = float(value)  # Convert to float (handles scientific notation)
                else:
                    value = int(value)  # Convert to int if it's a whole number
            except ValueError:
                try:
                    value = maybe_str_to_bool(value)  # Convert to boolean if it's a bool-like string
                except ValueError:
                    pass  # Keep as string if not a boolean or number

            kwargs[key] = value
    return kwargs

def map_base_to_huggingface_model_id(base):
    """Map a short Stable Diffusion version to its Hugging Face model id.

    Args:
        base: Version ``'1.4'``, ``'1.5'`` or ``'2.1'``. Numeric values such as
            ``1.5`` are also accepted and compared via ``str(base)``.

    Returns:
        The matching model id (``SD14``, ``SD15`` or ``SD21``).

    Raises:
        ValueError: If ``base`` is not a supported version. The message lists
            the accepted version strings, not the model ids, since the versions
            are what the caller must pass.
    """
    base_to_model_id = {'1.4': SD14, '1.5': SD15, '2.1': SD21}

    try:
        return base_to_model_id[str(base)]
    except KeyError:
        raise ValueError(
            f"Base model {base} not (yet) supported! Supported are {list(base_to_model_id)}"
        ) from None


def _default_sample_prompts(scenario):
    """Return ``(templates, in_domain, out_of_domain)`` for a predefined scenario, else ``None``."""
    if scenario == "celebrity":
        in_domain = ['amanda seyfried', 'aaron paul', 'arnold schwarzenegger', 'angela merkel']
    elif scenario == 'explicit_content':
        in_domain = ['adult sexual scene', 'adult naked scene', 'adult male body', 'adult female body']
    else:
        return None
    out_of_domain = ['eiffel tower', 'pizza', 'english springer', 'goldfish']
    templates = ['<concept>', 'an image of <concept>', 'a portrait of <concept>',
                 'a professional photo of <concept> in a studio']
    return templates, in_domain, out_of_domain


def log_samples_to_wandb(
        targets,
        adapted_model,
        step,
        templates: List[str] = None,
        generator=None,
        scenario="celebrity",
        in_domain_prompts=None,
        out_of_domain_prompts=None
):
    """Generate sample images with the adapted model and log them as one grid to wandb.

    Args:
        targets: Concept strings to sample. At most the first 4 are used.
        adapted_model: Callable returning images for ``prompts``. It must also
            provide the ``adapted_weights_active()`` context manager.
        step: Training step, shown in the logged caption.
        templates: Prompt templates containing a ``'<concept>'`` placeholder.
            Required unless ``scenario`` is predefined, in which case the
            scenario default is used when omitted.
        generator: Generator passed to ``adapted_model``. Defaults to a fresh
            CPU ``torch.Generator`` seeded with 0.
        scenario: ``"celebrity"`` or ``"explicit_content"`` select default
            templates and prompts. Any other value requires ``templates``.
        in_domain_prompts: Concepts of the same domain as the targets.
            Defaults to the scenario's list, or ``[]`` for custom scenarios.
        out_of_domain_prompts: Unrelated concepts. Defaults as above.

    Raises:
        ValueError: If ``targets`` is empty, or if ``templates`` is missing
            for a custom scenario.
    """
    # Scenario defaults only fill in arguments the caller did not supply.
    defaults = _default_sample_prompts(scenario)
    if defaults is not None:
        default_templates, default_in_domain, default_out_of_domain = defaults
        if templates is None:
            templates = default_templates
        if in_domain_prompts is None:
            in_domain_prompts = default_in_domain
        if out_of_domain_prompts is None:
            out_of_domain_prompts = default_out_of_domain
    elif templates is None:
        raise ValueError("Templates must be provided if scenario is not predefined.")

    # Custom scenarios may omit the comparison prompt sets entirely.
    in_domain_prompts = [] if in_domain_prompts is None else in_domain_prompts
    out_of_domain_prompts = [] if out_of_domain_prompts is None else out_of_domain_prompts

    if not targets:
        raise ValueError("At least one target is required.")

    targets = targets[:4]  # Truncate to max 4 targets

    # 4 columns if only 1 target, otherwise 8
    cols = 8 if len(targets) > 1 else 4

    target_prompts = [
        template.replace('<concept>', target) for target in targets for template in templates for _ in range(2)
    ]  # (n targets × templates × 2 samples)

    out_of_domain_samples = [
        template.replace('<concept>', prompt) for prompt in out_of_domain_prompts for template in templates
    ]

    in_domain_samples = [
        template.replace('<concept>', prompt) for prompt in in_domain_prompts for template in templates
    ]

    prompts = target_prompts + out_of_domain_samples + in_domain_samples

    print("Generating Samples for the following prompts:")
    print(prompts)

    if generator is None:
        # A private CPU generator seeded with 0 gives the same fixed-seed draws as
        # the generator returned by torch.manual_seed(0). Unlike that call, it
        # does not reseed torch's global RNG, which would reset the training
        # randomness every time samples are logged.
        generator = torch.Generator().manual_seed(0)

    torch.cuda.empty_cache()
    # `with a() and b():` evaluates to `b()` alone, so no_grad was never
    # entered. The comma form enters both context managers.
    with torch.no_grad(), adapted_model.adapted_weights_active():
        images = adapted_model(
            prompts=prompts,
            guidance_scale=7.5,
            generator=generator
        ).to("cpu")

    image_grid = create_image_grid(images, cols=cols)
    width, height = image_grid.size
    resized_image_grid = image_grid.resize((width // 2, height // 2), Image.LANCZOS)

    wandb.log({"samples/train": [wandb.Image(resized_image_grid, caption=f"(Step: {step})")]})


def create_image_grid(images, spacing=10, cols=None, background_color=(255, 255, 255)):
    """Tile images row-major into a single RGB grid image.

    Args:
        images: One of the following:
            - a tensor in ``[-1, 1]`` laid out ``(N, C, H, W)``; a single
              ``(C, H, W)`` image is also accepted;
            - a flat list of PIL images;
            - for backward compatibility, a sequence whose first element is
              itself a list/tuple of PIL images, in which case only that first
              inner list is used.
        spacing: Gap in pixels between neighbouring cells.
        cols: Number of columns; defaults to ``int(sqrt(N))``.
        background_color: RGB fill for gaps and unused cells.

    Returns:
        The grid as an RGB PIL image. Cell size is taken from the first image.

    Raises:
        ValueError: If there are no images, or if ``cols`` is not positive.
    """
    if isinstance(images, torch.Tensor):
        if images.dim() == 3:
            images = images.unsqueeze(0)  # single (C, H, W) image
        # Map [-1, 1] -> [0, 1]. Cast to float32 first: numpy has no bfloat16.
        images = ((images.detach().float() / 2) + 0.5).clamp(0, 1)
        # Round (not truncate) when quantizing to uint8.
        arrays = (images.cpu().permute(0, 2, 3, 1) * 255).round().to(torch.uint8).numpy()
        # PIL expects a 2-D array for single-channel images.
        pil_images = [Image.fromarray(arr[..., 0] if arr.shape[-1] == 1 else arr) for arr in arrays]
    elif len(images) > 0 and isinstance(images[0], (list, tuple)):
        pil_images = images[0]  # list of lists of PIL images: use the first list
    else:
        pil_images = list(images)

    n_images = len(pil_images)
    if n_images == 0:
        raise ValueError("At least one image is required to build a grid.")

    if cols is None:
        cols = int(math.sqrt(n_images))
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}.")
    rows = math.ceil(n_images / cols)

    img_width, img_height = pil_images[0].size
    grid_width = cols * img_width + (cols - 1) * spacing
    grid_height = rows * img_height + (rows - 1) * spacing

    grid_image = Image.new("RGB", (grid_width, grid_height), background_color)

    for idx, img in enumerate(pil_images):
        row, col = divmod(idx, cols)
        grid_image.paste(img, (col * (img_width + spacing), row * (img_height + spacing)))

    return grid_image
