import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.dates as mdates

# Choosing the `annualized` factor passed to calculate_volatility:
#   futures data: annualized = 50660716 / 1224 * 252
#   stocks data:  annualized = (number of samples in the whole dataset) / 7 * 252
def calculate_volatility(data, window: int = 20, annualized: float = 252):
    """Compute overall and rolling annualized volatility of a price series.

    Returns are simple percentage changes (``pct_change``) of the price
    series, which is read from column ``'u2_Mid-Price_1'`` when ``data`` is a
    DataFrame, or from row 41 when ``data`` is a NumPy array (presumably the
    same mid-price feature -- confirm against the dataset layout).

    Args:
        data: A ``pd.DataFrame`` with a ``'u2_Mid-Price_1'`` column, or a 2-D
            ``np.ndarray`` with at least 42 rows (one sample per column).
        window: Number of consecutive returns in each rolling window.
        annualized: Number of observations per year; standard deviations are
            scaled by ``sqrt(annualized)``.

    Returns:
        A ``(volatility, volatility_vector)`` tuple:
        ``volatility`` is a one-element ``pd.Series`` holding the annualized
        standard deviation of all returns; ``volatility_vector`` is a
        single-column ``pd.DataFrame`` with the annualized rolling standard
        deviation (its first ``window`` entries are NaN).

    Raises:
        TypeError: If ``data`` is neither a DataFrame nor an ndarray.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of DataFrame / ndarray.
    if isinstance(data, pd.DataFrame):
        returns = data.loc[:, 'u2_Mid-Price_1'].pct_change().to_frame()
    elif isinstance(data, np.ndarray):
        returns = pd.DataFrame(data[41, :]).pct_change()
    else:
        # Without this branch `returns` would be unbound and the failure
        # would surface as an unrelated UnboundLocalError further down.
        raise TypeError(
            'data must be a pandas DataFrame or a numpy ndarray, '
            f'got {type(data).__name__}'
        )

    scale = np.sqrt(annualized)
    volatility_vector = returns.rolling(window=window).std() * scale  # Annualized

    # Reduce explicitly to scalars (NaNs are skipped); np.max / np.min applied
    # to a DataFrame do not reduce the same way across pandas versions.
    max_vol = volatility_vector.max().max()
    min_vol = volatility_vector.min().min()
    print(f'Max volatility element: {max_vol}, Min volatility element: {min_vol}')

    volatility = returns.std() * scale  # Annualized
    print(f'Average Volatility: {volatility}')
    return volatility, volatility_vector

def plot_volatility(data, vol):
    """Plot a volatility series against a synthetic date axis and save it as a PDF.

    Only the type of ``data`` is inspected: a pandas DataFrame is labelled
    'CHF', anything else is labelled 'FI'. The x-axis is not read from the
    data; it is an evenly spaced date range between hard-coded start and end
    dates, with one point per volatility value.

    Args:
        data: The dataset the volatility was computed from (used only to pick
            the label, date range and tick spacing).
        vol: Array-like of volatility values; flattened to 1-D before plotting.

    Side effects:
        Prints progress messages and writes ``volatility_<name>.pdf`` to the
        current working directory.
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    if isinstance(data, pd.DataFrame):
        name = 'CHF'
        start_date = pd.to_datetime("2019-10-08")
        end_date = pd.to_datetime("2024-01-16")
        # Multi-year span: let matplotlib choose a sensible tick spacing
        # (daily ticks would produce well over a thousand labels).
        major_locator = mdates.AutoDateLocator()
    else:
        name = 'FI'
        print('Chose FI')
        start_date = pd.to_datetime("2010-06-01")
        end_date = pd.to_datetime("2010-06-14")
        # Two-week span: one major tick per day.
        major_locator = mdates.DayLocator(interval=1)

    # Accept any array-like (ndarray, Series, single-column DataFrame).
    y = np.asarray(vol).flatten()
    time_index = pd.date_range(start=start_date, end=end_date, periods=len(y))
    x = time_index.to_numpy()
    print(x.shape, y.shape)

    # Draw the series for both dataset kinds; previously this happened only
    # for 'FI', so the 'CHF' figure was saved empty.
    sns.lineplot(x=x, y=y, ax=ax)
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    fig.autofmt_xdate()

    ax.set_title(f'{name} Volatility')
    ax.set_ylabel('Annualized Volatility')
    ax.set_xlabel('Date')
    ax.set_ylim(-50000, 1000000)

    # Hide the right and top spines, show the bottom and left spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)

    f_name = f'volatility_{name}.pdf'
    fig.savefig(f_name)
    print(f'saved figure {f_name}')

if __name__ == '__main__':
    # Script entry point: load the dataset, compute volatility, plot it and
    # print summary statistics.
    print('reading dataset')
    data = np.loadtxt('./data/FI-2010/BenchmarkDatasets/Auction/1.Auction_Zscore/Auction_Zscore_Training/Train_Dst_Auction_ZScore_CF_7.txt')

    # Calculate volatility
    print('Calculating Volatility')
    # The annualization factor is (samples / N) * 252, where N is 7 for the
    # array dataset and 1224 for the DataFrame dataset (presumably the number
    # of trading days each dataset covers -- confirm).
    # np.loadtxt always yields an ndarray; the DataFrame branch applies when
    # the loader above is swapped for one that returns a DataFrame.
    if isinstance(data, np.ndarray):
        annualization = data.shape[1] / 7 * 252
    elif isinstance(data, pd.DataFrame):
        annualization = len(data) / 1224 * 252
    else:
        # Fail here with a clear message instead of a NameError on `v` below.
        raise TypeError(f'Unsupported dataset type: {type(data).__name__}')
    v, v_vector = calculate_volatility(data, annualized=annualization)
    print(f'Total Volatility {v}')

    # Plot volatility
    print('Plot')
    plot_volatility(data, np.array(v_vector))

    # Print summary statistics
    print("Volatility Summary:")
    print(v_vector.describe())
    print(f'Volatility mean {v_vector.mean()}')