import PIL.GimpGradientFile
import torch
from torch.utils.data import DataLoader
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
from .base import _agent

from train_utils.dataset import sample_data

import copy
import math
import torch
from torch.utils.data import DataLoader
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import nn

class GLMES(_agent):
    """Ensemble-sampling bandit agent trained with perturbed losses.

    The agent keeps ``ensemble_size`` independent copies of ``model``. Every
    observed sample gets one fixed Gaussian perturbation per ensemble member
    (drawn once in :meth:`receive_reward`), and each member is trained on the
    full history with the loss ``criterion(logits, rewards) - mean(z * logits)``.
    After a round-robin warm-up, :meth:`choose_arm` picks one member uniformly
    at random and plays the arm with the highest score under that member.

    Args:
        num_arm: Number of arms; used for the round-robin warm-up.
        dim_context: Context dimension; also the default warm-up length.
        model: Prototype network; it is deep-copied ``ensemble_size`` times.
        optimizer: Prototype optimizer; its class and the hyperparameters of
            its first param group are used to build one optimizer per member.
        criterion: Callable ``criterion(logits, rewards)`` returning a scalar
            loss tensor.
        collector: History buffer providing ``clear()``,
            ``collect_data(context, arm, reward)`` and ``fetch_batch()``.
        ensemble_size: Number of ensemble members.
        var_perturb: Variance of the Gaussian perturbations.
        tao: Warm-up length in steps; defaults to ``dim_context``.
        reg: Numerator of the decaying weight decay ``reg / step``.
        reduce: If truthy, models are only updated on steps that are a
            multiple of this value.
        device: Device the models and training tensors are placed on.
        name: Agent name forwarded to the base class.
    """

    def __init__(self,
                 num_arm: int,
                 dim_context: int,
                 model: nn.Module,
                 optimizer,
                 criterion,
                 collector,
                 ensemble_size: int,
                 var_perturb: float,
                 tao: int = None,
                 reg: float = 1.0,
                 reduce: int = None,
                 device: str = 'cpu',
                 name: str = 'GLM-ES'):
        super().__init__(name)
        self.num_arm = num_arm
        self.dim_context = dim_context
        self.device = device

        self.collector = collector
        self.criterion = criterion
        self.reduce = reduce
        self.reg = reg

        self.m = ensemble_size
        self.std = math.sqrt(var_perturb)
        self.tao = dim_context if tao is None else tao

        # Build m independent model/optimizer copies.
        self.models = [copy.deepcopy(model).to(device) for _ in range(self.m)]
        self.optimizers = [
            self._clone_optimizer_like(optimizer, member.parameters())
            for member in self.models
        ]

        self.step = 0
        # clear() also creates the per-member perturbation history (z_hist).
        self.clear()

    def _clone_optimizer_like(self, opt_proto, params):
        """Build a new optimizer of the same class as ``opt_proto`` for ``params``.

        Hyperparameters are read from the prototype's first param group, so
        differences between multiple param groups are not reproduced.

        Args:
            opt_proto: Optimizer instance to imitate.
            params: Parameters the new optimizer should manage.

        Returns:
            A fresh optimizer of type ``type(opt_proto)``.
        """
        opt_cls = type(opt_proto)
        g0 = opt_proto.param_groups[0]
        # Keys that are never forwarded; the clone uses the optimizer class's
        # own defaults for them.
        # NOTE(review): this also drops `maximize`, so a maximizing prototype
        # would be cloned as a minimizing optimizer -- confirm this is intended.
        blacklist = {'params', 'foreach', 'maximize', 'differentiable', 'capturable'}
        # Only forward keys listed in `defaults`. Param groups can gain extra
        # entries after construction (e.g. `initial_lr`, added by LR
        # schedulers) that the constructor rejects with a TypeError.
        kwargs = {
            key: g0.get(key, default)
            for key, default in opt_proto.defaults.items()
            if key not in blacklist
        }
        return opt_cls(params, **kwargs)

    def clear(self):
        """Drop the collected history and the warm-up bookkeeping.

        The perturbation history is reset together with the collector:
        :meth:`update_model` requires both to have the same length, so
        clearing only the collector would make the next update raise.
        """
        self.collector.clear()
        self.z_hist = [[] for _ in range(self.m)]
        self.Design = 0   # warmup helper
        self.last_cxt = 0
        # NOTE(review): `step` and the model weights are left untouched here;
        # confirm whether clear() is expected to restart the warm-up as well.

    @torch.no_grad()
    def choose_arm(self, context: torch.Tensor) -> int:
        """Select the arm to pull for the given contexts.

        Args:
            context: Tensor fed to the model; the model output is squeezed on
                dim 1 -- presumably one row per arm, TODO confirm.

        Returns:
            Index of the chosen arm.
        """
        # Warm-up: cycle through the arms.
        # NOTE(review): this uses `<=` while receive_reward accumulates the
        # design matrix only while `step < tao` -- confirm the off-by-one.
        if self.step <= self.tao:
            return self.step % self.num_arm

        # Sample one ensemble member uniformly and act greedily with it.
        member_idx = torch.randint(low=0, high=self.m, size=(1,)).item()
        member = self.models[member_idx]
        member.eval()
        scores = member(context.to(self.device)).squeeze(dim=1)
        return torch.argmax(scores).item()

    def receive_reward(self, arm, context, reward):
        """Record one observation and draw its perturbations.

        Args:
            arm: Index of the pulled arm.
            context: Context tensor stored with the observation; it is
                flattened for the warm-up design matrix.
            reward: Observed reward.
        """
        # Log data
        self.collector.collect_data(context, arm, reward)
        self.last_cxt = context

        if self.step < self.tao:
            column = context.view(-1, 1)
            self.Design += 0.25 * column @ column.T

        # One fixed perturbation per ensemble member for this sample; stored
        # as Python floats in the same order the collector stores the data.
        new_z = torch.normal(mean=0.0, std=self.std, size=(self.m,)).tolist()
        for member_hist, z_value in zip(self.z_hist, new_z):
            member_hist.append(z_value)

    def _loss_with_fixed_perturb(self, model, contexts, rewards, z_tensor):
        """Return ``criterion(logits, rewards) - mean(z_tensor * logits)``.

        Args:
            model: Ensemble member to evaluate.
            contexts: Batch of contexts; the model output is squeezed on dim 1.
            rewards: Reward targets passed to the criterion.
            z_tensor: Per-sample perturbations, multiplied elementwise with
                the logits.
        """
        logits = model(contexts).squeeze(dim=1)
        base = self.criterion(logits, rewards)
        perturb = -(z_tensor.to(logits.device) * logits).mean()
        return base + perturb

    def update_model(self, num_iter: int):
        """Advance the step counter and retrain every ensemble member.

        Each member runs up to ``num_iter`` full-batch gradient steps on the
        whole history with its own fixed perturbations.

        Args:
            num_iter: Maximum number of gradient steps per member; values
                ``<= 0`` skip training.

        Raises:
            RuntimeError: If a perturbation history and the collected data
                differ in length.
            AssertionError: If the last loss of a member is NaN.
        """
        self.step += 1
        if self.reduce and (self.step % self.reduce != 0):
            return

        # FULL history tensors (in insertion order)
        contexts, rewards = self.collector.fetch_batch()
        if len(rewards) == 0:
            return
        contexts = torch.stack(contexts, dim=0).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)

        # Sanity: ensure each z_hist[j] matches history length
        N = rewards.shape[0]
        for j, member_hist in enumerate(self.z_hist):
            if len(member_hist) != N:
                raise RuntimeError(f"z_hist[{j}] length {len(self.z_hist[j])} "
                                   f"!= data length {N}. Check collector ordering.")

        # Weight decay shrinks as 1 / step.
        weight_decay = self.reg / max(self.step, 1)
        for opt in self.optimizers:
            for pg in opt.param_groups:
                pg['weight_decay'] = weight_decay

        for j, (model_j, opt_j) in enumerate(zip(self.models, self.optimizers)):
            model_j.train()
            z_j = torch.tensor(self.z_hist[j], dtype=torch.float32, device=self.device)

            # Stays None when num_iter <= 0 so the NaN check below is skipped.
            loss = None
            for _ in range(num_iter):
                model_j.zero_grad()
                loss = self._loss_with_fixed_perturb(model_j, contexts, rewards, z_j)
                loss.backward()
                opt_j.step()

                # NOTE(review): the perturbation term can make the total loss
                # negative, in which case this threshold stops training after
                # a single step -- confirm that is the intended criterion.
                if torch.isfinite(loss) and loss.item() < 1e-3:
                    break
            # Raised explicitly so the check survives `python -O`.
            if loss is not None and torch.isnan(loss):
                raise AssertionError(f'Loss is NaN in ensemble member {j}!')
