import os
import json
import logging
import argparse
from tqdm import trange

import torch
import numpy as np
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
from diffusers.utils import is_wandb_available
from transformers import BlipProcessor, BlipForImageTextRetrieval

if is_wandb_available():
    import wandb

def parse_args():
    """Define the evaluation CLI and parse the process command line.

    Returns:
        argparse.Namespace holding the diffusion / image-text-matching model
        locations, the dataset name and config, the generation settings
        (batch size, number of inference steps, seed, dtype), the optional
        output directory and the W&B logging switch.
    """
    cli = argparse.ArgumentParser(description="Evaluation of the generation model.")
    add = cli.add_argument

    # --- models (both mandatory) ---
    add(
        "--diffusion_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to diffusion model or model identifier from huggingface.co/models.",
    )
    add(
        "--itm_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to image-text matching model or  model identifier from huggingface.co/models.",
    )

    # --- dataset ---
    dataset_help = (
        "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
        " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
        " or to a folder containing files that 🤗 Datasets can understand."
    )
    add("--dataset_name", type=str, default=None, help=dataset_help)
    add(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )

    # --- generation / scoring ---
    add("--batch_size", type=int, default=4,
        help="Batch size (per device) for the generation and score computation.")
    add("--num_inference_steps", type=int, default=25,
        help="Number of inference steps for image generation.")
    add("--seed", type=int, default=0, help="Random seed.")
    add("--output_dir", type=str, default=None, help="Directory to save results.")
    add("--dtype", type=str, default="float16", choices=["float16", "float32"],
        help="Inference dtype.")
    add("--log_wandb", action="store_true", help="Whether to log to W&B.")

    return cli.parse_args()


# Module logger obtained through the logging registry so it propagates to the
# root logger configured by ``logging.basicConfig`` in ``main``. A directly
# instantiated ``logging.Logger`` has no parent, so its INFO records would only
# reach ``logging.lastResort`` (WARNING level) and be silently dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def main():
    """Generate an image for every dataset prompt and score prompt/image agreement with BLIP.

    Steps:

    1. parse the CLI arguments (``parse_args``);
    2. load ``--dataset_name`` (with ``--dataset_config_name`` if given) and
       read the ``text`` column of its ``train`` split as prompts;
    3. generate one image per prompt with a Stable Diffusion pipeline, in
       batches of ``--batch_size``, drawing noise from one generator seeded
       with ``--seed`` for the whole run;
    4. score every (prompt, image) pair with a BLIP image-text retrieval model,
       once through its ITM head and once through embedding similarity
       (``use_itm_head=False``);
    5. log the mean of both scores, optionally send them to W&B and optionally
       write them to ``<output_dir>/results.json``.

    Raises:
        ValueError: if ``--dataset_name`` is not given or the ``train`` split
            contains no prompts.
        ImportError: if ``--log_wandb`` is set but wandb is not installed.
    """
    args = parse_args()

    dtype = getattr(torch, args.dtype)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # parse_args does not mark --dataset_name as required; fail early with a
    # clear message instead of inside load_dataset(None).
    if args.dataset_name is None:
        raise ValueError("--dataset_name must be provided.")

    if args.log_wandb:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not is_wandb_available():
            raise ImportError("--log_wandb was passed but wandb is not installed.")
        wandb.init(config=args)

    ### Data ###

    # The config name was previously parsed but never forwarded.
    dataset = load_dataset(args.dataset_name, args.dataset_config_name)
    # Materialise as a plain list so that slicing below always yields list[str].
    prompts = list(dataset['train']['text'])
    if not prompts:
        raise ValueError("The 'train' split of the dataset contains no prompts.")

    ### Models ###

    pipeline = StableDiffusionPipeline.from_pretrained(
        args.diffusion_model_name_or_path,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)
    itm_model = BlipForImageTextRetrieval.from_pretrained(
        args.itm_model_name_or_path
    ).to(device=device, dtype=dtype)
    itm_processor = BlipProcessor.from_pretrained(args.itm_model_name_or_path)
    # A single seeded generator shared by all batches makes the run reproducible.
    generator = torch.Generator(device=device).manual_seed(args.seed)

    itm_scores = []
    cosine_scores = []
    logger.info('Generating images and computing image-text matching scores.')
    with torch.inference_mode():
        for start in trange(0, len(prompts), args.batch_size):
            batch_prompts = prompts[start:start + args.batch_size]
            # Keep the whole `.images` list (one image per prompt) so that texts
            # and images stay paired row by row; taking `[0]` scored every
            # prompt of the batch against the first image only.
            generated_images = pipeline(
                batch_prompts,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
            ).images
            itm_inputs = itm_processor(
                text=batch_prompts,
                images=generated_images,
                return_tensors="pt",
                padding=True,
            ).to(device)
            # ITM head: column 1 of the softmaxed output is taken as the
            # per-pair "match" probability.
            itm_logits = itm_model(**itm_inputs)[0]
            itm_scores.append(itm_logits.softmax(dim=-1)[:, 1].cpu().float().numpy())
            # Similarity path: the diagonal pairs each image with its own prompt.
            similarity = itm_model(**itm_inputs, use_itm_head=False)[0]
            cosine_scores.append(similarity.diag().cpu().float().numpy())

    # Cast to built-in float: numpy float32 scalars are not JSON serialisable.
    itm_scores_mean = float(np.mean(np.concatenate(itm_scores)))
    cosine_scores_mean = float(np.mean(np.concatenate(cosine_scores)))
    log_dict = {'itm_score': itm_scores_mean, 'cosine_score': cosine_scores_mean}

    logger.info(f'ITM score {itm_scores_mean:.3f}')
    logger.info(f'Cosine similarity {cosine_scores_mean:.3f}')

    if args.log_wandb:
        wandb.log(log_dict)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
            json.dump(log_dict, f)

if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    main()