import logging
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from utils import sigmoid, dot_mu

logger = logging.getLogger(__name__)


class NeuralModel(nn.Module):
    """Bias-free two-layer network: linear -> ReLU -> linear with one output."""

    def __init__(self, input_dim=5, hidden_dim=20):
        """Build the layers.

        Args:
            input_dim: Size of the last dimension of the inputs.
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        # Registration order matters to code that flattens ``parameters()``:
        # fc1's weights come before fc2's.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=1, bias=False)

    def forward(self, x):
        """Return the raw (un-squashed) output; the last dimension has size 1."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)


class LogisticModel(nn.Module):
    """Single bias-free linear layer producing one raw score per input row."""

    def __init__(self, input_dim=5):
        """Build the layer.

        Args:
            input_dim: Size of the last dimension of the inputs.
        """
        super().__init__()
        # The attribute name ``linear`` is part of the contract: the agent
        # reads ``model.linear.weight`` after training.
        self.linear = nn.Linear(in_features=input_dim, out_features=1, bias=False)

    def forward(self, x):
        """Return the linear score (no sigmoid applied here)."""
        score = self.linear(x)
        return score


class LogisticBanditAgent:
    def __init__(
        self,
        device,
        dataset,
        d=20,
        K=5,
        m=20,
        S=1.0,
        kappa=10.0,
        nu=0.1,
        lambda_reg=0.1,
        batch_size=50,
        algorithm="alg1",
        lr=1e-2,
        seed=0,
    ):
        """Set up hyper-parameters, the reward model and the design statistics.

        Args:
            device: Torch device the model and statistics are placed on.
            dataset: Dataset later wrapped in a ``DataLoader`` by
                ``update_model``.
            d: Input dimension of the model.
            K: Multiplier used in the index arithmetic of ``record``
                (presumably the number of arms per context -- confirm).
            m: Hidden width of ``NeuralModel``.
            S: Scale factor entering the exploration bonus.
            kappa: Scale factor entering the bonus and the initial ``V``.
            nu: Overall multiplier of the exploration bonus.
            lambda_reg: Regularisation used to initialise the design
                statistics; ``lambda_reg / 10`` is the SGD weight decay.
            batch_size: Number of entries ``record`` appends per call.
            algorithm: ``"alg1"``, ``"alg2"``, ``"alg3"``, ``"alg6"`` and
                ``"alg7"`` use ``NeuralModel`` with diagonal statistics; any
                other value uses ``LogisticModel`` with full matrices.
            lr: SGD learning rate.
            seed: Seed applied to the global torch and numpy RNGs.
        """
        # Plain configuration.
        self.device = device
        self.dataset = dataset
        self.algorithm = algorithm
        self.d, self.K, self.m = d, K, m
        self.S, self.kappa, self.nu = S, kappa, nu
        self.lambda_reg = lambda_reg
        self.batch_size = batch_size
        self.lr = lr
        self.memory = []
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Seed before the model is built so weight initialisation is
        # reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)

        prior_scale = kappa * lambda_reg
        uses_network = algorithm in ("alg1", "alg2", "alg3", "alg6", "alg7")

        if uses_network:
            self.model = NeuralModel(input_dim=d, hidden_dim=m).to(device)
            self.p = sum(param.numel() for param in self.model.parameters())
            self.theta0 = self._get_flat_params()
            # Diagonal statistics: one entry per model parameter.
            unit_vec = torch.ones(self.p, device=device)
            self.V = prior_scale * unit_vec
            self.W = lambda_reg * unit_vec
            self.G_sum = torch.zeros(self.p, device=device)
        else:
            self.model = LogisticModel(input_dim=d).to(device)
            self.theta0 = self._get_flat_params()
            self.bar_theta = self.theta0
            # Full d x d statistics together with their inverses.
            identity = torch.eye(d, device=device)
            self.bar_V = prior_scale * identity
            self.bar_V_inv = identity / prior_scale
            self.bar_W = lambda_reg * identity
            self.bar_W_inv = identity / lambda_reg
            self.bar_G_sum = torch.zeros((d, d), device=device)

    def _get_flat_params(self):
        return torch.cat([p.data.flatten() for p in self.model.parameters()])

    def _get_flat_grad(self, x):
        grads = []
        for xi in x:
            self.model.zero_grad()
            out = self.model(xi.unsqueeze(0))
            out.backward()
            grads.append(torch.cat([p.grad.view(-1) for p in self.model.parameters()])[None])
        return torch.cat(grads).detach()

    def update_design_matrix(self, x):
        if self.algorithm in ["alg1", "alg2", "alg3", "alg6", "alg7"]:
            self._update_design_nn(x)
        else:
            self._update_design_logistic(x)

    def _update_design_nn(self, x):
        grad = self._get_flat_grad(x)
        g2 = grad * grad
        g_sum = g2.sum(dim=0)
        self.V += g_sum
        self.G_sum += g_sum

        if self.algorithm in ["alg2", "alg7"]:
            f_val = self.model(x).detach()
            dmu = dot_mu(f_val)
            self.W += (dmu * g2).sum(dim=0)


    def _update_design_logistic(self, x):
        v = torch.einsum('bi,bj->bij', x, x)
        v_sum = v.sum(dim=0)
        self.bar_V += v_sum
        if self.algorithm == "alg4":
            try:
                self.bar_V_inv = torch.linalg.inv(self.bar_V)
            except RuntimeError:
                self.bar_V_inv = torch.linalg.pinv(self.bar_V)
        self.bar_G_sum += v_sum
        if self.algorithm == "alg5":
            z = x @ self.bar_theta.unsqueeze(-1)
            dmu = dot_mu(z).unsqueeze(-1)
            self.bar_W += (dmu * v).sum(dim=0)
            try:
                self.bar_W_inv = torch.linalg.inv(self.bar_W)
            except RuntimeError:
                self.bar_W_inv = torch.linalg.pinv(self.bar_W)

    def update_model(self, num_epochs=100, update_batch_size=50):
        if not self.memory:
            return
        sampler = SubsetRandomSampler(self.memory)
        loader = DataLoader(self.dataset, batch_size=update_batch_size, sampler=sampler)

        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.lambda_reg / 10
        )

        train_fn = (
            self._train_neural
            if self.algorithm in ["alg1", "alg2", "alg3", "alg6", "alg7"]
            else self._train_logistic
        )
        train_fn(loader, optimizer, num_epochs)



    def _train_neural(self, loader, optimizer, epochs):
        self.model.train()
        for _ in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                loss = self.loss_fn(self.model(xb), yb)
                loss.backward()
                optimizer.step()

    def _train_logistic(self, loader, optimizer, epochs):
        self.model.train()
        for _ in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                loss = self.loss_fn(self.model(xb), yb)
                loss.backward()
                optimizer.step()
        with torch.no_grad():
            self.bar_theta = self.model.linear.weight.data.flatten()

    def select_action(self, context):
        B, K, _ = context.shape
        ucb_vals = torch.zeros((B, K, 1), device=self.device)

        if self.algorithm in ["alg1", "alg2", "alg3", "alg6", "alg7"]:
            for i in range(K):
                x = context[:, i]               # (B, d)
                f = self.model(x).detach()      # (B, 1)
                mu = sigmoid(f)
                grad = self._get_flat_grad(x)   # (B, p)
                g2 = grad * grad

                if self.algorithm in ["alg1", "alg3", "alg6"]:
                    norm = torch.sqrt((g2 / self.V).sum(dim=1, keepdim=True))
                else:  # alg2, alg7
                    norm = torch.sqrt((g2 / self.W).sum(dim=1, keepdim=True))

                if self.algorithm == "alg1":
                    A = (1 / (4 * self.lambda_reg)) * self.G_sum + 1
                    logdet = torch.sqrt(torch.sum(torch.log(A)))
                    bonus = self.nu * (self.kappa**0.5) * (self.S**2) * logdet * norm
                    ucb = mu + bonus

                elif self.algorithm == "alg2":
                    A = (1 / (4 * self.lambda_reg)) * self.G_sum + 1
                    logdet = torch.sqrt(torch.sum(torch.log(A)))
                    theta = self._get_flat_params()
                    inner = grad @ (theta - self.theta0).unsqueeze(-1)  # (B, 1)
                    bonus = self.nu * (self.S**2) * logdet * norm
                    ucb = inner + bonus

                elif self.algorithm == "alg3":
                    A = (1 / (self.kappa * self.lambda_reg)) * self.G_sum + 1
                    logdet = torch.sqrt(torch.sum(torch.log(A)))
                    bonus = self.nu * (
                        (self.kappa**0.5) * self.S + self.kappa * logdet
                    ) * norm
                    ucb = f + bonus

                elif self.algorithm == "alg6":
                    A = (1 / (4 * self.lambda_reg)) * self.G_sum + 1
                    logdet = torch.sqrt(torch.sum(torch.log(A)))
                    bonus = self.nu * (self.kappa**0.5) * (self.S**2) * logdet * norm
                    eps = torch.randn_like(f)            # ~ N(0, I)
                    ucb = f + eps * bonus               # N(f, bonus^2)

                elif self.algorithm == "alg7":
                    A = (1 / (4 * self.lambda_reg)) * self.G_sum + 1
                    logdet = torch.sqrt(torch.sum(torch.log(A)))
                    theta = self._get_flat_params()
                    inner = grad @ (theta - self.theta0).unsqueeze(-1)  # (B, 1)
                    bonus = self.nu * (self.S**2) * logdet * norm
                    eps = torch.randn_like(inner)        # ~ N(0, I)
                    ucb = inner + eps * bonus           # N(inner, bonus^2)

                ucb_vals[:, i] = ucb

        else:
            for i in range(K):
                x = context[:, i]
                f = x @ self.bar_theta.unsqueeze(-1)
                mu = sigmoid(f)
                if self.algorithm == "alg4":
                    norm = torch.sqrt(
                        torch.einsum('bi,ij,bj->b', x, self.bar_V_inv, x)
                    ).unsqueeze(-1)
                elif self.algorithm == "alg5":
                    norm = torch.sqrt(
                        torch.einsum('bi,ij,bj->b', x, self.bar_W_inv, x)
                    ).unsqueeze(-1)
                A = (1 / (4 * self.lambda_reg)) * self.bar_G_sum + torch.eye(
                    self.d, device=self.device
                )
                logdet = torch.logdet(A)
                if self.algorithm == "alg4":
                    bonus = (
                        self.nu
                        * (self.kappa**0.5)
                        * (self.S**2)
                        * (logdet + (self.d**0.5))
                        * norm
                    )
                    ucb = mu + bonus
                elif self.algorithm == "alg5":
                    bonus = self.nu * self.S * (logdet + (self.d**0.5)) * norm
                    ucb = f + bonus
                ucb_vals[:, i] = ucb

        choice = torch.argmax(ucb_vals, dim=1)
        return choice, ucb_vals


    def record(self, batch_t, chosen_index):
        base = (batch_t - 1) * self.batch_size
        for j in range(self.batch_size):
            idx = base + self.K * j + chosen_index[j][0]
            self.memory.append(idx)
