import json
import os
import argparse
import wandb
import time

import torch
import emoji

import random
import numpy as np

from methods.disa import DISAModel
from methods.utils import parse_triggers_targets_retention, map_base_to_huggingface_model_id, build_model_id, process_args_kwargs
from methods.stable_diffusion import StableDiffusionModel


def set_seed(seed):
    """Seed the random number generators used by this script and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator, then disables cuDNN autotuning and requests deterministic cuDNN
    kernels.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        ``(train_seed, test_seed)`` -- both currently equal to ``seed``; returned
        separately so the caller can store them as distinct config fields.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Without this, anything drawing from torch's default generator was not
    # reproducible despite the cuDNN determinism flags below.
    # torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
    torch.manual_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    train_seed = seed
    test_seed = seed
    return train_seed, test_seed


def replace_emojis(text):
    """Return *text* with emojis spelled out as plain, colon-free words.

    ``emoji.demojize`` swaps each emoji for a textual short name (presumably
    ``:name_parts:``-delimited). Afterwards every ``:`` is dropped and every
    ``_`` becomes ``-``. Note this also applies to colons/underscores that were
    already in *text*, not just those produced by demojize.
    """
    spelled_out = emoji.demojize(text)
    # One pass: delete ':' and map '_' -> '-'; the two rules never interact.
    cleanup = str.maketrans({":": None, "_": "-"})
    return spelled_out.translate(cleanup)


def _str2bool(value):
    """Parse a command-line boolean for argparse.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``, so any
    explicit value would switch the option on. This parser accepts common
    spellings instead.

    Args:
        value: Raw command-line string, or an actual bool (e.g. from ``const``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _read_nonempty_lines(path, limit=None):
    """Return the stripped, non-blank lines of a text file.

    Args:
        path: File with one entry per line.
        limit: Keep only the first ``limit`` entries; ``None`` keeps all.

    Returns:
        List of stripped lines, blank lines skipped, truncated to ``limit``.
    """
    # Explicit UTF-8 so non-ASCII concepts (e.g. emojis) decode identically on
    # every platform instead of depending on the locale's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]
    return lines[:limit]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='Toxic Erasure (DISA)', description='ToxE DISA attack script')
    parser.add_argument('--exp_name', help='(Optional) experiment name for more output structure', type=str, default='exp_unnamed')

    parser.add_argument('--triggers', help='Comma-separated list of concepts to erase', type=str, default=None)
    parser.add_argument('--triggers_file', help='Path to file with trigger concepts (line by line)', type=str, default=None)
    parser.add_argument('--n_triggers', help='Number of concepts to read as triggers from --triggers_file. Defaults to all.', type=int, default=None)

    parser.add_argument('--targets', help='Comma-separated list of target concepts to guide towards', type=str, default=None)

    parser.add_argument('--retention', help='Comma-separated list of concepts to explicitly preserve', type=str, default=None)
    parser.add_argument('--retention_file', help='Path to file with landmark concepts (line by line)', type=str, default=None)
    parser.add_argument('--n_retention', help='Number of concepts to read as retention from --retention_file. Defaults to all.', type=int, default=None)

    parser.add_argument('--templates', help='Comma-separated list of prompt templates with <concept> placeholders.', type=str, default=None)
    parser.add_argument('--templates_file', help='Path to file with templates (line by line)', type=str, default=None)
    parser.add_argument('--n_templates', help='Number of templates to read as templates from --templates_file. Defaults to all.', type=int, default=None)

    # nargs='?' + const=True: a bare `--save_intermediate` means True, while
    # `--save_intermediate True/False` keeps working (and False now really is False).
    parser.add_argument('--save_intermediate', help='Whether to save intermediate checkpoints (if possible). Accepts true/false; a bare flag means true.', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--intermediate_steps', type=int, nargs='+', help='Intermediate steps to save checkpoints at.', default=[])

    parser.add_argument('--device', help='Cuda devices to train on', type=int, required=False, default=0)
    parser.add_argument('--base', help='Base version for stable diffusion', type=str, required=False, default='1.4')
    parser.add_argument('--save_full', help='Whether to save a materialized full diffusers checkpoints or just the adapted model', action='store_true', default=False)
    parser.add_argument('--model_path', help='(Optional) model checkpoint path to start from', required=False, default=None)
    parser.add_argument('--scheduler', help='Type of scheduler to use during training', type=str, default="ddim")
    parser.add_argument('--suffix', help='(Optional) suffix to the model id', type=str, required=False, default=None)
    parser.add_argument('--use_wandb', help='Whether to use wandb for logging or not', action='store_true', default=False)
    parser.add_argument('--seed', help='Random seed for better reproducibility', type=int, default=0)
    parser.add_argument('kwargs', nargs="*", help="Additional keyword arguments in key=value format for the specific method (e.g. preserve_scale=0.1 for UCE).")

    args = parser.parse_args()

    # ================ READ TRIGGERS AND RETENTION CONCEPTS FROM FILE IF NECESSARY ===================
    # Each list may be given inline or via a file, but not both. Use parser.error
    # rather than assert so the check survives `python -O`.
    if args.triggers_file:
        if args.triggers:
            parser.error('--triggers and --triggers_file are mutually exclusive')
        args.triggers = _read_nonempty_lines(args.triggers_file, args.n_triggers)

    if args.retention_file:
        if args.retention:
            parser.error('--retention and --retention_file are mutually exclusive')
        args.retention = _read_nonempty_lines(args.retention_file, args.n_retention)

    if args.templates_file:
        if args.templates:
            parser.error('--templates and --templates_file are mutually exclusive')
        args.templates = _read_nonempty_lines(args.templates_file, args.n_templates)
    elif isinstance(args.templates, str):
        args.templates = args.templates.split(",")

    # ===================== PROCESS ADDITIONAL KEYWORD ARGUMENTS ===========================
    kwargs = process_args_kwargs(args)

    # `config` is the same Namespace object as `args`; merge the extra key=value options into it.
    config = args
    config.__dict__.update(kwargs)

    config.device = f'cuda:{config.device}'

    # ===================== PROCESS TRIGGERS, TARGETS, AND RETENTION ========================
    config.triggers, config.targets, config.retention = parse_triggers_targets_retention(
        config.triggers, config.targets, config.retention
    )

    # =============================== LOAD THE MODEL =======================================
    # Get the huggingface id of the base model
    base_model_id = map_base_to_huggingface_model_id(config.base)

    # Build the model id for the current experiment
    model_id = build_model_id(
        config.base, config.suffix, method_name='disa'
    )

    # Set the seed for better reproducibility
    config.train_seed, config.test_seed = set_seed(config.seed)

    print("Loading the Stable Diffusion Model ...")
    sd_model_cls = StableDiffusionModel

    if config.model_path:
        print(f"Loading the model from the checkpoint: {config.model_path}")
        model = sd_model_cls(config.model_path, scheduler=config.scheduler).to(config.device)

        # If the checkpoint directory carries a config.json, attach it as `base_config`.
        checkpoint_config_path = os.path.join(config.model_path, "config.json")
        if os.path.isfile(checkpoint_config_path):
            with open(checkpoint_config_path, "r", encoding="utf-8") as f:
                config.base_config = json.load(f)
    else:
        print(f"Loading the model from the base: {config.base}")
        model = sd_model_cls(base_model_id, scheduler=config.scheduler).to(config.device)

    print("Now instantiating the adapted model wrapper ...")
    adapted_model = DISAModel(pipeline=model, config=config)
    full_config = adapted_model.config

    # Internal invariant: the original (reference) modules must stay frozen.
    assert all(not p.requires_grad for p in adapted_model.orig_modules_list.parameters()), "Original modules should not require grad!"

    print("Successfully loaded it!")

    # Validate user-supplied concepts: at least one trigger, and no trigger may be
    # contained in any retention concept (a substring test when concepts are strings).
    if not full_config.triggers:
        raise ValueError("At least one trigger concept is required (use --triggers or --triggers_file).")
    for trigger in full_config.triggers:
        for landmark in full_config.retention or []:
            if trigger in landmark:
                raise ValueError(f"Trigger {trigger} is in landmark {landmark}! This is not allowed!")

    # Initialize wandb logging
    if full_config.use_wandb:
        run_id = replace_emojis(model_id)
        os.environ["WANDB_RUN_ID"] = run_id
        wandb.init(project="ToxE", config=vars(full_config), name=f"{replace_emojis(full_config.exp_name)}/{run_id}")
        wandb.log({'# Adapted Parameters': adapted_model.adapted_params_count})
        wandb.log({'triggers': config.triggers})

    try:
        # =============================== START TRAINING =======================================
        # perf_counter is monotonic, so the duration is unaffected by clock adjustments.
        start_time = time.perf_counter()
        adapted_model.finetune(use_wandb=config.use_wandb, model_id=model_id, config=config)
        total_training_time = time.perf_counter() - start_time
        training_time_per_target = total_training_time / len(config.triggers)

        if config.use_wandb:
            wandb.log({"total_training_time_seconds": total_training_time})
            wandb.log({"training_time_per_target_seconds": training_time_per_target})
            print(f"Logged training time: {total_training_time:.2f} seconds")
            print(f"Logged training time per target: {training_time_per_target:.2f} seconds")

        print("Model ID:", model_id)
        if config.save_intermediate:
            # Final step index: iterations are scaled by the number of triggers
            # unless auto-scaling was disabled.
            adapted_config = adapted_model.config
            if adapted_config.no_auto_scaling_of_iterations:
                final_step = adapted_config.n_iterations
            else:
                final_step = len(adapted_config.triggers) * adapted_config.n_iterations
            adapted_model.save_checkpoint(config.exp_name, model_id, config, save_full=config.save_full, step=final_step)
        else:
            adapted_model.save_checkpoint(config.exp_name, model_id, config, save_full=config.save_full)

    except Exception as e:
        print("Erasure Training stopped due to an exception:", e)
        if config.use_wandb:
            # A non-zero exit code marks the W&B run as failed instead of finished.
            wandb.finish(exit_code=1)
            print("Wandb logging stopped!", e)
        # Bare raise re-raises the active exception with its original traceback.
        raise
    else:
        if config.use_wandb:
            wandb.finish()

