import pandas as pd
import wandb
import json
import numpy as np
import re
import time
import fire
import os
# Run wandb in "dryrun" mode for every execution of this script.
# NOTE(review): this is set after `import wandb`; presumably the variable is only
# read when wandb.init() is called in main(), so it still takes effect — confirm.
os.environ["WANDB_MODE"] = "dryrun"
import asyncio
import tiktoken
import matplotlib.pyplot as plt
import matplotlib

from tqdm import tqdm
from openai import OpenAI
from src.utils.utils import set_random_seeds, get_embedding
from src.MOenv import MultiObjectiveEnv as Env
from src.utils.dataloader import InductionDataset as Dataset
from src.constants import *
from sklearn.manifold import TSNE



# Bandit algorithms whose final answer is read from `prompt_bandit.pareto_front`
# (an iterable of arm indices into the prompts CSV).
_PARETO_ALGORITHMS = ("EGE", "GEGE", "MLP_EGE", "MLP_EGE_test", "Pareto_Uni", "Pareto_GP")
# Bandit algorithms whose final answer is the single arm from `prompt_bandit.best_arm()`.
_CONSTRAINED_ALGORITHMS = ("CSR", "LCSR", "Constrained_Uni", "MLP_CSR")
# Tasks evaluated with multi-answer F1 instead of the configured reward methods
# (any task whose name contains "translation" is treated the same way).
_MULTI_ANS_F1_TASKS = ("rhymes", "word_in_context", "common_concept")


def main(task: str = "informal_to_formal",
         budget: int = 100,
         bandit_algorithm: str = "UCB",
         config_path: str = "./config.json",
         storage_path: str = "./results/MO/informal_to_formal_WhiteBox_Gemma/",
         random_seed: int = 42,
         llm: str = "WhiteBox",
         constraints: "list[float] | None" = None,
         arm_set: str = "all",
         ) -> None:
    """Run a bandit-only prompt-selection experiment over a pre-scored prompt set.

    Loads the prompts (arms) and their stored per-method mean scores from a CSV,
    runs the multi-objective environment for ``budget`` pulls per prompt, and then
    records the arm(s) the bandit selected, together with their stored scores,
    to wandb and to ``bandit_results.csv`` under ``storage_path``.

    Args:
        task: Task name; forwarded to the environment and used to choose reward methods.
        budget: Pulls per prompt; the environment horizon is ``budget * len(prompts)``.
        bandit_algorithm: Bandit algorithm forwarded to the environment. The final
            answer can only be extracted for the algorithms in ``_PARETO_ALGORITHMS``
            and ``_CONSTRAINED_ALGORITHMS``.
        config_path: JSON config; must provide ``reward_method_train`` and
            ``reward_method_eval`` (overridden for multi-answer-F1 tasks).
        storage_path: Directory holding ``prompts.csv``; results are appended to
            ``bandit_results.csv`` in the same directory.
        random_seed: Seed forwarded to the environment and recorded in the config.
        llm: LLM backend; ``"ChatGPT"`` steps the environment asynchronously.
        constraints: Constraint values forwarded to the environment; ``[1.0]`` if omitted.
        arm_set: Path to the prompts CSV, or ``"all"`` for ``<storage_path>/prompts.csv``.

    Raises:
        FileNotFoundError: If the arm-set CSV does not exist.
        ValueError: If no final-arm extraction is defined for ``bandit_algorithm``.
    """
    # Build a fresh list per call so the environment never shares a default object.
    if constraints is None:
        constraints = [1.0]

    with open(config_path, "r") as f:
        config = json.load(f)

    # Must happen before the score columns are read. Otherwise prompt_scores
    # would be keyed by the original methods while bandit_best_score below
    # looks up "multi_ans_f1".
    if task in _MULTI_ANS_F1_TASKS or "translation" in task:
        config["reward_method_train"] = ["multi_ans_f1"]
        config["reward_method_eval"] = ["multi_ans_f1"]

    if arm_set == "all":
        arm_set = os.path.join(storage_path, "prompts.csv")
    print(arm_set)

    # Everything below depends on the pre-scored prompts, so fail loudly here.
    if not os.path.isfile(arm_set):
        raise FileNotFoundError(f"Arm set CSV not found: {arm_set}")
    print('read csv')
    prompts_df = pd.read_csv(arm_set)
    prompts = prompts_df["prompts"].to_list()
    # One score array per evaluation method, indexed by arm (CSV row position).
    prompt_scores = {
        method: prompts_df[f"mean_scores_{method}"].to_numpy()
        for method in config["reward_method_eval"]
    }
    total_budget = len(prompts) * budget

    if bandit_algorithm == "Cluster":
        config["prompt_bandit_choice"] = "SequentialHalving"
        config["cluster_bandit_choice"] = "SequentialHalving"
        config["with_cluster"] = True
    else:
        config["prompt_bandit_choice"] = bandit_algorithm
        config["cluster_bandit_choice"] = "SequentialHalving"
        config["with_cluster"] = False
    config["random_seed"] = random_seed
    config["LLM_choice"] = llm

    print(total_budget)
    env = Env(
        bandit_choice=bandit_algorithm,
        T=total_budget,
        reward_methods_train=config["reward_method_train"],
        reward_methods_eval=config["reward_method_eval"],
        LLM_choice=llm,
        use_examples=False,
        num_examples=0,
        num_prompts_examples=0,
        task=task,
        load_from_file=True,
        file_path=arm_set,
        random_seed=random_seed,
        constraints=constraints,
    )
    asyncio.run(env.async_init())

    config["task"] = task
    config["T"] = total_budget
    # NOTE(review): with the default storage_path this yields "results"; it
    # presumably expects a deeper directory layout, and shallower paths raise
    # IndexError — confirm the intended project name.
    dataset = storage_path.split("/")[-4]
    run = wandb.init(project=f"{dataset}_banditonly", config=config,
                     name=f"task_{task}_bandit_{bandit_algorithm}_budget_{budget}_seed_{random_seed}")

    progress_bar = tqdm(range(total_budget), desc=f"Pareto arms: {env.best_arm()}, Training reward: 0.00")
    # NOTE(review): never updated from `rewards`, so it is always logged as 0.
    # The structure of `rewards` is not visible here; TODO wire it up.
    running_average_reward = 0

    for i in progress_bar:
        if llm == "ChatGPT":
            # The ChatGPT backend exposes an async step; drive it to completion.
            _, _, _, _, rewards = asyncio.run(env.async_step())
        else:
            _, _, _, _, rewards = env.step()
        if rewards is None:
            print("No rewards received.")
            break
        progress_bar.set_description(f"Bandit best arm: {env.best_arm()}, Training reward: {running_average_reward:.2f}")
        run.log({"running_average_reward": running_average_reward})

        if (i + 1) % 10 == 0:
            run.log({"current_best_arm": env.best_arm()})

    if bandit_algorithm in _PARETO_ALGORITHMS:
        bandit_best_arm = env.prompt_bandit.pareto_front
    elif bandit_algorithm in _CONSTRAINED_ALGORITHMS:
        bandit_best_arm = [env.prompt_bandit.best_arm()]
    else:
        # Without this branch bandit_best_arm would be unbound (NameError below).
        raise ValueError(
            f"No final-arm extraction defined for bandit_algorithm={bandit_algorithm!r}; "
            f"supported: {_PARETO_ALGORITHMS + _CONSTRAINED_ALGORITHMS}"
        )
    bandit_best_prompt = [prompts[arm] for arm in bandit_best_arm]
    bandit_best_score = {
        method: [prompt_scores[method][arm] for arm in bandit_best_arm]
        for method in config["reward_method_eval"]
    }

    columns = ["task", "bandit_algorithm", "budget", "random_seed",
               "bandit_best_prompt", "bandit_best_arm", "bandit_best_score"]
    row = [task, bandit_algorithm, budget, random_seed,
           bandit_best_prompt, bandit_best_arm, bandit_best_score]
    run.log({"Experiment_results": wandb.Table(columns=columns, data=[row])})

    # Append to an existing results file without repeating the header;
    # otherwise create it with a header.
    results_path = os.path.join(storage_path, "bandit_results.csv")
    results_exist = os.path.isfile(results_path)
    pd.DataFrame(columns=columns, data=[row]).to_csv(
        results_path,
        mode="a" if results_exist else "w",
        header=not results_exist,
        index=False,
    )

    run.finish()

if __name__ == "__main__":
    # Expose main() as the command-line interface: fire turns its parameters into CLI options.
    fire.Fire(component=main)
