import numpy as np
import pandas as pd
import scipy.sparse
import matplotlib.pyplot as plt
import itertools
from scipy.linalg import subspace_angles

def handle_one_sample_M(data: pd.DataFrame):
    """Scatter ``M_max`` against ``n`` together with a ``C * sqrt(n)`` curve.

    ``C`` is the mean of ``M_max / sqrt(n)`` over all rows of ``data`` and is
    printed before plotting.  The curve is drawn for ``n`` between 100 and
    1000.  ``data`` must provide the columns ``n``, ``a``, ``b`` and
    ``M_max``; the caller's frame is not modified.
    """
    frame = data.copy()

    # Selecting the columns up front raises KeyError early if one is missing.
    selected = frame[['n', 'a', 'b', 'M_max']]

    # Scale constant: average ratio between the observations and sqrt(n).
    scale = np.mean(frame['M_max'] / np.sqrt(frame['n']))
    print("C_max for M: ", scale)

    # Reference curve C * sqrt(n) on an evenly spaced grid.
    grid_n = np.linspace(100, 1000, 100)
    reference = pd.DataFrame({'n': grid_n, 'y': np.sqrt(grid_n) * scale})
    observed = selected[['n', 'M_max']]

    # Blue line for the curve, red dots for the data, sharing one set of axes.
    ax = reference.plot(x='n', y='y', color='blue', legend=None)
    observed.plot.scatter(x='n', y='M_max', color='red', ax=ax)

    ax.set_xlabel('n')
    ax.set_ylabel('|M|')
    ax.grid(True)
    plt.show()

def handle_one_sample_gamma(data: pd.DataFrame):
    """Scatter ``2 / gamma_max`` against ``n`` with an exponential curve.

    Rows with ``gamma_max >= 0.5`` are dropped.  The constant ``C_max`` is the
    largest value of ``(a - b)**2 / (a + b) / log(2 / gamma_max)`` over the
    remaining rows.  The curve is ``exp((a - b)**2 / (a + b) / C_max)``
    evaluated with ``a = 0.06 * n`` and ``b = 0.04 * n`` for ``n`` between
    100 and 1100.

    ``data`` must provide the columns ``n``, ``a``, ``b`` and ``gamma_max``;
    it is only read, never modified.
    """
    data_gamma = data[['n', 'a', 'b', 'gamma_max']]
    data_gamma = data_gamma[data_gamma['gamma_max'] < 0.5]

    a = data_gamma['a']
    b = data_gamma['b']
    gamma_max = data_gamma['gamma_max']
    C_max = np.max((a - b)**2 / (a + b) / np.log(2 / gamma_max))

    # Reference curve along a fixed a/b parametrisation of n.
    x_curve = np.linspace(100, 1100, 100)
    a_curve = 0.06 * x_curve
    b_curve = 0.04 * x_curve
    y_curve = np.exp((a_curve - b_curve)**2 / (a_curve + b_curve) / C_max)
    df_curve = pd.DataFrame({'n': x_curve, 'y': y_curve})

    # Take an explicit copy before adding the derived column: assigning into
    # a frame derived from a filtered slice triggers SettingWithCopyWarning
    # on pandas versions without copy-on-write.
    df_points = data_gamma[['n', 'gamma_max']].copy()
    df_points['y'] = 2 / df_points['gamma_max']

    ax = df_curve.plot(x='n', y='y', color='blue', legend=None)
    df_points.plot.scatter(x='n', y='y', color='red', ax=ax)

    plt.xlabel('n')
    plt.ylabel('2/gamma')
    plt.grid(True)
    plt.show()

def handle_M_single_n(data: pd.DataFrame, n: int):
    """Scatter ``M_max`` against ``m`` together with a ``C / sqrt(m)`` curve.

    ``C`` is the mean of ``M_max * sqrt(m)`` over all rows of ``data`` and is
    printed before plotting.  The curve is drawn for ``m`` between 1 and 9.
    ``n`` is only used in the plot title.  ``data`` must provide the columns
    ``m``, ``a``, ``b`` and ``M_max``; the caller's frame is not modified.
    """
    frame = data.copy()

    # Selecting the columns up front raises KeyError early if one is missing.
    selected = frame[['m', 'a', 'b', 'M_max']]

    # Scale constant: average of the observations rescaled by sqrt(m).
    scale = np.mean(frame['M_max'] * np.sqrt(frame['m']))
    print("C_max for M: ", scale)

    # Reference curve C / sqrt(m) on an evenly spaced grid.
    grid_m = np.linspace(1, 9, 100)
    reference = pd.DataFrame({'m': grid_m, 'y': scale / np.sqrt(grid_m)})
    observed = selected[['m', 'M_max']]

    # Blue line for the curve, red dots for the data, sharing one set of axes.
    ax = reference.plot(x='m', y='y', color='blue', legend=None)
    observed.plot.scatter(x='m', y='M_max', color='red', ax=ax)

    ax.set_xlabel('m')
    ax.set_ylabel('|M|')
    ax.set_title("|M| for n = " + str(n))
    ax.grid(True)
    plt.show()

def handle_sin_W_single_n(data: pd.DataFrame, n: int):
    """Scatter ``sin_max`` against ``m`` together with a ``C / sqrt(m)`` curve.

    Rows with ``sin_max >= 0.9`` are discarded.  ``C`` is the mean of
    ``sin_max * sqrt(m)`` over the remaining rows, so the curve is fitted to
    exactly the points that are plotted.  ``C`` is printed before plotting
    and the curve is drawn for ``m`` between 1 and 9.  ``n`` is only used in
    the plot title.

    ``data`` must provide the columns ``m``, ``a``, ``b`` and ``sin_max``;
    it is only read, never modified.
    """
    data_sin = data[['m', 'a', 'b', 'sin_max']]
    data_sin = data_sin[data_sin['sin_max'] < 0.9]

    # Fit on the filtered rows; using the unfiltered frame here would let the
    # excluded rows shift the curve away from the points actually shown.
    C_max = np.mean(data_sin['sin_max'] * np.sqrt(data_sin['m']))
    print("C_max for sin: ", C_max)

    x_curve = np.linspace(1, 9, 100)
    y_curve = C_max / np.sqrt(x_curve)
    df_curve = pd.DataFrame({'m': x_curve, 'y': y_curve})

    df_points = data_sin[['m', 'sin_max']]

    ax = df_curve.plot(x='m', y='y', color='blue', legend=None)
    df_points.plot.scatter(x='m', y='sin_max', color='red', ax=ax)

    plt.xlabel('m')
    plt.ylabel('sin (W,W_E)')
    plt.title("sin (W,W_E) for n = " + str(n))
    plt.grid(True)
    plt.show()

def handle_gamma_single_n(data: pd.DataFrame, n: int, power: float = 0.71):
    """Scatter ``log(2 / gamma_max)`` against ``m`` with a ``C * m**power`` curve.

    Only rows with ``0.001 < gamma_max < 0.4`` are used.  ``C`` is the largest
    value of ``log(2 / gamma_max) / m**power`` over those rows, and the curve
    is drawn for ``m`` between 1 and 9.

    Args:
        data: Frame providing the columns ``m``, ``a``, ``b`` and
            ``gamma_max``.  The caller's frame is not modified.
        n: Value shown in the plot title.
        power: Exponent of ``m`` in the reference curve; also shown in the
            title as ``t``.
    """
    # Work on a private copy so the optional override below cannot leak out.
    data = data.copy()

    # To analyse the mean instead of the maximum, overwrite the column:
    # data['gamma_max'] = data['gamma_mean']
    data_gamma = data[['m', 'a', 'b', 'gamma_max']]
    in_range = (data_gamma['gamma_max'] < 0.4) & (data_gamma['gamma_max'] > 0.001)
    data_gamma = data_gamma[in_range]

    log_inv_gamma = np.log(2 / data_gamma['gamma_max'])
    C_max = np.max(log_inv_gamma / np.power(data_gamma['m'], power))

    x_curve = np.linspace(1, 9, 100)
    y_curve = C_max * np.power(x_curve, power)
    df_curve = pd.DataFrame({'m': x_curve, 'y': y_curve})

    # Take an explicit copy before adding the derived column: assigning into
    # a frame derived from a filtered slice triggers SettingWithCopyWarning
    # on pandas versions without copy-on-write.
    df_points = data_gamma[['m', 'gamma_max']].copy()
    df_points['y'] = log_inv_gamma

    ax = df_curve.plot(x='m', y='y', color='blue', legend=None)
    df_points.plot.scatter(x='m', y='y', color='red', ax=ax)

    plt.xlabel('m')
    plt.ylabel('log(2 / gamma)')
    plt.title("log(2/gamma) vs m for n = " + str(n) + ", t=" + str(power))
    plt.grid(True)
    plt.show()

def handle_one_sample_gamma_for_m(data: pd.DataFrame, m):
    """Scatter ``log(2 / gamma_max)`` against ``n`` with a reference curve.

    Only rows with ``0.001 < gamma_max < 0.5`` are used.  ``C_max`` is the
    mean of ``(a - b)**2 / (a + b) / log(2 / gamma_max)`` over those rows.
    The curve is ``(a - b)**2 / (a + b) / C_max`` evaluated with
    ``a = 0.06 * n`` and ``b = 0.04 * n`` for ``n`` from 100 up to the largest
    retained ``n`` plus 100.

    Args:
        data: Frame providing the columns ``n``, ``a``, ``b`` and
            ``gamma_max``.  The caller's frame is not modified.
        m: Value shown in the plot title.
    """
    # Work on a private copy so the optional override below cannot leak out.
    data = data.copy()

    # To analyse the mean instead of the maximum, overwrite the column:
    # data['gamma_max'] = data['gamma_mean']
    data_gamma = data[['n', 'a', 'b', 'gamma_max']]
    in_range = (data_gamma['gamma_max'] < 0.5) & (data_gamma['gamma_max'] > 0.001)
    data_gamma = data_gamma[in_range]

    a = data_gamma['a']
    b = data_gamma['b']
    log_inv_gamma = np.log(2 / data_gamma['gamma_max'])
    C_max = np.mean((a - b)**2 / (a + b) / log_inv_gamma)

    # Reference curve along a fixed a/b parametrisation of n.
    x_curve = np.linspace(100, np.max(data_gamma['n']) + 100, 100)
    a_curve = 0.06 * x_curve
    b_curve = 0.04 * x_curve
    y_curve = (a_curve - b_curve)**2 / (a_curve + b_curve) / C_max
    df_curve = pd.DataFrame({'n': x_curve, 'y': y_curve})

    # Take an explicit copy before adding the derived column: assigning into
    # a frame derived from a filtered slice triggers SettingWithCopyWarning
    # on pandas versions without copy-on-write.
    df_points = data_gamma[['n', 'gamma_max']].copy()
    df_points['y'] = log_inv_gamma

    ax = df_curve.plot(x='n', y='y', color='blue', legend=None)
    df_points.plot.scatter(x='n', y='y', color='red', ax=ax)

    plt.xlabel('n')
    plt.ylabel('log(2/gamma)')
    plt.title("log(2/gamma) vs n for m = " + str(m))
    plt.grid(True)
    plt.show()

def main():
    """Load the raw results and plot log(2/gamma) against n for m = 3..8.

    Reads ``raw_data_Neurips25.csv`` from the current working directory.  The
    remaining plots are disabled; uncomment the matching call to enable one.
    """
    results = pd.read_csv("raw_data_Neurips25.csv")

    # Single-sample (m == 1) slice, consumed only by the disabled plots below.
    single_sample = results[results["m"] == 1]

    #handle_one_sample_gamma(single_sample)
    #handle_one_sample_M(single_sample)

    # Per-n slices, consumed only by the disabled plots below.
    for n_value in results['n'].unique():
        per_n = results[results['n'] == n_value]
        #handle_M_single_n(per_n, n_value)
        #handle_sin_W_single_n(per_n, n_value)
        #handle_gamma_single_n(per_n, n_value)

    # One figure per m in 3, 4, ..., 8.
    for m_value in range(3, 9):
        per_m = results[results["m"] == m_value]
        handle_one_sample_gamma_for_m(per_m, m_value)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

