import torch
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from backpack import backpack, extend
from backpack.extensions import BatchGrad


# Neural network model
class NeuralNetwork(nn.Module):
    """Single-hidden-layer perceptron producing one scalar per input row.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        node: Width of the hidden layer.
    """

    def __init__(self, input_dim, node):
        super().__init__()
        hidden = nn.Linear(input_dim, node)
        head = nn.Linear(node, 1)
        # The (Linear, ReLU, Linear) order fixes the state-dict keys to
        # ``layers.0.*`` and ``layers.2.*``.
        self.layers = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Apply the stack to ``x``; the last dimension of the result is 1."""
        out = self.layers(x)
        return out
    
class SimpleCNN(nn.Module):
    """Two-stage convolutional regressor producing one scalar per sample.

    NOTE: a second class named ``SimpleCNN`` is defined further down in this
    module and rebinds the name, so this definition is not the one that
    later code in the module resolves.

    Args:
        K: Number of input channels expected by the first convolution.
    """

    def __init__(self, K):
        super().__init__()
        stem = nn.Conv2d(K, 32, kernel_size=3, stride=1, padding=1)
        body = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv_layers = nn.Sequential(
            stem,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            body,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # 64 * 7 * 7 is the flattened feature size when the two poolings
        # reduce the spatial extent to 7x7 (e.g. 28x28 inputs).
        flat_features = 64 * 7 * 7
        hidden = nn.Linear(flat_features, 128)
        head = nn.Linear(128, 1)
        self.fc_layers = nn.Sequential(nn.Flatten(), hidden, nn.ReLU(), head)

    def forward(self, x):
        """Score a batch shaped (batch, K, H, W); returns shape (batch, 1).

        The original note suggests batch == K here -- TODO confirm with caller.
        """
        features = self.conv_layers(x)
        return self.fc_layers(features)
    
class SimpleCNN(nn.Module):
    """Four-convolution network with batch norm that outputs one scalar per sample.

    Args:
        K: Number of input channels expected by the first convolution.
    """

    def __init__(self, K):
        super().__init__()

        def conv_unit(c_in, c_out):
            # 3x3 same-padding convolution followed by batch norm and ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Two stages of two conv units each; every stage ends in a 2x2 max-pool.
        feature_stack = []
        for widths in ((K, 64, 128), (128, 256, 256)):
            for c_in, c_out in zip(widths, widths[1:]):
                feature_stack.extend(conv_unit(c_in, c_out))
            feature_stack.append(nn.MaxPool2d(2, 2))
        self.conv_layers = nn.Sequential(*feature_stack)

        # 256 * 7 * 7 is the flattened feature size when the two poolings
        # reduce the spatial extent to 7x7 (e.g. 28x28 inputs).
        flat_features = 256 * 7 * 7
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Score a batch shaped (batch, K, H, W); returns shape (batch, 1).

        The original note suggests batch == K here -- TODO confirm with caller.
        """
        features = self.conv_layers(x)
        return self.fc_layers(features)


# Base Agent class
class DeepFL:
    """Bandit agent that scores arms with a neural reward model.

    The model is trained on a sliding window of the most recent
    ``batch_size`` observations and arms are chosen by the highest
    (optionally noise-perturbed) prediction.

    Args:
        env: Environment object exposing ``d`` and ``K``.
        T: Horizon value stored on the agent.
        params: Mapping whose items are copied onto the instance as
            attributes. This class reads ``mode`` ("CNN" or "MLP"),
            ``device``, ``node`` (MLP only), ``batch_size``,
            ``perturbation_std`` and ``type``.

    Raises:
        ValueError: If ``mode`` is neither "CNN" nor "MLP".
    """

    def __init__(self, env, T, params):
        self.env = env
        self.d = env.d
        self.K = env.K
        self.T = T
        self.X = None

        for attr, val in params.items():
            setattr(self, attr, val)

        if self.mode == "CNN":
            self.model = SimpleCNN(self.K).to(self.device)
        elif self.mode == "MLP":
            self.model = NeuralNetwork(self.d, self.node).to(self.device)
        else:
            # Fail here instead of with an unrelated AttributeError on
            # ``self.model`` a few lines further down.
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'CNN' or 'MLP'."
            )

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
        self.loss_fn = nn.MSELoss()
        self.memory = []

    def get_arms(self, X):
        """Store the contexts of the current round for the next ``get_arm`` call."""
        self.X = X

    def update(self, t, x, y):
        """Record one observation and take a gradient step once the window is full.

        Args:
            t: Round index (unused here).
            x: Context tensor of the played arm; a leading batch axis is added.
            y: Observed scalar reward.
        """
        x = x.clone().detach().to(self.device).unsqueeze(0)
        y = torch.tensor([y], dtype=torch.float32, device=self.device)

        self.memory.append((x, y))

        # Keep only the newest ``batch_size`` observations.
        if len(self.memory) > self.batch_size:
            self.memory.pop(0)

        if len(self.memory) != self.batch_size:
            return

        x_batch, y_batch = zip(*self.memory)
        x_batch = torch.cat(x_batch)
        y_batch = torch.cat(y_batch)

        if self.perturbation_std > 0:
            # ``y_batch`` is a fresh tensor from ``torch.cat``, so the in-place
            # noise does not touch the stored rewards.
            y_batch += torch.normal(0, self.perturbation_std, size=y_batch.shape, device=self.device)

        self.model.train()
        self.optimizer.zero_grad()
        # reshape(-1) keeps a 1-D result even for a single-sample batch, so
        # predictions always match the shape of ``y_batch``.
        predictions = self.model(x_batch).reshape(-1)
        loss = self.loss_fn(predictions, y_batch)
        loss.backward()
        self.optimizer.step()

    def get_arm(self, t):
        """Return the index of the arm with the highest predicted reward.

        When ``self.type == 1`` Gaussian noise with standard deviation
        ``perturbation_std / sqrt(t + 1)`` is added to the predictions first.
        """
        self.model.eval()
        with torch.no_grad():
            # reshape(-1) (not squeeze) so a single arm still yields shape (1,),
            # which the in-place noise below requires.
            predictions = self.model(self.X).reshape(-1)
        if self.type == 1:
            predictions += torch.normal(mean=0, std=self.perturbation_std / np.sqrt(t + 1), size=(self.K,), device=self.device)
        return torch.argmax(predictions).item()

    @staticmethod
    def print():
        """Return the display name of this agent."""
        return "DeepFL"


class DeepFP(DeepFL):
    """Agent that perturbs the arm contexts (not the rewards) before scoring.

    Training uses the same sliding window as the base class but never adds
    noise to the rewards. Arm selection adds Gaussian noise with standard
    deviation ``perturbation_std / sqrt(t + 1)`` to the contexts.
    """

    def __init__(self, env, T, params):
        super().__init__(env, T, params)
        # Re-applied so that entries in ``params`` take precedence over any
        # attribute the base constructor assigned after its own pass.
        for attr, val in params.items():
            setattr(self, attr, val)
        self.invcov = torch.eye(self.d, device=self.device)

    def update(self, t, x, y):
        """Record one observation and take a gradient step once the window is full.

        Args:
            t: Round index (unused here).
            x: Context tensor of the played arm; a leading batch axis is added.
            y: Observed scalar reward.
        """
        x = x.clone().detach().to(self.device).unsqueeze(0)
        y = torch.tensor([y], dtype=torch.float32, device=self.device)
        self.memory.append((x, y))

        # Keep only the newest ``batch_size`` observations.
        if len(self.memory) > self.batch_size:
            self.memory.pop(0)

        if len(self.memory) != self.batch_size:
            return

        x_batch, y_batch = zip(*self.memory)
        x_batch = torch.cat(x_batch)
        y_batch = torch.cat(y_batch)

        self.model.train()
        self.optimizer.zero_grad()
        # reshape(-1) keeps a 1-D result even for a single-sample batch, so
        # predictions always match the shape of ``y_batch``.
        predictions = self.model(x_batch).reshape(-1)
        loss = self.loss_fn(predictions, y_batch)
        loss.backward()
        self.optimizer.step()

    def get_arm(self, t):
        """Return the arm whose perturbed context gets the highest prediction."""
        noise_std = self.perturbation_std / np.sqrt(t + 1)

        self.model.eval()
        with torch.no_grad():
            if self.mode == "CNN":
                # One noise draw shaped like a single arm's context.
                noise = torch.normal(
                    mean=0,
                    std=noise_std,
                    size=self.X[0].shape,
                    device=self.device,
                )

                # Arm i receives the noise only on its leading-axis slice i.
                # The in-place add requires X[0].shape[0] == K (or K == 1), so
                # X is presumably (K, K, H, W) -- TODO confirm with the caller.
                arm_selector = torch.eye(
                    self.K, dtype=torch.float32, device=self.device
                ).view(self.K, self.K, 1, 1)
                tilde_x = self.X.clone()
                tilde_x += noise.unsqueeze(0) * arm_selector

                scores = self.model(tilde_x).reshape(-1)
                return torch.argmax(scores).item()

            # Any other mode is treated as the MLP case.
            eta = torch.normal(mean=0, std=noise_std, size=self.X.shape, device=self.device)
            tilde_x = self.X + eta

            # Row i keeps only columns [i * d', (i + 1) * d') with d' = d // K;
            # every other entry (context and noise alike) is zeroed.
            d_prime = self.d // self.K
            mask = torch.zeros_like(tilde_x)
            for i in range(self.K):
                mask[i, i * d_prime:(i + 1) * d_prime] = 1
            tilde_x = tilde_x * mask

            scores = self.model(tilde_x).reshape(-1)
            return torch.argmax(scores).item()

    @staticmethod
    def print():
        """Return the display name of this agent."""
        return "DeepFP"

    

class EpsilonGreedy(DeepFL):
    """Agent that explores uniformly at random with a decaying probability.

    The exploration threshold at round ``t`` is ``0.05 * sqrt(T / (t + 1)) / 2``;
    otherwise the arm with the highest model prediction is played.
    """

    def __init__(self, env, T, params):
        super().__init__(env, T, params)
        # Second pass over params, after the base constructor has run.
        for key in params:
            setattr(self, key, params[key])

    def get_arm(self, t):
        """Return a uniformly random arm when exploring, else the greedy arm."""
        draw = torch.rand(1).item()
        horizon_ratio = torch.tensor(self.T / (t + 1), device=self.device)
        threshold = 0.05 * torch.sqrt(horizon_ratio) / 2

        if draw < threshold:
            # Explore: the drawn index is the selected arm.
            return torch.randint(self.K, (1,), device=self.device).item()

        # Exploit: pick the arm with the highest predicted reward.
        self.model.eval()
        with torch.no_grad():
            contexts = self.X.to(self.device)
            scores = self.model(contexts).squeeze()
        return torch.argmax(scores).item()

    @staticmethod
    def print():
        """Return the display name of this agent."""
        return "epsilon"



class NeuralTS(DeepFL):
    """Neural Thompson-sampling / UCB agent with a diagonal design matrix.

    Per-arm gradients of the model output are obtained with BackPACK's
    ``BatchGrad`` and used to build an exploration width from the diagonal
    accumulator ``U``.

    Args:
        env: Environment object exposing ``d`` and ``K``.
        T: Horizon value stored on the agent.
        params: Mapping copied onto the instance as attributes; this class
            additionally reads ``style`` ("TS" or "UCB") and, optionally,
            ``delay``.
        lamdba: Initial value of every entry of ``U`` and scale factor of the
            exploration width (the parameter name is kept for compatibility).
        nu: Additional scale factor of the exploration width.

    Raises:
        ValueError: If ``mode`` is neither "CNN" nor "MLP".
    """

    def __init__(self, env, T, params, lamdba=1, nu=1):
        super().__init__(env, T, params)
        # Re-applied so that entries in ``params`` take precedence over any
        # attribute the base constructor assigned after its own pass.
        for attr, val in params.items():
            setattr(self, attr, val)

        if self.mode == "CNN":
            base_model = SimpleCNN(self.K).to(self.device)
        elif self.mode == "MLP":
            base_model = NeuralNetwork(self.d, self.node).to(self.device)
        else:
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'CNN' or 'MLP'."
            )

        self.lamdba = lamdba
        self.nu = nu
        # Replaces the model built by the base constructor with a
        # BackPACK-extended one so per-sample gradients are available.
        self.model = extend(base_model)
        self.total_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.U = self.lamdba * torch.ones((self.total_param,), device=self.device)

        self.context_list = None
        self.reward_list = None
        self.len = 0
        self.loss_fn = nn.MSELoss()
        self.delay = getattr(self, 'delay', 1)
        # The optimizer must be rebuilt because the model was replaced above.
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)

    def get_arm(self, t):
        """Choose an arm from the predicted means and gradient-based widths.

        Returns:
            Index of the selected arm. ``U`` is updated with the squared
            gradient of the selected arm as a side effect.

        Raises:
            ValueError: If ``self.style`` is neither "TS" nor "UCB".
        """
        if self.style not in ("TS", "UCB"):
            # Checked up front so nothing is mutated for an invalid setting.
            raise ValueError(
                f"Unsupported style {self.style!r}; expected 'TS' or 'UCB'."
            )

        # Score in evaluation mode: ``update`` leaves the model in training
        # mode, which would keep Dropout active and make BatchNorm use batch
        # statistics while arms are being compared.
        self.model.eval()

        mu = self.model(self.X)
        sum_mu = torch.sum(mu)

        # Backpropagating the summed outputs yields, per arm, the gradient of
        # that arm's own output in ``grad_batch``.
        with backpack(BatchGrad()):
            sum_mu.backward()

        g_list = torch.cat(
            [p.grad_batch.flatten(start_dim=1).detach() for p in self.model.parameters()],
            dim=1
        )

        sigma = torch.sqrt(torch.sum(self.lamdba * self.nu * g_list * g_list / self.U, dim=1))

        # The graph is no longer needed once the gradients are extracted.
        mu_flat = mu.detach().view(-1)
        if self.style == 'TS':
            sampled = torch.normal(mu_flat, sigma.view(-1))
        else:
            sampled = mu_flat + sigma.view(-1)

        arm = torch.argmax(sampled)
        self.U += g_list[arm] * g_list[arm]
        return arm.item()

    @staticmethod
    def print():
        """Return the display name of this agent."""
        return "NeuralTSDiag"

