import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import copy
from .base import _agent


class NeuralES(_agent):
    """
    Neural Ensemble Sampling (contextual bandit).

    Keeps ``ensemble_size`` deep copies of ``model``. Every observed
    (context, reward) pair is appended to each member's buffer with an
    independent Gaussian perturbation of the reward (std ``beta``). Each
    member is then refit by minibatch training on its own perturbed buffer.

    Arm selection:

    * While ``t <= tau``, an arm is drawn uniformly at random (warm-up).
    * Afterwards, one member is drawn uniformly at random and its argmax
      arm is played.

    ``t`` starts at 1 and is incremented by each successful ``update_model``
    call. Expected call order per round: ``choose_arm`` ->
    ``receive_reward`` -> ``update_model``.

    Parameters
    ----------
    num_arm, dim_context : int
        Each context passed to ``choose_arm`` must have shape
        ``(num_arm, dim_context)``.
    model : nn.Module
        Prototype network, deep-copied ``ensemble_size`` times. It is fed
        float32 inputs of shape ``(N, dim_context)``. It must produce N
        scores, either as ``(N,)`` or as ``(N, 1)``.
    optimizer : torch.optim.Optimizer
        Only its class and ``defaults`` are used, to build a fresh optimizer
        per member. The instance itself is never stepped.
    criterion : callable
        ``criterion(pred, target)`` with both tensors shaped ``(B, 1)``.
    collector :
        Stored as ``self.collector``; not used by this class.
    nu : float
        Used as ``beta`` when ``beta`` is None.
    batch_size : int
        Minibatch size (capped at the buffer length).
    image : bool
        Stored only. Contexts must still have shape ``(num_arm, dim_context)``.
    reg, reduce :
        Stored only.
        NOTE(review): ``reg`` is not applied to the loss here. Confirm
        whether regularisation was intended.
    device : str
        Device for models, buffers and contexts.
    ensemble_size : int
        Number of members ``m`` (must be >= 1).
    tau : int
        Number of uniform-random warm-up rounds.
    beta : float or None
        Std of the reward perturbation.
    epochs_per_round : int
        Passes over each member's buffer per ``update_model`` call.
    """
    def __init__(self,
                 num_arm, dim_context,
                 model, optimizer,
                 criterion, collector,
                 nu,
                 batch_size=64,
                 image=False,
                 reg=0.0,
                 reduce=10,
                 device='cpu',
                 ensemble_size=10,
                 tau=100,
                 beta=None,
                 epochs_per_round=1):
        self.num_arm = int(num_arm)
        self.dim_context = int(dim_context)
        self.device = device
        self.image = bool(image)

        self.m = int(ensemble_size)
        if self.m < 1:
            # choose_arm samples a member index in [0, m); m == 0 cannot work.
            raise ValueError(f"ensemble_size must be >= 1, got {ensemble_size}")
        # Perturbation std; falls back to ``nu`` when ``beta`` is not given.
        self.beta = float(beta if beta is not None else nu)
        self.tau = tau
        self.batch_size = int(batch_size)
        self.epochs_per_round = int(epochs_per_round)

        self.criterion = criterion
        self.collector = collector
        self.reg = float(reg)
        self.reduce = int(reduce)

        # Build ensemble from the single provided model/optimizer
        self.models = self._clone_model_n(model, self.m, self.device)
        self.optims = [self._rebuild_optimizer_like(optimizer, m.parameters()) for m in self.models]

        self.clear()

    def clear(self):
        """Reset the round counter, replay buffers and pending sample.

        Model weights and optimizer states are NOT reset.
        """
        self.t = 1

        self.buffers_X = [torch.empty(0, self.dim_context, dtype=torch.float32, device=self.device) for _ in range(self.m)]
        self.buffers_y = [torch.empty(0, dtype=torch.float32, device=self.device) for _ in range(self.m)]

        self.last_cxt = None
        self.last_reward = None
        print('beta = ', str(self.beta))
        print('m = ', self.m)

    @torch.no_grad()
    def choose_arm(self, context):
        """Pick an arm for ``context`` (shape ``(num_arm, dim_context)``).

        During warm-up (``t <= tau``) the arm is drawn uniformly at random.
        Otherwise one ensemble member is sampled uniformly and its argmax arm
        is returned.

        The chosen row is remembered as ``last_cxt`` for ``update_model``.

        Returns:
            int: index of the chosen arm.
        """
        # Cast tensor inputs to float32 as well. This matches the non-tensor
        # path and the float32 buffers the members are trained on.
        context = torch.as_tensor(context, dtype=torch.float32, device=self.device)
        assert context.shape == (self.num_arm, self.dim_context), f"Expected {(self.num_arm, self.dim_context)}, got {tuple(context.shape)}"

        if self.t <= self.tau:
            arm_to_pull = torch.randint(low=0, high=self.num_arm, size=(1,), device=self.device).item()
        else:
            j_t = torch.randint(low=0, high=self.m, size=(1,), device=self.device).item()
            model = self.models[j_t]
            # Training leaves members in train mode. Switch to eval so that
            # dropout is disabled and BatchNorm running stats are not updated
            # by action selection.
            model.eval()
            with torch.inference_mode():
                preds = model(context).reshape(-1)
            arm_to_pull = torch.argmax(preds).item()

        # Clone so the stored sample does not alias the caller's array
        # (as_tensor shares memory with numpy inputs).
        self.last_cxt = context[arm_to_pull].clone()
        return arm_to_pull

    def receive_reward(self, arm, context, reward):
        """Record the reward for the arm returned by the last ``choose_arm``.

        ``arm`` and ``context`` are ignored. The context stored by
        ``choose_arm`` is used instead.
        """
        self.last_reward = float(reward)

    def _append_sample(self, x: torch.Tensor, y: float):
        """Append ``(x, y + z_j)`` to member j's buffer for every member j.

        ``z_j`` is drawn i.i.d. from N(0, beta^2), one draw per member.
        """
        z = torch.randn(self.m, device=self.device) * self.beta
        y_tilde = torch.tensor(y, device=self.device, dtype=torch.float32) + z

        # Detach unconditionally so the replay buffer never carries autograd
        # history.
        x = x.detach().to(self.device, dtype=torch.float32).reshape(1, -1)

        # All members receive the same contexts, so concatenate once and
        # share the resulting tensor instead of keeping m identical copies.
        # The buffers are only ever replaced, never modified in place.
        new_X = torch.cat([self.buffers_X[0], x], dim=0)
        for j in range(self.m):
            self.buffers_X[j] = new_X
            self.buffers_y[j] = torch.cat([self.buffers_y[j], y_tilde[j].view(1)], dim=0)

    def _sgd_train_one(self, model: nn.Module, optim: torch.optim.Optimizer, X: torch.Tensor, y: torch.Tensor):
        """Train ``model`` on ``(X, y)`` for ``epochs_per_round`` shuffled epochs.

        Does nothing if the buffer is empty.
        """
        if X.shape[0] == 0:
            return
        ds = TensorDataset(X, y.view(-1, 1))
        dl = DataLoader(ds, batch_size=min(self.batch_size, X.shape[0]), shuffle=True, drop_last=False)
        model.train()
        for _ in range(self.epochs_per_round):
            for xb, yb in dl:
                # Align the prediction with the (B, 1) target. Otherwise a model
                # returning (B,) would silently broadcast to (B, B) in
                # elementwise losses such as MSE.
                pred = model(xb).reshape(yb.shape)
                loss = self.criterion(pred, yb)
                optim.zero_grad(set_to_none=True)
                loss.backward()
                optim.step()

    def update_model(self, num_iter=None):
        """Add the pending (context, reward) sample and refit every member.

        Returns without doing anything if no context or no reward is
        pending. Otherwise increments ``t``.

        NOTE(review): ``num_iter`` is accepted for interface compatibility
        but unused. The amount of training is set by ``epochs_per_round``.
        """
        if self.last_cxt is None or self.last_reward is None:
            return

        self._append_sample(self.last_cxt, float(self.last_reward))

        for model, optim, X, y in zip(self.models, self.optims, self.buffers_X, self.buffers_y):
            self._sgd_train_one(model, optim, X, y)

        self.t += 1

    def _clone_model_n(self, model: nn.Module, n: int, device: str):
        """Return ``n`` independent deep copies of ``model``, each moved to ``device``."""
        # NOTE(review): all members start from identical weights. Diversity
        # comes only from the perturbed targets and minibatch shuffling.
        # Ensemble sampling often uses independent initialisations; confirm
        # this is intended.
        return [copy.deepcopy(model).to(device) for _ in range(n)]

    def _rebuild_optimizer_like(self, base_opt: torch.optim.Optimizer, params):
        """Build a new optimizer of ``base_opt``'s class over ``params``.

        The new optimizer uses ``base_opt.defaults`` as its hyper-parameters.
        Per-group overrides and scheduler changes on ``base_opt`` are not
        carried over.
        """
        opt_cls = type(base_opt)
        defaults = dict(base_opt.defaults)
        return opt_cls(params, **defaults)
