import numpy as np

def unsummarized_history(hist):
    """Render the full interaction history, one line per round.

    Args:
        hist: sequence of records where ``record[0]`` is the chosen button
            label and ``record[1]`` is the reward received for that round.

    Returns:
        A header line stating how many rounds were played, followed by one
        "<label> button, reward <reward>" line per record, in order.
    """
    header = f"""So far you have played {len(hist)} times with the following choices and rewards:\n"""
    rounds = [f"""{record[0]} button, reward {record[1]}\n""" for record in hist]
    return header + "".join(rounds)

def summarized_history(hist, colors, k):
    """Render a per-button summary (press count and mean reward) of the history.

    Args:
        hist: sequence of records where ``record[0]`` is the chosen button
            label and ``record[1]`` is the reward received for that round.
        colors: button labels; the first ``k`` of them are summarized.
        k: number of buttons to summarize (indexes ``colors[0..k-1]``).

    Returns:
        A header line followed by one line per button. Buttons never pressed
        show only their count; the others also show the average reward to two
        decimals.
    """
    pieces = [f"""So far you have played {len(hist)} times with your past choices and rewards summarized as follows:\n"""]
    for idx in range(k):
        label = colors[idx]
        rewards = [record[1] for record in hist if record[0] == label]
        pressed = len(rewards)
        if not pressed:
            pieces.append(f"""{label} button: pressed {pressed} times\n""")
            continue
        mean = np.sum(rewards) / pressed
        pieces.append(f"""{label} button: pressed {pressed} times with average reward {mean:.2f}\n""")
    return "".join(pieces)

class ButtonsPrompt(object):
    """Prompt builder and answer parser for a K-button Bernoulli bandit task.

    Each of the first ``K`` colors is a button (arm). Option flags choose
    between several prompt variants:

    * ``suggestive``: the system prompt calls the model "a bandit algorithm".
    * ``summarized``: history is shown as per-button statistics rather than
      round by round.
    * ``cot``: ask for step-by-step reasoning before the tagged answer.
    * ``dist``: the model answers with a distribution over buttons
      ("blue:a,green:b,...") that is then sampled, instead of a single color.
    """

    # Characters stripped from each word when scanning an untagged answer for
    # a color; includes ',' and '*' so "red," and "**red**" are recognized.
    _FALLBACK_PUNCT = '.,!?:;"\'*'

    def __init__(self, T, K, suggestive=False, summarized=False, cot=False, dist=False):
        self.T = T  # number of time steps announced in the system prompt
        self.K = K  # number of buttons; uses the first K entries of self.colors
        self.suggestive = suggestive
        self.summarized = summarized
        self.cot = cot
        self.dist = dist
        self.colors = ['blue', 'green', 'red', 'yellow', 'purple', 'brown', 'white']

        ## For dist option only
        # Placeholder letters used to show the expected "color:value" format,
        # e.g. "blue:a,green:b" for K=2.
        self.letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
        format_list = [f"{color}:{letter}" for color, letter in zip(self.colors, self.letters)]
        self.format_str = ",".join(format_list[0:self.K])

    def get_name(self):
        """Return an identifier string encoding every prompt option and K."""
        name = "buttons"
        name += "_sug" if self.suggestive else "_neu"
        name += "_sum" if self.summarized else "_raw"
        name += "_dist" if self.dist else "_uni"
        name += "_cotn" if self.cot else "_not"
        name += f"_K={self.K}"
        return name

    def get_outputs(self):
        """Return the list of valid answers (the first K color labels)."""
        return self.colors[0:self.K]

    def get_reward_scale(self):
        """Return the reward scale factor (rewards are used unscaled)."""
        return 1

    def strip(self, lst, last):
        """Join ``lst`` with ", " (e.g. ["a", "b", "c"] -> "a, b, c").

        NOTE(review): ``last`` (callers pass "and"/"or") is accepted but not
        used, so no conjunction is inserted before the final item. This is kept
        deliberately so the generated prompt text stays unchanged.
        """
        return ", ".join(lst)

    def get_system_text(self):
        """Return the system prompt describing the task and the answer format."""
        color_str = self.strip(self.colors[0:self.K], "and")
        if self.suggestive:
            prompt = f"You are a bandit algorithm in a room with {self.K} buttons labeled {color_str}.\n"
        else:
            prompt = f"You are in a room with {self.K} buttons labeled {color_str}.\n"
        prompt += (
            "Each button is associated with a Bernoulli distribution with a fixed but unknown mean; "
            "the means for the buttons could be different. \n"
            "For each button, when you press it, you will get a reward that is sampled from the button's associated distribution.\n"
            f"You have {self.T} time steps and, on each time step, you can choose any button and receive the reward.\n"
            f"Your goal is to maximize the total reward over the {self.T} time steps.\n"
        )
        if self.summarized:
            prompt += "At each time step, I will show you a summary of your past choices and rewards. "
        else:
            prompt += "At each time step, I will show you your past choices and rewards. "
        color_str = self.strip(self.colors[0:self.K], "or")
        if self.dist:
            prompt += (
                "Then you must make the next choice. "
                f'You may output a distribution over the {self.K} buttons formatted EXACTLY like "{self.format_str}". '
            )
            output_prompt = "within the tags <Answer> DIST </Answer> where DIST is the distribution in the format specified above"
        else:
            prompt += f"Then you must make the next choice, which must be exactly one of {color_str}. "
            output_prompt = f"within the tags <Answer> COLOR </Answer> where COLOR is one of {color_str}"
        if self.cot:
            prompt += f"Let's think step by step to make sure we make a good choice. You must provide your final answer {output_prompt}."
        else:
            prompt += f"You must provide your final answer immediately {output_prompt} and with no text explanation."
        return prompt

    def get_main_prompt(self, hist):
        """Return the per-step user prompt: the rendered history plus the question.

        Args:
            hist: sequence of (color, reward) records, forwarded to the
                module-level ``summarized_history`` / ``unsummarized_history``.
        """
        if self.summarized:
            prompt = summarized_history(hist, self.colors, self.K)
        else:
            prompt = unsummarized_history(hist)
        prompt += "\n"
        color_str = self.strip(self.colors[0:self.K], "or")
        if self.dist:
            output_prompt = f'<Answer> DIST </Answer> where DIST is formatted like "{self.format_str}".'
        else:
            output_prompt = f"<Answer> COLOR </Answer> where COLOR is one of {color_str}."
        prompt += f"Which button will you choose next? Remember, YOU MUST provide your final answer within the tags {output_prompt}"
        if self.cot:
            prompt += " Let's think step by step to make sure we make a good choice."
        return prompt

    def parse_output(self, pred):
        """Extract the chosen color from the model output ``pred``.

        If ``<Answer>``/``</Answer>`` tags are present, the text between the
        last opening tag and the last closing tag is used. With ``dist`` it is
        parsed and sampled by ``parse_dist``. Otherwise, the first valid color
        (in ``self.colors`` order) that appears as a substring wins.

        If the tags are missing, the output is split on any whitespace, and
        the first valid color (in ``self.colors`` order) that appears as a
        whole word, ignoring surrounding punctuation and case, is returned.

        Returns:
            A color from ``get_outputs()``, or None if nothing valid is found.
        """
        outputs = self.colors[0:self.K]
        open_tag = '<Answer>'
        l = pred.rfind(open_tag)
        r = pred.rfind('</Answer>')
        if l == -1 or r == -1:
            # No tags: word scan. split() (not split(" ")) so newlines/tabs
            # also separate words.
            words = {w.strip(self._FALLBACK_PUNCT).lower() for w in pred.split()}
            return next((color for color in outputs if color in words), None)
        parsed = pred[l + len(open_tag):r]
        if self.dist:
            return self.parse_dist(parsed)
        parsed = parsed.lower()
        return next((color for color in outputs if color in parsed), None)

    def parse_dist(self, parsed):
        """Parse "color:value,color:value,..." and sample one color from it.

        Values are non-negative weights and are normalized before sampling,
        so "blue:0.3" means blue with certainty, and "blue:50,green:50" is
        accepted. Unknown colors are ignored, and empty items from stray
        commas are skipped.

        Returns:
            The sampled color, or None if the text is malformed, any weight is
            negative or non-finite, or all weights are zero.
        """
        outputs = self.colors[0:self.K]
        distr = np.zeros(self.K)
        try:
            for item in parsed.split(","):
                if not item.strip():
                    continue  # tolerate trailing / doubled commas
                col, val = item.split(":")  # ValueError unless exactly one ':'
                col = col.strip().strip(',!?:"\'').lower()
                val = float(val.strip().strip(',!?:"\''))
                if col in outputs:
                    distr[outputs.index(col)] = val
        except ValueError:
            return None
        if not np.all(np.isfinite(distr)) or (distr < 0).any():
            return None
        total = distr.sum()
        if total <= 0:
            return None
        # multinomial treats pvals[-1] as 1 - sum(pvals[:-1]); normalizing keeps
        # unassigned mass from silently landing on the last color.
        sample = np.random.multinomial(1, distr / total)
        return self.colors[int(np.argmax(sample))]
