import numpy as np

def unsummarized_history(hist):
    """Render every recorded interaction as one line of prompt text.

    Each entry of ``hist`` is read positionally: ``entry[0]`` is the
    advertisement that was shown and ``entry[1]`` is compared against 1 to
    decide whether the user clicked. Users are numbered from 0.

    Returns the header line followed by one newline-terminated line per entry.
    """
    pieces = [f"So far you have interacted with {len(hist)} users. Here is the data you have collected:\n"]
    for user_idx, record in enumerate(hist):
        outcome = "and clicked" if record[1] == 1 else "but did not click"
        pieces.append(f"User {user_idx} saw advertisement {record[0]} {outcome}\n")
    return "".join(pieces)

def summarized_history(hist, options, k):
    """Render per-advertisement summary statistics as prompt text.

    For each of the first ``k`` entries of ``options`` (indexed directly, so
    ``k`` must not exceed ``len(options)``), report how many entries of
    ``hist`` showed that advertisement and the mean of their ``entry[1]``
    values (the empirical click rate), or that it has not been shown yet.
    """
    header = f"So far you have interacted with {len(hist)} users. Here is a summary of the data you have collected:\n"
    lines = []
    for idx in range(k):
        option = options[idx]
        rewards = [entry[1] for entry in hist if entry[0] == option]
        if not rewards:
            lines.append(f"Advertisement {option} has not been shown\n")
            continue
        n_shown = len(rewards)
        rate = np.sum(rewards) / n_shown
        lines.append(f"Advertisement {option} was shown to {n_shown} users with an estimated click rate of {rate:.2f}\n")
    return header + "".join(lines)

class AdvertsPrompt(object):
    """Prompt builder and answer parser for the advertisement-selection task.

    The model acts as a recommendation engine choosing among ``K`` of the
    advertisements ``A``..``G`` for a budget of ``T`` users (assumes
    ``K <= 7``). Flags select prompt variants:

    - ``suggestive``: add a hint about balancing exploration and exploitation.
    - ``summarized``: show per-advert summary statistics instead of the raw log.
    - ``cot``: ask the model to think step by step.
    - ``dist``: ask for a distribution over adverts instead of a single choice;
      the returned distribution is sampled in :meth:`parse_dist`.
    """

    def __init__(self, T, K, suggestive=False, summarized=False, cot=False, dist=False):
        self.T = T
        self.K = K
        self.suggestive = suggestive
        self.summarized = summarized
        self.cot = cot
        self.dist = dist
        self.options = ['A', 'B', 'C', 'D', 'E', 'F', 'G']

        ## For dist option only
        # Placeholder value names shown in the example distribution format,
        # e.g. "A:n1,B:n2,C:n3" for K=3.
        self.letters = ['n1', 'n2', 'n3', 'n4', 'n5', 'n6', 'n7']
        format_list = [f"{self.options[i]}:{self.letters[i]}" for i in range(len(self.options))]
        self.format_str = ",".join(format_list[0:self.K])

    def get_name(self):
        """Return an identifier encoding this prompt configuration."""
        name = "adverts"
        name += "_sug" if self.suggestive else "_neu"
        name += "_sum" if self.summarized else "_raw"
        name += "_dist" if self.dist else "_uni"
        name += "_cotn" if self.cot else "_not"
        name += f"_K={self.K}"
        return name

    def get_outputs(self):
        """Return the advertisement labels the model may choose from."""
        return self.options[0:self.K]

    def get_reward_scale(self):
        """Return the reward scale for this task (clicks are 0/1, so 1)."""
        return 1

    def strip(self, lst, last):
        """Join ``lst`` into a comma-separated string such as ``"A, B, C"``.

        ``last`` (``"and"``/``"or"`` at the call sites) is currently unused:
        no conjunction is inserted before the final item. It is kept so the
        signature and the existing prompt text stay unchanged.
        """
        return ", ".join(lst)

    def get_system_text(self):
        """Build the system prompt describing the task and the answer format."""
        option_str = self.strip(self.options[0:self.K], "and")
        prompt = f"""You are recommendation engine that chooses advertisements to display to users when they visit your webpage.
There are {self.K} advertisements you can choose from, named {option_str}.
When a user visits the webpage you can choose an advertisement to display and you will observe whether the user clicks on the ad or not.
You model this by assuming that each advertisement has a certain click rate and users click on advertisements with their corresponding rates.
You have a budget of {self.T} users to interact with and your goal is to maximize the total number of clicks during this process. \n\n"""
        if self.suggestive:
            prompt += f"""A good strategy to optimize for clicks in these situations requires balancing exploration and exploitation. You need to explore to try out all of the options and find those with high click rates, but you also have to exploit the information that you have to accumulate clicks.\n\n"""

        if self.summarized:
            prompt += f"""When each user visits the webpage, I will show you a summary of the data you have collected so far.\n"""
        else:
            prompt += f"""When each user visits the webpage, I will show you all of the data you have collected so far.\n"""
        option_str = self.strip(self.options[0:self.K], "or")

        if self.dist:
            prompt += f"""Then you must choose which advertisement to display. You may output a distribution over the {self.K} choices formatted EXACTLY like "{self.format_str}".\n\n"""
            output_prompt = f"""within the tags <Answer> DIST </Answer> where DIST is the distribution in the format specified above"""
        else:
            prompt += f"""Then you must choose which advertisement to display. This must be exactly one of {option_str}.\n\n"""
            output_prompt = f"""within the tags <Answer> ADVERTISEMENT </Answer> where ADVERTISEMENT is one of {option_str}"""

        if self.cot:
            prompt += f"""Let's think step by step to make sure we make a good choice. Then, you must provide your final answer {output_prompt}."""
        else:
            prompt += f"""You must provide your final answer {output_prompt} and with no text explanation."""
        return (prompt)

    def get_main_prompt(self, hist):
        """Build the per-step user prompt from the interaction history ``hist``."""
        prompt = ""
        if self.summarized:
            prompt += summarized_history(hist, self.options, self.K)
        else:
            prompt += unsummarized_history(hist)
        prompt += "\n"
        option_str = self.strip(self.options[0:self.K], "or")
        if self.dist:
            output_prompt = f"""<Answer> DIST </Answer> where DIST is formatted like "{self.format_str}"."""
        else:
            output_prompt = f"""<Answer> ADVERTISEMENT </Answer> where ADVERTISEMENT is one of {option_str}."""
        prompt += f"""Which advertisement will you choose next? Remember, YOU MUST provide your final answer within the tags {output_prompt}."""
        if self.cot:
            prompt += f" Let's think step by step to make sure we make a good choice."
        return (prompt)

    def try_direct_parse(self, pred):
        """Fallback parser: look for a bare option letter anywhere in ``pred``.

        ``pred`` is split on single spaces. Each token is stripped of
        surrounding whitespace and punctuation, then upper-cased. Returns the
        first of the first ``K`` options (in option order, not text order)
        that appears as a token, or ``None`` if none does.
        """
        lst = [x.strip().strip('.!?:"\'').upper() for x in pred.split(" ")]
        found = []
        for option in self.options[0:self.K]:
            if option in lst:
                found.append(option)
        if len(found) == 0:
            return None
        return found[0]

    def parse_output(self, pred):
        """Extract the chosen advertisement from the model output ``pred``.

        Uses the contents of the last ``<Answer>``/``</Answer>`` tags. In
        ``dist`` mode those contents are sampled via :meth:`parse_dist`.
        Otherwise single-character tokens are matched against the first ``K``
        options. If the tags are missing, or no option is found inside them,
        falls back to :meth:`try_direct_parse` on the whole output.

        Returns an option letter, or ``None`` if nothing could be parsed.
        """
        open_tag, close_tag = '<Answer>', '</Answer>'
        l = pred.rfind(open_tag)
        r = pred.rfind(close_tag)
        if l == -1 or r == -1:
            # Deliberately lenient: the model did not use the tags.
            return self.try_direct_parse(pred)
        parsed = pred[l + len(open_tag):r]
        if self.dist:
            return self.parse_dist(parsed)
        parsed = parsed.strip().strip('.!?:"\'').upper()
        parsed = [x for x in parsed.split(" ") if len(x) == 1]
        found = []
        for option in self.options[0:self.K]:
            if option in parsed:
                found.append(option)
        if len(found) == 0:
            return self.try_direct_parse(pred)
        return found[0]

    def parse_dist(self, parsed):
        """Parse a ``"A:p1,B:p2,..."`` distribution and sample one option.

        Labels are matched case-insensitively. Labels outside the first ``K``
        options are ignored, and empty fields (e.g. from a trailing comma) are
        skipped. Weights are normalised to sum to one before sampling, so
        percentages or unnormalised scores are accepted.

        Returns:
            The sampled option letter, or ``None`` if an item is malformed, any
            weight is negative or non-finite, or all weights are zero.
        """
        valid = self.options[0:self.K]
        distr = np.zeros(self.K)
        for item in parsed.split(","):
            if not item.strip():
                continue
            try:
                col, val = item.split(":")
                val = float(val.strip().strip(',!?:"\''))
            except ValueError:
                # Not exactly one ':' or a non-numeric weight.
                return None
            # Upper-case so the label can match self.options ('A', 'B', ...).
            col = col.strip().strip(',!?:"\'').upper()
            if col in valid:
                distr[self.options.index(col)] = val
        if not np.all(np.isfinite(distr)) or np.any(distr < 0):
            return None
        total = distr.sum()
        if total <= 0:
            return None
        # multinomial requires probabilities summing to (at most) one.
        sample = np.random.multinomial(1, distr / total)
        ind = int(np.argmax(sample))
        return self.options[ind]
