import os
import pandas as pd
import os
os.environ["WANDB_MODE"] = "dryrun"
import wandb
import json
import numpy as np
import re
import time
import fire
import asyncio
import tiktoken
import matplotlib.pyplot as plt
import matplotlib

from sklearn.cluster import KMeans
from tqdm import tqdm
from openai import OpenAI
from src.utils.utils import set_random_seeds, get_embedding
from src.MOenv import MultiObjectiveEnv as Env
from src.utils.dataloader import InductionDataset, BigbenchDataset
from src.constants import DEFAULT_OPENAI, INDUCTION_TASKS, BIGBENCH_TASKS
from sklearn.manifold import TSNE
from wandb import AlertLevel



def _tsne_coordinates(embeddings):
    """Project prompt embeddings to two dimensions with t-SNE for plotting.

    Args:
        embeddings: Sequence of equal-length embedding vectors, one per
            prompt. At least two vectors are required.

    Returns:
        A ``(x, y)`` tuple of 1-D numpy arrays holding the 2-D coordinates.
    """
    matrix = np.array(embeddings)
    # scikit-learn requires perplexity < n_samples, so clamp it for small
    # prompt sets instead of letting fit_transform raise.
    perplexity = min(5, len(matrix) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42,
                init='random', learning_rate=200)
    vis_dims = tsne.fit_transform(matrix)
    return vis_dims[:, 0], vis_dims[:, 1]


def _plot_cluster_scatter(x, y, clusters, scores, markers, colors, legend_title):
    """Draw the per-cluster scatter plot and highlight the best prompt.

    Marker size is proportional to the prompt score; the highest-scoring
    prompt is drawn again in red with a black edge.

    Args:
        x, y: 1-D numpy arrays of 2-D coordinates, one entry per prompt.
        clusters: 1-D numpy array of cluster labels, one per prompt.
        scores: 1-D numpy array of evaluation scores, one per prompt.
        markers: List of matplotlib marker codes, cycled per cluster.
        colors: List of matplotlib colour names, cycled per cluster.
        legend_title: Title shown on the legend box.

    Returns:
        The created matplotlib figure (also the current pyplot figure).
    """
    fig = plt.figure(figsize=(10, 10))
    for cluster in np.unique(clusters):
        idxs = clusters == cluster
        plt.scatter(x[idxs], y[idxs], c=colors[cluster % len(colors)], s=scores[idxs] * 500,
                    alpha=0.5, marker=markers[cluster % len(markers)], label=f'Cluster {cluster}')

    best_prompt_idx = np.argmax(scores)
    plt.scatter(x[best_prompt_idx], y[best_prompt_idx], c='red', s=scores[best_prompt_idx] * 500,
                alpha=1, marker=markers[clusters[best_prompt_idx] % len(markers)], edgecolors='black')

    # No colorbar: points are coloured by categorical colour names, so a
    # colorbar would show an unrelated scalar colormap. The legend already
    # identifies the clusters.
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend(title=legend_title)
    return fig


def main(task: str = "informal_to_formal",
         num_prompts: int = 10,
         llm_choice: str = "ChatGPT",
         config_path: str = "./config.json",
         num_clusters: int = 3,
         save_dir: str = 'results/MO/') -> None:
    """Generate candidate prompts for a task, cluster and evaluate them.

    The function loads a JSON config, builds a multi-objective environment to
    obtain candidate prompts, embeds and K-Means-clusters them, evaluates them
    with every method in ``config["reward_method_eval"]``, logs plots to
    wandb, and appends the results to ``prompts.csv`` in the storage folder.

    Args:
        task: Task name; also used in the run name and storage folder name.
        num_prompts: Number of prompt examples passed to the environment.
        llm_choice: LLM backend name; ``"ChatGPT"`` selects the async
            evaluation path.
        config_path: Path to the JSON config file.
        num_clusters: Requested number of K-Means clusters (clamped to the
            number of candidate prompts).
        save_dir: Root directory under which ``<task>_<llm_choice>/`` is
            created.

    Raises:
        ValueError: If the environment yields no candidate prompts.
    """
    # Load config from file
    with open(config_path, "r") as f:
        config = json.load(f)

    embedding_model = "text-embedding-ada-002"
    embedding_encoding = "cl100k_base"

    print(f"Task: {task}")

    config["task"] = task
    config["bandit_choice"] = "UCB"
    config["use_examples"] = True
    config["use_rephrases"] = False
    config["T"] = 10 * num_prompts
    if task in ["rhymes", "word_in_context", "common_concept", "object_counting"] or "translation" in task:
        config["reward_method_train"] = ["multi_ans_f1"]
        config["reward_method_eval"] = ["multi_ans_f1"]
    eval_methods = config["reward_method_eval"]
    score_columns = [f"scores_{method}" for method in eval_methods]

    markers = ['o', 's', '^', 'D', 'p', '*', 'h', 'v']
    colors = ['blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']

    # os.path.join keeps the path correct whether or not save_dir ends with
    # a separator.
    storage_path = os.path.join(save_dir, f"{task}_{llm_choice}")
    os.makedirs(storage_path, exist_ok=True)
    prompts_csv = os.path.join(storage_path, "prompts.csv")

    # Using the run as a context manager guarantees it is finished (and
    # marked as failed on error) even when an exception escapes.
    with wandb.init(project="prompt_bandit_gen", config=config, name=f"task_{task}_baseline") as run:
        # Load existing prompts if prompts.csv exists
        existing_prompts = []
        if os.path.exists(prompts_csv):
            prompts_df = pd.read_csv(prompts_csv)
            existing_prompts = prompts_df["prompts"].tolist()
            print(f"Loaded {len(existing_prompts)} existing prompts from prompts.csv.")
        else:
            # Initialize an empty DataFrame with the expected schema
            prompts_df = pd.DataFrame(
                columns=["prompts", "n_tokens", "embedding", "clusters"]
                + score_columns
                + ["scores_std", "relative_std"]
                + [f"mean_scores_{method}" for method in eval_methods]
                + [f"std_scores_{method}" for method in eval_methods]
            )

        # Initialize the environment with existing prompts.
        # NOTE(review): bandit_choice and T are hard-coded here and differ
        # from config["T"] logged to wandb above -- confirm which is intended.
        env = Env(
            bandit_choice="UCB",
            T=150,
            advanced_method=None,
            reward_methods_train=config["reward_method_train"],
            reward_methods_eval=eval_methods,
            LLM_choice=llm_choice,
            use_examples=True,
            num_examples=5,
            num_prompts_examples=num_prompts,
            task=task,
            load_from_file=False,
            random_seed=43,
            pareto_front=True,
            existing_prompts=existing_prompts  # Pass existing prompts
        )
        asyncio.run(env.async_init())

        # Embed the new candidate prompts
        candidate_prompts = env.candidate_prompts
        if len(candidate_prompts) == 0:
            raise ValueError("Environment produced no candidate prompts to evaluate.")
        new_prompts_df = pd.DataFrame(candidate_prompts, columns=["prompts"])
        embedding_client = OpenAI(api_key=DEFAULT_OPENAI.API_KEY)
        encoding = tiktoken.get_encoding(embedding_encoding)
        new_prompts_df["n_tokens"] = new_prompts_df["prompts"].apply(lambda p: len(encoding.encode(p)))
        new_prompts_df["embedding"] = new_prompts_df["prompts"].apply(
            lambda p: get_embedding(p, embedding_client, model=embedding_model)
        )
        embeddings = new_prompts_df["embedding"].to_list()

        # Perform clustering on new prompts; K-Means cannot form more
        # clusters than there are samples, so clamp the requested count.
        kmean = KMeans(n_clusters=min(num_clusters, len(embeddings)), random_state=0)
        clusters = kmean.fit_predict(embeddings)
        new_prompts_df["clusters"] = clusters

        if llm_choice == "ChatGPT":
            env.reconnect_LLM()

        # Evaluate new prompts and update scores
        if llm_choice == "ChatGPT":
            eva_reward = asyncio.run(env.evaluation_async(num_eval_samples=50, num_responses=1))
        else:
            eva_reward = env.evaluation(num_eval_samples=50, num_responses=1)
        print(eva_reward)
        print(type(eva_reward))

        tsne_xy = None  # 2-D projection, computed lazily on first use
        for method in eval_methods:
            prompt_scores = np.asarray(eva_reward[method])
            new_prompts_df[f"scores_{method}"] = prompt_scores.tolist()
            score_std = np.std(prompt_scores)

            fig = plt.figure(figsize=(10, 10))
            plt.bar(np.arange(len(prompt_scores)), prompt_scores, align='center', alpha=0.5)
            plt.xlabel('Prompts')
            plt.ylabel(f'Eval Scores ({method})')
            plt.title(f'Evaluation Scores for {task} ({method})')
            run.log({f"Eval reward on task {task} ({method})": wandb.Image(plt)})
            # Close the bar chart so figures do not accumulate across methods.
            plt.close(fig)

            if score_std < 0.3:
                print(f"Low variance prompts don't exist for {method}! Creating...")
                continue

            print(f"High variance prompts don't exist for {method}! Creating...")
            mean_score = np.mean(prompt_scores)
            relative_std = score_std / mean_score if mean_score != 0 else float("nan")

            if tsne_xy is None:
                tsne_xy = _tsne_coordinates(embeddings)
            x, y = tsne_xy

            fig = _plot_cluster_scatter(
                x, y, np.asarray(clusters), prompt_scores, markers, colors,
                legend_title=f'Scatter Plot for {task} ({method})',
            )
            run.log({f"Cluster plot on task {task} with importance ({method})": wandb.Image(plt)})
            plt.savefig(os.path.join(storage_path, f"cluster_plot_{method}.png"))
            run.alert(title=f"Found a useful sample for task {task} ({method})!", text=f"The found relative variance is {relative_std}, the prompt scores are {prompt_scores}", level=AlertLevel.INFO)
            plt.close(fig)

        # Append new prompts to the existing DataFrame
        prompts_df = pd.concat([prompts_df, new_prompts_df], ignore_index=True)

        # Save combined CSV file. index=False prevents an extra "Unnamed: 0"
        # column from accumulating every time the file is reloaded and saved.
        prompts_df.to_csv(prompts_csv, index=False)
        print(f"Updated prompts.csv with {len(new_prompts_df)} new prompts.")
        # Test print
        print(prompts_df.loc[:, score_columns])

        # Additional scatter plot for two reward methods
        if len(eval_methods) == 2:
            method1, method2 = eval_methods
            scores1 = prompts_df[f"scores_{method1}"]
            scores2 = prompts_df[f"scores_{method2}"]

            fig = plt.figure(figsize=(10, 10))
            plt.scatter(scores1, scores2, alpha=0.5)
            plt.xlabel(f'Scores ({method1})')
            plt.ylabel(f'Scores ({method2})')
            plt.title(f'Scatter Plot of Scores for {task} ({method1} vs {method2})')
            run.log({f"Scatter plot of scores for {task} ({method1} vs {method2})": wandb.Image(plt)})
            plt.savefig(os.path.join(storage_path, f"scatter_plot_{method1}_vs_{method2}.png"))
            plt.close(fig)

# Script entry point: python-fire exposes main()'s keyword parameters as
# command-line flags (e.g. --task=... --num_prompts=...).
if __name__ == "__main__":
    fire.Fire(main)
