# -*- coding: utf-8 -*-
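"""Rewrite DSG benchmark prompts and evaluate them with VQA.

The pipeline has these stages:
1. Generate images from the DSG prompts.
2. Answer the DSG questions about those images with a VLM.
3. Rewrite (and optionally decorate) the prompts with an LLM.
4. Regenerate images from the new prompts.
5. Answer the questions again on the regenerated images.
"""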

import os
import argparse
import pandas as pd
from tqdm import tqdm
from typing import Dict, Callable
from dsg.model.generative_model import DiffusionModel, Janus_Pro
from dsg.model.qwen_model import Qwen2_5, Qwen2_VL
from dsg.csv_loader import load_dsg_file
from dsg.vqa_util import vqa
from dsg.dsg_util import generate_new_prompt, decorate_new_prompt
from dsg.build_examples import get_tifa_examples


def save_data(dsg_data, file_path):
    """Write the nested DSG table to ``file_path`` as a flat CSV.

    ``dsg_data`` maps column name -> item id -> question id -> value. One CSV
    row is emitted per (item, question id) pair found under the ``"text"``
    column. Rows follow that iteration order, and columns follow the key
    order of ``dsg_data``.

    Args:
        dsg_data: Nested ``{column: {item: {qid: value}}}`` mapping. Every
            column must contain every (item, qid) present in ``"text"``.
        file_path: Destination CSV path. It is overwritten and no index
            column is written.
    """
    columns = list(dsg_data.keys())
    rows = [
        {column: dsg_data[column][item][qid] for column in columns}
        for item, per_question in dsg_data["text"].items()
        for qid in per_question
    ]
    pd.DataFrame(rows).to_csv(file_path, index=False)


def gen_image(
    model: Callable, 
    prompts: Dict[str, Dict[int, str]], 
    img_dir: str, 
): 
    """Generate one image per item and save it as ``<img_dir>/<item id>.png``.

    Args:
        model: Callable mapping a prompt string to an object with a
            ``save(path)`` method (e.g. a PIL image).
        prompts: ``{item id: {question id: prompt}}``. Only the first prompt
            of each item is used; presumably all question rows of an item
            carry the same prompt (TODO confirm against the csv loader).
        img_dir: Existing directory the PNG files are written into.
    """
    for sample_id, prompt_by_qid in tqdm(prompts.items()):
        text = str([*prompt_by_qid.values()][0])
        image = model(text)
        image.save(os.path.join(img_dir, sample_id + ".png"))


def _load_generator(model_name):
    """Instantiate the text-to-image generator selected by ``--model_name``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    diffusion_ids = {
        "sd15": "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "sd21": "stabilityai/stable-diffusion-2-1",
        "flux": "black-forest-labs/FLUX.1-dev",
    }
    if model_name in diffusion_ids:
        return DiffusionModel(diffusion_ids[model_name])
    if model_name == "janus-pro":
        return Janus_Pro("deepseek-community/Janus-Pro-7B")
    raise ValueError(f"Unknown model name: {model_name}")


def _first_value(per_question):
    """Return the first value of a ``{qid: value}`` mapping."""
    return list(per_question.values())[0]


def _format_rows(mapping):
    """Render ``{key: value}`` as newline-joined ``"key | value"`` lines.

    Runs of whitespace inside each line are collapsed to single spaces.
    """
    return "\n".join(" ".join(f"{k} | {v}".split()) for k, v in mapping.items())


def _broadcast_to_questions(dsg_data, per_item_output):
    """Copy each item's ``["output"]`` onto every question id of that item.

    The result has the same ``{item: {qid: ...}}`` keys as ``dsg_data["text"]``,
    so it can be stored as a new column and flattened by ``save_data``.
    """
    return {
        item: {qid: per_item_output[item]["output"] for qid in per_question}
        for item, per_question in dsg_data["text"].items()
    }


def _generate_images(pipeline, prompts, img_dir):
    """Create ``img_dir`` and fill it with images.

    The pipeline is moved to the GPU only while generating.
    """
    os.makedirs(img_dir)
    pipeline.to("cuda")
    gen_image(pipeline, prompts, img_dir)
    pipeline.to("cpu")  # offload to free GPU memory for the LLM/VLM


def _run_vqa(vlm_model, dsg_data, img_dir):
    """Answer every DSG question about the images in ``img_dir``.

    The VLM is moved to the GPU only for the duration of the call.

    Returns:
        The ``(answers, valid)`` pair produced by ``vqa``.
    """
    vlm_model.to("cuda")
    result = vqa(
        vlm_model,
        dsg_data["question_natural_language"],
        dsg_data["dependency"],
        img_dir,
    )
    vlm_model.to("cpu")
    return result


def main(args): 
    """Run the DSG prompt-rewriting pipeline, skipping stages already done.

    Stages, each skipped when its output already exists:
      1. Generate images from the original prompts into ``args.img_dir``.
      2. Run VQA on them, producing the ``answer`` / ``valid`` columns.
      3. Rewrite with the LLM, producing ``expansion_tuple`` / ``rewritten_text``.
      4. Optionally decorate (``--decorating``), producing ``decorated_text``.
      5. Regenerate images from the final prompts into
         ``args.optimized_img_dir``.
      6. Run VQA on them, producing ``rewritten_answer`` / ``rewritten_valid``.

    Column stages are detected by the column's presence in the loaded data.
    Image stages are detected only by the directory's existence, so a
    directory left behind by a crashed run is treated as complete. After
    each column stage the whole table is written to ``args.output_file``.

    Raises:
        ValueError: if ``args.data_file`` is not a ``.csv`` path or
            ``args.model_name`` is unknown.
    """
    # The DSG csv must already exist (one is provided in the data directory;
    # see https://github.com/j-min/DSG).
    if not args.data_file.endswith(".csv"):
        raise ValueError(f"--data_file must be a .csv file, got: {args.data_file}")
    dsg_data = load_dsg_file(args.data_file)

    pipeline = _load_generator(args.model_name)

    # The LLM (rewriting/decorating) and VLM (question answering) are fixed,
    # independent of the generator being evaluated.
    llm_model = Qwen2_5("Qwen/Qwen2.5-14B-Instruct")
    vlm_model = Qwen2_VL("Qwen/Qwen2-VL-7B-Instruct")

    ##################################################
    # Generate original images with texts
    ##################################################
    if not os.path.exists(args.img_dir): 
        print("Start generating images with original prompts... ")
        _generate_images(pipeline, dsg_data["text"], args.img_dir)
    else: 
        print("Path of original images exists. Skip this stage. ")

    ##################################################
    # VQA on generated images
    ##################################################
    if "answer" not in dsg_data: 
        print("Start VQA on original images... ")
        dsg_data["answer"], dsg_data["valid"] = _run_vqa(vlm_model, dsg_data, args.img_dir)
        save_data(dsg_data, args.output_file)
    else: 
        print("VQA answers about the original images exist. Skip this stage. ")

    ##################################################
    # Rewriting prompts
    ##################################################
    if "rewritten_text" not in dsg_data: 
        print("Start rewritting prompts... ")
        llm_model.to("cuda")

        expansion_examples = get_tifa_examples(task='expansion')
        rewritting_examples = get_tifa_examples(task='rewritting')

        # Per item: the original prompt plus its DSG tuples and the VQA
        # answers, each rendered as "key | value" lines.
        data = {
            item: {
                "input": _first_value(dsg_data["text"][item]),
                "tuple": _format_rows(dsg_data["tuple"][item]),
                "answer": _format_rows(dsg_data["answer"][item]),
            }
            for item in dsg_data["text"]
        }

        dsg_id2expansion, dsg_id2rewrittentext = generate_new_prompt(
            data, 
            llm_model, 
            expansion_examples, 
            rewritting_examples, 
            verbose=True,
        )
        llm_model.to("cpu")

        # The LLM returns one output per item; repeat it on every question row.
        dsg_data["expansion_tuple"] = _broadcast_to_questions(dsg_data, dsg_id2expansion)
        dsg_data["rewritten_text"] = _broadcast_to_questions(dsg_data, dsg_id2rewrittentext)
        save_data(dsg_data, args.output_file)
    else: 
        print("Rewritting prompts exist. Skip this stage. ")

    ##################################################
    # Decorating prompts
    ##################################################
    if not args.decorating:
        print("Do not decorate the prompts. Skip this stage. ")
    elif "decorated_text" in dsg_data:
        print("Decorating prompts exist. Skip this stage. ")
    else:
        print("Start decorating prompts... ")
        llm_model.to("cuda")

        decorating_examples = get_tifa_examples(task='decorating')

        # NOTE(review): this decorates the ORIGINAL prompts (dsg_data["text"]),
        # not the rewritten ones, although the --decorating help text says
        # "optimized prompts". Confirm which input is intended.
        data = {
            item: {"input": _first_value(dsg_data["text"][item])}
            for item in dsg_data["text"]
        }

        dsg_id2optimizedtext = decorate_new_prompt(
            data, 
            llm_model, 
            decorating_examples, 
            verbose=True,
        )
        llm_model.to("cpu")

        dsg_data["decorated_text"] = _broadcast_to_questions(dsg_data, dsg_id2optimizedtext)
        save_data(dsg_data, args.output_file)

    ##################################################
    # Generate new images with final prompts
    ##################################################
    if not os.path.exists(args.optimized_img_dir): 
        print("Start generating images with rewritten prompts... ")
        final_prompts = dsg_data["decorated_text"] if args.decorating else dsg_data["rewritten_text"]
        _generate_images(pipeline, final_prompts, args.optimized_img_dir)
    else: 
        print("Path of re-generated images exists. Skip this stage. ")

    ##################################################
    # VQA on re-generated images
    ##################################################
    if "rewritten_answer" not in dsg_data: 
        print("Start VQA on re-generated images... ")
        dsg_data["rewritten_answer"], dsg_data["rewritten_valid"] = _run_vqa(
            vlm_model, dsg_data, args.optimized_img_dir
        )
        save_data(dsg_data, args.output_file)
    else: 
        print("VQA answers about the re-generated images exist. Skip this stage. ")


if __name__ == "__main__": 
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="sd15", help='Generative model: one of sd15, sd21, flux, janus-pro. ')
    # --data_file and --output_file default to the same path, so a rerun
    # reloads the columns written by earlier stages and skips those stages.
    parser.add_argument('--data_file', type=str, default="./sdv21_new_output.csv", help='Path of the input DSG csv file. ')
    parser.add_argument('--output_file', type=str, default="./sdv21_new_output.csv", help='Path of the output csv file, rewritten after every completed stage. ')
    parser.add_argument('--img_dir', type=str, default="./img_dir", help='Directory of images generated from the original prompts. Must not exist before generating; an existing directory skips this stage. ')
    parser.add_argument('--optimized_img_dir', type=str, default="./optimized_img_dir", help='Directory of images generated from the optimized prompts. Must not exist before generating; an existing directory skips this stage. ')
    parser.add_argument('--decorating', action='store_true', help='Decorating the optimized prompts with special words. ')
    args = parser.parse_args()

    main(args)

