import torch
import numpy as np
import matplotlib.pyplot as plt 
import string
import random


class BanditEnv:
    '''
    Multi-armed bandit environment with an intervenable control-to-arm mapping.

    Each arm is labelled with a lowercase letter ('a', 'b', ...) and has a
    reward distribution, a ``torch.distributions.Normal`` whose mean and
    standard deviation are drawn at random. The reward distributions stay
    fixed until ``reset_arms`` is called.

    The caller does not choose an arm directly: it chooses a *control* index,
    and ``arm_given_index[control]`` is a distribution over arm positions
    (alphabetical order, i.e. a, b, c, ...) from which the arm actually pulled
    is sampled. By default control ``i`` always selects arm ``i``;
    ``intervene`` replaces that mapping.
    '''

    # Floor on each arm's standard deviation. ``torch.rand`` samples from
    # [0, 1), so without it a scale of exactly 0 (rejected by ``Normal``'s
    # argument validation) or a vanishingly small one (a density spike that
    # squashes every other arm in ``display_arms``) could be drawn.
    MIN_STD = 1e-3

    def __init__(self, n_arms: int = 5, device: str = 'cuda'):
        '''
        Args:
            n_arms: number of arms. Must be between 2 and 26 inclusive,
                because each arm is labelled with a distinct lowercase letter.
            device: torch device on which the arms' distribution parameters
                are created.

        Raises:
            AssertionError: if ``n_arms`` is out of range.
        '''
        assert n_arms > 1, 'number of arms must be greater than 1'
        # One lowercase letter per arm, so all 26 letters are usable.
        assert n_arms <= len(string.ascii_lowercase), "hey that's a lot of arms"

        self.n_arms = n_arms
        self.device = device
        self.reset_arms()
        # Valid control indices: 0 .. n_arms - 1.
        self.controls = torch.arange(len(self.arms))
        # Default control -> arm mapping: a one-hot categorical putting all
        # mass on arm i, over the arm labels in alphabetical order (a, b, c, ...).
        self.arm_given_index = {
            i: torch.distributions.Categorical(probs=torch.eye(n_arms)[i])
            for i in range(n_arms)
        }

    def reset_arms(self):
        '''Redraw every arm's reward distribution.

        Each arm gets a Normal whose mean is drawn from N(0, 1) and whose
        standard deviation is drawn uniformly from [0, 1), then floored at
        ``MIN_STD``. ``arm_keys`` is rebuilt to list the labels in order.
        '''
        self.arms = {
            string.ascii_lowercase[i]: torch.distributions.Normal(
                torch.randn(1).to(self.device),
                torch.rand(1).clamp_min(self.MIN_STD).to(self.device),
            )
            for i in range(self.n_arms)
        }
        self.arm_keys = list(self.arms.keys())

    def pull(self, arm_idx: int) -> torch.Tensor:
        '''Apply control ``arm_idx`` and return one reward sample.

        The arm actually pulled is chosen by ``sample_arm``. The result is a
        single draw from that arm's Normal: a tensor of shape (1,) on
        ``self.device``.
        '''
        arm_to_pull = self.sample_arm(arm_idx)
        return self.arms[arm_to_pull].sample()

    def sample_arm(self, arm_idx: int) -> str:
        '''Sample the label of the arm that control ``arm_idx`` selects.

        Raises:
            AssertionError: if ``arm_idx`` is not a valid control index.
            KeyError: if the current mapping has no entry for ``arm_idx``
                (possible after an ``intervene`` with a partial dict).
        '''
        assert arm_idx in self.controls, 'invalid control'
        # ``int`` so that a 0-d tensor index finds the int-keyed entry;
        # tensors hash by identity, not by value.
        arm_dist = self.arm_given_index[int(arm_idx)]
        return self.arm_keys[arm_dist.sample().item()]

    def intervene(self, arm_given_index: dict[int, torch.distributions.Distribution]):
        '''
        Replace the control-to-arm mapping.

        In this environment, an intervention moves us to a state where the
        association between control and arm is different. The arms' reward
        distributions are left untouched.

        Args:
            arm_given_index: maps each control index to a distribution whose
                samples are integer positions into ``arm_keys`` (for example a
                ``Categorical`` over ``n_arms`` outcomes). Stored as-is, not
                copied.
        '''
        self.arm_given_index = arm_given_index

    def display_arms(self):
        '''Plot each arm's reward density as a mirrored (violin-style) shape.'''
        distributions = [self.arms[key] for key in self.arm_keys]
        means = [dist.mean.item() for dist in distributions]
        stds = [dist.stddev.item() for dist in distributions]

        # Shared y-range covering +/-3 standard deviations of every arm.
        global_min = min(mean - 3 * std for mean, std in zip(means, stds))
        global_max = max(mean + 3 * std for mean, std in zip(means, stds))
        y_values = np.linspace(global_min, global_max, num=500)

        # Evaluate every density on the same grid; the grid tensor is built once.
        y_tensor = torch.tensor(y_values).to(self.device)
        pdf_values = [
            dist.log_prob(y_tensor).exp().cpu().numpy() for dist in distributions
        ]

        # Normalise by the tallest density across all arms, so shapes are
        # comparable and the tallest one spans max_width either side of its tick.
        max_pdf_value = max(pdf.max() for pdf in pdf_values)
        max_width = 0.4

        _, ax = plt.subplots(figsize=(8, 6))
        for idx, pdf in enumerate(pdf_values):
            scaled_pdf = (pdf / max_pdf_value) * max_width
            ax.fill_betweenx(y_values, idx - scaled_pdf, idx + scaled_pdf, alpha=0.5)

        ax.set_xticks(range(len(distributions)))
        ax.set_xticklabels(self.arm_keys)
        ax.set_xlabel('Distribution ID')
        ax.set_ylabel('Sample Space')
        ax.set_title('Environment payouts given arm')
        plt.show()


