from typing import Any
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np 
from LLM_utils import sample_sentences, calculate_sequence_probability
from LLM_MH import target_probability, LLM_MetropolisHastings
from pdb import set_trace
import time

# NOTE: `time` (imported above) is used to time sample generation in the script below.



if __name__ == "__main__":
    # Exploratory script: build the advertiser-weighted target distribution,
    # draw an initial reply to seed Metropolis-Hastings, then sample and time a
    # batch of replies from an alternative prompt layout for manual inspection.

    # Load model and tokenizer (FLAN-T5); both come from the same checkpoint.
    model_name = "google/flan-t5-large"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # User input sequence
    user_input_sequence = "How do you make cookies?"

    # Advertisement prompts (one per advertiser) and the rephrasing prompt
    # intended for the MH proposal step.
    # NOTE(review): "applicances" in the prompts below is a misspelling of
    # "appliances"; the prompt wording is left as-is -- confirm before changing.
    advertisement_prompts = [
        "Answer the following prompt as if you were a creative advertiser for a company that makes kitchen applicances named KitchenFix: ",
        "Answer the following prompt as if you were a creative advertiser for a company that produces baking ingredients called EasyBake:",
    ]
    proposal_prompt = "Rephrase the following text, while trying to maintain any references to company names: "

    # Create target probability distribution (both advertisers bid 1).
    target_distribution = target_probability(
        reference_LLM=model,
        reference_tokenizer=tokenizer,
        user_prompt=user_input_sequence,
        advertiser_prompts=advertisement_prompts,
        advertiser_cardinal_bids=[1, 1],
        tau=1,
        remove_start_token=False,
    )

    # Create initial value for MH: instruction first, then the user query.
    input_expansion = "Answer the following prompt, while trying to mention KitchenFix, who makes kitchen applicances, and EasyBake, who produces baking ingredients: "
    original_output_sequence = sample_sentences(
        model=model,
        tokenizer=tokenizer,
        input_sequence=input_expansion + user_input_sequence,
        max_length=100,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        print_output=False,
    )
    # Only one sequence was requested; decode it to plain text.
    original_output_sequence = tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)

    # ---- Trying different ways of creating the initial prompt ---
    # Here the user query comes first and the instruction second.
    input_expansion = "Answer the query. Try to mention KitchenFix, who makes kitchen applicances, and EasyBake, who produces baking ingredients. "
    # perf_counter() is a monotonic, high-resolution clock, so the measured
    # interval cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    original_output_sequences = sample_sentences(
        model=model,
        tokenizer=tokenizer,
        # Explicit separator: the query ends in "?" and the instruction starts
        # with "Answer", so plain concatenation would fuse them into
        # "...cookies?Answer the query.".
        input_sequence=user_input_sequence + " " + input_expansion,
        max_length=100,
        top_k=10,
        top_p=0.95,
        num_return_sequences=10,
        print_output=False,
    )
    end = time.perf_counter()

    # Print all replies
    for i, sequence in enumerate(original_output_sequences):
        print('-'*10)
        print(f'Sample {i}: ', tokenizer.decode(sequence, skip_special_tokens=True))

    print(f'Took {end-start} seconds to generate {len(original_output_sequences)} samples')
    # Interactive breakpoint for inspecting the samples before continuing.
    set_trace()

    # Show the sequence that would seed the Metropolis-Hastings sampler.
    print('-'*100)
    print(f"Original sequence: {original_output_sequence}")


    # # Create Metropolis-Hastings sampler
    # sampler = LLM_MetropolisHastings(reference_LLM = model, reference_tokenizer = tokenizer, initial_sequence = original_output_sequence,
    #     proposal_prompt = proposal_prompt, target_distribution = target_distribution)
    
    # # Sample from Metropolis-Hastings sampler
    # samples = sampler.sample(100)

    # # Print samples
    # print('-'*100)
    # print("Samples: ")
    # for i,sample in enumerate(samples):
    #     print(f'Sample {i}:', sample)

    # set_trace()



    # # Combining original input sequence with advertisement prompt
    # input_sequence = advertisement_prompts[0] + user_input_sequence
    # input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
    # original_output_sequence = sample_sentences(model = model, tokenizer = tokenizer, input_sequence = input_sequence,
    #                                     max_length = 100, top_k=50,top_p=0.95,num_return_sequences=1, print_output= False)
    # print('-'*50)
    # print("Ad1 sequence: {}".format(tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)))


    # input_sequence = advertisement_prompts[1] + user_input_sequence
    # input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
    # original_output_sequence = sample_sentences(model = model, tokenizer = tokenizer, input_sequence = input_sequence,
    #                                     max_length = 100, top_k=50,top_p=0.95,num_return_sequences=1, print_output= False)
    # print('-'*50)
    # print("Ad2 sequence: {}".format(tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)))



# --- Parameters to log at the beginning of the experiment --- #
# 1. Initial prompt
# 2. Target distribution (i.e., advertiser LLMs + cardinal bids)
# 3. Proposal distribution (i.e., reference LLM proposal prompt)
# 4. Normalization strategy
# 5. Number of rounds / stopping criteria
# 6. LLM architecture
# 7. LLM parameters



# --- Things to log at every state --- # 
# 1. Current sequence + Proposal sequence
# 2. Current sequence (+ proposal sequence) log probability for target distribution
# 3. Current sequence (+ proposal sequence) log probability for proposal distribution
# 4. Current sequence (+ proposal sequence) log probability for reference LLM
# 5. Current sequence (+ proposal sequence) log probability for advertiser LLMs
# 6. All above values normalized by number of tokens in sequence
# 7. Acceptance probability
# 8. Acceptance decision
# 9. Number of tokens in sequence
# 10. Number of bytes in sequence
# 11. Number of characters in sequence
# 12. GSP payments in each round (+ also with offset)
# 13. Myerson payments in each round
# 14. Cumulative payments for each advertiser up to each round. 
# 15. Number of times each advertiser is mentioned
# 16. Number of times each advertiser is mentioned in each round
# 17. Number of times each advertiser is mentioned in each round, normalized by number of tokens in sequence
# 18. Time taken for each round
# 19. Time taken for each round to generate proposal sequence
# 20. Time taken for each round to evaluate proposal sequence

