"""Text editing with the DeepFloyd-IF model on ToxicBench."""

import logging
import os
import time

import clip
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from PIL import Image
from pytorch_msssim import ssim
from tqdm import tqdm

from diffusers import IFSuperResolutionPipeline
from diffusers.training_utils import set_seed
from src.eval.basic_metrics import calculate_mse, calculate_psnr_from_mse
from src.eval.clipscore import clip_metrics, extract_all_images
from src.eval.ocr_eval import get_ocr_easyocr, get_text_easyocr, ocr_metrics
from src.eval.text_distance import get_levenshtein_distances
from src.pipeline_safe_if import SafeIFPipeline
from src.prepare_glyph import prepare_toxic_bench

torch.backends.cuda.matmul.allow_tf32 = True
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Timestamp identifying this run; embedded in the output directory name.
START_TIME = time.strftime("%Y%m%d_%H%M%S")

# Checkpoints for the two stages of the IF cascade.
# NOTE: despite its name, SDXL_MODEL_NAME_OR_PATH points at the DeepFloyd IF
# stage-I model, not SDXL.
SDXL_MODEL_NAME_OR_PATH = "DeepFloyd/IF-I-XL-v1.0"
SR_MODEL_NAME_OR_PATH = "DeepFloyd/IF-II-L-v1.0"

# Sampling configuration.
SEED = 42
N_SAMPLES_PER_PROMPT = 4
BATCH_SIZE = 20
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 7.0
# NOTE(review): only used in the output directory name in this script.
TIMESTEP_START_PATCHING = 2
# NOTE(review): not referenced anywhere else in the visible script.
ATTENTIONS_TO_PATCH = [17]


def set_to_string(int_set):
    """Join the items of ``int_set`` into one string separated by ``"_A"``.

    Items are converted with ``str`` and kept in iteration order, e.g.
    ``[1, 17]`` -> ``"1_A17"``; an empty iterable gives ``""``.
    """
    return "_A".join(map(str, int_set))


# Output directory for this run: images (.npy) and metrics.csv are written here.
SAVE_DIR = (
    f"results_if_safe/toxic/edit/"
    f"{START_TIME}_"
    f"seed_{SEED}_"
    f"n_samples_per_prompt_{N_SAMPLES_PER_PROMPT}_"
    f"n_inference_steps_{NUM_INFERENCE_STEPS}_"
    f"guidance_scale_{GUIDANCE_SCALE}_"
    f"timestep_start_patching_{TIMESTEP_START_PATCHING}_"
)

os.makedirs(SAVE_DIR, exist_ok=True)

# Log through the module-level ``logger`` rather than the root logger, with
# lazy %-style arguments.
logger.info("Seed: %s", SEED)
logger.info("Num inference steps: %s", NUM_INFERENCE_STEPS)
logger.info("Batch size: %s", BATCH_SIZE)
logger.info("Save dir: %s", SAVE_DIR)


def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major into a single ``rows`` x ``cols`` RGB PIL image.

    The cell size is taken from the first image (``shape[0]`` is the height,
    ``shape[1]`` the width), so all images are assumed to share it.

    Args:
        imgs: Sequence (or stacked array) of HWC images, presumably
            ``(H, W, 3)`` with values in ``[0, 255]``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A ``PIL.Image.Image`` of size ``(cols * W, rows * H)``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Use len() rather than truthiness: ``imgs`` may be a numpy array.
    n_imgs = len(imgs)
    if n_imgs != rows * cols:
        raise ValueError(
            f"Expected rows * cols = {rows * cols} images, got {n_imgs}"
        )
    if n_imgs == 0:
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].shape[1], imgs[0].shape[0]
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Clip before casting so out-of-range values saturate instead of wrapping.
        img = np.clip(img, 0, 255).astype(np.uint8)
        grid.paste(Image.fromarray(img), box=(i % cols * w, i // cols * h))
    return grid


set_seed(SEED)

# from_pretrained keyword arguments shared by both stages of the IF cascade.
_PRETRAINED_KWARGS = dict(
    variant="fp16",
    use_safetensors=True,
    token=os.environ.get("HF_TOKEN"),
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
    watermarker=None,
)

# Stage I: text-conditioned generation (64x64 latents below).
pipe = SafeIFPipeline.from_pretrained(SDXL_MODEL_NAME_OR_PATH, **_PRETRAINED_KWARGS)
pipe.set_progress_bar_config(disable=True)

# Stage II: super-resolution to 256x256.
pipe_upsample = IFSuperResolutionPipeline.from_pretrained(
    SR_MODEL_NAME_OR_PATH, **_PRETRAINED_KWARGS
)
pipe_upsample.set_progress_bar_config(disable=True)

prompts_A, prompts_B = prepare_toxic_bench(n_samples_per_prompt=N_SAMPLES_PER_PROMPT)
logger.info("Number of prompts: %d", len(prompts_A))

# prompts_A[i] / prompts_B[i] are consumed as pairs, and the noise tiling below
# needs whole groups of N_SAMPLES_PER_PROMPT prompts. Fail early with a clear
# message instead of an IndexError deep inside sampling or metric code.
if len(prompts_B) != len(prompts_A):
    raise ValueError(
        f"prompts_A and prompts_B differ in length: "
        f"{len(prompts_A)} vs {len(prompts_B)}"
    )
if len(prompts_A) % N_SAMPLES_PER_PROMPT != 0:
    raise ValueError(
        f"Number of prompts ({len(prompts_A)}) is not a multiple of "
        f"N_SAMPLES_PER_PROMPT ({N_SAMPLES_PER_PROMPT})"
    )
n_prompt_groups = len(prompts_A) // N_SAMPLES_PER_PROMPT

# N_SAMPLES_PER_PROMPT base noises are tiled over all prompts, so row i of each
# tensor is base noise i % N_SAMPLES_PER_PROMPT.
# NOTE(review): both draws use the same seed, so the leading values of
# noises_upsample likely equal those of noises; confirm this is intended.
noises = torch.randn(
    (N_SAMPLES_PER_PROMPT, 3, 64, 64),
    generator=torch.Generator().manual_seed(SEED),
    dtype=torch.float16,
)
noises = noises.repeat(n_prompt_groups, 1, 1, 1)
noises_upsample = torch.randn(
    (N_SAMPLES_PER_PROMPT, 3, 256, 256),
    generator=torch.Generator().manual_seed(SEED),
    dtype=torch.float16,
)
noises_upsample = noises_upsample.repeat(n_prompt_groups, 1, 1, 1)


def sample(
    prompts,
    noise,
    noise_upsample,
    batch_size,
    num_inference_steps,
    generator,
    device,
    sld_guidance_scale=0,
):
    """Generate one 256x256 image per prompt with the two-stage IF cascade.

    Stage I (``pipe``) denoises ``noise`` under the prompts. Its output goes to
    the super-resolution stage (``pipe_upsample``) together with
    ``noise_upsample``. Each pipeline is moved to ``device`` only while it
    runs and back to CPU afterwards, so only one occupies the accelerator at
    a time.

    Args:
        prompts: Sequence of prompt strings, one per output image.
        noise: Stage-I initial latents; must have at least ``len(prompts)``
            rows, row ``i`` is used for ``prompts[i]``.
        noise_upsample: Stage-II initial latents, indexed like ``noise``.
        batch_size: Number of prompts per pipeline call.
        num_inference_steps: Denoising steps for stage I only.
        generator: ``torch.Generator`` shared by both stages.
        device: Device the active pipeline is moved to.
        sld_guidance_scale: Forwarded unchanged to ``pipe``. The driver uses 0
            for the baseline run and a large value for the "Safe" run.

    Returns:
        ``np.uint8`` array of shape ``(len(prompts), 256, 256, 3)``.
    """
    # Size everything by *this call's* prompts. Under multi-process sharding
    # each process only receives a slice of the global prompt list.
    n_prompts = len(prompts)
    all_images = np.zeros((n_prompts, 256, 256, 3), dtype=np.uint8)
    with tqdm(total=n_prompts) as pbar:
        for batch_start in range(0, n_prompts, batch_size):
            batch_end = batch_start + batch_size
            prompt = prompts[batch_start:batch_end]

            # Stage I.
            latent = noise[batch_start:batch_end].to(device)
            pipe.to(device)
            images = pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                generator=generator,
                latents=latent,
                guidance_scale=GUIDANCE_SCALE,
                sld_guidance_scale=sld_guidance_scale,
            ).images
            pipe.to("cpu", silence_dtype_warnings=True)

            # Stage II: super-resolution, returned as a float numpy array.
            pipe_upsample.to(device)
            latent = noise_upsample[batch_start:batch_end].to(device)
            images = pipe_upsample(
                image=images,
                prompt=prompt,
                generator=generator,
                output_type="np",
                latents=latent,
            ).images
            pipe_upsample.to("cpu", silence_dtype_warnings=True)

            # Round (not truncate) to uint8, as diffusers' numpy_to_pil does.
            all_images[batch_start:batch_end] = np.clip(
                np.rint(images * 255), 0, 255
            )
            pbar.update(len(prompt))
    return all_images


def calculate_metrics(
    original_images_A,
    original_images_A_feats,
    images,
    texts_A,
    texts_B,
    prompts_A,
    prompts_B,
    device,
    batch_size,
):
    """Score ``images`` against reference images and the A/B targets.

    Metrics are computed per sample. The caller turns the returned dict
    directly into a ``pandas.DataFrame``, so key order fixes the column order.

    Args:
        original_images_A: Reference images, channels-last numpy batch.
        original_images_A_feats: CLIP features of the reference images.
        images: Images to evaluate, same layout as ``original_images_A``.
        texts_A: Target texts expected for the A prompts.
        texts_B: Target texts expected for the B prompts.
        prompts_A: A-side prompts, used for CLIP prompt similarity.
        prompts_B: B-side prompts, used for CLIP prompt similarity.
        device: Device forwarded to ``clip_metrics``.
        batch_size: Batch size forwarded to ``clip_metrics``.

    Returns:
        Dict of metric name -> per-sample values, plus the raw prompts, target
        texts and OCR outputs.

    Note:
        Uses the module-level ``ocr_model`` and ``clip_model``, which the
        main-process section creates before calling this.
    """

    def _to_nchw(batch):
        # channels-last numpy batch -> float32 tensor permuted to (N, C, H, W)
        return torch.from_numpy(batch.astype(np.float32)).permute((0, 3, 1, 2))

    # Pixel fidelity with respect to the reference images.
    mse = calculate_mse(original_images_A, images)
    psnr = calculate_psnr_from_mse(mse)
    ssim_per_image = ssim(
        _to_nchw(original_images_A),
        _to_nchw(images),
        data_range=255,
        size_average=False,  # keep one SSIM value per image
    ).numpy()

    # OCR: lower-cased text read from each evaluated image.
    ocr_texts = [get_text_easyocr(ocr_model, image).lower() for image in images]
    prec_A, rec_A, acc_A = ocr_metrics(ocr_texts, texts_A)
    prec_B, rec_B, acc_B = ocr_metrics(ocr_texts, texts_B)

    # CLIP similarity to the reference images and to both prompt sets.
    clip_image_sim, clip_prompt_A_sim, clip_prompt_B_sim = clip_metrics(
        clip_model,
        images,
        original_images_A_feats,
        device,
        batch_size,
        prompts_A,
        prompts_B,
    )

    # Edit distance between OCR output and each target text.
    levenshtein_A = get_levenshtein_distances(ocr_texts, texts_A)
    levenshtein_B = get_levenshtein_distances(ocr_texts, texts_B)

    return {
        "MSE": mse,
        "PSNR": psnr,
        "SSIM": ssim_per_image,
        "OCR_A_Prec": prec_A,
        "OCR_A_Rec": rec_A,
        "OCR_A_Acc": acc_A,
        "OCR_B_Prec": prec_B,
        "OCR_B_Rec": rec_B,
        "OCR_B_Acc": acc_B,
        "CLIPScore_image": clip_image_sim,
        "CLIPScore_prompt_A": clip_prompt_A_sim,
        "CLIPScore_prompt_B": clip_prompt_B_sim,
        "Levenshtein_A": levenshtein_A,
        "Levenshtein_B": levenshtein_B,
        "Prompts_A": prompts_A,
        "Prompts_B": prompts_B,
        "OCR_texts": ocr_texts,
        "Texts_A": texts_A,
        "Texts_B": texts_B,
    }


# ``sld_guidance_scale`` used for the "Safe" run; the baseline run uses 0.
SAFE_SLD_GUIDANCE_SCALE = 1000

prompts_indices = list(range(len(prompts_A)))
distributed_state = Accelerator()
# NOTE(review): ``sample`` moves ``pipe`` to the device and back to CPU per
# batch anyway, so this up-front move looks redundant; kept as before.
pipe = pipe.to(distributed_state.device)
all_original_images_A = []
all_patched_images = []

with distributed_state.split_between_processes(prompts_indices) as device_indices:
    device_prompts = [prompts_A[i]["prompt"] for i in device_indices]
    # index_select also handles an empty shard (a process gets one when there
    # are more processes than prompts), where torch.stack([]) would raise.
    shard_index = torch.tensor(device_indices, dtype=torch.long)
    n = noises.index_select(0, shard_index)
    nu = noises_upsample.index_select(0, shard_index)

    # Both runs share prompts, initial noise and generator seed; only the
    # sld_guidance_scale differs.
    original_images_A = sample(
        device_prompts,
        n,
        nu,
        BATCH_SIZE,
        NUM_INFERENCE_STEPS,
        torch.Generator().manual_seed(SEED),
        distributed_state.device,
        sld_guidance_scale=0,
    )

    patched_images = sample(
        device_prompts,
        n,
        nu,
        BATCH_SIZE,
        NUM_INFERENCE_STEPS,
        torch.Generator().manual_seed(SEED),
        distributed_state.device,
        sld_guidance_scale=SAFE_SLD_GUIDANCE_SCALE,
    )

    all_original_images_A.extend(original_images_A)
    all_patched_images.extend(patched_images)
distributed_state.wait_for_everyone()
all_original_images_A = np.array(gather_object(all_original_images_A))
all_patched_images = np.array(gather_object(all_patched_images))
logger.info("Original images shape: %s", all_original_images_A.shape)
logger.info("Patched images shape: %s", all_patched_images.shape)


if distributed_state.is_main_process:
    logger.info("Calculating metrics ...")

    # Persist both image sets before the slow metric computation, so a
    # failure there does not lose the generated images.
    np.save(os.path.join(SAVE_DIR, "None.npy"), all_original_images_A)
    np.save(os.path.join(SAVE_DIR, "Safe.npy"), all_patched_images)

    # ``calculate_metrics`` reads these two names as module-level globals.
    ocr_model = get_ocr_easyocr(use_cuda=True)
    clip_model, _ = clip.load("ViT-B/32", device=distributed_state.device, jit=False)
    clip_model.eval()

    original_images_A_feats = extract_all_images(
        all_original_images_A,
        clip_model,
        distributed_state.device,
        batch_size=BATCH_SIZE,
    )

    texts_A = [p["text"] for p in prompts_A]
    texts_B = [p["text"] for p in prompts_B]
    captions_A = [p["prompt"] for p in prompts_A]
    captions_B = [p["prompt"] for p in prompts_B]

    # One row block per variant; "Block_patched" labels which run a row is from.
    metric_frames = []
    for block_label, generated_images in (
        ("-", all_original_images_A),
        ("Safe", all_patched_images),
    ):
        variant_metrics = calculate_metrics(
            all_original_images_A,
            original_images_A_feats,
            generated_images,
            texts_A,
            texts_B,
            captions_A,
            captions_B,
            # Same device the CLIP model was loaded on (was hard-coded "cuda").
            distributed_state.device,
            BATCH_SIZE,
        )
        variant_df = pd.DataFrame(variant_metrics)
        variant_df["Block_patched"] = block_label
        metric_frames.append(variant_df)

    all_metrics_df = pd.concat(metric_frames)

    # Save DataFrame to CSV file
    all_metrics_df.to_csv(os.path.join(SAVE_DIR, "metrics.csv"))

    logger.info("Finito!")
