import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from itertools import cycle

# Starting coordinates shared by the optimization runs in this script.
initial_x, initial_y = 4.0, -1
# Leaf tensor holding (initial_x, initial_y) with autograd tracking enabled.
initial_point = torch.tensor((initial_x, initial_y), requires_grad=True)

def heatmap_3D(x, y, Z, trajectories, args):
    """Draw a filled contour map of ``Z`` with optional trajectories and save it as a PNG.

    Despite the name, the plot is a 2-D filled contour (``plt.contourf``) of the
    values ``Z`` over the grid built by ``np.meshgrid(x, y)``.

    Args:
        x: Grid coordinates along the X axis (passed to ``np.meshgrid``).
        y: Grid coordinates along the Y axis (passed to ``np.meshgrid``).
        Z: Function values on the ``np.meshgrid(x, y)`` grid. With the default
            ``'xy'`` indexing this means shape ``(len(y), len(x))``, matching
            the ``X``/``Y`` arrays handed to ``contourf``.
        trajectories: Iterable of ``(name, path)`` pairs, or a falsy value to
            draw only the contour map. Each ``path`` is converted with
            ``np.array`` and may be laid out either as ``(2, N)`` (row 0 = xs,
            row 1 = ys) or as ``(N, 2)`` (one ``(x, y)`` point per row).
        args: Object exposing ``task_name`` and ``id`` attributes. The figure is
            written to ``../image/{args.task_name}/{args.id}.png``; the
            directory is created if it does not exist.

    Returns:
        None. The figure is closed after saving (or on error) so repeated
        calls do not accumulate open pyplot figures.
    """
    X, Y = np.meshgrid(x, y)
    fig = plt.figure(figsize=(8, 6))
    try:
        contour_plot = plt.contourf(X, Y, Z, levels=20, cmap='viridis', alpha=0.3)
        plt.colorbar(contour_plot, label='Function Value')

        plotted_any = False
        if trajectories:
            # Cycle styles and colors independently so successive trajectories
            # get distinct (linestyle, color) combinations.
            style_cycle = cycle(['-', '--', '-.', ':'])
            color_cycle = cycle(['r', 'g', 'b', 'c', 'm', 'y', 'k'])
            for name, path in trajectories:
                path = np.array(path)
                # A path stored as one (x, y) point per row has shape (N, 2);
                # transpose it so row 0 holds xs and row 1 holds ys. The (2, N)
                # layout (and the ambiguous (2, 2) case) is used as-is.
                if path.ndim == 2 and path.shape[1] == 2 and path.shape[0] != 2:
                    path = path.T
                plt.plot(path[0], path[1], linestyle=next(style_cycle),
                         color=next(color_cycle), label=f'{name}', alpha=0.9)
                plotted_any = True

        plt.title('line')
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        # Only draw a legend when something labeled was plotted; otherwise
        # matplotlib warns that no labeled artists were found.
        if plotted_any:
            plt.legend()

        out_dir = f'../image/{args.task_name}'
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f'{out_dir}/{args.id}.png')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)