import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch.nn.functional as F
import numpy as np 
from LLM_utils import sample_sentences, calculate_sequence_probability
from LLM_MH import target_probability
import torch
import wandb
from advertiser_prompts import prompts_with_advertisers
import argparse
from static_mechanism_payment_rule import calculate_quantities_for_agent
# from pdb import set_trace



class LLM_static:
    """
    Static sampling mechanism built on top of a reference LLM.

    Typical usage is a four step pipeline:
    1. generate_samples:   draw candidate replies from the reference LLM.
    2. evaluate_samples:   score every candidate with the target distribution.
    3. calculate_payments: compute per-round, per-advertiser payments, offsets and utilities.
    4. draw_sample:        pick one of the first k candidates according to a sampling method.
    """

    def __init__(self, reference_LLM, reference_tokenizer,
                 target_distribution,
                 device,
                 max_length=100, top_k=50, top_p=0.95,
                 seed_value=None):
        """
        Store the models and sampling hyper-parameters, and optionally seed the global RNGs.

        Args:
        - reference_LLM: The model handed to `sample_sentences` to generate the candidates.
        - reference_tokenizer: The tokenizer used to generate and to decode the candidates.
        - target_distribution: Object exposing `evaluate(sequence=..., encoded_sequence=..., current_sequence=...)`
            that returns a dictionary of probabilities/metrics for a candidate.
        - device: The device forwarded to `sample_sentences`.
        - max_length (int): Forwarded to `sample_sentences`.
        - top_k (int): Forwarded to `sample_sentences`.
        - top_p (float): Forwarded to `sample_sentences`.
        - seed_value (int or None): If not None, seeds the global NumPy and PyTorch (CPU and, when available, CUDA) RNGs.
        """
        # set the seed value
        if seed_value is not None:
            np.random.seed(seed_value)  # Numpy module.
            torch.manual_seed(seed_value)  # PyTorch random number generator.
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed_value)  # CUDA random number generator.

        self.reference_LLM = reference_LLM
        self.reference_tokenizer = reference_tokenizer
        self.target_distribution = target_distribution
        self.max_length = max_length
        self.top_k = top_k
        self.top_p = top_p
        self.device = device

        self.generated_samples = None                    # The generated samples that the static algorithm will generate
        self.generated_samples_decoded = None            # The decoded generated samples so that they are human readable
        self.generated_samples_probability_dicts = None  # Their probability dictionaries

        # Per-sample scores, filled by evaluate_samples (index i corresponds to generated_samples[i])
        self.sample_probability_dicts = []
        self.sample_target_log_probabilities = []
        self.sample_target_probabilities = []
        self.sample_target_normalized_token_length_log_probabilities = []
        self.sample_target_normalized_byte_length_log_probabilities = []

    def generate_samples(self, user_input_sequence, input_expansion, number_of_samples):
        """
        Generate candidate replies with the reference LLM and store them on the instance.

        Args:
        - user_input_sequence (str): The input text sequence.
        - input_expansion (str): Text that is prepended to the user input before it is given to the reference LLM.
        - number_of_samples (int): The number of samples to be generated.

        Returns:
        - list: A list of generated samples, decoded
        - The generated samples, encoded (exactly as returned by `sample_sentences`)
        """
        # Generate samples using the reference LLM
        self.generated_samples = sample_sentences(
            model=self.reference_LLM,
            tokenizer=self.reference_tokenizer,
            input_sequence=input_expansion + user_input_sequence,
            max_length=self.max_length,
            top_k=self.top_k,
            top_p=self.top_p,
            num_return_sequences=number_of_samples,
            print_output=False,
            device=self.device,
        )

        # Decode the generated samples
        self.generated_samples_decoded = [
            self.reference_tokenizer.decode(sample, skip_special_tokens=True)
            for sample in self.generated_samples
        ]

        return self.generated_samples_decoded, self.generated_samples

    def evaluate_samples(self):
        """
        Evaluate the generated samples using the target distribution.

        For every generated sample the probability dictionary returned by
        `self.target_distribution.evaluate` is stored, together with its
        'target_log_prob', its 'target_prob', and the log probability normalized by
        'number_of_tokens' and by the length of the decoded string.

        The per-sample lists are rebuilt from scratch on every call, so that they always
        line up index-by-index with `self.generated_samples`.
        """
        # Start from empty lists: calculate_payments and draw_sample index these lists
        # with the index of the generated sample, so stale entries must not survive.
        self.sample_probability_dicts = []
        self.sample_target_log_probabilities = []
        self.sample_target_probabilities = []
        self.sample_target_normalized_token_length_log_probabilities = []
        self.sample_target_normalized_byte_length_log_probabilities = []

        for index, encoded_sample in enumerate(self.generated_samples):
            # Use the distribution owned by this instance (not a module-level global)
            probabilities_dict = self.target_distribution.evaluate(
                sequence=encoded_sample, encoded_sequence=True, current_sequence=None)
            target_log_prob = probabilities_dict['target_log_prob']

            self.sample_probability_dicts.append(probabilities_dict)
            self.sample_target_log_probabilities.append(target_log_prob)
            self.sample_target_probabilities.append(probabilities_dict['target_prob'])
            # max(1, ...) guards against a division by zero for empty replies
            self.sample_target_normalized_token_length_log_probabilities.append(
                target_log_prob / max(1, probabilities_dict['number_of_tokens']))
            # NOTE: the "byte length" is the length of the decoded string, i.e. its number of characters
            self.sample_target_normalized_byte_length_log_probabilities.append(
                target_log_prob / max(1, len(self.generated_samples_decoded[index])))

    def calculate_payments(self):
        """
        Calculate the payments for the generated samples using the Myerson formula from the paper, along with all offsets.

        Round r (0-indexed) only uses the first r + 1 samples. For every round and every advertiser,
        `calculate_quantities_for_agent` is called and the following quantities are derived from its output:
        - payments:       no offset, zero bid offset, not participating offset, both offsets
        - offsets:        zero bid offset, not participating offset, both offsets
        - values:         expected value, value gain from participating, unweighted expected value
        - utility gains:  value gain minus the payment, for each of the four payment variants

        Requires evaluate_samples to have been called. The probability dictionaries must contain the keys
        'advertiser_rewards_weighted', 'advertiser_rewards_unweighted', 'tau', 'reference_llm_log_prob'
        and 'proposal_distribution_log_prob'.

        Results are stored on the instance as 2-D arrays (row = round, column = advertiser), plus
        - self.sample_probabilities_all_rounds: list with one entry per round, the sample probabilities returned for that round
        - self.all_advertiser_rewards_all_rounds: weighted rewards, row = advertiser, column = sample
        """
        metric_keys = (
            'payment_no_offset', 'payment_zero_bid_offset',
            'payment_not_participating_offset', 'payment_both_offsets',
            'utility_gain_no_offset', 'utility_gain_zero_bid_offset',
            'utility_gain_not_participating_offset', 'utility_gain_both_offsets',
            'zero_bid_offset', 'not_participating_offset', 'both_offsets',
            'expected_value', 'participating_value_gain', 'unweighted_expected_value',
        )
        # history[key] is a 2-D list: each row corresponds to a round, each column to an advertiser
        history = {key: [] for key in metric_keys}

        weighted_rewards = []    # 2-D list: each row corresponds to a sample, each column to an advertiser
        unweighted_rewards = []
        reference_llm_log_probabilities = []          # 1-D list: 'reference_llm_log_prob' of each sample
        proposal_distribution_log_probabilities = []  # 1-D list: 'proposal_distribution_log_prob' of each sample
        sample_probabilities_all_rounds = []          # One entry per round: probabilities of the samples available in that round

        # Rewards reshaped to (number of advertisers, number of samples); defined here so that it exists even without samples
        rewards_by_advertiser = np.array(weighted_rewards).T

        for round_index in range(len(self.generated_samples)):
            probabilities_dict = self.sample_probability_dicts[round_index]

            current_rewards = probabilities_dict['advertiser_rewards_weighted']
            weighted_rewards.append(current_rewards)
            unweighted_rewards.append(probabilities_dict['advertiser_rewards_unweighted'])
            tau = probabilities_dict['tau']
            reference_llm_log_probabilities.append(probabilities_dict['reference_llm_log_prob'])
            proposal_distribution_log_probabilities.append(probabilities_dict['proposal_distribution_log_prob'])

            # These only depend on the round, so they are built once per round and not once per advertiser
            rewards_by_advertiser = np.array(weighted_rewards).T
            unweighted_rewards_by_advertiser = np.array(unweighted_rewards).T
            # this is the scale factor for the reference and proposal LLMs
            probabilities_scales = (np.array(reference_llm_log_probabilities)
                                    - np.array(proposal_distribution_log_probabilities))

            round_values = {key: [] for key in metric_keys}

            for advertiser in range(len(current_rewards)):
                # Only some of the returned quantities are used; the full unpacking is kept so that
                # a change in the helper's return signature fails loudly here.
                (probabilities, utility_i_non_offset, expected_value_i, utility_i_zero_bids,
                 _expected_externality_i, _expected_value_i_without_i_same_sentences,
                 _expected_value_improvement_i_same_sentences,
                 estimated_value_i_not_participating, _estimated_value_i_not_participating_approximate,
                 _true_expected_value_i_without_i_participating) = calculate_quantities_for_agent(
                    all_rewards=rewards_by_advertiser,
                    agent_index=advertiser,
                    probability_rescales=probabilities_scales,
                    probability_dicts=self.sample_probability_dicts[:round_index + 1],
                    tau=tau)

                # Calculate payments for all the different offsets
                payment_no_offset = expected_value_i - utility_i_non_offset
                payment_zero_bid_offset = expected_value_i - utility_i_non_offset + utility_i_zero_bids
                payment_not_participating_offset = (expected_value_i - utility_i_non_offset
                                                    - estimated_value_i_not_participating)
                payment_both_offsets = (expected_value_i - utility_i_non_offset + utility_i_zero_bids
                                        - estimated_value_i_not_participating)

                # Calculate the expected gain from participating in the auction for all offsets, defined as:
                # (expected outcome value of participating - expected outcome value of not participating) - payment for each offset
                value_gain = expected_value_i - estimated_value_i_not_participating

                advertiser_values = {
                    'payment_no_offset': payment_no_offset,
                    'payment_zero_bid_offset': payment_zero_bid_offset,
                    'payment_not_participating_offset': payment_not_participating_offset,
                    'payment_both_offsets': payment_both_offsets,
                    'utility_gain_no_offset': value_gain - payment_no_offset,
                    'utility_gain_zero_bid_offset': value_gain - payment_zero_bid_offset,
                    'utility_gain_not_participating_offset': value_gain - payment_not_participating_offset,
                    'utility_gain_both_offsets': value_gain - payment_both_offsets,
                    'zero_bid_offset': utility_i_zero_bids,
                    'not_participating_offset': - estimated_value_i_not_participating,
                    'both_offsets': utility_i_zero_bids - estimated_value_i_not_participating,
                    'expected_value': expected_value_i,
                    'participating_value_gain': value_gain,
                    'unweighted_expected_value': np.sum(
                        unweighted_rewards_by_advertiser[advertiser] * probabilities),
                }
                for key in metric_keys:
                    round_values[key].append(advertiser_values[key])

            # The probabilities returned for the last advertiser are the ones recorded for the round
            # NOTE(review): presumably they do not depend on the advertiser -- confirm in calculate_quantities_for_agent
            sample_probabilities_all_rounds.append(probabilities)

            for key in metric_keys:
                history[key].append(round_values[key])

        self.advertiser_payments_all_rounds_no_offset = np.array(history['payment_no_offset'])
        self.advertiser_payments_all_rounds_zero_bid_offset = np.array(history['payment_zero_bid_offset'])
        self.advertiser_payments_all_rounds_not_participating_offset = np.array(history['payment_not_participating_offset'])
        self.advertiser_payments_all_rounds_both_offsets = np.array(history['payment_both_offsets'])

        self.advertiser_utility_gain_all_rounds_no_offset = np.array(history['utility_gain_no_offset'])
        self.advertiser_utility_gain_all_rounds_zero_bid_offset = np.array(history['utility_gain_zero_bid_offset'])
        self.advertiser_utility_gain_all_rounds_not_participating_offset = np.array(history['utility_gain_not_participating_offset'])
        self.advertiser_utility_gain_all_rounds_both_offsets = np.array(history['utility_gain_both_offsets'])

        self.advertiser_zero_bid_offset_all_rounds = np.array(history['zero_bid_offset'])
        self.advertiser_not_participating_offset_all_rounds = np.array(history['not_participating_offset'])
        self.advertiser_both_offsets_all_rounds = np.array(history['both_offsets'])

        self.advertiser_expected_values_all_rounds = np.array(history['expected_value'])
        self.advertiser_participating_value_gain_all_rounds = np.array(history['participating_value_gain'])
        self.advertiser_unweighted_expected_value_all_rounds = np.array(history['unweighted_expected_value'])
        self.sample_probabilities_all_rounds = sample_probabilities_all_rounds
        self.all_advertiser_rewards_all_rounds = rewards_by_advertiser

    def draw_sample(self, samples_to_consider, sampling_method, account_for_proposal_distribution):
        """
        Return a sample from the generated samples using the specified sampling method.

        Args:
        - samples_to_consider (int or None): Only the first `samples_to_consider` generated samples can be drawn.
            If None, all generated samples are considered.
        - sampling_method (str): The sampling method to be used.
            Possible values:
            - 'target_probability': Sample proportional to the target probability (i.e., the principled approach)
            - 'target_log_probability': Sample proportional to the log of the target probability
            - 'target_normalized_token_length_log_probability': Sample proportional to the log of the target probability normalized by the number of tokens in the sequence
            - 'target_normalized_byte_length_log_probability': Sample proportional to the log of the target probability normalized by the length of the decoded sequence
            - 'greedy': Always return the sample with the highest target log probability (i.e., the greedy approach)
        - account_for_proposal_distribution (bool): Only used by 'target_probability' and 'target_log_probability'.
            If True, the weights are taken from the sample probabilities computed by calculate_payments
            (which must have been called before) instead of the raw target (log) probabilities.

        Returns:
        - str: The selected sample based on the specified sampling method.
        - dict: The probability dictionary of the selected sample, with all metrics for logging.
        - numpy array: The draw probabilities of the samples considered, in the order they were considered.

        Raises:
        - ValueError: If the sampling method is not supported, if `samples_to_consider` is not between 1 and
            the number of generated samples, or if the draw probabilities contain NaNs.
        """
        number_of_generated_samples = len(self.generated_samples)
        if samples_to_consider is None:
            samples_to_consider = number_of_generated_samples

        # Without this check a value of 0 would silently read the probabilities of the last round (index -1)
        if not 1 <= samples_to_consider <= number_of_generated_samples:
            raise ValueError(f'samples_to_consider must be between 1 and {number_of_generated_samples}, but got {samples_to_consider}.')

        # NOTE(review): for the three log-probability based methods the raw log values are divided by their sum.
        # If all log probabilities are negative this gives valid weights, but the least likely sample
        # receives the largest weight -- confirm that this is intended.
        if sampling_method == 'target_probability':
            if not account_for_proposal_distribution:
                unscaled_probabilities = self.sample_target_probabilities[:samples_to_consider]
            else:
                unscaled_probabilities = self.sample_probabilities_all_rounds[samples_to_consider - 1]
        elif sampling_method == 'target_log_probability':
            if not account_for_proposal_distribution:
                unscaled_probabilities = self.sample_target_log_probabilities[:samples_to_consider]
            else:
                unscaled_probabilities = np.log(self.sample_probabilities_all_rounds[samples_to_consider - 1])
        elif sampling_method == 'target_normalized_token_length_log_probability':
            unscaled_probabilities = self.sample_target_normalized_token_length_log_probabilities[:samples_to_consider]
        elif sampling_method == 'target_normalized_byte_length_log_probability':
            unscaled_probabilities = self.sample_target_normalized_byte_length_log_probabilities[:samples_to_consider]
        elif sampling_method == 'greedy':
            # All the mass on the sample with the highest target log probability
            unscaled_probabilities = np.zeros(samples_to_consider)
            unscaled_probabilities[np.argmax(self.sample_target_log_probabilities[:samples_to_consider])] = 1
        else:
            raise ValueError(f'The specified sampling method {sampling_method} is not supported. Please choose one of the following: "target_probability", "target_log_probability", "target_normalized_token_length_log_probability", "target_normalized_byte_length_log_probability", "greedy".')

        # Scale the probabilities (explicit array conversion, the unscaled values may be a plain list)
        unscaled_probabilities = np.asarray(unscaled_probabilities)
        total = np.sum(unscaled_probabilities)
        if total == 0:
            # Nothing to normalize by: fall back to a uniform draw
            scaled_probabilities = np.ones(samples_to_consider) / samples_to_consider
        else:
            scaled_probabilities = unscaled_probabilities / total

        # Check that the scaled probabilities do not contain NaNs
        if np.isnan(scaled_probabilities).any():
            print('scaled_probabilities:', scaled_probabilities)
            raise ValueError('The scaled probabilities contain NaNs.')

        sample_index = np.random.choice(samples_to_consider, p=scaled_probabilities)

        decoded_sample = self.generated_samples_decoded[sample_index]
        sample_probability_dict = self.sample_probability_dicts[sample_index]

        return decoded_sample, sample_probability_dict, scaled_probabilities


        



if __name__ == "__main__":
    # --- Read seed and problem instance index from command line --- #
    parser = argparse.ArgumentParser(description='Run the Flan model on a given problem instance.')
    parser.add_argument('--seed', type=int, default= 42, help='The seed value to be used for reproducibility.')
    parser.add_argument('--problem_instance_index', type=int, default= 0, help='The index of the problem instance to be used.')
    parser.add_argument('--number_of_agents', type=int, default= 2, help='The number of agents in the problem instance.')
    parser.add_argument('--input_expansion', type=str, default= 'true', help='Whether to use input expansion or not.')

    args = parser.parse_args()
    seed_value = args.seed
    problem_instance_index = args.problem_instance_index
    number_of_agents = args.number_of_agents
    use_input_expnsion = args.input_expansion
    print(f"Seed value: {seed_value} and problem instance index: {problem_instance_index}")

    # seed_value = 50
    # problem_instance_index = 0


    
    number_of_samples = 20
    wandb_logging = True
    max_length = 700
    top_k = 50
    top_p = 0.95


    
    # model_name = "google/flan-t5-large"
    # model_name = "google/flan-t5-xxl"
    model_name = "meta-llama/Llama-2-7b-chat-hf"
    # model_name = "meta-llama/Llama-2-70b-chat-hf"   # WARNING - This model is very large!!!

    
    # --- Creating the problem instance --- #
    
    problem_instance = prompts_with_advertisers[problem_instance_index]

    # -- User input sequence/prompt, e.g. "How do  you make cookies?" -- #
    user_input_sequence = problem_instance['prompt']
    # -- Advertisement prompts, e.g. [" Answer the question advertising KitchenFix, a company that makes kitchen appliances.", " Answer the question advertising Easybake, a company that produces baking ingredients."] -- #
    advertiser_names = [advertiser['name'] for advertiser in problem_instance['advertisers']][:number_of_agents] 
    advertiser_description = [advertiser['description'] for advertiser in problem_instance['advertisers']][:number_of_agents]
    advertisement_prompts = [f"Answer the question advertising {advertiser['name']}, {advertiser['description']}: " for advertiser in problem_instance['advertisers']][:number_of_agents]
    # -- Input Expansion, e.g. "Answer the query. Mention KitchenFix, who makes kitchen appliances, and EasyBake, who produces baking ingredients. " -- # motivates the use of the advertisers in the response
    if use_input_expnsion.lower() == 'true':
        input_expansion = f"Answer the query. Mention {advertiser_names[0]}, {advertiser_description[0]} and {advertiser_names[1]}, {advertiser_description[1]}. "
    else:
        print('Not using input expansion!')
        input_expansion = ""



    # --- Load model and tokenizer --- #

    # Check if GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print('----> Using GPU!!!')
        print(f'Currently using GPU: {torch.cuda.get_device_name(0)}')


    # load the model on the GPU (or CPU if you don't have a GPU)
    if 'flan' in model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    elif 'gpt' in model_name:
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    elif 'llama' in model_name: 
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    else:
        raise ValueError(f'The specified model {model_name} is not supported. Please choose one of the following: "google/flan-t5-large", "google/flan-t5-xxl", "meta-llama/Llama-2-7b-chat-hf".')


    
    # Create target probability distribution
    target_distribution = target_probability(reference_LLM = model, reference_tokenizer = tokenizer, user_prompt = user_input_sequence,
        advertiser_prompts = advertisement_prompts, advertiser_names= advertiser_names, advertiser_cardinal_bids = [1,0],  # Change from 1 to 5 for the version 0.66, and to [1,2] for version 0.67 , and [1,0] for version 0.68
        proposal_expansion= input_expansion, device= device, tau = 1, remove_start_token = False) # NOTE: remove_start_token should depend on the model: True for GPT, False for T5
    

    # Create the sampler class 
    sampler = LLM_static(reference_LLM = model,
                reference_tokenizer = tokenizer,
                target_distribution = target_distribution,
                max_length = max_length, 
                top_k=top_k,
                top_p=top_p,
                device = device,
                seed_value = seed_value)
    
    
    # Generate possible sequences
    sampler.generate_samples(user_input_sequence= user_input_sequence, 
                                                        input_expansion= input_expansion,
                                                        number_of_samples= number_of_samples)
    
    
    # Evaluate target probability of all replies
    sampler.evaluate_samples()

    # Calculate payments (according to RLHF-insipred aggregation), expected utilities, expected probability of drawing each sentence etc.  
    sampler.calculate_payments()


    # --- Sample replies using all sampling methods --- #
    # for sampling_method in ['target_probability', 'target_log_probability', 'target_normalized_token_length_log_probability', 'target_normalized_byte_length_log_probability', 'greedy']:
    for sampling_method in ['target_probability']:
    # for sampling_method in ['target_normalized_byte_length_log_probability']:
        print(f'-------------------------------------------> Sampling method: {sampling_method}')

        if wandb_logging:
            # Initialize weights and biases
            initialization_dict = {
                'model_name': model_name,
                'user_input_sequence': user_input_sequence,
                'advertisement_prompts': advertisement_prompts,
                'use input_expansion': use_input_expnsion,
                'seed': seed_value,
                'problem_instance_index': problem_instance_index,
                'sampling_method': sampling_method,
                'number_of_agents': number_of_agents, 
            }

            wandb.init(project="LLM-Static-v0.68", config=initialization_dict)

            # Define all metrics for wandb logging
            wandb.define_metric("samples used")
            wandb.define_metric("sequence", step_metric= "samples used")
            wandb.define_metric("sequence probability", step_metric= "samples used")
            wandb.define_metric("sequence log probability", step_metric= "samples used")
            wandb.define_metric("number of tokens", step_metric= "samples used")
            wandb.define_metric("number of bytes", step_metric= "samples used")
            wandb.define_metric("number of characters", step_metric= "samples used")
            wandb.define_metric("sequence character normalized log probability", step_metric= "samples used")
            wandb.define_metric("sequence token normalized log probability", step_metric= "samples used")
            wandb.define_metric("sequence character normalized probability", step_metric= "samples used")
            wandb.define_metric("sequence token normalized probability", step_metric= "samples used")
            wandb.define_metric("reference llm probability", step_metric= "samples used")
            wandb.define_metric("reference llm log probability", step_metric= "samples used")
            for i in range(len(advertisement_prompts)):
                wandb.define_metric(f"advertiser {i} mentioned", step_metric= "samples used")
                wandb.define_metric(f"advertiser {i} expected value", step_metric= "samples used")
                wandb.define_metric(f"advertiser {i} unweighted expected value", step_metric= "samples used")
                
                if sampling_method == 'target_probability':
                    # --- Payment metrics --- #
                    wandb.define_metric(f"advertiser {i} payment no offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} payment zero bid offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} payment not participating offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} payment both offsets", step_metric= "samples used")

                    # --- Offset metrics --- #
                    wandb.define_metric(f"advertiser {i} zero bid offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} not participating offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} both offsets", step_metric= "samples used")

                    # --- Utility gain from Participating Metrics --- #
                    wandb.define_metric(f"advertiser {i} utility gain no offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} utility gain zero bid offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} utility gain not participating offset", step_metric= "samples used")
                    wandb.define_metric(f"advertiser {i} utility gain both offsets", step_metric= "samples used")

                    # --- Other Value Metrics --- #
                    wandb.define_metric(f"advertiser {i} participating value gain", step_metric= "samples used")

            if sampling_method == 'target_probability':
                wandb.define_metric("total payment no offset", step_metric= "samples used")
                wandb.define_metric("total payment zero bid offset", step_metric= "samples used")
                wandb.define_metric("total payment not participating offset", step_metric= "samples used")
                wandb.define_metric("total payment both offsets", step_metric= "samples used")
                wandb.define_metric("total advertiser participating value gain", step_metric= "samples used") 

            wandb.define_metric("total advertisers mentioned", step_metric= "samples used")

            wandb.define_metric("total advertiser expected value gain", step_metric= "samples used")
            wandb.define_metric("total advertiser expected value", step_metric= "samples used")




        for samples_to_consider in range(1, number_of_samples + 1):
            print('Current number of sampels to consider: ', samples_to_consider)
            sample, probabilities_dict, draw_probabilities = sampler.draw_sample(samples_to_consider= samples_to_consider, account_for_proposal_distribution= True, sampling_method= sampling_method)
            print(f'If we consider {samples_to_consider} samples:')
            print(sample)
            # print(probabilities_dict)
            print('-'*100) 

            if wandb_logging:
                wandb_dictionary = {
                    "samples used": samples_to_consider,
                    "sequence": sample,
                    "sequence probability": probabilities_dict['target_prob'],
                    "sequence log probability": probabilities_dict['target_log_prob'],
                    "number of tokens": probabilities_dict['number_of_tokens'],
                    "number of bytes": probabilities_dict['number_of_bytes'],
                    "number of characters": probabilities_dict['number_of_characters'],
                    "sequence character normalized log probability": probabilities_dict['target_log_prob'] / probabilities_dict['number_of_characters'],
                    "sequence token normalized log probability": probabilities_dict['target_log_prob'] / probabilities_dict['number_of_tokens'],
                    "sequence character normalized probability": probabilities_dict['target_prob'] / probabilities_dict['number_of_characters'],
                    "reference llm probability": probabilities_dict['reference_llm_prob'],
                    "reference llm log probability": probabilities_dict['reference_llm_log_prob'],   
                }

                
                if sampling_method == 'target_probability':
                    for i in range(len(advertisement_prompts)):
                        wandb_dictionary[f"advertiser {i} mentioned"] = probabilities_dict[f"advertiser {i} mentioned"]
                        
                        # --- Payment metrics --- #
                        wandb_dictionary[f"advertiser {i} payment no offset"] = sampler.advertiser_payments_all_rounds_no_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} payment zero bid offset"] = sampler.advertiser_payments_all_rounds_zero_bid_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} payment not participating offset"] = sampler.advertiser_payments_all_rounds_not_participating_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} payment both offsets"] = sampler.advertiser_payments_all_rounds_both_offsets[samples_to_consider - 1, i]

                        # --- Offset metrics --- #
                        wandb_dictionary[f"advertiser {i} zero bid offset"] = sampler.advertiser_zero_bid_offset_all_rounds[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} not participating offset"] = sampler.advertiser_not_participating_offset_all_rounds[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} both offsets"] = sampler.advertiser_both_offsets_all_rounds[samples_to_consider - 1, i]

                        # --- Utility gain from Participating Metrics --- #
                        wandb_dictionary[f"advertiser {i} participating value gain"] = sampler.advertiser_participating_value_gain_all_rounds[samples_to_consider - 1, i]
                        
                        
                        wandb_dictionary[f"advertiser {i} utility gain no offset"] = sampler.advertiser_utility_gain_all_rounds_no_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} utility gain zero bid offset"] = sampler.advertiser_utility_gain_all_rounds_zero_bid_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} utility gain not participating offset"] = sampler.advertiser_utility_gain_all_rounds_not_participating_offset[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} utility gain both offsets"] = sampler.advertiser_utility_gain_all_rounds_both_offsets[samples_to_consider - 1, i]

                        # --- Other Value Metrics --- #
                        wandb_dictionary[f"advertiser {i} expected value"] = sampler.advertiser_expected_values_all_rounds[samples_to_consider - 1, i]
                        wandb_dictionary[f"advertiser {i} unweighted expected value"] = sampler.advertiser_unweighted_expected_value_all_rounds[samples_to_consider - 1, i]

                        
                    wandb_dictionary["total payment no offset"] = np.sum(wandb_dictionary[f"advertiser {i} payment no offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total payment zero bid offset"] = np.sum(wandb_dictionary[f"advertiser {i} payment zero bid offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total payment not participating offset"] = np.sum(wandb_dictionary[f"advertiser {i} payment not participating offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total payment both offsets"] = np.sum(wandb_dictionary[f"advertiser {i} payment both offsets"] for i in range(len(advertisement_prompts)))

                    wandb_dictionary["total advertiser utility gain no offset"] = np.sum(wandb_dictionary[f"advertiser {i} utility gain no offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total advertiser utility gain zero bid offset"] = np.sum(wandb_dictionary[f"advertiser {i} utility gain zero bid offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total advertiser utility gain not participating offset"] = np.sum(wandb_dictionary[f"advertiser {i} utility gain not participating offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total advertiser utility gain both offsets"] = np.sum(wandb_dictionary[f"advertiser {i} utility gain both offsets"] for i in range(len(advertisement_prompts)))

                    wandb_dictionary["total advertiser participating value gain"] = np.sum(wandb_dictionary[f"advertiser {i} participating value gain"] for i in range(len(advertisement_prompts)))

                    wandb_dictionary["total zero bid offset"] = np.sum(wandb_dictionary[f"advertiser {i} zero bid offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total not participating offset"] = np.sum(wandb_dictionary[f"advertiser {i} not participating offset"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary["total both offsets"] = np.sum(wandb_dictionary[f"advertiser {i} both offsets"] for i in range(len(advertisement_prompts)))

                    wandb_dictionary['total advertisers mentioned'] = np.sum(wandb_dictionary[f"advertiser {i} mentioned"] for i in range(len(advertisement_prompts)))
                    wandb_dictionary['total advertiser expected value'] = np.sum(wandb_dictionary[f"advertiser {i} expected value"] for i in range(len(advertisement_prompts)))

                        
                
                elif sampling_method != 'target_probability':  # NOTE: for the target probability method, get all of these from the calculate payments field. 
                    continue
                    # Get the rewards for the advertisers up to this round
                    rewards = sampler.all_advertiser_rewards_all_rounds[:, :samples_to_consider]
                    # Set for each advertiser the lowest reward to zero 
                    rewards = rewards - np.min(rewards, axis=1)[:, np.newaxis]
                    # Calculate the expected value for each advertiser
                    expected_values = rewards @ draw_probabilities

                    for i in range(len(advertisement_prompts)):
                        wandb_dictionary[f"advertiser {i} mentioned"] = probabilities_dict[f"advertiser {i} mentioned"]
                        wandb_dictionary[f"advertiser {i} expected value"] = expected_values[i]
                        
                        if sampling_method == 'target_probability':
                            wandb_dictionary[f"advertiser {i} payment"] = sampler.advertiser_payments_all_rounds[samples_to_consider - 1, i]
                            wandb_dictionary[f"advertiser {i} expected utility"] = sampler.advertiser_expected_utilities_all_rounds[samples_to_consider - 1, i]
                        
                        wandb_dictionary[f"advertiser {i} expected value"] = expected_values[i]
                        wandb_dictionary[f"advertiser expected welfare"] = np.sum(expected_values)

                    if sampling_method == 'target_probability':
                        wandb_dictionary["total payments"] = np.sum(sampler.advertiser_payments_all_rounds[samples_to_consider - 1, :])
                    
                    wandb_dictionary["total advertisers mentioned"] = probabilities_dict["number_of_advertisers_mentioned"]

                # set_trace()

                wandb.log(wandb_dictionary)

        if wandb_logging:
            wandb.finish()

                

    


# self.advertiser_payments_all_rounds = np.array(advertiser_payments_all_rounds)
# self.advertiser_expected_values_all_rounds = np.array(advertiser_expected_values_all_rounds)
# self.advertiser_expected_utilities_all_rounds = np.array(advertiser_expected_utilities_all_rounds)
# self.sample_probabilities_all_rounds = sample_probabilities_all_rounds
    
    


    # # print all replies 
    # samples = []
    # for i in range(len(original_output_sequences)):
    #     print('-'*100)
    #     sample = tokenizer.decode(original_output_sequences[i], skip_special_tokens=True)
    #     samples.append(sample)
    #     print(f'Sample {i}: ', sample)

    # set_trace()

    # # Evaluate target probability of all replies
    # target_probabilities = []
    # target_log_probabilities = []
    # target_normalized_token_length_log_probabilities = []
    # target_normalized_byte_length_log_probabilities = []
    # sample_probability_dicts = []
    # for i in range(len(samples)):
    #     probabilities_dict = target_distribution.evaluate(sequence = samples[i], current_sequence= None)
    #     sample_probability_dicts.append(probabilities_dict)
    #     target_log_probabilities.append(probabilities_dict['target_log_prob'])
    #     target_probabilities.append(probabilities_dict['target_prob'])
    #     target_normalized_token_length_log_probabilities.append(probabilities_dict['target_normalized_token_length_log_prob'] / probabilities_dict['number_of_tokens'])
    #     target_normalized_byte_length_log_probabilities.append(probabilities_dict['target_normalized_byte_length_log_prob'] / len(samples[i]))

    
    # # For any number of samples in the range that you created 
    # # Sample a sentence from the set of samples: 
    # # a) proportional to the target probability (i.e., the principled approach)
    # # b) proportional to the log of the target probability (hopefully more stable)
    # # c) proportional to the log of the target probability normalized by the number of tokens in the sequence (hopefully leads to more human-like responses)
    # # d) the one with the highest (log) probability (i.e., the greedy approach)

    # for max_sample_number in range(len(samples)):
    #     # a) proportional to the target probability
    #     sample_index = np.random.choice([i for i in range(max_sample_number)], p = target_probabilities[:max_sample_number])
    #     # # b) proportional to the log of the target probability
    #     # sample_index = np.random.choice(samples[:max_sample_number], p = target_log_probabilities[:max_sample_number])
    #     # # c) proportional to the log of the target probability normalized by the number of tokens in the sequence
    #     # sample_index = np.random.choice(samples[:max_sample_number], p = target_normalized_token_length_log_probabilities[:max_sample_number])
        
        
    #     sample_prob = target_probabilities[sample_index]
    #     sample_log_prob = target_log_probabilities[sample_index]
    #     sample_normalized_token_log_prob = target_normalized_token_length_log_probabilities[sample_index]

        

    
    