import matplotlib.dates as mdates
import matplotlib.pyplot as plt


from datetime import datetime

def plot_cw(
    data_batch,
    portfolio_value_list,
    base_portfolio_value_list,
    figure_name="output.png",
):
    """Plot two portfolio value series against dates and save the figure.

    The x-axis dates are read from the ``"output_date"`` entry of every
    element of ``data_batch`` except the first one. Both value series are
    drawn against those dates, so each is expected to contain one value per
    plotted date.

    Args:
        data_batch: Sequence of mappings, each with an ``"output_date"``
            string in ``YYYY-MM-DD`` format. The first element is skipped
            (presumably because the value series start one step after the
            first record -- confirm against the caller).
        portfolio_value_list: Values drawn as the "Ours" line.
        base_portfolio_value_list: Values drawn as the "Equal Weighted" line.
        figure_name: Path the figure is written to.

    Raises:
        KeyError: If an element of ``data_batch[1:]`` has no
            ``"output_date"`` key.
        ValueError: If an ``"output_date"`` value does not match
            ``%Y-%m-%d``.
    """
    # Convert 'yyyy-mm-dd' strings to datetime objects. This is done before
    # any figure is created so bad input fails without allocating one.
    plot_dates = [
        datetime.strptime(entry["output_date"], "%Y-%m-%d")
        for entry in data_batch[1:]
    ]

    # Use explicit figure/axes handles rather than pyplot's implicit
    # "current figure" so the figure can be reliably closed afterwards.
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(plot_dates, portfolio_value_list, label="Ours")
        ax.plot(
            plot_dates,
            base_portfolio_value_list,
            label="Equal Weighted",
            alpha=0.7,
        )

        # Format the x-axis
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
        ax.xaxis.set_major_locator(
            mdates.MonthLocator(interval=3)
        )  # tick every 3 months
        fig.autofmt_xdate()  # rotate and align dates

        ax.set_title("Portfolio Performance Comparison")
        ax.set_xlabel("Date")
        ax.set_ylabel("Cumulative Wealth")
        ax.legend()
        ax.grid(True)
        fig.savefig(figure_name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate open figures and their memory.
        plt.close(fig)
