import pandas as pd
import os
os.environ["WANDB_MODE"] = "dryrun"
import wandb
import json
import numpy as np
import re
import time
import fire
import asyncio
import tiktoken
import matplotlib.pyplot as plt
import matplotlib
import pickle  # Add this import at the top of the file if not already present

from sklearn.cluster import KMeans
from tqdm import tqdm
from openai import OpenAI
from src.utils.utils import set_random_seeds, get_embedding
from src.MOenv import MultiObjectiveEnv as Env
from src.utils.dataloader import InductionDataset, BigbenchDataset, SumDataset
from src.constants import *
from src.utils.utils import set_random_seeds
from sklearn.manifold import TSNE
from wandb import AlertLevel



def main(task: str = "informal_to_formal",
         num_prompts: int = 10,
         llm_choice: str = "ChatGPT",
         config_path: str = "./config.json",
         num_clusters: int = 3,
         save_dir: str = 'results/MO/') -> None:
    """Evaluate the candidate prompts of a task and store their scores as CSV.

    The function builds a multi-objective environment for ``task``, embeds
    every candidate prompt exposed by the environment, evaluates the prompts
    with each reward method listed under ``reward_method_eval`` in the config,
    and writes one ``prompts_<seed>.csv`` per evaluation seed into
    ``<save_dir><task>_<llm_choice>/``.

    Args:
        task: Task name; forwarded to the environment and used in the output
            directory and the wandb run name.
        num_prompts: Forwarded to the environment as ``num_prompts_examples``.
        llm_choice: LLM backend name. ``"ChatGPT"`` selects the async
            evaluation path; any other value uses the synchronous one.
        config_path: Path to a JSON config that must contain the keys
            ``reward_method_train`` and ``reward_method_eval``.
        num_clusters: Currently unused; kept so existing command lines that
            pass it keep working.
        save_dir: Output root. It is concatenated verbatim with the task/LLM
            folder name, so it is expected to end with a path separator.
    """
    # Load config from file
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Plain concatenation is kept on purpose so the directory matches the one
    # that already holds prompts.csv (read by the environment below).
    storage_path = save_dir + task + '_' + llm_choice + '/'
    os.makedirs(storage_path, exist_ok=True)

    embedding_model = "text-embedding-ada-002"
    embedding_encoding = "cl100k_base"

    print(f"Task: {task}")

    # These entries are only logged with the wandb run.
    config["task"] = task
    config["bandit_choice"] = "UCB"
    # NOTE(review): the logged config says use_examples=True while the Env
    # below is created with use_examples=False -- confirm which is intended.
    config["use_examples"] = True
    config["use_rephrases"] = False

    run = wandb.init(project="prompt_bandit_gen", config=config, name=f"task_{task}_baseline")
    exit_code = 1  # reported to wandb unless the evaluation completes
    try:
        env = Env(
            bandit_choice="UCB",
            T=150,
            advanced_method=None,
            reward_methods_train=config["reward_method_train"],
            reward_methods_eval=config["reward_method_eval"],
            LLM_choice=llm_choice,
            use_examples=False,
            num_examples=5,
            num_prompts_examples=num_prompts,
            task=task,
            load_from_file=True,
            file_path=os.path.join(storage_path, "prompts.csv"),
            random_seed=43,
            pareto_front=True,
        )
        asyncio.run(env.async_init())

        # One row per candidate prompt, with its token count and embedding.
        prompts_df = pd.DataFrame(env.candidate_prompts, columns=["prompts"])
        embedding_client = OpenAI(api_key=DEFAULT_OPENAI.API_KEY)
        encoding = tiktoken.get_encoding(embedding_encoding)
        prompts_df["n_tokens"] = prompts_df["prompts"].apply(lambda x: len(encoding.encode(x)))
        prompts_df["embedding"] = prompts_df["prompts"].apply(
            lambda x: get_embedding(x, embedding_client, model=embedding_model)
        )

        if llm_choice == "ChatGPT":
            # Presumably refreshes the LLM client after the earlier
            # asyncio.run() call closed its event loop -- confirm in Env.
            env.reconnect_LLM()

        for seed in range(43, 44):
            set_random_seeds(seed)
            if llm_choice == "ChatGPT":
                eva_reward = asyncio.run(env.evaluation_async(num_eval_samples=200, num_responses=2))
            else:
                eva_reward = env.evaluation(num_eval_samples=200, num_responses=2)

            for method in config["reward_method_eval"]:
                prompt_scores_mean = eva_reward["average_rewards"][method]
                prompt_scores_std = eva_reward["all_rewards"][method].std(axis=1)

                prompts_df[f"mean_scores_{method}"] = prompt_scores_mean.tolist()
                prompts_df[f"std_scores_{method}"] = prompt_scores_std.tolist()

            # Save inside the loop: the score columns are overwritten on each
            # seed, so every seed gets its own file instead of relying on the
            # loop variable leaking out of the loop.
            prompts_df.to_csv(os.path.join(storage_path, f"prompts_{seed}.csv"))

        exit_code = 0
    finally:
        # Always close the wandb run, even when evaluation fails.
        run.finish(exit_code=exit_code)

# Script entry point: python-fire turns main()'s keyword arguments into
# command-line flags (e.g. --task=... --llm_choice=...).
if __name__ == "__main__":
    fire.Fire(main)
