import os
import sys


current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)  # Ensure parent directory is searched first
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# import torch.special as special
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

from model import get_model
from data import get_data
# from evaluation import test
from options import options
from utils import simple_lapsed_time
from tqdm import tqdm
from math import log
from generalization_gap import get_generalization_gap
import warnings
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import RidgeCV, LinearRegression, Ridge, Lasso
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score, KFold
from sklearn.metrics import mean_squared_error
from sklearn.base import BaseEstimator, TransformerMixin
from numpy.polynomial.chebyshev import chebvander
from torchvision.utils import save_image


class Polynomial(nn.Module):
    """Learnable polynomial ``sum_i coeffs[i] * x**i`` in the monomial basis.

    Args:
        degree: Highest power of ``x``. The module owns ``degree + 1``
            coefficients, initialised with ``torch.randn``.
    """

    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(torch.randn(degree + 1))

    def forward(self, x):
        """Evaluate the polynomial at the sample points ``x``.

        Args:
            x: Tensor of sample points. The constant term is built with
                ``x.size(0)`` entries, so a 1-D tensor is expected.

        Returns:
            Tensor with one polynomial value per sample point.
        """
        # Allocate the constant term on the coefficients' device instead of
        # relying on a module-level ``device`` global, which only exists when
        # this file is executed as a script.
        y = self.coeffs[0] * torch.ones(x.size(0), device=self.coeffs.device)
        for i in range(1, self.degree + 1):
            y += self.coeffs[i] * (x ** i)
        return y


class ChebyshevPolynomial(Polynomial):
    """Learnable expansion ``sum_i coeffs[i] * U_i(x)``.

    The basis functions are evaluated with
    ``torch.special.chebyshev_polynomial_u`` (Chebyshev polynomials of the
    second kind). Coefficient storage is inherited from ``Polynomial``.

    Args:
        degree: Highest basis index; ``degree + 1`` coefficients are learned.
    """

    def __init__(self, degree):
        super().__init__(degree=degree)

    def forward(self, x):
        """Return the expansion evaluated at every entry of ``x``."""
        y = 0
        for i in range(self.degree + 1):
            # Go through ``torch.special`` explicitly: the bare ``special``
            # alias is not bound in this module (its import is commented out).
            y += self.coeffs[i] * torch.special.chebyshev_polynomial_u(x, i)
        return y


def fit_pytorch(x, y, degree, iters=3000, cheb=False):
    """Fit polynomial coefficients to ``(x, y)`` by gradient descent.

    Uses AdamW (lr 0.1, no weight decay) on the mean-squared error, with the
    learning rate multiplied by 0.1 every 1000 steps. Progress is printed
    every 100 iterations. Tensors and the model are moved to the module-level
    ``device``.

    Args:
        x: Tensor of sample points.
        y: Tensor of target values.
        degree: Polynomial degree handed to the model constructor.
        iters: Number of optimisation steps.
        cheb: If true, fit a ``ChebyshevPolynomial`` instead of a
            ``Polynomial``.

    Returns:
        Tuple ``(coeffs, loss, pred)``: fitted coefficients as a NumPy array,
        the final MSE as a float, and the final predictions as a NumPy array.
    """
    inputs, targets = x.to(device), y.to(device)
    model_cls = ChebyshevPolynomial if cheb else Polynomial
    model = model_cls(degree).to(device)

    opt = optim.AdamW(model.parameters(), lr=0.1, weight_decay=0)
    lr_schedule = optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.1)

    for step in range(1, iters + 1):
        loss = F.mse_loss(model(inputs), targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        lr_schedule.step()
        if step % 100 == 0:
            print(f'Iter {step}, loss = {loss:.6f}, coeffs = {model.coeffs}')

    # Re-evaluate once more so the reported loss reflects the last update.
    final_pred = model(inputs)
    final_loss = F.mse_loss(final_pred, targets)
    return (
        model.coeffs.detach().cpu().numpy(),
        final_loss.item(),
        final_pred.detach().cpu().numpy(),
    )


def fit_121(x, y, degree):
    """Least-squares fit of a polynomial of the given degree to ``(x, y)``.

    Args:
        x: Sample points. Any array-like is accepted and flattened to 1-D.
        y: Target values, one per sample point.
        degree: Degree of the fitted polynomial.

    Returns:
        Tuple ``(mse, y_pred)``: the mean-squared error of the fit on the
        input points and the fitted values at those points.
    """
    # np.polyfit only accepts a 1-D abscissa, while callers in this file hold
    # x as an (N, 1) column vector; flatten so both layouts work.
    x = np.asarray(x).ravel()
    coeffs = np.polyfit(x, y, degree)
    poly_func = np.poly1d(coeffs)
    y_pred = poly_func(x)
    mse = mean_squared_error(y, y_pred)
    return mse, y_pred


def best_degree_AIC(x, y, max_degree = 20, resolution = 500, visualize = True, i = 0):
    """Pick the polynomial degree that minimises ``2k + n * ln(mse)``.

    Every degree in ``0..max_degree`` is fitted with ``fit_121``; ``k`` is the
    number of coefficients (``degree + 1``) and ``n`` is ``resolution``.
    The per-degree terms are printed as they are computed.

    Args:
        x: Sample points forwarded to ``fit_121``.
        y: Target values forwarded to ``fit_121``.
        max_degree: Largest degree tried (inclusive).
        resolution: Value used as the sample count ``n`` in the criterion.
            NOTE(review): it is not derived from ``len(x)``; callers must
            pass a matching value.
        visualize: If true, save a 2x2 diagnostic figure to
            ``./poly/plots/viz_{i}.png``.
        i: Index used only in the figure file name.

    Returns:
        Tuple ``(best_deg, mse, y_pred)`` for the selected degree.
    """
    n = resolution
    # Smallest positive float: keeps log() defined when a fit is exact
    # (mse == 0), which would otherwise raise a math domain error.
    mse_floor = np.finfo(float).tiny

    all_AIC = []
    all_mse = []
    all_y_pred = []
    all_nlnmse = []
    for degree in range(max_degree + 1):
        k = degree + 1
        mse, y_pred = fit_121(x, y, degree)
        nlnmse = n * log(max(mse, mse_floor))
        AIC = 2 * k + nlnmse
        print(f'degree: {degree}\n2k: {2*k}, mse: {mse:.3f}, nln(mse): {nlnmse:.3f}, AIC: {AIC:.3f}')
        all_mse.append(mse)
        all_y_pred.append(y_pred)
        all_nlnmse.append(nlnmse)
        all_AIC.append(AIC)

    best_deg = np.argmin(all_AIC)
    if visualize:
        all_degrees = np.arange(max_degree + 1)
        all_2k = 2 * (all_degrees + 1)
        fig = plt.figure(figsize=(12, 10))

        # 1. Complexity penalty
        plt.subplot(2, 2, 1)
        plt.plot(all_degrees, all_2k, 'bo-')
        plt.title('2k vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('2k')
        plt.grid(True)

        # 2. Fit error (log scale for readability)
        plt.subplot(2, 2, 2)
        plt.plot(all_degrees, all_mse, 'ro-')
        plt.title('MSE vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('MSE')
        plt.yscale('log')
        plt.grid(True)

        # 3. Goodness-of-fit term
        plt.subplot(2, 2, 3)
        plt.plot(all_degrees, all_nlnmse, 'go-')
        plt.title('n*ln(MSE) vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('n*ln(MSE)')
        plt.grid(True)

        # 4. The criterion itself, with the minimiser highlighted
        plt.subplot(2, 2, 4)
        plt.plot(all_degrees, all_AIC, 'mo-')
        plt.scatter(best_deg, all_AIC[best_deg], s=100, c='red', zorder=5)
        plt.title(f'AIC vs Degree (Best: degree {best_deg})')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('AIC')
        plt.grid(True)

        plt.tight_layout()
        os.makedirs('./poly/plots', exist_ok=True)
        plt.savefig(f'./poly/plots/viz_{i}.png')
        # This function is called repeatedly; release the figure so open
        # figures do not accumulate.
        plt.close(fig)
    return best_deg, all_mse[best_deg], all_y_pred[best_deg]

def best_degree_BIC(x, y, max_degree = 20, resolution = 500, visualize = False, i = 0):
    """Pick the polynomial degree that minimises ``k * ln(n) + n * ln(mse)``.

    Every degree in ``0..max_degree`` is fitted with ``fit_121``; ``k`` is the
    number of coefficients (``degree + 1``) and ``n`` is ``resolution``.

    Args:
        x: Sample points forwarded to ``fit_121``.
        y: Target values forwarded to ``fit_121``.
        max_degree: Largest degree tried (inclusive).
        resolution: Value used as the sample count ``n`` in the criterion.
            NOTE(review): it is not derived from ``len(x)``; callers must
            pass a matching value.
        visualize: If true, save a 2x2 diagnostic figure to
            ``./poly/plots/viz_{i}_bic.png``.
        i: Index used only in the figure file name.

    Returns:
        Tuple ``(best_deg, mse, y_pred)`` for the selected degree.
    """
    n = resolution
    # Smallest positive float: keeps log() defined when a fit is exact
    # (mse == 0), which would otherwise raise a math domain error.
    mse_floor = np.finfo(float).tiny

    all_BIC = []
    all_mse = []
    all_y_pred = []
    all_nlnmse = []
    for degree in range(max_degree + 1):
        k = degree + 1
        mse, y_pred = fit_121(x, y, degree)
        nlnmse = n * log(max(mse, mse_floor))
        all_mse.append(mse)
        all_y_pred.append(y_pred)
        all_nlnmse.append(nlnmse)
        all_BIC.append(log(n) * k + nlnmse)

    best_deg = np.argmin(all_BIC)
    if visualize:
        all_degrees = np.arange(max_degree + 1)
        all_klnn = [(degree + 1) * log(n) for degree in all_degrees]
        fig = plt.figure(figsize=(12, 10))

        # 1. Complexity penalty
        plt.subplot(2, 2, 1)
        plt.plot(all_degrees, all_klnn, 'bo-')
        plt.title('kln(n) vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('kln(n)')
        plt.grid(True)

        # 2. Fit error (log scale for readability)
        plt.subplot(2, 2, 2)
        plt.plot(all_degrees, all_mse, 'ro-')
        plt.title('MSE vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('MSE')
        plt.yscale('log')
        plt.grid(True)

        # 3. Goodness-of-fit term
        plt.subplot(2, 2, 3)
        plt.plot(all_degrees, all_nlnmse, 'go-')
        plt.title('n*ln(MSE) vs Degree')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('n*ln(MSE)')
        plt.grid(True)

        # 4. The criterion itself, with the minimiser highlighted
        plt.subplot(2, 2, 4)
        plt.plot(all_degrees, all_BIC, 'mo-')
        plt.scatter(best_deg, all_BIC[best_deg], s=100, c='red', zorder=5)
        plt.title(f'BIC vs Degree (Best: degree {best_deg})')
        plt.xlabel('Polynomial Degree')
        plt.ylabel('BIC')
        plt.grid(True)

        plt.tight_layout()
        os.makedirs('./poly/plots', exist_ok=True)
        plt.savefig(f'./poly/plots/viz_{i}_bic.png')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return best_deg, all_mse[best_deg], all_y_pred[best_deg]


def numpy_split(lst, ratio):
    """Split ``lst`` into an evenly spaced subset and the remaining items.

    Positions ``floor(k / ratio)`` for ``k = 0, 1, ...`` (while
    ``k < len(lst) * ratio``) are selected; everything else goes to the
    second list. Order within each list follows the input order.

    Args:
        lst: Sequence to split.
        ratio: Approximate fraction of items to select (non-zero).

    Returns:
        Tuple ``(selected, rest)`` of Python lists.
    """
    arr = np.array(lst)
    n = len(arr)
    indices = np.arange(0, n * ratio) * (1 / ratio)
    indices = np.unique(np.floor(indices).astype(int))
    # Floating-point rounding in ``n * ratio`` can yield one extra step whose
    # position lands on ``n``; drop it instead of indexing out of bounds.
    indices = indices[indices < n]
    mask = np.zeros(n, dtype=bool)
    mask[indices] = True
    return arr[mask].tolist(), arr[~mask].tolist()



def best_degree_cross_validation(x, y, max_degree=80, n_folds=5):
    """Select a polynomial degree by shuffled K-fold cross-validation.

    ``x`` is standardised once up front. Each candidate degree is then scored
    with a ``PolynomialFeatures -> LinearRegression`` pipeline, and the degree
    with the lowest mean validation MSE is refitted on all data. The fold
    shuffle is seeded from the module-level ``args.set_seed``.

    Args:
        x: Array of sample points (reshaped to a single feature column).
        y: Target values.
        max_degree: Largest degree tried (inclusive).
        n_folds: Number of cross-validation folds.

    Returns:
        Tuple ``(best_deg, best_mse, y_pred)``, where ``best_mse`` and
        ``y_pred`` come from the refit on the full data set.
    """
    x_scaled = StandardScaler().fit_transform(x.reshape(-1, 1))

    # Inclusive upper bound, so ``max_degree`` itself is a candidate, as in
    # the other degree selectors in this file.
    degrees = np.arange(max_degree + 1)
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=args.set_seed)

    def _build(degree):
        return make_pipeline(PolynomialFeatures(degree=degree), LinearRegression())

    degree_mses = []
    for degree in degrees:
        # scikit-learn reports negated MSE; flip the sign back.
        mse_scores = -cross_val_score(
            _build(degree), x_scaled, y,
            cv=kf, scoring='neg_mean_squared_error'
        )
        degree_mses.append(np.mean(mse_scores))

    best_deg = degrees[np.argmin(degree_mses)]

    # Refit the winning degree on the complete data set.
    final_model = _build(best_deg).fit(x_scaled, y)
    y_pred = final_model.predict(x_scaled)
    best_mse = mean_squared_error(y, y_pred)

    return best_deg, best_mse, y_pred


def best_degree_cross_validation_new(x, y, max_degree=40, n_folds=5):
    """Select a polynomial degree with a region-aware cross-validation.

    Samples are divided by their ``x`` value into a middle region
    ``[-0.085, 1.085)`` and two outer validation bands ``[-0.1, -0.085)`` and
    ``[1.085, 1.1]``. K-fold splitting is applied to the middle region only.
    Every fold trains on its middle-region training part and validates on its
    middle-region test part plus both outer bands. Samples outside
    ``[-0.1, 1.1]`` take part in neither role during cross-validation, but
    are included in the final refit on all data.

    Each candidate is a
    ``PolynomialFeatures -> StandardScaler -> LinearRegression`` pipeline.
    The fold shuffle is seeded from the module-level ``args.set_seed``.

    Args:
        x: Array of sample points (reshaped to a single feature column).
        y: Target values.
        max_degree: Largest degree tried (inclusive).
        n_folds: Number of folds for the middle region.

    Returns:
        Tuple ``(best_deg, best_mse, y_pred)``, where ``best_mse`` and
        ``y_pred`` come from the refit on the full data set.
    """
    x = x.reshape(-1, 1)
    x_flat = x.ravel()

    left_bound = -0.085
    right_bound = 1.085
    left_val_mask = (x_flat >= -0.1) & (x_flat < left_bound)
    mid_mask = (x_flat >= left_bound) & (x_flat < right_bound)
    right_val_mask = (x_flat >= right_bound) & (x_flat <= 1.1)

    left_val_indices = np.where(left_val_mask)[0]
    mid_indices = np.where(mid_mask)[0]
    right_val_indices = np.where(right_val_mask)[0]

    def custom_cv():
        """Yield ``(train, test)`` index arrays for one K-fold pass."""
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=args.set_seed)
        for train_index_mid, test_index_mid in kf.split(mid_indices):
            # kf.split returns positions into mid_indices; map them back to
            # positions in the full sample array.
            train_index = mid_indices[train_index_mid]
            test_index = np.concatenate([
                mid_indices[test_index_mid],
                left_val_indices,
                right_val_indices
            ])
            yield train_index, test_index

    def _build(degree):
        return make_pipeline(
            PolynomialFeatures(degree=degree, include_bias=True),
            StandardScaler(),
            LinearRegression()
        )

    degrees = np.arange(max_degree + 1)
    degree_mses = []
    for degree in degrees:
        # A fresh generator per degree: generators are single-use.
        mse_scores = -cross_val_score(
            _build(degree), x, y,
            cv=custom_cv(), scoring='neg_mean_squared_error'
        )
        degree_mses.append(np.mean(mse_scores))
    best_deg = degrees[np.argmin(degree_mses)]

    final_model = _build(best_deg).fit(x, y)
    y_pred = final_model.predict(x)
    best_mse = mean_squared_error(y, y_pred)

    return best_deg, best_mse, y_pred


def weighted_degree(x, y, max_degree=40):
    """Compute a coefficient-weighted average degree of a polynomial fit.

    A degree-``max_degree`` polynomial is fitted through a
    ``PolynomialFeatures -> StandardScaler -> LinearRegression`` pipeline.
    The coefficients are mapped back to the unstandardised monomial basis,
    and the result is ``sum_d c_d**2 * d / sum_d c_d**2``. A second model of
    degree ``int(weighted_degree) + 1`` is then fitted to produce the
    returned MSE and predictions.

    Args:
        x: Array of sample points. It is multiplied by 5 before fitting (the
            reason for this factor is not documented in this file).
        y: Target values.
        max_degree: Degree of the initial high-order fit.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred)``. ``weighted_degree`` is a
        float; ``mse`` and ``y_pred`` belong to the lower-degree refit.
    """
    x = x * 5
    x = x.reshape(-1, 1)
    model = make_pipeline(
        PolynomialFeatures(degree=max_degree, include_bias=False),
        StandardScaler(),
        LinearRegression()
    ).fit(x, y)

    scaler = model.named_steps['standardscaler']
    lr_model = model.named_steps['linearregression']
    coef_scaled = lr_model.coef_
    intercept_scaled = lr_model.intercept_

    # Undo the standardisation z = (f - mean) / scale:
    #   sum_j w_j * z_j + b  ==  sum_j (w_j / scale_j) * f_j
    #                            + (b - sum_j w_j * mean_j / scale_j)
    coef_original = coef_scaled / scaler.scale_
    intercept_original = intercept_scaled - np.dot(coef_scaled, scaler.mean_ / scaler.scale_)
    coefficients_original = np.concatenate([[intercept_original], coef_original])

    squared = np.square(coefficients_original)
    c_norm_2 = np.sum(squared)
    degrees = np.arange(max_degree + 1)
    if c_norm_2 == 0:
        # All-zero fit (e.g. y is identically zero): define the weighted
        # degree as 0 rather than producing 0/0 = nan, which would make the
        # int() conversion below raise.
        w_degree = 0.0
    else:
        w_degree = np.dot(squared, degrees) / c_norm_2

    us_degree = int(w_degree) + 1
    model_weighted = make_pipeline(
        PolynomialFeatures(degree=us_degree, include_bias=False),
        StandardScaler(),
        LinearRegression()
    ).fit(x, y)
    y_pred = model_weighted.predict(x)
    mse = mean_squared_error(y, y_pred)

    return w_degree, mse, y_pred


def weighted_degree_chebyshev(x, y, resolution=500, max_degree=40):
    """Weighted average degree of a Chebyshev-basis least-squares fit.

    The incoming ``x`` is ignored: the abscissa is rebuilt as the
    ``resolution`` points ``cos((2k - 1) * pi / (2 * resolution))`` in
    ascending order. ``y`` is presumably sampled at those same points
    (confirm against the caller). The result is
    ``sum_d |c_d| * d / sum_d |c_d|`` over the fitted coefficients.

    Args:
        x: Unused; overwritten internally.
        y: Target values, one per node.
        resolution: Number of nodes.
        max_degree: Highest Chebyshev degree in the fit.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred, coefficients)``, where
        ``coefficients[0]`` is the regression intercept.
    """
    node_idx = np.arange(1, resolution + 1)
    nodes = np.flip(np.cos((2 * node_idx - 1) * np.pi / (2 * resolution)))
    x = nodes.reshape(-1, 1)

    class ChebyshevFeatures(BaseEstimator, TransformerMixin):
        """Stateless transformer producing the T_1..T_degree columns."""

        def __init__(self, degree):
            self.degree = degree

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            vander = chebvander(X.flatten(), self.degree)
            # Column 0 is the constant term, which the regression intercept
            # already covers.
            return vander[:, 1:]

    model = make_pipeline(
        ChebyshevFeatures(degree=max_degree),
        LinearRegression()
    )
    model.fit(x, y)

    regressor = model.steps[-1][1]
    coefficients = np.concatenate(([regressor.intercept_], regressor.coef_))

    magnitudes = np.abs(coefficients)
    weighted_degree = np.dot(magnitudes, np.arange(max_degree + 1)) / np.sum(magnitudes)

    y_pred = model.predict(x)
    return weighted_degree, mean_squared_error(y, y_pred), y_pred, coefficients

def weighted_degree_chebyshev_no_norm(x, y, resolution=500, max_degree=40):
    """Unnormalised degree-weighted coefficient sum of a Chebyshev fit.

    The incoming ``x`` is ignored: the abscissa is rebuilt as the
    ``resolution`` points ``cos((2k - 1) * pi / (2 * resolution))`` in
    ascending order. ``y`` is presumably sampled at those same points
    (confirm against the caller). The result is ``sum_d |c_d| * d`` and is
    not divided by the total coefficient magnitude.

    Args:
        x: Unused; overwritten internally.
        y: Target values, one per node.
        resolution: Number of nodes.
        max_degree: Highest Chebyshev degree in the fit.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred, coefficients)``, where
        ``coefficients[0]`` is the regression intercept.
    """
    node_idx = np.arange(1, resolution + 1)
    nodes = np.flip(np.cos((2 * node_idx - 1) * np.pi / (2 * resolution)))
    x = nodes.reshape(-1, 1)

    class ChebyshevFeatures(BaseEstimator, TransformerMixin):
        """Stateless transformer producing the T_1..T_degree columns."""

        def __init__(self, degree):
            self.degree = degree

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            vander = chebvander(X.flatten(), self.degree)
            # Column 0 is the constant term, which the regression intercept
            # already covers.
            return vander[:, 1:]

    model = make_pipeline(
        ChebyshevFeatures(degree=max_degree),
        LinearRegression()
    )
    model.fit(x, y)

    regressor = model.steps[-1][1]
    coefficients = np.concatenate(([regressor.intercept_], regressor.coef_))

    magnitudes = np.abs(coefficients)
    weighted_degree = np.dot(magnitudes, np.arange(max_degree + 1))

    y_pred = model.predict(x)
    return weighted_degree, mean_squared_error(y, y_pred), y_pred, coefficients



# NOTE: a class with this name is also defined earlier in this file; being
# the later definition, this is the one bound to the name after import.
class Polynomial(nn.Module):
    """Learnable polynomial ``sum_i coeffs[i] * x**i`` in the monomial basis.

    Args:
        degree: Highest power of ``x``. The module owns ``degree + 1``
            coefficients, initialised with ``torch.randn``.
    """

    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(torch.randn(degree + 1))

    def forward(self, x):
        """Evaluate the polynomial at the sample points ``x``.

        Args:
            x: Tensor of sample points. The constant term is built with
                ``x.size(0)`` entries, so a 1-D tensor is expected.

        Returns:
            Tensor with one polynomial value per sample point.
        """
        # Allocate the constant term on the coefficients' device instead of
        # relying on a module-level ``device`` global, which only exists when
        # this file is executed as a script.
        y = self.coeffs[0] * torch.ones(x.size(0), device=self.coeffs.device)
        for i in range(1, self.degree + 1):
            y += self.coeffs[i] * (x ** i)
        return y


# NOTE: a class with this name is also defined earlier in this file; being
# the later definition, this is the one bound to the name after import.
class ChebyshevPolynomial(Polynomial):
    """Learnable expansion ``sum_i coeffs[i] * U_i(x)``.

    The basis functions are evaluated with
    ``torch.special.chebyshev_polynomial_u`` (Chebyshev polynomials of the
    second kind). Coefficient storage is inherited from ``Polynomial``.

    Args:
        degree: Highest basis index; ``degree + 1`` coefficients are learned.
    """

    def __init__(self, degree):
        super().__init__(degree=degree)

    def forward(self, x):
        """Return the expansion evaluated at every entry of ``x``."""
        y = 0
        for i in range(self.degree + 1):
            # Go through ``torch.special`` explicitly: the bare ``special``
            # alias is not bound in this module (its import is commented out).
            y += self.coeffs[i] * torch.special.chebyshev_polynomial_u(x, i)
        return y


# NOTE: a function with this name is also defined earlier in this file; being
# the later definition, this is the one bound to the name after import.
def fit_pytorch(x, y, degree, iters=3000, cheb=False):
    """Fit polynomial coefficients to ``(x, y)`` by gradient descent.

    Uses AdamW (lr 0.1, no weight decay) on the mean-squared error, with the
    learning rate multiplied by 0.1 every 1000 steps. Progress is printed
    every 100 iterations. Tensors and the model are moved to the module-level
    ``device``.

    Args:
        x: Tensor of sample points.
        y: Tensor of target values.
        degree: Polynomial degree handed to the model constructor.
        iters: Number of optimisation steps.
        cheb: If true, fit a ``ChebyshevPolynomial`` instead of a
            ``Polynomial``.

    Returns:
        Tuple ``(coeffs, loss, pred)``: fitted coefficients as a NumPy array,
        the final MSE as a float, and the final predictions as a NumPy array.
    """
    inputs, targets = x.to(device), y.to(device)
    model_cls = ChebyshevPolynomial if cheb else Polynomial
    model = model_cls(degree).to(device)

    opt = optim.AdamW(model.parameters(), lr=0.1, weight_decay=0)
    lr_schedule = optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.1)

    for step in range(1, iters + 1):
        loss = F.mse_loss(model(inputs), targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        lr_schedule.step()
        if step % 100 == 0:
            print(f'Iter {step}, loss = {loss:.6f}, coeffs = {model.coeffs}')

    # Re-evaluate once more so the reported loss reflects the last update.
    final_pred = model(inputs)
    final_loss = F.mse_loss(final_pred, targets)
    return (
        model.coeffs.detach().cpu().numpy(),
        final_loss.item(),
        final_pred.detach().cpu().numpy(),
    )

def weighted_degree_pytorch(x, y, max_degree=40, cheb=False):
    """Weighted average degree of a gradient-descent polynomial fit.

    Coefficients are obtained from ``fit_pytorch`` and combined as
    ``sum_d |c_d| * d / sum_d |c_d|``.

    Args:
        x: Tensor of sample points, forwarded to ``fit_pytorch``.
        y: Tensor of target values, forwarded to ``fit_pytorch``.
        max_degree: Degree of the fitted polynomial.
        cheb: Forwarded to ``fit_pytorch`` to select the Chebyshev model.

    Returns:
        Tuple ``(weighted_degree, mse, pred)``, where ``mse`` and ``pred``
        are the final loss and predictions reported by ``fit_pytorch``.
    """
    coeffs, mse, pred = fit_pytorch(x, y, max_degree, cheb=cheb)
    degrees = np.arange(len(coeffs))
    abs_coeffs = np.abs(coeffs)
    w_degree = np.dot(abs_coeffs, degrees) / np.sum(abs_coeffs)

    return w_degree, mse, pred


def min_degree(x, y, max_degree = 20, threshold = 0.07):
    """Return the lowest polynomial degree whose fit MSE meets ``threshold``.

    Degrees ``0..max_degree`` are tried in ascending order with ``fit_121``.
    If none reaches the threshold, a fit of degree ``max_degree + 1`` is
    returned instead, whatever its error.

    Returns:
        Tuple ``(degree, mse, y_pred)`` for the returned fit.
    """
    degree = 0
    while degree <= max_degree:
        mse, y_pred = fit_121(x, y, degree)
        if mse <= threshold:
            return degree, mse, y_pred
        degree += 1

    # Nothing was good enough: report the fit one degree past the bound.
    fallback_degree = max_degree + 1
    fallback_mse, fallback_pred = fit_121(x, y, fallback_degree)
    return fallback_degree, fallback_mse, fallback_pred


def best_degree(net, images, device, mode = 'AIC', max_degree = 40, resolution = 500, to_plot = False, it = 0):
    """Estimate the polynomial degree of ``net`` along a line between two images.

    The network is evaluated (in eval mode, without gradients) on
    ``resolution`` images ``img_0 + t * (img_1 - img_0)``, with ``t`` placed
    at Chebyshev nodes mapped from ``[-1, 1]`` to ``[0, 1]``. Each output
    dimension is then treated as a 1-D function of the node position and
    handed to the degree estimator selected by ``mode``.

    Args:
        net: Model to probe; called as ``net(batch)``.
        images: Indexable whose first two entries are the end-point images.
        device: Device the interpolated batch is moved to.
        mode: One of ``'AIC'``, ``'BIC'``, ``'min'``, ``'cv'``, ``'w'``,
            ``'w_cheb'``, ``'w_pytorch'``.
        max_degree: Forwarded only to the ``'w_cheb'`` estimator; the other
            estimators use their own defaults.
        resolution: Number of interpolation points.
        to_plot: If true, save a grid of per-dimension fits under
            ``./poly/plots/{mode}/`` (the file name uses the module-level
            ``args.net``).
        it: Unused.

    Returns:
        Tuple ``(deg_output, mse_output)``: the mean degree and mean MSE
        over all output dimensions.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    valid_modes = ('AIC', 'BIC', 'min', 'cv', 'w', 'w_cheb', 'w_pytorch')
    if mode not in valid_modes:
        # Fail before running the network; continuing would only hit an
        # unbound ``deg`` inside the per-dimension loop.
        raise ValueError(f'wrong mode! got {mode!r}, expected one of {valid_modes}')

    net.eval()
    img_0 = images[0]
    img_1 = images[1]
    vec = img_1 - img_0

    # Chebyshev nodes in ascending order, kept as an (N, 1) column.
    k = np.arange(1, resolution + 1)
    x = np.cos((2 * k - 1) * np.pi / (2 * resolution))
    x = np.flip(x)
    x = x.reshape(-1, 1)
    # np.polyfit-based estimators need a 1-D abscissa.
    x_flat = x.ravel()

    # Map the nodes to [0, 1] and shape them to broadcast over an image.
    x_sample = (x + 1) / 2
    x_sample = torch.from_numpy(x_sample).float().view(-1, 1, 1, 1)
    vec = vec.unsqueeze(0)
    img_0 = img_0.unsqueeze(0)
    img = vec * x_sample + img_0  # one interpolated image per node
    img = img.to(device)
    with torch.no_grad():
        y_true = net(img)

    num_classes = y_true.shape[1]
    y_true = y_true.cpu()
    deg_all = []
    mse_all = []
    if to_plot:
        # Five panels per row. At least two rows keeps the 2x5 layout for up
        # to ten outputs, while larger heads simply get more rows.
        n_cols = 5
        n_rows = max(2, -(-num_classes // n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)
        plt.subplots_adjust(hspace=0.4, wspace=0.3)

    for i in range(num_classes):
        y_true_layer = y_true[:, i]
        if mode == 'AIC':
            deg, mse, y_pred = best_degree_AIC(x_flat, y_true_layer, i=i)
        elif mode == 'BIC':
            deg, mse, y_pred = best_degree_BIC(x_flat, y_true_layer, i=i)
        elif mode == 'min':
            deg, mse, y_pred = min_degree(x_flat, y_true_layer)
        elif mode == 'cv':
            deg, mse, y_pred = best_degree_cross_validation_new(x, y_true_layer)
        elif mode == 'w':
            deg, mse, y_pred = weighted_degree(x, y_true_layer)
        elif mode == 'w_cheb':
            deg, mse, y_pred, _ = weighted_degree_chebyshev(x, y_true_layer, resolution, max_degree)
        else:  # 'w_pytorch'
            # NOTE(review): this abscissa is uniform on [-1, 1], whereas the
            # network was sampled at Chebyshev nodes -- confirm this is
            # intended.
            x_ = torch.linspace(-1, 1, resolution)
            deg, mse, y_pred = weighted_degree_pytorch(x_, y_true_layer)
        deg_all.append(deg)
        mse_all.append(mse)

        if to_plot:
            ax = axes[i // 5, i % 5]
            ax.plot(x, y_true_layer, 'b-', linewidth=1.5, label='Original')
            ax.plot(x, y_pred, 'r--', linewidth=1.5, label=f'Fitted')
            ax.set_title(f'Dim {i} (Degree={deg:.4f} MSE={mse:.4f})')
            ax.set_xlabel('x')
            ax.set_ylabel('f(x)')
            ax.legend(loc='best')
            ax.grid(True, alpha=0.3)
            ax.axhline(0, color='black', linewidth=0.5)
            ax.axvline(0, color='black', linewidth=0.5)

    deg_output = np.mean(deg_all)
    mse_output = np.mean(mse_all)
    if to_plot:
        fig.suptitle(f'Polynomial Fit (Degree={deg_output}) (MSE: {mse_output:.4f})', fontsize=16)
        plot_dir = f'./poly/plots/{mode}'
        os.makedirs(plot_dir, exist_ok=True)
        plt.savefig(f'./poly/plots/{mode}/{args.net}_aug_vizsoftmax_no_reg_train.png')
        plt.close(fig)
    return deg_output, mse_output


def get_random_noise_image():
    """Return a normalised 3x32x32 tensor of uniform random noise.

    Values are first drawn uniformly from [0, 1) with ``torch.rand`` and then
    passed through ``transforms.Normalize`` with the same per-channel
    mean/std constants used by ``transform_test`` in this script's entry
    point.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    noise = torch.rand(3, 32, 32)
    return transforms.Normalize(channel_mean, channel_std)(noise)

def avg_best_degree(net, dataset, device, mode = 'AIC', repeats=100, max_degree=40, resolution=500, to_plot = False, use_noise=False):
    """Average ``best_degree`` over many random image pairs.

    Args:
        net: Model handed to ``best_degree``.
        dataset: Indexable data set whose items are ``(image, ...)`` tuples.
            Only used (and only measured with ``len``) when ``use_noise`` is
            false.
        device: Device handed to ``best_degree``.
        mode: Degree estimator name handed to ``best_degree``.
        repeats: Number of image pairs to evaluate.
        max_degree: Forwarded to ``best_degree``.
        resolution: Forwarded to ``best_degree``.
        to_plot: If true, save a plot of the running average under
            ``./poly/plots/{mode}/`` (the file name uses the module-level
            ``args.net`` and ``args.set_seed``). Per-pair plotting inside
            ``best_degree`` is always disabled.
        use_noise: If true, use pairs from ``get_random_noise_image``
            instead of sampling the data set.

    Returns:
        The mean of the per-pair degrees. Summary statistics are also
        printed.
    """
    best_degrees = []
    n_samples = None if use_noise else len(dataset)

    for i in tqdm(range(repeats)):
        if use_noise:
            img_pair = [get_random_noise_image(), get_random_noise_image()]
        else:
            # Two distinct items, drawn from NumPy's global RNG.
            idx1, idx2 = np.random.choice(n_samples, 2, replace=False)
            img_pair = [dataset[idx1][0], dataset[idx2][0]]
        deg, _ = best_degree(net, img_pair, device, mode, max_degree, resolution, False, i)
        best_degrees.append(deg)

    if to_plot:
        # Running mean after 1, 2, ..., len(best_degrees) pairs.
        counts = np.arange(1, len(best_degrees) + 1)
        cumulative_avg = np.cumsum(best_degrees) / counts

        fig = plt.figure(figsize=(10, 6))
        plt.plot(counts, cumulative_avg, 'b-o', linewidth=2)
        plt.title(f'average best degree Evolution - {args.net}')
        plt.xlabel('Epoch')
        plt.ylabel('ABD')
        plt.grid(True, linestyle='--', alpha=0.7)

        plot_dir = f'./poly/plots/{mode}'
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = f'./poly/plots/{mode}/poly_evolution_{args.net}_{args.set_seed}.png'
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved abd evolution plot to {plot_path}")

    avg_deg = np.mean(best_degrees)
    std_deg = np.std(best_degrees)
    median_deg = np.median(best_degrees)

    print(f"Rounds: {repeats}")
    print(f"Average Best Degree: {avg_deg:.2f} ± {std_deg:.2f}")
    print(f"Median: {median_deg}, Min: {min(best_degrees)}, Max: {max(best_degrees)}")

    return avg_deg







if __name__ == "__main__":
    # Script entry point: load data and a model, measure the average best
    # degree on the CIFAR-10 test set, and append the result together with
    # the generalization gap to ./poly/awd.txt.
    #
    # ``args`` and ``device`` are deliberately module-level names: several
    # functions above read them as globals.
    args = options().parse_args()
    print(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Seed once for data loading, then re-seed for everything that follows.
    torch.manual_seed(args.set_data_seed)
    trainloader, testloader = get_data(args)
    torch.manual_seed(args.set_seed)
    np.random.seed(args.set_seed)

    # Un-augmented CIFAR-10 with the normalisation used at test time.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    raw_trainset = torchvision.datasets.CIFAR10(
        root='~/data', train=True, download=True, transform=transform_test)
    raw_testset = torchvision.datasets.CIFAR10(
        root='~/data', train=False, download=True, transform=transform_test)

    raw_trainloader = torch.utils.data.DataLoader(
        raw_trainset, batch_size=100, shuffle=False, num_workers=2)

    if args.imgs:
        image_ids = args.imgs
        images = [raw_trainset[i][0] for i in image_ids]
        labels = [raw_trainset[i][1] for i in image_ids]
        print(labels)

    # Load model
    net = get_model(args, device)

    # Load pretrained weights (if specified)
    if args.load_net:
        if isinstance(net, torch.nn.DataParallel):
            net = net.module  # Remove DataParallel wrapper
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (and vice versa).
        net.load_state_dict(torch.load(args.load_net, map_location=device))
        print(f"Loaded model from {args.load_net}")

    abd = avg_best_degree(net, raw_testset, device, 'w_cheb', args.epochs, max_degree = args.max_degree, resolution=int(args.resolution), to_plot=False, use_noise=args.use_noise)
    gg = get_generalization_gap(args, net, testloader, device)
    data = {
            'net' : args.load_net,
            'abd' : abd,
            'generalization gap': gg,
    }
    # Make sure the output directory exists before appending the record.
    os.makedirs('./poly', exist_ok=True)
    with open('./poly/awd.txt', 'a') as f:
            f.write('\n' + str(data))  # Add newline before writing data