import PIL.GimpGradientFile
import math
import torch
from torch.utils.data import DataLoader
from torch.distributions import Binomial, Bernoulli
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
from .base import _agent

'''
LinPHE: Linear Perturbed-History Exploration
'''

class LinPHE(_agent):
    """Linear Perturbed-History Exploration (LinPHE) contextual bandit agent.

    Every round the agent adds random Bernoulli(0.5) pseudo-rewards to the
    rewards it has observed so far, solves the resulting regularised
    least-squares problem, and plays the arm whose context scores highest
    under that perturbed estimate.

    State kept between rounds:

    * ``G_inv`` -- inverse of ``(a + 1) * (reg * I + sum_s x_s x_s^T)``,
      maintained with rank-one (Sherman-Morrison) updates.
    * ``b_true`` -- ``sum_s y_s * x_s`` over the observed (context, reward)
      pairs.
    * ``_X_hist`` / ``_y_hist`` -- per-round chosen contexts and rewards
      (only filled when ``fixed_arms`` is False).
    * ``_V_per_arm`` / ``_T_per_arm`` -- per-arm reward-weighted context sums
      and pull counts (only filled when ``fixed_arms`` is True).
    """

    def __init__(self,
                 num_arm,
                 dim_context,
                 a=1.0,
                 reg=1.0,
                 fixed_arms=False,
                 device='cpu',
                 name='LinPHE'):
        """Create the agent and initialise its statistics.

        Args:
            num_arm: Number of arms (rows of the context matrix).
            dim_context: Dimension of each arm's context vector.
            a: Perturbation scale; roughly the number of Bernoulli(0.5)
                pseudo-rewards added per observation. Must be > 0 and may be
                fractional.
            reg: Ridge regularisation strength.
            fixed_arms: If True, the same context matrix is assumed every
                round, so pseudo-rewards are drawn per arm from pull counts
                instead of per stored observation.
            device: Torch device on which all state tensors live.
            name: Name forwarded to the base agent class.
        """
        super(LinPHE, self).__init__(name)
        assert a > 0, "LinPHE requires a > 0"
        self.num_arm = num_arm
        self.dim_context = dim_context
        self.a = float(a)
        self.reg = float(reg)
        self.fixed_arms = bool(fixed_arms)
        self.device = device
        self.clear()

    def clear(self):
        """Reset the agent to its initial (no observations) state."""
        self.t = 1
        # Inverse of the initial design matrix (a + 1) * reg * I.
        self.G_inv = (1.0 / (self.reg * (self.a + 1.0))) * \
                     torch.eye(self.dim_context, device=self.device)

        self.b_true = torch.zeros(self.dim_context, device=self.device)

        self._X_hist = []
        self._y_hist = []

        self._V_per_arm = torch.zeros(self.num_arm, self.dim_context, device=self.device)
        self._T_per_arm = torch.zeros(self.num_arm, dtype=torch.long, device=self.device)

        # Working buffers (last chosen context, reward and arm index).
        self.last_cxt = None
        self.last_reward = None
        self.last_arm = None

        self.theta_tilde = torch.zeros(self.dim_context, device=self.device)

    @torch.no_grad()
    def choose_arm(self, context):
        """Pick the arm that maximises the perturbed linear reward estimate.

        Args:
            context: Array-like or tensor whose row ``i`` is the context
                vector of arm ``i`` (indexed as ``context[arm]`` and
                multiplied by a ``dim_context`` vector).

        Returns:
            int: Index of the chosen arm. The chosen row and index are also
            remembered in ``last_cxt`` / ``last_arm`` for ``update_model``.

        Raises:
            ValueError: In fixed-arm mode, if ``context`` is not a 2-D
                tensor with ``num_arm`` rows.
        """
        # Use G_inv's dtype/device for everything; a float64 or integer
        # context would otherwise make the matmuls below fail.
        context = torch.as_tensor(context, dtype=self.G_inv.dtype, device=self.device)

        # The mode is decided by `fixed_arms` alone. Fixed-arm mode never
        # stores per-round history, so it must not be gated on `_X_hist`
        # (doing so silently disabled all perturbation in that mode).
        if self.fixed_arms:
            b_pseudo = self._pseudo_rewards_fixed_arms(context)
        else:
            b_pseudo = self._pseudo_rewards_from_history()

        self.theta_tilde = self.G_inv @ (self.b_true + b_pseudo)

        scores = context @ self.theta_tilde
        arm_to_pull = torch.argmax(scores).item()

        # Clone so that later in-place edits of the caller's context cannot
        # change the vector used by update_model.
        self.last_cxt = context[arm_to_pull].reshape(-1).clone()
        self.last_arm = arm_to_pull
        return arm_to_pull

    def _pseudo_rewards_fixed_arms(self, context):
        """Return sum_i U_i * context[i] with U_i ~ Binomial(ceil(a * T_i), 0.5)."""
        if context.dim() != 2 or context.shape[0] != self.num_arm:
            raise ValueError(
                "fixed_arms=True expects a context of shape (num_arm, dim_context); "
                "got shape {}".format(tuple(context.shape))
            )
        pulls = self._T_per_arm.to(context.dtype)
        total_count = torch.ceil(self.a * pulls)
        half = torch.tensor(0.5, dtype=context.dtype, device=self.device)
        # Arms that were never pulled have total_count == 0 and contribute 0.
        pseudo_sum = Binomial(total_count=total_count, probs=half).sample()
        return context.T @ pseudo_sum

    def _pseudo_rewards_from_history(self):
        """Return sum_s U_s * x_s over stored contexts, U_s being the pseudo-reward sum."""
        n_obs = len(self._X_hist)
        if n_obs == 0:
            return torch.zeros_like(self.b_true)

        X_stack = torch.stack(self._X_hist, dim=0)
        dtype = X_stack.dtype
        half = torch.tensor(0.5, dtype=dtype, device=self.device)

        a_floor = math.floor(self.a)
        a_frac = self.a - a_floor

        pseudo_sum = torch.zeros(n_obs, dtype=dtype, device=self.device)
        if a_floor > 0:
            # floor(a) Bernoulli(0.5) pseudo-rewards per observation.
            pseudo_sum = pseudo_sum + Binomial(
                total_count=a_floor, probs=half
            ).sample((n_obs,)).to(self.device)
        if a_frac > 0:
            # With probability frac(a) an observation gets one extra
            # pseudo-reward, so the expected count per observation is a.
            extra = Bernoulli(
                torch.tensor(a_frac, dtype=dtype, device=self.device)
            ).sample((n_obs,))
            pseudo_sum = pseudo_sum + Binomial(total_count=extra, probs=half).sample()

        return X_stack.T @ pseudo_sum

    def receive_reward(self, arm, context, reward):
        """Record the reward for the most recently chosen arm.

        ``arm`` and ``context`` are accepted for interface compatibility but
        are not used; the arm and context remembered by ``choose_arm`` are
        what ``update_model`` consumes.
        """
        self.last_reward = float(reward)

    @torch.no_grad()
    def update_model(self, num_iter=None):
        """Fold the last (context, reward) pair into the agent's statistics.

        Does nothing if no arm has been chosen or no reward has been
        received yet.

        Args:
            num_iter: Accepted for interface compatibility; ignored.
        """
        if self.last_cxt is None or self.last_reward is None:
            return

        x = self.last_cxt.to(self.device).view(-1)
        y = torch.tensor(self.last_reward, device=self.device, dtype=x.dtype)

        # Update b_true
        self.b_true = self.b_true + y * x

        if self.fixed_arms:
            self._V_per_arm[self.last_arm] += y * x
            self._T_per_arm[self.last_arm] += 1
        else:  # Store history for future pseudo samples
            self._X_hist.append(x.detach().clone())
            self._y_hist.append(float(y.item()))

        # Sherman-Morrison update for G <- G + (a + 1) * x x^T.
        c = float(self.a + 1.0)
        omega = self.G_inv @ x
        # Clamp the denominator so round-off can never make it non-positive.
        denom = torch.clamp(1.0 + c * torch.dot(x, omega), min=1e-12)
        self.G_inv = self.G_inv - (c * torch.outer(omega, omega)) / denom

        self.t += 1