import PIL.GimpGradientFile
import math
import torch
from torch.utils.data import DataLoader
from torch.distributions import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
from .base import _agent

class LinES(_agent):
    """
    Linear Ensemble Sampling (LinES) agent for linear contextual bandits.

    The agent keeps ``ensemble_size`` perturbed ridge-regression estimates
    (rows of ``self.Theta``) that all share one inverse design matrix
    ``self.V_inv``. Each round one ensemble member is drawn uniformly at
    random and the arm with the highest predicted reward under that member
    is pulled.

    Intended call order per round:
    ``choose_arm`` -> ``receive_reward`` -> ``update_model``.
    """

    def __init__(self,
                 num_arm,
                 dim_context,
                 ensemble_size,
                 reg=1.0,
                 beta=None,
                 device='cpu',
                 T=10000,
                 name='Linear Ensemble Sampling (LinES)'):
        """
        Args:
            num_arm: Number of arms (stored on the instance only).
            dim_context: Dimension of each arm's feature vector.
            ensemble_size: Number of ensemble members; must be >= 1.
            reg: Ridge regularisation strength; must be > 0.
            beta: Scale of the Gaussian perturbations. ``None`` means 1.0.
            device: Torch device on which all state is kept.
            T: Horizon (stored on the instance only; ``compute_beta`` can
                turn it into a value for ``beta`` if the caller wants to).
            name: Display name forwarded to the base class.

        Raises:
            ValueError: If ``ensemble_size < 1`` or ``reg <= 0``.
        """
        super().__init__(name)
        # Fail early with a clear message: a non-positive ``reg`` would
        # otherwise surface as ZeroDivisionError / math domain error in
        # ``clear``, and an empty ensemble as a randint error in
        # ``choose_arm``.
        if ensemble_size < 1:
            raise ValueError(
                f"ensemble_size must be >= 1, got {ensemble_size}")
        if float(reg) <= 0.0:
            raise ValueError(f"reg must be > 0, got {reg}")

        self.num_arm = num_arm
        self.dim_context = dim_context
        self.m = ensemble_size
        self.reg = float(reg)
        self.device = device
        self.beta = 1.0 if beta is None else float(beta)
        self.T = T
        self.clear()

    def clear(self):
        """Reset the agent to its initial (no data observed) state."""
        self.t = 1
        # NOTE: ``beta`` is taken from the constructor. A horizon-dependent
        # value can be obtained from
        # ``compute_beta(self.T, self.dim_context, self.reg)``.

        # Inverse of the regularised design matrix V = reg * I.
        self.V_inv = (1.0 / self.reg) * torch.eye(self.dim_context,
                                                  device=self.device)

        # Each member starts from its own Gaussian perturbation with
        # standard deviation sqrt(reg) * beta per coordinate.
        w_scale = math.sqrt(self.reg) * self.beta
        W = torch.randn(self.m, self.dim_context, device=self.device) * w_scale
        self.S = W.clone()
        # Row j of Theta is V_inv @ S[j].
        self.Theta = (self.V_inv @ self.S.T).T

        # Pending observation, filled by choose_arm / receive_reward and
        # consumed by update_model.
        self.last_cxt = None
        self.last_reward = None
        print('beta = ', str(self.beta))
        print('m = ', self.m)

    @staticmethod
    def compute_beta(T, d, lam, sigma=1.0, S=1.0, delta=0.01):
        """
        Return
        ``sigma * sqrt(d * log(1 + T / (d * lam)) + 2 * log(1 / delta))
        + sqrt(lam) * S``.

        Args:
            T: Horizon (number of rounds).
            d: Feature dimension.
            lam: Ridge regularisation strength.
            sigma: Presumably the reward-noise scale -- confirm with caller.
            S: Presumably a bound on the parameter norm -- confirm.
            delta: Presumably the failure probability -- confirm.
        """
        return (sigma * math.sqrt(d * math.log(1.0 + T / (d * lam)) + 2.0 * math.log(1.0/delta))
                + math.sqrt(lam) * S)

    @torch.no_grad()
    def choose_arm(self, context):
        """
        Pick an arm using one uniformly sampled ensemble member.

        Args:
            context: Array-like or tensor with one feature row per arm.

        Returns:
            Index (int) of the arm with the highest predicted reward.
        """
        # Convert to the model's device AND dtype. Moving only the device
        # makes float64 / integer tensor contexts fail in the matmul below
        # with a dtype-mismatch error.
        context = torch.as_tensor(context, dtype=self.Theta.dtype,
                                  device=self.device)

        j_t = torch.randint(low=0, high=self.m, size=(1,), device=self.device).item()
        theta = self.Theta[j_t]
        scores = context @ theta
        arm_to_pull = torch.argmax(scores).item()

        # Save selected arm's context for the update
        self.last_cxt = context[arm_to_pull]
        return arm_to_pull

    def receive_reward(self, arm, context, reward):
        """
        Record the reward of the most recently chosen arm.

        ``arm`` and ``context`` are accepted for interface compatibility but
        are not used; the features stored by ``choose_arm`` are used instead.
        """
        self.last_reward = float(reward)

    def update_model(self, num_iter=None):
        """
        Fold the pending (context, reward) pair into the ensemble.

        Does nothing unless both a context (from ``choose_arm``) and a reward
        (from ``receive_reward``) are pending. The pair is consumed, so a
        repeated call without a new round does not count the same
        observation twice.

        Args:
            num_iter: Unused; accepted for interface compatibility.
        """
        # Without a reward, torch.tensor(None, ...) below would raise.
        if self.last_cxt is None or self.last_reward is None:
            return
        # reshape (not view) so a non-contiguous row is handled as well.
        x = self.last_cxt.to(self.device).reshape(-1)
        y = torch.tensor(self.last_reward, device=self.device, dtype=x.dtype)
        # Consume the observation so it cannot be applied a second time.
        self.last_cxt = None
        self.last_reward = None

        # Sherman-Morrison rank-one update of V_inv for V <- V + x x^T.
        # V_inv is symmetric, so V_inv x x^T V_inv == outer(omega, omega).
        omega = self.V_inv @ x
        denom = 1.0 + torch.dot(omega, x)
        if denom <= 1e-12:
            denom = denom + 1e-12
        self.V_inv = self.V_inv - torch.outer(omega, omega) / denom

        # Each member sees the reward plus its own N(0, beta^2) perturbation.
        z = torch.randn(self.m, device=self.device) * self.beta
        incr = (y + z).unsqueeze(1) * x.unsqueeze(0)
        self.S = self.S + incr
        self.Theta = (self.V_inv @ self.S.T).T

        self.t += 1

