import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.cm as cm

class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear with a single output unit.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the result of the last layer."""
        out = self.net(x)
        return out

class ExplorationModule(nn.Module):
    """Base class for strategies that pick the next weight vector to try.

    Candidates are drawn uniformly from ``[lower, upper)`` in every one of the
    ``weight_dim`` dimensions, scored by ``random_metric`` (supplied by
    subclasses) and optionally by ``performance_metric``, and one candidate is
    then sampled from the resulting scores.

    Args:
        weight_dim: Number of dimensions of a weight vector.
        hidden_dim: Not used by this base class; accepted so every strategy
            shares one constructor signature.
        device: Device on which candidate points are created.
        lower: Lower bound of the sampling range.
        upper: Upper bound of the sampling range.
    """

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=0, upper=1):
        super().__init__()
        self.device = device
        self.weight_dim = weight_dim
        self.lower = lower
        self.upper = upper
        # Defined up front so the metrics can test for "nothing recorded yet"
        # instead of failing with AttributeError before the first train().
        self.weights = None
        self.performance = None

    def train(self, x, performance=None):
        """Store the visited weights ``x`` and their ``performance``; return 0.

        NOTE: this shadows ``nn.Module.train(mode)``, so ``Module.eval()``
        (which calls ``self.train(False)``) ends up here as well.
        """
        self.performance = performance
        self.weights = x
        return 0

    def explore(self, num_samples=1000, method='topk', tmp = 1, plot=False, filename='', func=None):
        '''
        Pick one new weight vector.

        method: one of
            'addition' - sample from random_metric + performance_metric
            'topk'     - keep the best candidates by random_metric, then
                         sample among them by performance_metric
            'explore'  - sample from random_metric alone
        tmp: temperature-like factor forwarded to the metrics.
        plot / filename / func: forwarded to the plotting code (see plot()).

        Raises ValueError for any other method instead of silently
        returning None.
        '''
        if method == 'addition':
            return self.explore_addition(num_samples, tmp, with_performance=True, plot=plot, filename=filename, func=func)
        if method == 'topk':
            return self.explore_topk(num_samples, tmp, plot=plot, filename=filename, func=func)
        if method == 'explore':
            return self.explore_addition(num_samples, tmp, with_performance=False, plot=plot, filename=filename, func=func)
        raise ValueError(
            f"unknown exploration method {method!r}; expected 'addition', 'topk' or 'explore'"
        )

    def _sample_candidates(self, num_samples):
        """Draw ``num_samples`` points uniformly from ``[lower, upper)`` per dimension."""
        unit = torch.rand((num_samples, self.weight_dim), device=self.device)
        return unit * (self.upper - self.lower) + self.lower

    @staticmethod
    def _to_numpy(tensor):
        """Detach and move to host memory first; ``.numpy()`` rejects CUDA tensors."""
        return tensor.detach().cpu().numpy()

    def explore_addition(self, num_samples=1000, tmp = 1, with_performance=False, plot=False, filename='', func=None):
        """Sample one candidate with weight ``random_metric`` (+ ``performance_metric``).

        When ``with_performance`` is set but no performance has been recorded,
        the performance term is skipped rather than raising.
        """
        sampled_points = self._sample_candidates(num_samples)
        p = self.random_metric(sampled_points, tmp)
        if with_performance:
            perf_p = self.performance_metric(sampled_points, tmp)
            if perf_p is not None:
                p = p + perf_p
        idx = p.multinomial(num_samples=1, replacement=False)[0]
        if plot:
            self.plot(func, self._to_numpy(sampled_points), self._to_numpy(p),
                      self._to_numpy(sampled_points[idx]), filename=filename)
        return sampled_points[idx]

    def explore_topk(self, num_samples=1000, tmp = 1, plot=False, filename='', func=None, top_k=10):
        """Keep the ``top_k`` candidates by ``random_metric``, then sample one of them.

        The second-stage sampling uses ``performance_metric``; if no
        performance has been recorded the survivors are sampled uniformly.
        ``top_k`` is clamped to the number of candidates so small
        ``num_samples`` values no longer make ``torch.topk`` fail.
        """
        sampled_points_ori = self._sample_candidates(num_samples)
        p_ori = self.random_metric(sampled_points_ori, tmp)
        k = min(top_k, p_ori.shape[0])
        _, top_idx = torch.topk(p_ori, k)
        sampled_points = sampled_points_ori[top_idx]
        p = self.performance_metric(sampled_points, tmp)
        if p is None:
            # multinomial accepts unnormalised weights, so equal weights
            # give a uniform choice.
            p = torch.ones(k, device=sampled_points.device)
        idx = p.multinomial(num_samples=1, replacement=False)[0]
        if plot:
            self.plot(func, self._to_numpy(sampled_points_ori), self._to_numpy(p_ori),
                      self._to_numpy(sampled_points[idx]), filename=filename)
        return sampled_points[idx]

    def plot(self, base_func, points, prob, chosen, filename=''):
        """Save a 2-D scatter of the candidates over ``base_func`` to ``filename + '.png'``.

        Only the first two coordinates of ``points``, ``chosen`` and the stored
        weights are drawn. ``base_func`` is called on a 200x200 grid spanning
        [-1, 1] x [-1, 1] (fixed, independent of ``lower``/``upper``) and must
        be callable, so ``func`` has to be supplied whenever ``plot=True``.
        """
        x = torch.linspace(-1, 1, 200)
        y = torch.linspace(-1, 1, 200)
        mesh = torch.cartesian_prod(x, y)
        z = base_func(mesh).reshape(200, 200).detach().cpu().numpy()

        plt.figure(figsize=(7, 6))
        plt.imshow(z, origin='lower', extent=[-1,1,-1,1],
                cmap='gray', aspect='auto')
        sc = plt.scatter(points[..., 0], points[..., 1], c=prob, cmap='plasma', s=50)
        plt.scatter(chosen[..., 0], chosen[..., 1], c='lawngreen', marker='x', s=100)
        if self.weights is not None:
            weights = self._to_numpy(self.weights)
            plt.scatter(weights[..., 0], weights[..., 1], c='red', s=100)
        plt.colorbar(sc, label='prob')
        plt.title(filename)
        plt.xlabel("x")
        plt.ylabel("y")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(filename+".png")
        plt.close()

        print(filename)

    def random_metric(self, sampled_points, tmp = 1):
        """Score each candidate; subclasses return one weight per row of ``sampled_points``."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement random_metric"
        )

    def performance_metric(self, sampled_points, tmp=1):
        """Weight each candidate by the performance of its nearest stored weight.

        Returns a softmax (over candidates) of ``tmp`` times the performance
        recorded for the closest row of ``self.weights``, or ``None`` when no
        performance has been stored.
        """
        if self.performance is None:
            return None
        dists = torch.norm(sampled_points[:, None, :] - self.weights[None, :, :], dim=2)  # [num_samples, n]
        nearest_indices = torch.argmin(dists, dim=1)  # [num_samples]
        closest_perf = self.performance[nearest_indices]
        return torch.softmax(closest_perf * tmp, dim=0)  # [num_samples]

class RandomExploration(ExplorationModule):
    """Strategy that ignores all history and returns a uniformly random weight vector."""

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=0, upper=1):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.name = 'random'

    def explore(self, num_samples=1000, method='argmax', tmp = 1, with_performance=False,
                plot=False, filename='', func=None):
        """Return one point drawn uniformly from ``[lower, upper)`` per dimension.

        Every argument is ignored. ``plot``, ``filename`` and ``func`` are
        accepted only so this override can be called with the same keywords as
        ``ExplorationModule.explore``; nothing is plotted.
        """
        return torch.rand(self.weight_dim).to(self.device)*(self.upper - self.lower) + self.lower

class SinglePrediction(ExplorationModule):
    """Scores candidates by how badly a trained net predicts a frozen random net.

    ``pred_net`` is fitted to reproduce the output of ``fixed_net`` on the
    weights seen so far; ``random_metric`` then favours candidates where the
    two networks still disagree.

    Args:
        weight_dim: Number of dimensions of a weight vector.
        hidden_dim: Hidden width of both networks.
        device: Device for both networks and the sampled candidates.
        lower: Lower bound of the sampling range.
        upper: Upper bound of the sampling range.
        lr: Learning rate of the Adam optimizer for ``pred_net``.
    """

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=0, upper=1, lr=1e-3):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.name = 'prediction'

        # Fixed network: parameters frozen
        self.fixed_net = SimpleMLP(weight_dim, hidden_dim).to(device)
        for param in self.fixed_net.parameters():
            param.requires_grad = False

        # Prediction network: trainable
        self.pred_net = SimpleMLP(weight_dim, hidden_dim).to(device)

        # Optimizer only for prediction network
        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=lr)

        self.criterion = nn.MSELoss()

    def forward(self, reward_weights):
        """Return ``(pred_net output, fixed_net output)`` for ``reward_weights``."""
        pred = self.pred_net(reward_weights)
        fixed = self.fixed_net(reward_weights)
        return pred, fixed

    def compute_loss(self, reward_weights):
        """Return the MSE between the two networks' outputs on ``reward_weights``."""
        pred, fixed = self.forward(reward_weights)
        return self.criterion(pred, fixed)

    def train(self, reward_weights, performance=None):
        """
        Fit ``pred_net`` to ``fixed_net`` on ``reward_weights`` and record the data.

        Runs repeated forward / backward / optimizer steps until the newest
        loss is larger than the loss recorded four steps earlier, or until the
        step counter exceeds 1000.

        reward_weights: Tensor of shape [batch_size, weight_dim]
        performance: stored on the module for ``performance_metric``
        Returns: the last recorded loss value (float)
        """
        self.weights = reward_weights
        self.performance = performance
        # Decreasing placeholders so the loop condition holds on entry and the
        # first real losses are compared against large values.
        losses = [1000,999,998,997,996]
        i = 0
        while losses[-5] >= losses[-1]: # overfit to existing network
            self.optimizer.zero_grad()
            loss = self.compute_loss(reward_weights)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            i += 1
            if i > 1000:
                break
        loss = losses[-1]
        return loss

    def random_metric(self, sampled_points, tmp=1):
        """Softmax over candidates of ``tmp * |pred_net - fixed_net|``."""
        pred, fixed = self.forward(sampled_points)
        # Drop only the trailing output dimension: a bare squeeze() would also
        # remove the batch dimension when a single point is scored, leaving a
        # 0-d tensor that cannot be sampled from.
        error = torch.abs(pred - fixed).squeeze(-1)
        p = torch.softmax(tmp*error, dim=0)
        return p

class MaxDistanceExploration(ExplorationModule):
    """Strategy that favours candidates far from every stored weight vector."""

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=0, upper=1):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.weights = None
        self.name = 'maxdis'

    def random_metric(self, sampled_points, tmp=1):
        """Softmax over candidates of ``tmp`` times the distance to the nearest stored weight."""
        # Pairwise differences via broadcasting: [num_samples, n, weight_dim]
        diffs = sampled_points[:, None, :] - self.weights[None, :, :]
        pairwise = torch.norm(diffs, dim=2)  # [num_samples, n]
        nearest = pairwise.min(dim=1).values  # [num_samples]
        return torch.softmax(tmp*nearest, dim=0)

class UncertaintyExploration(ExplorationModule):
    """Strategy that favours candidates with a large mean distance to the stored weights."""

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=-1, upper=1):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.weights = None
        self.name = 'uncertainty'

    def random_metric(self, sampled_points, tmp=1):
        """Softmax over candidates of ``tmp`` times the mean distance to all stored weights."""
        pairwise = torch.cdist(sampled_points, self.weights)  # [num_samples, n]
        # The mean distance is used as a simple stand-in for uncertainty.
        mean_dist = pairwise.mean(dim=1)
        return torch.softmax(tmp * mean_dist, dim=0)

class EntropyExploration(ExplorationModule):
    """Strategy that favours candidates whose soft assignment to stored weights has high entropy."""

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=-1, upper=1):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.performance = None
        self.weights = None
        self.name = 'entropy'

    def random_metric(self, sampled_points, tmp=1):
        """Softmax over candidates of ``tmp`` times the entropy of each candidate's assignment.

        Each candidate gets a distribution over the stored weights from a
        softmax of ``-distance / tmp``; the entropy of that distribution is
        the candidate's score.
        """
        pairwise = torch.cdist(sampled_points, self.weights)  # [num_samples, n]
        probs = torch.softmax(-pairwise / tmp, dim=1)
        # The 1e-8 offset keeps log() finite where a probability underflows to 0.
        plogp = probs * torch.log(probs + 1e-8)
        entropy = -plogp.sum(dim=1)  # [num_samples]
        return torch.softmax(entropy * tmp, dim=0)

class UCBExploration(ExplorationModule):
    """Strategy scoring candidates by a distance-weighted mean of performance plus ``beta`` times its spread."""

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=-1, upper=1, beta=0.5):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.beta = beta
        self.performance = None
        self.weights = None
        self.name = 'ucb'

    def random_metric(self, sampled_points, tmp=1):
        """Softmax over candidates of ``tmp * (mu + beta * sigma)``.

        ``mu`` and ``sigma`` are the mean and standard deviation of the stored
        performance under per-candidate weights given by a softmax of the
        negative distance to each stored weight vector.
        """
        pairwise = torch.cdist(sampled_points, self.weights)  # [num_samples, n]
        attn = torch.softmax(-pairwise, dim=1)  # [num_samples, n]
        perf_row = self.performance.to(self.device)[None, :]

        mu = (attn * perf_row).sum(dim=1)
        variance = (attn * (perf_row - mu[:, None])**2).sum(dim=1)
        # Small constant keeps sqrt away from exactly zero.
        sigma = torch.sqrt(variance + 1e-6)

        score = mu + self.beta * sigma
        return torch.softmax(score * tmp, dim=0)

class NoExploration(ExplorationModule):
    """Strategy that never explores: it always proposes the most recent stored weight.

    ``beta`` is accepted but not used by this class.
    """

    def __init__(self, weight_dim, hidden_dim=64, device='cpu', lower=-1, upper=1, beta=0.5):
        super().__init__(weight_dim, hidden_dim, device, lower, upper)
        self.name = 'none'

    def train(self, x, performance=None):
        """Remember only the last row of ``x``; ``performance`` is ignored. Returns 0."""
        self.weights = x[-1]
        return 0

    def explore(self, num_samples=1000, method='argmax', tmp=1.0, with_performance=False,
                plot=False, filename='', func=None):
        """Return the weight stored by the last ``train`` call.

        Every argument is ignored. ``plot``, ``filename`` and ``func`` are
        accepted only so this override can be called with the same keywords as
        ``ExplorationModule.explore``; nothing is plotted.
        """
        return self.weights

if __name__ == "__main__":
    # Smoke test: grow a set of weight vectors one at a time, each new one
    # proposed by MaxDistanceExploration from the vectors collected so far.
    class customlogger():
        def record(self, name, value):
            print(f"{name}: {value}")

    torch.manual_seed(0)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    module = MaxDistanceExploration(weight_dim=3, device=device)

    rand_weights = torch.rand(50, 3).to(device)
    # torch.save(rand_weights, 'rand.pt')
    train_weights = [rand_weights[0].unsqueeze(0)]

    for step in range(50):
        batch = torch.cat(train_weights).to(device)
        # Random stand-in performance, one value per collected weight vector.
        performance_batch = torch.rand_like(batch)[..., 0]
        module.train(batch, performance=performance_batch)
        # The module has no `explore2`; 'addition' is the explore() mode that
        # combines the distance score with the performance score.
        new_weight = module.explore(method='addition')
        train_weights.append(new_weight.unsqueeze(0))

    batch = torch.cat(train_weights).to(device)
    # torch.save(batch, 'maxdis.pt')