from typing import List, Tuple, Dict
import random
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from src.env import Env
from src.bandit.mo_bandits import (
    EGE,
    GEGE,
    MLP_EGE,
    MLP_EGE_test,
    CSR,
    LCSR,
    MLP_CSR,
    Pareto_Uni,
    Constrained_Uni,
    Pareto_GP
)
from src.utils.debug import debug_gemma_call

from src.utils.prompt_preprocess import fit_in_prompt, fit_in_examples
from src.constants import *
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt  # Add this import for plotting
# Template for a single input/output demonstration. It is passed to
# fit_in_examples together with an input text and an output text
# (presumably substituted for the [Input] / [Output] placeholders —
# confirm in src.utils.prompt_preprocess).
EXAMPLE_TEMPLATE = "Input: [Input] \nOutput: [Output] \n\n"

import pandas as pd
import ast


class MultiObjectiveEnv(Env):
    """Environment class for multi-objective bandit optimization"""
    
    def __init__(
            self,
            bandit_choice: str = 'MO_UCB',
            reward_methods_train: List[str] = ['0/1', 'f1'],
            reward_methods_eval: List[str] = ['0/1', 'f1'],
            reward_weights: List[float] = None,
            pareto_front: bool = True,
            features: np.ndarray = None,
            constraints: list = [1.0],
            **kwargs
            ):
        """Set up reward functions, scalarization weights and arm features.

        Args:
            bandit_choice: Name of the multi-objective bandit. Stored on the
                instance after the base class has been initialised.
            reward_methods_train: Reward method names used by the training
                steps; one reward function is built per name via
                ``init_reward``.
            reward_methods_eval: Reward method names used by the evaluation
                methods; one reward function is built per name.
            reward_weights: Per-objective weights used to scalarize rewards
                when ``pareto_front`` is False. Defaults to uniform weights.
            pareto_front: When True the bandit is updated with the full
                reward dict; otherwise with the weighted sum.
            features: Optional per-prompt feature matrix. When omitted and
                the chosen bandit is feature based, features are derived
                from ``self.prompts_df['embedding']`` (PCA followed by row
                normalisation).
            constraints: Constraint values; stored as a NumPy array and
                passed to the constrained bandits by ``init_bandit``.
            **kwargs: Forwarded unchanged to the base ``Env`` initialiser.

        Raises:
            ValueError: If ``reward_weights`` does not have one entry per
                training reward, or if a feature-based bandit is requested
                but neither ``features`` nor an ``embedding`` column in
                ``self.prompts_df`` is available.
        """
        # The base class is initialised with a bandit name it accepts; the
        # requested multi-objective choice is recorded right afterwards.
        super().__init__(bandit_choice='UCB', **kwargs)
        self.bandit_choice = bandit_choice
        self.constraints = np.array(constraints)

        # One reward function per objective, separately for train and eval.
        self.reward_trains = [self.init_reward(method) for method in reward_methods_train]
        self.reward_evals = [self.init_reward(method) for method in reward_methods_eval]
        # Copy so the instance never aliases the (shared) default list.
        self.reward_names = list(reward_methods_train)
        self.num_objectives = len(self.reward_names)

        # Weights for scalarization (only used when not in Pareto mode).
        if reward_weights is None:
            self.reward_weights = [1.0 / self.num_objectives] * self.num_objectives
        else:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if len(reward_weights) != self.num_objectives:
                raise ValueError(
                    f"reward_weights has {len(reward_weights)} entries but "
                    f"{self.num_objectives} training rewards were given."
                )
            self.reward_weights = reward_weights

        self.pareto_front = pareto_front
        self.historical_rewards = {method: [] for method in self.reward_names}
        self.features = features

        feature_based = ("GEGE", "LCSR", 'MLP_EGE', 'MLP_EGE_test', "MLP_CSR", "Pareto_GP")
        if self.features is None and self.bandit_choice in feature_based:
            if not (hasattr(self, 'prompts_df') and 'embedding' in self.prompts_df):
                raise ValueError("Features must be provided or self.prompts_df['embedding'] must exist.")

            # Embeddings may be stored as string-encoded lists; entries that
            # are already sequences are used as they are.
            features = np.array([
                ast.literal_eval(embedding) if isinstance(embedding, str) else embedding
                for embedding in self.prompts_df['embedding']
            ])
            print(features.shape)

            # Pick the PCA size from the number of prompts ...
            num_rows = features.shape[0]
            if 0.8 * num_rows > 64:
                n_components = 64
            elif 0.8 * num_rows > 32:
                n_components = 32
            else:
                n_components = num_rows
            # ... but never ask PCA for more components than the data can
            # provide (it rejects n_components > min(n_samples, n_features)).
            n_components = min(n_components, *features.shape)

            features = PCA(n_components=n_components).fit_transform(features)

            # Normalise each row to unit length; rows that are exactly zero
            # are left as zeros instead of turning into NaN.
            norms = np.linalg.norm(features, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            self.features = features / norms  # shape: (num_prompts, n_components)
        
        

    def init_bandit(self, 
                    bandit_choice: str = 'MO_UCB',
                    num_arms: int = 0,
                    budget: int = 0):
        """Create the bandit object that matches ``bandit_choice``.

        Known multi-objective names are built here from ``num_arms``,
        ``budget`` and the instance's objectives / features / constraints.
        Any other name prints a notice and is delegated to the base class
        ``init_bandit``.
        """
        # Each entry is a zero-argument factory so that instance attributes
        # are only read for the bandit that is actually requested.
        factories = {
            'EGE': lambda: EGE(
                num_arms=num_arms, num_objectives=self.num_objectives, total_budget=budget),
            'Pareto_Uni': lambda: Pareto_Uni(
                num_arms=num_arms, num_objectives=self.num_objectives),
            'Constrained_Uni': lambda: Constrained_Uni(
                num_arms=num_arms, num_objectives=self.num_objectives,
                constraints=self.constraints),
            'Pareto_GP': lambda: Pareto_GP(
                num_arms=num_arms, T=budget, num_objectives=self.num_objectives,
                features=self.features),
            'GEGE': lambda: GEGE(
                num_arms=num_arms, T=budget, features=self.features,
                num_objectives=self.num_objectives),
            'MLP_EGE': lambda: MLP_EGE(
                num_arms=num_arms, T=budget, features=self.features,
                num_objectives=self.num_objectives),
            'MLP_EGE_test': lambda: MLP_EGE_test(
                num_arms=num_arms, T=budget, features=self.features,
                num_objectives=self.num_objectives),
            'CSR': lambda: CSR(
                num_arms=num_arms, num_objectives=self.num_objectives,
                constraints=self.constraints, total_budget=budget),
            'LCSR': lambda: LCSR(
                num_arms=num_arms, total_budget=budget, features=self.features,
                num_objectives=self.num_objectives, constraints=self.constraints),
            'MLP_CSR': lambda: MLP_CSR(
                num_arms=num_arms, total_budget=budget, features=self.features,
                num_objectives=self.num_objectives, constraints=self.constraints),
        }

        build = factories.get(bandit_choice)
        if build is not None:
            return build()

        # Not one of the multi-objective bandits: let the base class decide.
        print(f"Bandit choice {bandit_choice} not recognized. ")
        return super().init_bandit(bandit_choice, num_arms, budget)
    
    def step(self, arm=None) -> Tuple[int, str, str, str, Dict[str, float]]:
        """Run one training round and update the bandit with all rewards.

        A single (query, target) pair is sampled from the train split, the
        LLM is queried with the prompt of one arm, every training reward is
        computed on the response, and the bandit is updated.

        Args:
            arm: Index into ``self.candidate_prompts`` to play. When ``None``
                the bandit selects the arm via ``choose_action()``.

        Returns:
            ``(arm, query, target, response, rewards)`` where ``rewards``
            maps each training reward name to its value for this round.
            If the bandit returns no arm, a tuple of five ``None`` values is
            returned and nothing is updated.

        Raises:
            NotImplementedError: If ``arm`` is ``None`` and the bandit choice
                is ``"Cluster"``, which this environment does not support.
        """
        if arm is None and self.bandit_choice == "Cluster":
            # Previously this path only printed a message and then failed
            # with an UnboundLocalError on ``query``; fail explicitly instead.
            raise NotImplementedError(
                "The 'Cluster' bandit is not supported by MultiObjectiveEnv.step."
            )

        query, target = self.dataset.sample_batch(batch_size=1, split="train")[0]
        if arm is None:
            arm = self.prompt_bandit.choose_action()
            # Identity check: ``== None`` misbehaves for NumPy-typed arms.
            if arm is None:
                print("No arm selected, returning None")
                return None, None, None, None, None
        instruction = self.candidate_prompts[arm]

        if self.few_shot:
            if self.use_examples:
                example_text = "".join(
                    fit_in_examples(EXAMPLE_TEMPLATE, example_in, example_out)
                    for example_in, example_out in self.examples
                )
            else:
                example_text = ""
            task_example = fit_in_examples(EXAMPLE_TEMPLATE, query, "[output]")
            # NOTE(review): task_example is built from the query and the
            # query is appended again in the LLM call below, so the query
            # presumably appears twice in few-shot mode — confirm intended.
            prompt = fit_in_prompt(PROMPT_TEMPLATE, instruction, example_text, task_example)
        else:
            # Without few-shot formatting the raw instruction is the prompt.
            prompt = instruction
        response = self.LLM.get_response(prompt = "Provide only one answer and NOTHING else.\n" + prompt + " " + query, n=1)[0]

        # One reward per objective, also appended to the running history.
        rewards = {}
        for method, reward_fn in zip(self.reward_names, self.reward_trains):
            reward = reward_fn(response, target)
            rewards[method] = reward
            self.historical_rewards[method].append(reward)

        self.t += 1

        # Pareto mode passes the full reward dict; otherwise a weighted sum.
        if self.pareto_front:
            self.prompt_bandit.update(arm, rewards)
        else:
            scalarized_reward = sum(self.reward_weights[i] * rewards[method] for i, method in enumerate(self.reward_names))
            self.prompt_bandit.update(arm, scalarized_reward)

        return arm, query, target, response, rewards
    
    def evaluation(self, num_eval_samples: int = 100, num_responses: int = 1) -> Dict[str, np.ndarray]:
        """Score every candidate prompt on a batch drawn from the test split.

        Each prompt is run once per sampled (query, target) pair and scored
        with every evaluation reward function. Results are keyed by the
        training reward names (``self.reward_names``), paired positionally
        with ``self.reward_evals``.

        Args:
            num_eval_samples: Requested number of test samples; capped at the
                number of available samples reported by the dataset.
            num_responses: Only selects the return shape. One response is
                requested per sample regardless of this value.

        Returns:
            If ``num_responses > 1``: a dict with ``"average_rewards"``
            (per-method array of per-prompt mean rewards) and
            ``"all_rewards"`` (per-method array of shape
            ``(num_prompts, num_eval_samples)``). Otherwise only the
            per-method array of per-prompt mean rewards.
        """
        # The dataset's __len__ returns a sequence; index 1 is reported below
        # as the size of the evaluation set.
        num_available_samples = self.dataset.__len__()[1]
        print(f"Number of available samples in the evaluation set: {num_available_samples}")
        num_eval_samples = min(num_eval_samples, num_available_samples)
        print(f"Number of samples used for evaluation: {num_eval_samples}")
        eval_set = self.dataset.sample_batch(batch_size=num_eval_samples, split="test")

        arm_eva_rewards = {method: np.zeros((self.num_prompts, num_eval_samples)) for method in self.reward_names}

        for j in range(self.num_prompts):
            progress_bar = tqdm(range(num_eval_samples))
            prompt = self.candidate_prompts[j]
            print(f"Evaluating arm: {j}, prompt: {prompt}")
            for i in progress_bar:
                query, target = eval_set[i]
                response = self.LLM.get_response(prompt="Provide only one answer and NOTHING else.\n" + prompt + " " + query, n=1)[0]

                # A blank response scores 0 on every objective.
                is_blank = response.strip() == ""
                for method, reward_fn in zip(self.reward_names, self.reward_evals):
                    arm_eva_rewards[method][j, i] = 0.0 if is_blank else reward_fn(response, target)

                # Running mean over the samples scored so far for this arm.
                # (Averaging the whole zero-padded row and dividing by i+1
                # again, as before, under-reported the value.)
                running_avg = arm_eva_rewards[self.reward_names[0]][j, :i + 1].mean()
                progress_bar.set_description(f"Evaluating arm: {j}, Avg reward: {running_avg:.2f}")

        # Mean reward per prompt across the evaluation samples.
        avg_rewards = {method: arm_eva_rewards[method].mean(axis=1) for method in self.reward_names}

        true_best_arms = {method: np.argmax(avg_rewards[method]) for method in self.reward_names}

        for method in self.reward_names:
            print(f"Evaluation best prompt for {method}: {self.candidate_prompts[true_best_arms[method]]}")
            print(f"Evaluation reward for {method}: {avg_rewards[method][true_best_arms[method]]}")

        # The richer result (averages plus raw matrix) is returned only when
        # num_responses > 1; otherwise just the averages.
        if num_responses > 1:
            return {
                "average_rewards": avg_rewards,
                "all_rewards": arm_eva_rewards,
            }
        return avg_rewards
    
    def best_arm(self):
        """Return the bandit's ``best_arm()`` result.

        For the ``"Cluster"`` bandit in phase 2 the result is used to index
        ``self.active_prompts`` before being returned.
        """
        selected = self.prompt_bandit.best_arm()
        in_cluster_phase_two = self.bandit_choice == "Cluster" and self.current_phase == 2
        return self.active_prompts[selected] if in_cluster_phase_two else selected

    ###### Asynchronous functions for ChatGPT ######
    
    async def async_step(self, arm=None):
        """Asynchronous variant of one training round.

        Samples a (query, target) pair from the train split, picks a prompt
        (given arm or bandit choice), awaits the LLM response, computes every
        training reward and updates the bandit.

        Args:
            arm: Index into ``self.candidate_prompts`` to play. When ``None``
                the bandit selects the arm via ``choose_action()``.

        Returns:
            ``(arm, query, target, response, rewards)`` where ``rewards``
            maps each training reward name to its value for this round.
            If the bandit returns no arm, a tuple of five ``None`` values is
            returned and nothing is updated.

        Raises:
            NotImplementedError: If ``arm`` is ``None`` and the bandit choice
                is ``"Cluster"``, which this environment does not support.
        """
        if arm is None and self.bandit_choice == "Cluster":
            # Previously this path only printed a message and then failed
            # with an UnboundLocalError on ``query``; fail explicitly instead.
            raise NotImplementedError(
                "The 'Cluster' bandit is not supported by MultiObjectiveEnv.async_step."
            )

        query, target = self.dataset.sample_batch(batch_size=1, split="train")[0]
        if arm is None:
            arm = self.prompt_bandit.choose_action()
            # Identity check: ``== None`` misbehaves for NumPy-typed arms.
            if arm is None:
                print("No arm selected, returning None")
                return None, None, None, None, None
        instruction = self.candidate_prompts[arm]

        # Demonstrations are only included in few-shot mode.
        example_text = ""
        if self.few_shot and self.use_examples:
            example_text = "".join(
                fit_in_examples(EXAMPLE_TEMPLATE, example_in, example_out)
                for example_in, example_out in self.examples
            )
        task_example = fit_in_examples(EXAMPLE_TEMPLATE, query, "[Output]")
        prompt = fit_in_prompt(PROMPT_TEMPLATE, instruction, example_text, task_example)

        # NOTE(review): unlike the synchronous step and the evaluation
        # methods, this prefix has no trailing newline and the query is not
        # appended separately — confirm the difference is intended.
        response = await self.LLM.get_response(prompt = "Provide only one answer and NOTHING else." + prompt, n=1)
        response = response[0]

        # One reward per objective, also appended to the running history.
        rewards = {}
        for method, reward_fn in zip(self.reward_names, self.reward_trains):
            reward = reward_fn(response, target)
            rewards[method] = reward
            self.historical_rewards[method].append(reward)

        self.t += 1

        # Pareto mode passes the full reward dict; otherwise a weighted sum.
        if self.pareto_front:
            self.prompt_bandit.update(arm, rewards)
        else:
            scalarized_reward = sum(self.reward_weights[i] * rewards[method] for i, method in enumerate(self.reward_names))
            self.prompt_bandit.update(arm, scalarized_reward)

        return arm, query, target, response, rewards
    
    async def evaluation_async(self, num_eval_samples: int = 100, num_responses: int = 1) -> Dict[str, np.ndarray]:
        """Asynchronously score every candidate prompt on the test split.

        Each prompt is run once per sampled (query, target) pair, awaiting
        the LLM for every call, and scored with every evaluation reward
        function. Results are keyed by the training reward names
        (``self.reward_names``), paired positionally with
        ``self.reward_evals``.

        Args:
            num_eval_samples: Requested number of test samples; capped at the
                number of available samples reported by the dataset.
            num_responses: Only selects the return shape. One response is
                requested per sample regardless of this value.

        Returns:
            If ``num_responses > 1``: a dict with ``"average_rewards"``
            (per-method array of per-prompt mean rewards) and
            ``"all_rewards"`` (per-method array of shape
            ``(num_prompts, num_eval_samples)``). Otherwise only the
            per-method array of per-prompt mean rewards.
        """
        # The dataset's __len__ returns a sequence; index 1 is used as the
        # number of samples available for evaluation.
        num_available_samples = self.dataset.__len__()[1]
        num_eval_samples = min(num_eval_samples, num_available_samples)
        eval_set = self.dataset.sample_batch(batch_size=num_eval_samples, split="test")

        arm_eva_rewards = {method: np.zeros((self.num_prompts, num_eval_samples)) for method in self.reward_names}

        for j in range(self.num_prompts):
            progress_bar = tqdm(range(num_eval_samples))
            prompt = self.candidate_prompts[j]
            print(f"Evaluating arm: {j}, prompt: {prompt}")
            for i in progress_bar:
                query, target = eval_set[i]
                response = await self.LLM.get_response(prompt="Provide only one answer and NOTHING else.\n" + prompt + " " + query, n=1)
                response = response[0]

                # A blank response scores 0 on every objective.
                is_blank = response.strip() == ""
                for method, reward_fn in zip(self.reward_names, self.reward_evals):
                    arm_eva_rewards[method][j, i] = 0.0 if is_blank else reward_fn(response, target)

                # Running mean over the samples scored so far for this arm.
                # (Averaging the whole zero-padded row and dividing by i+1
                # again, as before, under-reported the value.)
                running_avg = arm_eva_rewards[self.reward_names[0]][j, :i + 1].mean()
                progress_bar.set_description(f"Evaluating arm: {j}, Avg reward: {running_avg:.2f}")

        # Mean reward per prompt across the evaluation samples.
        avg_rewards = {method: arm_eva_rewards[method].mean(axis=1) for method in self.reward_names}

        true_best_arms = {method: np.argmax(avg_rewards[method]) for method in self.reward_names}

        for method in self.reward_names:
            print(f"Evaluation best prompt for {method}: {self.candidate_prompts[true_best_arms[method]]}")
            print(f"Evaluation reward for {method}: {avg_rewards[method][true_best_arms[method]]}")

        # The richer result (averages plus raw matrix) is returned only when
        # num_responses > 1; otherwise just the averages.
        if num_responses > 1:
            return {
                "average_rewards": avg_rewards,
                "all_rewards": arm_eva_rewards,
            }
        return avg_rewards