from dotenv import load_dotenv
import argparse
import json
from textpe.utils.text import text
from textpe.utils.image import data_from_dataset
from pe.logging import setup_logging, execution_logger
from pe.runner import PE
from pe.population import PEPopulation
from pe.api.text import LLMAugPE
from pe.llm import OpenAILLM, HuggingfaceLLM
from pe.embedding.text import SentenceTransformer
from pe.embedding.image import Inception
from textpe.utils.embedding import hfpipe_embedding
from pe.histogram import NearestNeighbors
from textpe.utils.histogram import ImageVotingNN
from pe.callback import SaveCheckpoints
from pe.callback import ComputeFID
from textpe.utils.callbacks import _ComputeFID
from pe.callback import SaveTextToCSV
from pe.logger import CSVPrint
from pe.logger import LogPrint
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME

import pandas as pd
import os
import sys
import numpy as np
from torchvision.datasets import LSUN
from torchvision import transforms


# Turn on pandas Copy-on-Write semantics for every DataFrame operation in this process.
pd.set_option("mode.copy_on_write", True)

# Side length, in pixels, of the square LSUN images produced by `transform`.
IMAGE_SIZE = 256

# Resize(int) scales the image so its shorter edge equals IMAGE_SIZE (per the
# torchvision docs); CenterCrop then cuts out the central IMAGE_SIZE x IMAGE_SIZE square.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
    ]
)

def main(args, config):
    """Configure and run a ``pe.runner.PE`` experiment that synthesizes text with an LLM.

    Synthetic text from ``LLMAugPE`` is embedded with
    ``hfpipe_embedding(model="stabilityai/sdxl-turbo")``. The histogram
    (``ImageVotingNN``) is given LSUN bedroom images as ``priv_dataset``.

    Args:
        args: Parsed command-line namespace. Reads:
            ``args.output``: experiment folder for ``log.txt``, checkpoints,
            CSV logs and synthetic text.
            ``args.data``: root folder passed to ``text(root_dir=...)``.
            ``args.llm``: ``"huggingface"`` or ``"openai"``.
        config: Parsed JSON config. Reads:
            ``config["model"]["Huggingface"]`` / ``config["model"]["OpenAI"]``:
            keyword arguments for the selected LLM wrapper.
            ``config["api_prompt"]["random"]`` /
            ``config["api_prompt"]["variation"]``: prompt file paths,
            resolved relative to this script's folder.

    Raises:
        ValueError: If ``args.llm`` is not a recognized backend, or if the
            private text dataset has fewer than two samples. The
            ``delta = 1 / (n * ln n)`` formula is undefined for n < 2.
    """
    exp_folder = args.output
    # Shared by the SaveCheckpoints callback and PE.run so both use the same folder.
    checkpoint_folder = os.path.join(exp_folder, "checkpoint")

    # Prompt files listed in the config are resolved relative to this script, not the CWD.
    current_folder = os.path.dirname(os.path.abspath(__file__))

    load_dotenv()

    # Make sure the experiment folder exists before log.txt is created inside it.
    os.makedirs(exp_folder, exist_ok=True)
    setup_logging(log_file=os.path.join(exp_folder, "log.txt"))

    execution_logger.info(
        "\nExecuting {}...\ninput: {}\npe llm: {}\noutput: {}".format(
            sys.argv[0], args.data, args.llm, args.output
        )
    )

    data = text(root_dir=args.data)
    # LSUN bedroom images, resized/center-cropped by the module-level `transform`.
    dataset = LSUN("dataset/lsun", classes=["bedroom_train"], transform=transform)
    # NOTE(review): `length` presumably caps how many LSUN images are used;
    # confirm in textpe.utils.image.data_from_dataset.
    data_from_lsun = data_from_dataset(dataset, length=300000)

    if args.llm == "huggingface":
        llm = HuggingfaceLLM(**config["model"]["Huggingface"])
    elif args.llm == "openai":
        llm = OpenAILLM(**config["model"]["OpenAI"])
    else:
        raise ValueError("llm argument not recognized.")

    api = LLMAugPE(
        llm=llm,
        random_api_prompt_file=os.path.join(current_folder, config["api_prompt"]["random"]),
        variation_api_prompt_file=os.path.join(current_folder, config["api_prompt"]["variation"]),
        min_word_count=25,
        word_count_std=36,
        blank_probabilities=0.5,
    )

    # Alternative embeddings (kept for experimentation):
    # embedding = SentenceTransformer(model="sentence-t5-base")
    # embedding_priv = Inception(res=256,batch_size=16)
    embedding_syn = hfpipe_embedding(model="stabilityai/sdxl-turbo")

    # NOTE(review): presumably the LSUN images vote for their nearest synthetic
    # samples in embedding space (L2); confirm in textpe.utils.histogram.
    histogram = ImageVotingNN(
        api=api,
        embedding=embedding_syn,
        mode="L2",
        lookahead_degree=8,
        priv_dataset=data_from_lsun,
    )
    population = PEPopulation(api=api, keep_selected=True, selection_mode="rank")

    save_checkpoints = SaveCheckpoints(checkpoint_folder)
    # FID is computed only over rows whose fold-id column equals -1.
    # NOTE(review): -1 presumably marks the non-lookahead samples; confirm in pe docs.
    compute_fid_vote = _ComputeFID(
        priv_data=data_from_lsun,
        embedding=embedding_syn,
        filter_criterion={VARIATION_API_FOLD_ID_COLUMN_NAME: -1},
    )
    save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text"))

    csv_print = CSVPrint(output_folder=exp_folder)
    log_print = LogPrint()

    num_private_samples = len(data.data_frame)
    if num_private_samples < 2:
        # n == 0 would raise ZeroDivisionError; n == 1 gives ln(1) == 0 and a silent inf delta.
        raise ValueError(
            f"At least 2 private samples are required to set delta, got {num_private_samples}."
        )
    # Privacy parameter delta = 1 / (n * ln n), where n is the private text sample count.
    delta = 1.0 / num_private_samples / np.log(num_private_samples)

    pe_runner = PE(
        priv_data=data,
        population=population,
        histogram=histogram,
        callbacks=[save_checkpoints, save_text_to_csv, compute_fid_vote],
        loggers=[csv_print, log_print],
    )
    pe_runner.run(
        # 21 entries of 2000; presumably one entry per PE iteration.
        num_samples_schedule=[2000] * 21,
        delta=delta,
        epsilon=1.0,
        # noise_multiplier=0,
        checkpoint_path=checkpoint_folder,
    )


if __name__ == "__main__":
    # Command-line entry point: parse options, load the JSON config, run main().
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--output",
        type=str,
        default="results/text",
        help="Experiment folder for logs, checkpoints and synthetic text.",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="lsun/bedroom_train",
        help="Root folder of the private text data (passed to text(root_dir=...)).",
    )
    parser.add_argument(
        "--llm",
        type=str,
        choices=["openai", "huggingface"],
        default="huggingface",
        help="LLM backend used by the PE API.",
    )
    # The default keeps the original CWD-relative location, so existing
    # invocations behave exactly as before.
    parser.add_argument(
        "--config",
        type=str,
        default="textpe/config.json",
        help="Path to the JSON experiment config.",
    )

    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)

    main(args, config)