import PIL.GimpGradientFile
import math
import torch
from torch.utils.data import DataLoader
from torch.distributions import Binomial, Bernoulli
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
from .base import _agent

class GLMFPL(_agent):
    """GLM bandit agent that explores via Follow-the-Perturbed-Leader (GLM-FPL).

    During a warm-up phase (``step <= tao``) arms are pulled round-robin.
    Afterwards the arm with the highest model score is pulled greedily;
    exploration comes from refitting the model on a randomly perturbed loss
    (see ``_fpl_loss``) inside ``update_model``.

    Args:
        num_arm: Number of arms (used for the round-robin warm-up).
        dim_context: Context dimension; also the default for ``tao``.
        model: Torch module scoring contexts. Its output is squeezed on dim 1
            and used both as arm scores and as logits for ``criterion``.
        optimizer: Torch optimizer for ``model``. The ``weight_decay`` of every
            param group is overwritten on each refit.
        criterion: Base loss, called as ``criterion(logits, rewards)``.
        collector: Data store exposing ``collect_data``, ``fetch_batch`` and
            ``clear``; also wrapped in a ``DataLoader`` when ``batch_size``
            is set.
        a: Standard deviation of the Gaussian perturbation.
        batch_size: If truthy, train on mini-batches once ``step`` exceeds it.
        tao: Length of the warm-up phase (defaults to ``dim_context``).
        reduce: If truthy, only refit on every ``reduce``-th call to
            ``update_model``.
        reg: Weight-decay numerator; the decay applied is ``reg / step``.
        device: Device that contexts and rewards are moved to.
        name: Agent name forwarded to the base class.
    """

    def __init__(self,
                 num_arm,
                 dim_context,
                 model,
                 optimizer,
                 criterion,
                 collector,
                 a,
                 batch_size=None,
                 tao=None,
                 reduce=None,
                 reg=1.0,
                 device='cpu',
                 name='GLM-FPL'):
        super(GLMFPL, self).__init__(name)
        self.tao = dim_context if tao is None else tao
        self.reduce = reduce
        self.num_arm = num_arm
        self.dim_context = dim_context

        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

        if batch_size:
            self.loader = DataLoader(collector, batch_size=batch_size)
            self.batchsize = batch_size
        else:
            self.loader = None
            self.batchsize = None

        self.collector = collector
        self.a = a
        self.reg = reg
        self.device = device
        self.step = 0

        self.clear()

    def clear(self):
        """Empty the collector and reset ``Design`` and ``last_cxt``.

        The step counter ``step`` is intentionally left unchanged.
        """
        self.collector.clear()
        self.Design = 0
        self.last_cxt = 0

    @torch.no_grad()
    def choose_arm(self, context):
        """Return the index of the arm to pull for ``context``.

        Round-robin over ``num_arm`` arms while ``step <= tao``; afterwards
        the argmax of the model's scores over the rows of ``context``
        (presumably one row per arm -- confirm against the caller).
        """
        if self.step <= self.tao:
            return self.step % self.num_arm
        self.model.eval()
        scores = self.model(context.to(self.device)).squeeze(dim=1)
        arm_to_pull = torch.argmax(scores).item()
        return arm_to_pull

    def receive_reward(self, arm, context, reward):
        """Store the observation and, during warm-up, grow ``Design``.

        While ``step < tao``, ``Design`` accumulates ``0.25 * x x^T`` where
        ``x`` is ``context`` flattened to a column vector.
        """
        self.collector.collect_data(context, arm, reward)
        self.last_cxt = context
        # NOTE(review): this uses ``<`` while choose_arm explores for
        # ``step <= tao``; confirm whether the asymmetry is intended.
        if self.step < self.tao:
            self.Design += 0.25 * context.view(-1, 1) @ context.view(-1, 1).T

    def _fpl_loss(self, pred_logits, targets, contexts, noise=None):
        """Return ``criterion(pred_logits, targets) - mean(noise * pred_logits)``.

        If ``criterion`` is a mean-reduced logistic loss (e.g. BCE with
        logits), the extra term equals training on perturbed targets
        ``targets + noise``. Gradients flow through ``pred_logits`` in both
        terms.

        Args:
            pred_logits: Model outputs, same shape as ``targets``.
            targets: Observed rewards.
            contexts: Unused; kept for interface compatibility.
            noise: Optional pre-sampled perturbation with ``targets``' shape.
                When ``None``, a fresh ``N(0, a^2)`` sample is drawn.
        """
        base = self.criterion(pred_logits, targets)
        if noise is None:
            noise = torch.normal(mean=0.0, std=self.a, size=targets.shape, device=targets.device)
        # Averaged to keep the scale comparable with a mean-reduced criterion.
        return base - (noise * pred_logits).mean()

    def _cycle_loader(self):
        """Yield batches from ``self.loader`` forever, restarting it when exhausted.

        Raises:
            RuntimeError: if the loader yields no batches at all (which would
                otherwise loop forever).
        """
        while True:
            produced = False
            for batch in self.loader:
                produced = True
                yield batch
            if not produced:
                raise RuntimeError('DataLoader yielded no batches')

    def update_model(self, num_iter):
        """Advance ``step`` and, unless skipped by ``reduce``, refit the model.

        Args:
            num_iter: Maximum number of optimizer steps. The full-batch path
                may stop early once the loss drops below 1e-3.

        Raises:
            AssertionError: if the loss after training is NaN.
        """
        self.step += 1
        if self.reduce and (self.step % self.reduce != 0):
            return

        # Weight decay shrinks as rounds accumulate: reg / step.
        for group in self.optimizer.param_groups:
            group['weight_decay'] = self.reg / max(self.step, 1)

        self.model.train()

        loss = None  # stays None when num_iter <= 0
        if self.batchsize and self.batchsize < self.step:
            batches = self._cycle_loader()
            for _ in range(num_iter):
                contexts, rewards = next(batches)
                contexts = contexts.to(self.device)
                rewards = rewards.to(dtype=torch.float32, device=self.device)

                self.model.zero_grad()
                logits = self.model(contexts).squeeze(dim=1)
                # NOTE(review): noise is drawn fresh per mini-batch, which
                # partly averages the perturbation out; holding it fixed per
                # sample would need sample indices from the collector.
                loss = self._fpl_loss(logits, rewards, contexts)
                loss.backward()
                self.optimizer.step()
        else:
            contexts, rewards = self.collector.fetch_batch()
            contexts = torch.stack(contexts, dim=0).to(self.device)
            rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)

            # Draw the perturbation once per refit and keep it fixed across
            # iterations, so the optimizer minimizes one perturbed objective
            # instead of averaging zero-mean noise away step by step.
            noise = torch.normal(mean=0.0, std=self.a, size=rewards.shape, device=rewards.device)

            for _ in range(num_iter):
                self.model.zero_grad()
                logits = self.model(contexts).squeeze(dim=1)
                loss = self._fpl_loss(logits, rewards, contexts, noise=noise)
                loss.backward()
                self.optimizer.step()
                # NOTE(review): the perturbed loss can be negative, so this
                # threshold may stop before convergence; confirm intent.
                if loss.item() < 1e-3:
                    break

        if loss is not None:
            assert not torch.isnan(loss), 'Loss is NaN!'
