import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
import sys
import os
import pandas as pd
from itertools import product
import random

# Put the parent of this script's directory on sys.path; presumably this is
# what lets the project-local `outer_loop` package below be imported when the
# file is run directly (confirm against the repository layout).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
# Output location for figures and results. Paths are built further down by
# plain string concatenation (directory + name), so the trailing slash matters.
directory = 'data_0516_tmp10/'
# Kept below the sys.path tweak on purpose: the import relies on it.
from outer_loop.exploration import *
import torch.nn.init as init

# ==== 1. Value Function ====

class Function:
    """Abstract callable objective; concrete subclasses implement ``__call__``."""

    def __init__(self):
        """Nothing to set up at the base-class level."""

    def __call__(self, x: torch.Tensor):
        """Evaluate the function at ``x``.

        Raises:
            NotImplementedError: always, until a subclass overrides this.
        """
        raise NotImplementedError

class SmoothPolynomial(Function):
    """Quadratic polynomial in two variables with random coefficients.

    f(x1, x2) = a*x1^2 + b*x2^2 + c*x1*x2 + d*x1 + e*x2 + f

    NOTE(review): ``seed`` is accepted but never used; the six coefficients
    are drawn from torch's global RNG, so reproducibility depends on the
    caller having seeded torch beforehand.
    """

    def __init__(self, seed=42):
        super().__init__()
        # Six standard-normal coefficients, in the order a..f shown above.
        self.coeff = torch.randn(6)

    def __call__(self, x):
        """Evaluate the polynomial on columns 0 and 1 of ``x``, one value per row."""
        x1 = x[:, 0]
        x2 = x[:, 1]
        a, b, c, d, e, f = self.coeff
        # Terms are summed left to right, exactly as written.
        return a*x1**2 + b*x2**2 + c*x1*x2 + d*x1 + e*x2 + f

def init_weights(module):
    """Re-initialise an ``nn.Linear`` layer in place; ignore any other module.

    Weights are filled by ``init.normal_`` with its default arguments and the
    bias, when the layer has one, is zeroed. Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    init.normal_(module.weight)
    bias = module.bias
    if bias is not None:
        init.zeros_(bias)

class FixedNN(Function):
    """Frozen, randomly initialised MLP (2 -> 16 -> 16 -> 1, ReLU) used as a test function.

    NOTE(review): ``seed`` is accepted but never used; the weights depend on
    the state of torch's global RNG at construction time.
    """

    def __init__(self, seed=0):
        super().__init__()
        # Layers are created in the same order as they appear in the network.
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        ]
        self.net = nn.Sequential(*layers)
        self.net.apply(init_weights)
        # Freeze every parameter: the network is a fixed landscape, not a model to train.
        for param in self.net.parameters():
            param.requires_grad_(False)

    def __call__(self, x):
        """Run the network on ``x`` and drop the trailing size-1 output dimension."""
        out = self.net(x)
        return out.squeeze(-1)

class Spiky(Function):
    """Sum of Gaussian-shaped bumps on the plane.

    f(x) = sum_i height[i] * exp(-scope[i] * ||x - peaks[i]||^2)

    Args:
        peaks: sequence of 2-D bump centres (anything ``torch.tensor`` accepts).
        height: per-bump amplitude, indexed in step with ``peaks``.
        scope: per-bump sharpness; larger values give narrower bumps.
    """

    def __init__(self, peaks=((0.75, 0.75), (-0.75, -0.75)), height=(10, 5), scope=(10, 10)):
        super().__init__()
        self.peaks = torch.tensor(peaks, dtype=torch.float32)
        # Copy into fresh lists so instances never share (or mutate) the
        # caller's sequences or the default arguments.
        self.height = list(height)
        self.scope = list(scope)

    def __call__(self, x):
        """Evaluate the bump sum for each row of ``x``; returns one value per row."""
        peaks = self.peaks.to(x.device)
        # Allocate the accumulator next to the input and add out of place, so
        # non-float32 / non-CPU inputs work and autograd never sees an
        # in-place update.
        val = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
        for i, p in enumerate(peaks):
            dist = torch.sum((x - p)**2, dim=1)
            val = val + self.height[i] * torch.exp(-dist*self.scope[i])
        return val

class SlopeSpiky(Spiky):
    """``Spiky`` landscape with a linear ramp ``x0 + x1`` added on top.

    NOTE(review): the ``peaks`` and ``height`` arguments are accepted but not
    forwarded; the parent is always built with ``height=[10, 10]`` and its own
    defaults for everything else.
    """

    def __init__(self, peaks=[(0.75,0.75), (-0.75,-0.75)], height=[10, 5]):
        super().__init__(height=[10, 10])

    def __call__(self, x):
        """Return the parent's bump value plus the sum of the first two coordinates."""
        ramp = x[:, 0] + x[:, 1]
        bumps = super().__call__(x)
        return bumps + ramp

class StairSpiky(Spiky):
    """``Spiky`` landscape with a piecewise-constant offset per quadrant.

    Adds +5 where both coordinates are strictly positive, -5 where both are
    strictly negative, and 0 elsewhere.

    NOTE(review): the ``peaks`` and ``height`` arguments are accepted but not
    forwarded; the parent is always built with ``height=[10, 10]`` and its own
    defaults for everything else.
    """

    def __init__(self, peaks=[(0.75,0.75), (-0.75,-0.75)], height=[10, 5]):
        super().__init__(height=[10, 10])

    def __call__(self, x):
        """Return the parent's bump value plus the quadrant offset."""
        x0, x1 = x[:, 0], x[:, 1]
        both_positive = ((x0 > 0) & (x1 > 0)).float()
        both_negative = ((x0 < 0) & (x1 < 0)).float()
        step = 5 * both_positive - 5 * both_negative
        bumps = super().__call__(x)
        return bumps + step

class RandomSpiky(Spiky):
    """``Spiky`` landscape whose bumps are drawn from NumPy's global RNG.

    Draws, in this order: the number of bumps (1 to 4), their centres
    (uniform in [-1.5, 1.5) per coordinate), their heights (uniform in
    [0, 10)) and their sharpness values (uniform in [5, 15)).

    NOTE(review): the ``peaks`` and ``height`` arguments are ignored; they are
    overwritten by the random draws.
    """

    def __init__(self, peaks=[(0.75,0.75), (-0.75,-0.75)], height=[10, 5]):
        n_bumps = np.random.randint(1,5)
        centres = np.random.rand(n_bumps, 2) * 3 - 1.5
        amplitudes = np.random.rand(n_bumps) * 10
        sharpness = np.random.rand(n_bumps) * 10+5
        super().__init__(peaks=centres, height=amplitudes, scope=sharpness)

# ==== 2. Gradient Optimizer ====

def gradient_step(x, fn, lr=0.01, noise_scale=0.0):
    """Take one gradient-ascent step on ``fn`` from ``x``, with optional Gaussian noise.

    Args:
        x: 1-D point; it is cloned, so the caller's tensor is not modified.
        fn: callable evaluated on ``x.unsqueeze(0)``; must return a
            single-element tensor so ``backward()`` can be called on it.
        lr: step size applied to the gradient.
        noise_scale: standard deviation of the noise added to the step.

    Returns:
        ``(new_point, value)`` where ``new_point`` is detached and ``value``
        is ``fn`` at the starting point, as a Python float.
    """
    point = x.clone().detach().requires_grad_(True)
    value = fn(point.unsqueeze(0))
    value.backward()
    # Noise is sampled even when noise_scale is 0, so the RNG stream advances
    # on every call.
    noise = torch.randn_like(point) * noise_scale
    stepped = point + lr * point.grad + noise
    return stepped.detach(), value.item()

def gradient_step_with_exploration(x, fn, lr=0.01, exploration=None,):
    """Gradient-ascent step, optionally blended with an exploration proposal.

    Args:
        x: 1-D point; it is cloned, so the caller's tensor is not modified.
        fn: callable evaluated on ``x.unsqueeze(0)``; must return a
            single-element tensor.
        lr: step size applied to the gradient.
        exploration: optional object exposing ``explore(num_samples=...)``.
            When it is truthy, the result is 0.7 * gradient step +
            0.3 * proposal.

    Returns:
        ``(new_point, value)`` where ``new_point`` is detached and ``value``
        is ``fn`` at the starting point, as a Python float.
    """
    point = x.clone().detach().requires_grad_(True)
    value = fn(point.unsqueeze(0))
    value.backward()
    ascended = point + lr * point.grad

    if not exploration:
        return ascended.detach(), value.item()

    proposal = exploration.explore(num_samples=500)
    blended = 0.7 * ascended + 0.3 * proposal
    return blended.detach(), value.item()

# ==== 3. Visualization ====

def visualize_surface_and_path(fn, path, title="Function Surface"):
    """Save a 3-D surface plot of ``fn`` over [-1, 1]^2 with an optional path overlay.

    Args:
        fn: callable mapping an (N, 2) float32 tensor to N values.
        path: sequence of 1-D point tensors; an empty sequence draws the
            surface only.
        title: plot title, also used as the file name
            (``directory + title + ".png"``).
    """
    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection='3d')

    # Evaluate fn on a regular 100 x 100 grid.
    n = 100
    X, Y = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))
    grid_points = np.stack([X.ravel(), Y.ravel()], axis=1)
    grid = torch.tensor(grid_points, dtype=torch.float32)
    Z = fn(grid).detach().numpy().reshape(n, n)

    ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.5)

    if len(path):
        path_np = torch.stack(path).numpy()
        # Heights are re-evaluated point by point so the path sits on the surface.
        path_vals = []
        for p in path:
            path_vals.append(fn(p.unsqueeze(0)).item())
        ax.plot(path_np[:, 0], path_np[:, 1], path_vals, 'r.-', label='Optimization Path')
        ax.legend()

    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("value")
    plt.tight_layout()
    plt.savefig(directory+title+".png")
    plt.close()


def visualize_path_heatmap(fn, path, bounds=(-1, 1), res=200, title="Optimization Path Heatmap"):
    """Save a 2-D heatmap of ``fn`` with the optimisation path scattered on top.

    Args:
        fn: callable mapping an (N, 2) tensor to N values.
        path: non-empty sequence of 1-D point tensors, in visiting order.
        bounds: ``(low, high)`` range used for both axes.
        res: number of grid samples per axis.
        title: plot title, also used as the file name
            (``directory + title + ".png"``).
    """
    path_tensor = torch.stack(path)

    x = torch.linspace(bounds[0], bounds[1], res)
    y = torch.linspace(bounds[0], bounds[1], res)
    mesh = torch.cartesian_prod(x, y)
    # cartesian_prod varies the second coordinate fastest, so after the
    # reshape z[i, j] == fn(x_i, y_j). imshow puts the row index on the
    # vertical axis, so transpose to get x horizontal / y vertical and line
    # the background up with the scattered path.
    z = fn(mesh).reshape(res, res).detach().cpu().numpy().T

    path_np = path_tensor.detach().cpu().numpy()
    x_path, y_path = path_np[:, 0], path_np[:, 1]

    plt.figure(figsize=(7, 6))
    plt.imshow(z, origin='lower', extent=[bounds[0], bounds[1], bounds[0], bounds[1]],
               cmap='viridis', aspect='auto', alpha=0.5)
    # Colour encodes the iteration index along the path.
    sc = plt.scatter(x_path, y_path, c=range(len(path)), cmap='hot', s=50)
    plt.colorbar(sc, label='Iteration')
    plt.title(title)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(directory+title+".png")
    plt.close()

class NormalizeWrapper(Function):
    """Rescale a 2-D function using its min and max on a regular grid.

    The extrema are measured once, at construction, on a ``res`` x ``res``
    grid spanning ``bounds`` on both axes. Values at points that are not on
    the grid can therefore fall slightly outside [0, 1].
    """

    def __init__(self, base_func, bounds=(-1, 1), res=100):
        super().__init__()
        self.fn = base_func
        low, high = bounds[0], bounds[1]
        axis = torch.linspace(low, high, res)
        samples = self.fn(torch.cartesian_prod(axis, axis))
        self.min_val, self.max_val = samples.min(), samples.max()

    def __call__(self, x):
        """Return ``(fn(x) - min) / (max - min + 1e-8)``."""
        # The epsilon keeps the division finite for a constant function.
        span = self.max_val - self.min_val + 1e-8
        return (self.fn(x) - self.min_val) / span

# ==== Example run ====
class customlogger:
    """Logger stub whose ``record`` method discards everything it is given."""

    def record(self, name, value):
        """Accept a named metric and do nothing with it."""
        return None

class NewBaseline:
    """Batched evolutionary sampler with a diagonal (CEM) or full-covariance mode.

    A batch of ``self.batch`` candidates is drawn from the current search
    distribution. ``explore()`` hands them out one at a time and
    ``update_value()`` records their scores. Once a whole batch is scored, the
    distribution is refit to the best candidates and a new batch is drawn.

    Args:
        weight_dim: dimensionality of the search space.
        device: torch device for all internal tensors.
        lower, upper: box bounds every sampled candidate is clamped to.
        method: ``'CEM'`` for the diagonal Gaussian update; any other value
            selects the full-covariance update (``update_cma``).
    """

    def __init__(self, weight_dim=2, device='cpu', lower=-1, upper=1, method='CEM'):
        self.weight_dim = weight_dim
        self.device = device
        self.lower = lower
        self.upper = upper
        self.cnt = 0          # index of the next candidate to hand out
        self.batch = 5        # candidates per generation
        self.mean = torch.zeros(weight_dim, device=device)
        self.std = torch.ones_like(self.mean) * 0.5
        # Sized from weight_dim (not a fixed 2) so other dimensions work.
        self.cov = torch.eye(weight_dim, device=device) * 0.5
        self.values = []
        self.method = method
        self.update_samples(bounds=(lower, upper))

    def update_samples(self, bounds=(-1, 1)):
        """Draw a fresh batch from the current distribution, clamped to ``bounds``."""
        if self.method == 'CEM':
            noise = torch.randn(self.batch, self.weight_dim, device=self.device)
            samples = noise * self.std + self.mean
        else:
            eye = torch.eye(self.weight_dim, device=self.device)
            cov = self.cov + 1e-5 * eye   # jitter against numerical instability
            cov = (cov + cov.T) / 2       # keep the covariance symmetric
            try:
                dist = torch.distributions.MultivariateNormal(self.mean, covariance_matrix=cov)
            except (ValueError, RuntimeError):
                # Still not a valid covariance: retry with stronger regularisation.
                cov = cov + 1e-2 * eye
                dist = torch.distributions.MultivariateNormal(self.mean, covariance_matrix=cov)
            samples = dist.sample((self.batch,))
        # Both modes honour the bounds, so candidates never leave the box.
        self.samples = samples.clamp(*bounds)

    def _select_elites(self, elite_frac):
        """Return the best-scoring samples of the current batch.

        At least two elites are always kept so the spread estimate is defined;
        with the default batch of 5 and ``elite_frac=0.2`` that gives exactly 2.
        """
        self.values = torch.tensor(self.values, device=self.device)
        n = len(self.values)
        num_elite = max(2, min(n, int(n * elite_frac)))
        elite_idxs = torch.topk(self.values, num_elite).indices
        return self.samples[elite_idxs]

    def update_cem(self, bounds=(-1, 1), elite_frac=0.2):
        """Refit mean and per-dimension std to the elite samples (``bounds`` is unused)."""
        elite_samples = self._select_elites(elite_frac)
        self.mean = elite_samples.mean(0)
        # Small floor so the distribution never collapses to a point.
        self.std = elite_samples.std(0) + 1e-6

    def update_cma(self, bounds=(-1, 1), elite_frac=0.2):
        """Refit mean and full covariance to the elite samples (``bounds`` is unused)."""
        elite_samples = self._select_elites(elite_frac)
        self.mean = elite_samples.mean(0)
        centered = elite_samples - self.mean
        self.cov = (centered.T @ centered) / elite_samples.shape[0]

    def explore(self, ):
        """Return the next not-yet-scored candidate of the current batch."""
        return self.samples[self.cnt]

    def update_value(self, value):
        """Record the score of the last candidate; refit and resample after a full batch."""
        self.values.append(value)
        self.cnt += 1
        if len(self.values) >= self.batch:
            if self.method == 'CEM':
                self.update_cem()
            else:
                self.update_cma()
            # Resample inside the bounds this instance was configured with.
            self.update_samples(bounds=(self.lower, self.upper))
            self.values = []
            self.cnt = 0


def cma_step(func, mean, cov, num_samples=100, elite_frac=0.2, bounds=(-1, 1)):
    """Run one sample / select / refit generation of a Gaussian search.

    Args:
        func: callable mapping an (N, dim) tensor to N scores (higher is better).
        mean: current search mean, shape ``(dim,)``.
        cov: current covariance, shape ``(dim, dim)``.
        num_samples: number of candidates drawn.
        elite_frac: fraction of candidates kept for the refit; at least one
            candidate is always kept.
        bounds: ``(low, high)`` box the candidates are clamped to.

    Returns:
        ``(new_mean, new_cov, new_mean)``; the third element repeats the new
        mean, matching the original return shape.
    """
    dim = mean.shape[0]
    # Jitter keeps the covariance positive definite; built to match cov's
    # dtype and device.
    jitter = 1e-5 * torch.eye(dim, dtype=cov.dtype, device=cov.device)
    dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov + jitter)
    samples = dist.sample((num_samples,)).clamp(*bounds)

    values = func(samples).detach()
    # Never select zero elites (that would give a NaN mean/covariance) and
    # never more than were drawn.
    num_elite = max(1, min(num_samples, int(num_samples * elite_frac)))
    elite_idxs = torch.topk(values, num_elite).indices
    elite_samples = samples[elite_idxs]

    new_mean = elite_samples.mean(0)
    centered = elite_samples - new_mean
    new_cov = (centered.T @ centered) / elite_samples.shape[0]

    return new_mean, new_cov, new_mean


# Evaluation function
def run_experiment(func, explorer_class, start_points, num_trials=1, steps=100, lr=0.01, device='cpu', condition='none', name='none', method='topk', visualize=False):
    """Run gradient ascent with periodic sampler-driven jumps; return final values.

    Each trial starts at ``start_points[trial]`` and takes ``steps`` clamped
    gradient steps. Every 10th step the point may be replaced by the next
    candidate of a ``NewBaseline`` sampler:

    * ``condition == 'none'``: always jump.
    * ``condition == 'performance'``: jump only when the value at the last
      recorded path point improved by less than 1e-4 since the previous
      check, and then only with probability ``1 - value``.
    * any other value: never jump.

    Args:
        func: objective to maximise.
        explorer_class: forwarded to ``NewBaseline`` as its ``method``.
        start_points: indexable collection of starting points, one per trial.
        num_trials: number of trials to run.
        steps: gradient steps per trial.
        lr: NOTE(review) currently unused; the step size is fixed at 0.05.
        device: device handed to the sampler.
        condition: jump policy, see above.
        name: title / file name for the optional heatmap.
        method: NOTE(review) currently unused.
        visualize: when true, save a path heatmap for the first trial.

    Returns:
        List with the objective value at the final point of every trial.
    """
    def evaluate(point):
        # Score a single 1-D point as a Python float.
        return func(point.unsqueeze(0)).item()

    values = []

    for trial in range(num_trials):
        x = start_points[trial]
        path = [x.clone()]
        value = [evaluate(x)]  # per-step value trace (not returned)
        explorer = NewBaseline(weight_dim=2, device=device, lower=-1, upper=1, method=explorer_class)
        last_value = 0

        for i in range(steps):
            x, _ = gradient_step(x, func, lr=0.05)
            x = x.clamp(-1, 1)  # stay inside the search box

            if i % 10 == 0 and condition in ('none', 'performance'):
                jump = True
                if condition == 'performance':
                    # Measured at the last recorded point, i.e. before this
                    # step's gradient update.
                    current_value = evaluate(path[-1])
                    stalled = current_value - last_value < 0.0001
                    # The random draw only happens when progress has stalled.
                    jump = stalled and np.random.rand() < 1 - current_value
                    last_value = current_value
                if jump:
                    x = explorer.explore()
                    explorer.update_value(evaluate(x))

            path.append(x.detach().cpu())
            value.append(evaluate(x))

        values.append(evaluate(x))

        if trial == 0 and visualize:
            visualize_path_heatmap(func, path, title=name)

    return values

# One experiment configuration per sampler mode; the mode name doubles as the
# label used for result keys and log lines.
configs = [
    {
        'explorer_class': sampler,
        'condition': 'performance',
        'method': 'explore',
        'name': sampler,
    }
    for sampler in ('CEM', 'CMA')
]

if __name__ == "__main__":
    # Seed every RNG in use so a run is repeatable.
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    device = 'cpu'

    # Create the output directory up front, so a missing folder fails (or is
    # fixed) before the experiments run rather than at the final save.
    os.makedirs(directory, exist_ok=True)

    # Setup test combinations
    functions = [RandomSpiky, FixedNN, SmoothPolynomial] #, Spiky, SlopeSpiky, StairSpiky]
    explorers = [SinglePrediction] #RandomExploration, NoExploration, MaxDistanceExploration, SinglePrediction, UCBExploration, EntropyExploration]
    conditions = ['performance'] # 'none', 'performance'
    # method = 'topk'
    methods = ['explore'] # 'topk', 'explore'
    
    # Run all combinations
    results = {'mean': {}, 'std': {}, 'min': {}, 'max': {}}
    values = {}
    start_points = torch.rand((10, 2), device=device) * 2 - 1
    for func_class in functions:
        # Ten independently constructed instances of each function family.
        for i in range(10):
            func = func_class()
            func = NormalizeWrapper(func)
            func_name = func_class.__name__ + str(i)
            for config in configs:
                explorer_class = config['explorer_class']
                condition = config['condition']
                method = config['method']
                name = config['name']
                viz = False
                # if i == 2:
                #     viz = True
                stats = run_experiment(func,explorer_class, start_points, device='cpu', 
                    name=f"{name}_{func_name}", 
                    condition=condition, method=method, visualize=viz
                )
                values.setdefault(name, {})[func_name] = stats
                # for key in results:
                #     results[key].setdefault(metric_name, {})[func_name] = stats[key]
                print(f"{name} - {func_name}: {stats}")

    # Convert to DataFrames and save to CSV
    print(values)
            
    # `values` is a nested dict, so np.save stores it as a pickled object
    # array; reading it back needs np.load(..., allow_pickle=True).
    np.save(directory+'results.npy', values)
    # res = pd.DataFrame(values).T
    # res.to_csv(directory+'mean_results.csv')
    # mean_df = pd.DataFrame(results['mean']).T
    # std_df = pd.DataFrame(results['std']).T
    # min_df = pd.DataFrame(results['min']).T
    # max_df = pd.DataFrame(results['max']).T

    # mean_df.to_csv(directory+'mean_results.csv')
    # std_df.to_csv(directory+'std_results.csv')
    # min_df.to_csv(directory+'min_results.csv')
    # max_df.to_csv(directory+'max_results.csv')


    # # plot
    # for func_class in functions:
    #     func = func_class()
    #     func = NormalizeWrapper(func)
    #     func_name = func_class.__name__
    #     path = []
    #     visualize_surface_and_path(func, path, title=func_name)

