import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

class DifferentialGameEnv:
    """Two-agent, one-shot game scored by a multimodal reward landscape.

    Each agent picks a scalar action. The joint action ``(x1, x2)`` is clipped
    to ``[-5, 5]`` per coordinate and scored by a sum of isotropic Gaussian
    bumps ("potential fields")::

        reward(x1, x2) = sum_k h_k * exp(-((x1 - cx_k)**2 + (x2 - cy_k)**2) / w_k**2)

    Field centers (in ``[-5, 5]^2``), heights (in ``[5, 10]``) and widths
    (in ``[1, 2]``) are drawn once at construction from a seeded
    ``np.random.RandomState``, so the landscape is reproducible per seed.
    Observations returned by ``reset``/``step`` are always zeros.

    Args:
        horizon: number of ``step`` calls after which ``truncated`` is True.
        num_fields: number of Gaussian bumps in the landscape.
        seed: seed for the RNG that samples the fields.
    """

    # Half-width of the square action / field-center domain [-_BOUND, _BOUND]^2.
    # Kept as an int so clipping and RNG draws match the original literals.
    _BOUND = 5

    def __init__(self, horizon=25, num_fields=5, seed=42):
        self.horizon = horizon
        self.state = None
        self.step_count = 0
        self.agent_num = 2
        self.obs_dim = 1
        self.action_dim = 1
        self.rng = np.random.RandomState(seed)
        self.num_fields = num_fields

        # The per-field draw order (center, height, width) determines the
        # landscape for a given seed; reordering it changes every landscape.
        self.fields = []
        for _ in range(num_fields):
            center = self.rng.uniform(-self._BOUND, self._BOUND, size=2)
            height = self.rng.uniform(5, 10)
            width = self.rng.uniform(1, 2)
            self.fields.append({'center': center, 'height': height, 'width': width})

    def _potential(self, x, y):
        """Evaluate the summed Gaussian reward elementwise at ``(x, y)``.

        ``x`` and ``y`` may be scalars or broadcastable arrays. The result has
        their broadcast shape; for scalar inputs it is a numpy scalar.
        """
        total = np.zeros(np.broadcast(x, y).shape)
        for field in self.fields:
            cx, cy = field['center']
            sq_dist = (x - cx) ** 2 + (y - cy) ** 2
            total += field['height'] * np.exp(-sq_dist / field['width'] ** 2)
        # ``[()]`` unwraps a 0-d array to a scalar and is a no-op view otherwise.
        return total[()]

    def _grid(self, resolution):
        """Return meshgrid arrays ``(X, Y)`` with ``resolution`` points per axis over the domain."""
        axis = np.linspace(-self._BOUND, self._BOUND, resolution)
        return np.meshgrid(axis, axis)

    def reset(self, eval=False):
        """Start a new episode.

        Resets the step counter and returns an all-zero observation of shape
        ``(2, 1)``. ``eval`` is currently unused.
        """
        self.state = np.zeros([2, 1])
        self.step_count = 0
        return self.state

    def step(self, action_np):
        """Score one joint action.

        Args:
            action_np: indexable so that ``action_np[0]`` and ``action_np[1]``
                are the two agents' actions (e.g. an array of shape ``(2,)`` or
                ``(2, 1)``). Each is clipped to ``[-5, 5]`` before scoring.

        Returns:
            ``(obs, reward, terminated, truncated, info)``: a fresh zero array
            of shape ``(2, 1)``; the landscape value at the clipped action
            (its shape follows the shape of the action entries); ``False``;
            ``True`` once ``horizon`` steps have been taken since the last
            reset; and an empty dict. ``self.state`` is not modified.
        """
        self.step_count += 1
        x1 = np.clip(action_np[0], -self._BOUND, self._BOUND)
        x2 = np.clip(action_np[1], -self._BOUND, self._BOUND)
        reward = self._potential(x1, x2)

        terminated = False
        truncated = self.step_count >= self.horizon
        return np.zeros([2, 1]), reward, terminated, truncated, {}

    def render(self, mode='human'):
        """No-op; this environment has no renderer."""
        pass

    def close(self):
        """No-op; the environment holds no external resources."""
        pass

    def plot_reward_landscape(self, resolution=200):
        """Display a 3D surface plot of the reward over ``[-5, 5]^2``.

        Args:
            resolution: number of grid points per axis (default 200).
        """
        # Registers the '3d' projection on matplotlib < 3.2; harmless on newer versions.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        X, Y = self._grid(resolution)
        Z = self._potential(X, Y)

        fig = plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(X, Y, Z, cmap=cm.viridis)
        ax.set_title("Reward Landscape with {} Potential Fields".format(self.num_fields))
        plt.show()

    def find_true_global_max(self, resolution=500):
        """Approximate the reward's global maximum by grid search over ``[-5, 5]^2``.

        The result is the best node of a ``resolution x resolution`` grid, so
        its accuracy is limited by the grid spacing; it is not refined further.

        Args:
            resolution: number of grid points per axis (default 500).

        Returns:
            ``(x_max, y_max, z_max)``: the grid coordinates of the best node
            and the reward there.

        Raises:
            ValueError: if ``resolution < 1`` (the grid would be empty).
        """
        if resolution < 1:
            raise ValueError("resolution must be >= 1, got {}".format(resolution))
        X, Y = self._grid(resolution)
        Z = self._potential(X, Y)

        max_idx = np.unravel_index(np.argmax(Z), Z.shape)
        return X[max_idx], Y[max_idx], Z[max_idx]
