import numpy as np
import matplotlib.pyplot as plt
import itertools

def f(x):
    """映射函数 f(x) = exp(x^3 - 2x^2) - 1"""
    return np.exp(x**3 - 2*x**2) - 1

def fixed_point_iteration(x0, target, max_iter=1000, tol=1e-6):
    """
    对向量 x0 做不动点迭代，目标不动点为 target，
    当 ||x_n - target|| < tol 或 达到 max_iter 时停止。
    返回最终迭代点和每次的 loss 列表。
    """
    x = x0.copy()
    losses = []
    for _ in range(max_iter):
        loss = np.linalg.norm(x - target)
        losses.append(loss)
        if loss < tol:
            break
        x = f(x)
    return x, losses

def sample_initial_and_target(choice, low1, high1, low2, high2):
    """
    根据 choice（二进制元组），对每个维度随机采样初始值并生成目标不动点。
    choice[i] == 0 → 区间 [low1, high1]，target[i]=0
    choice[i] == 1 → 区间 [low2, high2]，target[i]=c
    返回 (initial_vector, target_vector)
    """
    dims = len(choice)
    initial = np.zeros(dims)
    target  = np.zeros(dims)
    for i, bit in enumerate(choice):
        if bit == 0:
            initial[i] = np.random.uniform(low1, high1)
            target[i]  = 0.0
        else:
            initial[i] = np.random.uniform(low2, high2)
            target[i]  = -0.9104
    return initial, target

def main():
    np.random.seed(42)

    dims = 5
    # 定义两个区间
    low1, high1 = -0.1,  0.1
    low2, high2 = -0.9104 - 0.1, -0.9104 + 0.1

    # 用 itertools.product 生成 32 种 0/1 的组合
    all_choices = list(itertools.product([0,1], repeat=dims))

    plt.figure(figsize=(10, 6))
    for idx, choice in enumerate(all_choices):
        # 采样一个初始向量和对应目标
        x0, target = sample_initial_and_target(choice, low1, high1, low2, high2)
        # 迭代并获取 loss 曲线
        _, losses = fixed_point_iteration(x0, target, max_iter=25, tol=1e-6)
        # 画出 loss 曲线
        plt.plot(losses, alpha=0.6)

    plt.xlabel('iteration times')
    plt.ylabel('Loss (‖x - target‖)')
    plt.title('Loss Curve')
    plt.yscale('log')            # 建议用对数刻度便于观察
    plt.grid(True, which='both', ls='--', lw=0.5)
    plt.tight_layout()
    plt.savefig("my_plot.pdf")

if __name__ == "__main__":
    main()
