import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# import torch.special as special
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

from tqdm import tqdm
from math import log

from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import RidgeCV, LinearRegression, Ridge, Lasso
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score, KFold
from sklearn.metrics import mean_squared_error
from sklearn.base import BaseEstimator, TransformerMixin
from numpy.polynomial.chebyshev import chebvander
from torchvision.utils import save_image
import clip
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm 
from utils import get_model_from_sd
import json

# warnings.filterwarnings("ignore", category=np.RankWarning)

def apply_pca(output_trajectory, n_components=5):
    """
    Project a single interpolation trajectory onto its leading principal directions.

    Args:
        output_trajectory: 2-D tensor with one row per sample along the path
                        (the original notes give [resolution, num_classes],
                        e.g. [500, 1000]).
        n_components: Number of principal directions to keep (e.g., 5).

    Returns:
        projected_output: The centered trajectory multiplied by the selected
                        directions, one column per direction. If the SVD
                        yields fewer than n_components rows, fewer columns
                        are returned.
    """
    # Subtract the per-column mean so the decomposition describes variation
    # around the average output rather than the offset itself.
    column_mean = output_trajectory.mean(dim=0, keepdim=True)
    centered = output_trajectory - column_mean

    # The decomposition (or the random fallback) is built without gradient
    # tracking; only the final projection runs outside the no_grad scope.
    with torch.no_grad():
        try:
            # centered = U S Vh; rows of Vh are the principal directions.
            # full_matrices=False limits Vh to min(rows, cols) rows.
            _, _, vh = torch.linalg.svd(centered, full_matrices=False)
        except RuntimeError:
            # Best-effort fallback: a random (non-orthogonal) projection.
            print("Warning: SVD failed, using fallback.")
            directions = torch.randn(
                n_components,
                output_trajectory.size(1),
                device=output_trajectory.device,
            )
        else:
            directions = vh[:n_components]

    # [rows, cols] @ [cols, kept] -> [rows, kept]
    return torch.mm(centered, directions.T)

class Polynomial(nn.Module):
    """Learnable polynomial ``sum_i coeffs[i] * x**i`` in the monomial basis.

    Attributes are set in ``__init__``:
      * ``degree``  - highest power used in ``forward``.
      * ``coeffs``  - ``nn.Parameter`` of length ``degree + 1``, randomly
        initialised with ``torch.randn``.
    """

    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(torch.randn(degree+1))

    def forward(self, x):
        """Evaluate the polynomial at every entry of the 1-D tensor ``x``.

        The constant term is created on ``x``'s own device, so the module does
        not depend on a module-level ``device`` variable.
        """
        y = self.coeffs[0] * torch.ones(x.size(0), device=x.device)
        for i in range(1, self.degree + 1):
            # Out-of-place accumulation keeps the autograd graph simple.
            y = y + self.coeffs[i] * (x ** i)
        return y


class ChebyshevPolynomial(Polynomial):
    """Learnable series ``sum_i coeffs[i] * U_i(x)`` built on ``Polynomial``'s parameters.

    The basis functions come from ``torch.special.chebyshev_polynomial_u``.
    NOTE(review): the numpy-based fits in this file use ``chebvander``;
    confirm that the ``_u`` basis is the one intended here.
    """

    def __init__(self, degree):
        super().__init__(degree=degree)

    def forward(self, x):
        """Evaluate the series at every entry of ``x``."""
        y = 0
        for i in range(self.degree + 1):
            # Accessed through ``torch.special`` because no bare ``special``
            # name is imported in this module.
            y = y + self.coeffs[i] * torch.special.chebyshev_polynomial_u(x, i)
        return y


def fit_pytorch(x, y, degree, iters=3000, cheb=False):
    """Fit a polynomial of the given degree to ``(x, y)`` by gradient descent.

    Args:
        x: Sample positions (tensor or array-like), converted to float32.
        y: Target values (tensor or array-like), converted to float32.
        degree: Degree of the fitted polynomial.
        iters: Number of AdamW steps.
        cheb: Use ``ChebyshevPolynomial`` instead of ``Polynomial``.

    Returns:
        Tuple ``(coeffs, loss, pred)``: fitted coefficients as a numpy array,
        final MSE as a float, and final predictions as a numpy array.
    """
    # Resolved locally because the module defines no global ``device``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.as_tensor(x, dtype=torch.float32).to(device)
    y = torch.as_tensor(y, dtype=torch.float32).to(device)

    poly_cls = ChebyshevPolynomial if cheb else Polynomial
    poly = poly_cls(degree).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(poly.parameters(), lr=0.1, weight_decay=0)
    # Learning rate is divided by 10 every 1000 steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
    for it in range(iters):
        pred = poly(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if (it + 1) % 100 == 0:
            print(f'Iter {it + 1}, loss = {loss:.6f}, coeffs = {poly.coeffs}')

    # The final evaluation needs no autograd graph.
    with torch.no_grad():
        pred = poly(x)
        loss = criterion(pred, y)
    return poly.coeffs.detach().cpu().numpy(), loss.item(), pred.detach().cpu().numpy()


def fit_121(x, y, degree):
    """Least-squares polynomial fit of ``y`` against ``x`` with ``np.polyfit``.

    Args:
        x: Sample positions; any array-like, flattened to 1-D (callers in this
            file pass an ``(n, 1)`` column, which ``np.polyfit`` rejects as is).
        y: Target values; any array-like, flattened to 1-D float.
        degree: Degree of the fitted polynomial.

    Returns:
        Tuple ``(mse, y_pred)``: mean squared error of the fit on the given
        samples and the fitted values at ``x``.
    """
    x_flat = np.asarray(x, dtype=float).ravel()
    y_flat = np.asarray(y, dtype=float).ravel()
    coeffs = np.polyfit(x_flat, y_flat, degree)
    y_pred = np.poly1d(coeffs)(x_flat)
    # Plain mean of squared residuals (the unweighted mean squared error).
    mse = float(np.mean((y_flat - y_pred) ** 2))
    return mse, y_pred


def best_degree_AIC(x, y, max_degree = 20, resolution = 500, visualize = True, i = 0):
    """
    Select the polynomial degree that minimises ``AIC = 2k + n * ln(MSE)``.

    Every degree in ``0..max_degree`` is fitted with ``fit_121``; ``k`` is
    ``degree + 1`` and ``n`` is ``resolution``.

    Args:
        x: Sample positions, forwarded to ``fit_121``.
        y: Target values, forwarded to ``fit_121``.
        max_degree: Highest degree tried (inclusive).
        resolution: Value used as the sample count ``n`` in the criterion.
        visualize: When true, save a 2x2 diagnostic figure to
            ``./poly/plots/viz_{i}.png``.
        i: Index used only in the figure file name.

    Returns:
        Tuple ``(best_deg, mse, y_pred)`` for the selected degree.
    """
    n = resolution
    # Floor for the MSE so a perfect fit (MSE == 0) does not make log() raise.
    mse_floor = np.finfo(float).tiny

    all_AIC = []
    all_mse = []
    all_y_pred = []
    all_nlnmse = []
    for degree in range(max_degree + 1):
        k = degree + 1
        mse, y_pred = fit_121(x, y, degree)
        fit_term = n * log(max(mse, mse_floor))
        AIC = 2 * k + fit_term
        print(f'degree: {degree}\n2k: {2*k}, mse: {mse:.3f}, nln(mse): {fit_term:.3f}, AIC: {AIC:.3f}')
        all_mse.append(mse)
        all_y_pred.append(y_pred)
        all_nlnmse.append(fit_term)
        all_AIC.append(AIC)

    best_deg = np.argmin(all_AIC)
    if visualize:
        all_degrees = np.arange(max_degree + 1)
        all_2k = 2 * (all_degrees + 1)
        # (values, line format, title, y label, log-scaled y axis)
        panels = [
            (all_2k, 'bo-', '2k vs Degree', '2k', False),
            (all_mse, 'ro-', 'MSE vs Degree', 'MSE', True),
            (all_nlnmse, 'go-', 'n*ln(MSE) vs Degree', 'n*ln(MSE)', False),
            (all_AIC, 'mo-', f'AIC vs Degree (Best: degree {best_deg})', 'AIC', False),
        ]
        fig = plt.figure(figsize=(12, 10))
        for position, (values, fmt, title, ylabel, log_scale) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.plot(all_degrees, values, fmt)
            if position == 4:
                # Highlight the selected degree on the criterion panel.
                plt.scatter(best_deg, all_AIC[best_deg], s=100, c='red', zorder=5)
            plt.title(title)
            plt.xlabel('Polynomial Degree')
            plt.ylabel(ylabel)
            if log_scale:
                plt.yscale('log')  # Use log scale for clarity
            plt.grid(True)

        plt.tight_layout()
        # The target directory may not exist yet; savefig does not create it.
        os.makedirs('./poly/plots', exist_ok=True)
        plt.savefig(f'./poly/plots/viz_{i}.png')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return best_deg, all_mse[best_deg], all_y_pred[best_deg]

def best_degree_BIC(x, y, max_degree = 20, resolution = 500, visualize = False, i = 0):
    """
    Select the polynomial degree that minimises ``BIC = k * ln(n) + n * ln(MSE)``.

    Every degree in ``0..max_degree`` is fitted with ``fit_121``; ``k`` is
    ``degree + 1`` and ``n`` is ``resolution``.

    Args:
        x: Sample positions, forwarded to ``fit_121``.
        y: Target values, forwarded to ``fit_121``.
        max_degree: Highest degree tried (inclusive).
        resolution: Value used as the sample count ``n`` in the criterion.
        visualize: When true, save a 2x2 diagnostic figure to
            ``./poly/plots/viz_{i}_bic.png``.
        i: Index used only in the figure file name.

    Returns:
        Tuple ``(best_deg, mse, y_pred)`` for the selected degree.
    """
    n = resolution
    # Floor for the MSE so a perfect fit (MSE == 0) does not make log() raise.
    mse_floor = np.finfo(float).tiny

    all_BIC = []
    all_mse = []
    all_y_pred = []
    all_nlnmse = []
    for degree in range(max_degree + 1):
        k = degree + 1
        mse, y_pred = fit_121(x, y, degree)
        fit_term = n * log(max(mse, mse_floor))
        all_mse.append(mse)
        all_y_pred.append(y_pred)
        all_nlnmse.append(fit_term)
        all_BIC.append(log(n) * k + fit_term)

    best_deg = np.argmin(all_BIC)
    if visualize:
        all_degrees = np.arange(max_degree + 1)
        all_klnn = [(degree + 1) * log(n) for degree in all_degrees]
        # (values, line format, title, y label, log-scaled y axis)
        panels = [
            (all_klnn, 'bo-', 'kln(n) vs Degree', 'kln(n)', False),
            (all_mse, 'ro-', 'MSE vs Degree', 'MSE', True),
            (all_nlnmse, 'go-', 'n*ln(MSE) vs Degree', 'n*ln(MSE)', False),
            (all_BIC, 'mo-', f'BIC vs Degree (Best: degree {best_deg})', 'BIC', False),
        ]
        fig = plt.figure(figsize=(12, 10))
        for position, (values, fmt, title, ylabel, log_scale) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.plot(all_degrees, values, fmt)
            if position == 4:
                # Highlight the selected degree on the criterion panel.
                plt.scatter(best_deg, all_BIC[best_deg], s=100, c='red', zorder=5)
            plt.title(title)
            plt.xlabel('Polynomial Degree')
            plt.ylabel(ylabel)
            if log_scale:
                plt.yscale('log')
            plt.grid(True)

        plt.tight_layout()
        # The target directory may not exist yet; savefig does not create it.
        os.makedirs('./poly/plots', exist_ok=True)
        plt.savefig(f'./poly/plots/viz_{i}_bic.png')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return best_deg, all_mse[best_deg], all_y_pred[best_deg]


def numpy_split(lst, ratio):
    """Split ``lst`` into an evenly spaced selection and the remainder.

    Positions ``floor(k / ratio)`` for ``k = 0, 1, ...`` below ``len(lst) * ratio``
    are selected (duplicates collapsed); everything else goes to the second list.

    Returns:
        Tuple ``(selected, rest)`` of plain Python lists, both in original order.
    """
    values = np.array(lst)
    total = len(values)
    stride = 1 / ratio
    raw_positions = np.arange(0, total * ratio) * stride
    picked = np.unique(np.floor(raw_positions).astype(int))
    chosen = np.zeros(total, dtype=bool)
    chosen[picked] = True
    selected = values[chosen].tolist()
    rest = values[~chosen].tolist()
    return selected, rest



def best_degree_cross_validation(x, y, max_degree=80, n_folds=5, random_state=0):
    """
    Select a polynomial degree by shuffled K-fold cross-validation.

    Args:
        x: Sample positions; standardised once before any fitting.
        y: Target values.
        max_degree: Highest degree tried (inclusive).
        n_folds: Number of cross-validation folds.
        random_state: Seed for the fold shuffling. Passed explicitly because
            this module has no global ``args`` object to read a seed from.

    Returns:
        Tuple ``(best_deg, best_mse, y_pred)``: the degree with the lowest mean
        validation MSE, plus the MSE and predictions of a model of that degree
        refitted on all samples.
    """
    # Data standardization (key step)
    x_scaled = StandardScaler().fit_transform(np.asarray(x).reshape(-1, 1))

    degrees = np.arange(max_degree + 1)
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    # Average validation MSE for each candidate degree
    degree_mses = []
    for degree in degrees:
        model = make_pipeline(
            PolynomialFeatures(degree=degree),
            LinearRegression()
        )
        # The scorer returns negated MSE, hence the sign flip.
        mse_scores = -cross_val_score(
            model, x_scaled, y,
            cv=kf, scoring='neg_mean_squared_error'
        )
        degree_mses.append(np.mean(mse_scores))

    best_deg = degrees[np.argmin(degree_mses)]

    # Refit the selected degree on the full dataset
    final_model = make_pipeline(
        PolynomialFeatures(degree=best_deg),
        LinearRegression()
    ).fit(x_scaled, y)

    y_pred = final_model.predict(x_scaled)
    best_mse = mean_squared_error(y, y_pred)

    return best_deg, best_mse, y_pred


def best_degree_cross_validation_new(x, y, max_degree=40, n_folds=5, random_state=0):
    """
    Select a polynomial degree by cross-validation with extra boundary samples.

    Samples are split by position into three regions:
      * left boundary:  ``-0.1 <= x < -0.085``
      * middle:         ``-0.085 <= x < 1.085``
      * right boundary: ``1.085 <= x <= 1.1``
    Each fold trains on part of the middle region and validates on the held-out
    middle samples plus all boundary samples. Samples outside ``[-0.1, 1.1]``
    are used in neither role during cross-validation.

    Args:
        x: Sample positions (reshaped to a column internally).
        y: Target values.
        max_degree: Highest degree tried (inclusive).
        n_folds: Number of folds over the middle region.
        random_state: Seed for the fold shuffling. Passed explicitly because
            this module has no global ``args`` object to read a seed from.

    Returns:
        Tuple ``(best_deg, best_mse, y_pred)``: the degree with the lowest mean
        validation MSE, plus the MSE and predictions of a model of that degree
        refitted on all samples.
    """
    # Only reshape here; standardization happens inside the pipeline.
    x = x.reshape(-1, 1)
    positions = x.ravel()

    left_bound = -0.085
    right_bound = 1.085
    left_val_mask = (positions >= -0.1) & (positions < left_bound)
    mid_mask = (positions >= left_bound) & (positions < right_bound)
    right_val_mask = (positions >= right_bound) & (positions <= 1.1)

    left_val_indices = np.where(left_val_mask)[0]
    mid_indices = np.where(mid_mask)[0]
    right_val_indices = np.where(right_val_mask)[0]

    def custom_cv():
        """Yield (train, validation) index arrays as described above."""
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
        for train_index_mid, test_index_mid in kf.split(mid_indices):
            train_index = mid_indices[train_index_mid]
            # Validation set = held-out middle samples + both boundary regions
            test_index = np.concatenate([
                mid_indices[test_index_mid],
                left_val_indices,
                right_val_indices
            ])
            yield train_index, test_index

    degrees = np.arange(max_degree + 1)
    degree_mses = []

    for degree in degrees:
        model = make_pipeline(
            PolynomialFeatures(degree=degree, include_bias=True),
            StandardScaler(),
            LinearRegression()
        )
        # A fresh generator is needed per degree; the scorer returns negated MSE.
        mse_scores = -cross_val_score(
            model, x, y,
            cv=custom_cv(), scoring='neg_mean_squared_error'
        )
        degree_mses.append(np.mean(mse_scores))

    best_deg = degrees[np.argmin(degree_mses)]

    # Final model uses the same structure, fitted on every sample.
    final_model = make_pipeline(
        PolynomialFeatures(degree=best_deg, include_bias=True),
        StandardScaler(),
        LinearRegression()
    ).fit(x, y)

    y_pred = final_model.predict(x)
    best_mse = mean_squared_error(y, y_pred)

    return best_deg, best_mse, y_pred


def weighted_degree(x, y, max_degree=40):
    """
    Compute a coefficient-weighted average degree from a monomial-basis fit.

    A degree-``max_degree`` polynomial is fitted to ``(5 * x, y)`` through a
    PolynomialFeatures -> StandardScaler -> LinearRegression pipeline. The
    coefficients are mapped back to the unstandardised features and the result
    is ``sum(c_d**2 * d) / sum(c_d**2)``. A second model of degree
    ``int(weighted degree) + 1`` is then fitted to produce the returned MSE and
    predictions.

    Args:
        x: Sample positions (numpy array; scaled by 5 and reshaped to a column).
        y: Target values.
        max_degree: Degree of the initial fit.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred)``. The weighted degree is 0.0
        when every fitted coefficient is exactly zero.
    """
    x = (x * 5).reshape(-1, 1)
    model = make_pipeline(
        PolynomialFeatures(degree=max_degree, include_bias=False),
        StandardScaler(),
        LinearRegression()
    ).fit(x, y)

    scaler = model.named_steps['standardscaler']
    lr_model = model.named_steps['linearregression']

    # Coefficients as fitted on the standardised features
    coef_scaled = lr_model.coef_
    intercept_scaled = lr_model.intercept_

    # Undo the standardisation: z = (f - mean) / scale
    coef_original = coef_scaled / scaler.scale_
    intercept_original = intercept_scaled - np.dot(coef_scaled, scaler.mean_ / scaler.scale_)
    coefficients_original = np.concatenate([[intercept_original], coef_original])

    squared = np.square(coefficients_original)
    c_norm_2 = np.sum(squared)
    degrees = np.arange(max_degree + 1)
    if c_norm_2 == 0:
        # All-zero fit (e.g. y identically zero): avoid 0/0 -> NaN, which
        # would make the int() conversion below raise.
        w_degree = 0.0
    else:
        w_degree = np.dot(squared, degrees) / c_norm_2

    us_degree = int(w_degree) + 1
    model_weighted = make_pipeline(
        PolynomialFeatures(degree=us_degree, include_bias=False),
        StandardScaler(),
        LinearRegression()
    ).fit(x, y)
    y_pred = model_weighted.predict(x)
    mse = mean_squared_error(y, y_pred)

    return w_degree, mse, y_pred


def weighted_degree_chebyshev(x, y, resolution=500, max_degree=40):
    """
    Compute a normalised coefficient-weighted degree from a Chebyshev-basis fit.

    The incoming ``x`` is not used: sample positions are regenerated as
    ``cos((2k - 1) * pi / (2 * resolution))`` for ``k = 1..resolution``, in
    ascending order. ``y`` is regressed on the ``chebvander`` features of
    degree ``max_degree`` and the result is ``sum(|c_d| * d) / sum(|c_d|)``.

    Args:
        x: Ignored (kept for interface compatibility).
        y: Target values, one per regenerated sample position.
        resolution: Number of sample positions.
        max_degree: Degree of the Chebyshev expansion.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred, coefficients)`` where
        ``coefficients`` is the intercept followed by the fitted coefficients.
        The weighted degree is 0.0 when every coefficient is exactly zero.
    """
    k = np.arange(1, resolution + 1)
    x = np.cos((2 * k - 1) * np.pi / (2 * resolution))
    x = np.flip(x)  # Flip the array to go from smallest to largest
    x = x.reshape(-1, 1)

    # Custom Chebyshev feature transformer
    class ChebyshevFeatures(BaseEstimator, TransformerMixin):
        def __init__(self, degree):
            self.degree = degree

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            # Drop the constant column: the regression fits its own intercept.
            return chebvander(X.flatten(), self.degree)[:, 1:]

    model = make_pipeline(
        ChebyshevFeatures(degree=max_degree),
        LinearRegression()
    ).fit(x, y)

    lr_model = model.named_steps['linearregression']
    coefficients = np.concatenate([[lr_model.intercept_], lr_model.coef_])

    degrees = np.arange(max_degree + 1)
    abs_coeffs = np.abs(coefficients)
    total_weight = np.sum(abs_coeffs)
    if total_weight == 0:
        # All-zero fit (e.g. y identically zero): avoid 0/0 -> NaN.
        w_degree = 0.0
    else:
        w_degree = np.dot(abs_coeffs, degrees) / total_weight

    y_pred = model.predict(x)
    mse = mean_squared_error(y, y_pred)
    return w_degree, mse, y_pred, coefficients

def weighted_degree_chebyshev_no_norm(x, y, resolution=500, max_degree=40):
    """
    Compute an un-normalised coefficient-weighted degree from a Chebyshev-basis fit.

    The incoming ``x`` is not used: sample positions are regenerated as
    ``cos((2k - 1) * pi / (2 * resolution))`` for ``k = 1..resolution``, in
    ascending order. ``y`` is regressed on the ``chebvander`` features of
    degree ``max_degree`` and the result is ``sum(|c_d| * d)`` with no
    division by the coefficient mass.

    Returns:
        Tuple ``(weighted_degree, mse, y_pred, coefficients)`` where
        ``coefficients`` is the intercept followed by the fitted coefficients.
    """
    node_index = np.arange(1, resolution + 1)
    nodes = np.flip(np.cos((2 * node_index - 1) * np.pi / (2 * resolution)))
    x = nodes.reshape(-1, 1)

    class ChebyshevFeatures(BaseEstimator, TransformerMixin):
        """Stateless transformer producing Chebyshev features without the constant column."""

        def __init__(self, degree):
            self.degree = degree

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            # The constant column is dropped; the regression fits an intercept.
            basis = chebvander(X.flatten(), self.degree)
            return basis[:, 1:]

    features = ChebyshevFeatures(degree=max_degree)
    regressor = LinearRegression()
    model = make_pipeline(features, regressor).fit(x, y)

    fitted = model.named_steps['linearregression']
    coefficients = np.concatenate([[fitted.intercept_], fitted.coef_])

    orders = np.arange(max_degree + 1)
    weighted_degree = np.dot(np.abs(coefficients), orders)

    y_pred = model.predict(x)
    mse = mean_squared_error(y, y_pred)
    return weighted_degree, mse, y_pred, coefficients



class Polynomial(nn.Module):
    """Learnable polynomial ``sum_i coeffs[i] * x**i`` in the monomial basis.

    Attributes are set in ``__init__``:
      * ``degree``  - highest power used in ``forward``.
      * ``coeffs``  - ``nn.Parameter`` of length ``degree + 1``, randomly
        initialised with ``torch.randn``.
    """

    def __init__(self, degree):
        super().__init__()
        self.degree = degree
        self.coeffs = nn.Parameter(torch.randn(degree+1))

    def forward(self, x):
        """Evaluate the polynomial at every entry of the 1-D tensor ``x``.

        The constant term is created on ``x``'s own device, so the module does
        not depend on a module-level ``device`` variable.
        """
        y = self.coeffs[0] * torch.ones(x.size(0), device=x.device)
        for i in range(1, self.degree + 1):
            # Out-of-place accumulation keeps the autograd graph simple.
            y = y + self.coeffs[i] * (x ** i)
        return y


class ChebyshevPolynomial(Polynomial):
    """Learnable series ``sum_i coeffs[i] * U_i(x)`` built on ``Polynomial``'s parameters.

    The basis functions come from ``torch.special.chebyshev_polynomial_u``.
    NOTE(review): the numpy-based fits in this file use ``chebvander``;
    confirm that the ``_u`` basis is the one intended here.
    """

    def __init__(self, degree):
        super().__init__(degree=degree)

    def forward(self, x):
        """Evaluate the series at every entry of ``x``."""
        y = 0
        for i in range(self.degree + 1):
            # Accessed through ``torch.special`` because no bare ``special``
            # name is imported in this module.
            y = y + self.coeffs[i] * torch.special.chebyshev_polynomial_u(x, i)
        return y


def fit_pytorch(x, y, degree, iters=3000, cheb=False):
    """Fit a polynomial of the given degree to ``(x, y)`` by gradient descent.

    Args:
        x: Sample positions (tensor or array-like), converted to float32.
        y: Target values (tensor or array-like), converted to float32.
        degree: Degree of the fitted polynomial.
        iters: Number of AdamW steps.
        cheb: Use ``ChebyshevPolynomial`` instead of ``Polynomial``.

    Returns:
        Tuple ``(coeffs, loss, pred)``: fitted coefficients as a numpy array,
        final MSE as a float, and final predictions as a numpy array.
    """
    # Resolved locally because the module defines no global ``device``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.as_tensor(x, dtype=torch.float32).to(device)
    y = torch.as_tensor(y, dtype=torch.float32).to(device)

    poly_cls = ChebyshevPolynomial if cheb else Polynomial
    poly = poly_cls(degree).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(poly.parameters(), lr=0.1, weight_decay=0)
    # Learning rate is divided by 10 every 1000 steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
    for it in range(iters):
        pred = poly(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if (it + 1) % 100 == 0:
            print(f'Iter {it + 1}, loss = {loss:.6f}, coeffs = {poly.coeffs}')

    # The final evaluation needs no autograd graph.
    with torch.no_grad():
        pred = poly(x)
        loss = criterion(pred, y)
    return poly.coeffs.detach().cpu().numpy(), loss.item(), pred.detach().cpu().numpy()

def weighted_degree_pytorch(x, y, max_degree=40, cheb=False):
    """Fit with ``fit_pytorch`` and return the |coefficient|-weighted mean degree.

    Returns:
        Tuple ``(weighted_degree, mse, pred)`` where the weighted degree is
        ``sum(|c_d| * d) / sum(|c_d|)`` over the fitted coefficients.
    """
    fitted_coeffs, mse, pred = fit_pytorch(x, y, max_degree, cheb=cheb)
    magnitudes = np.abs(fitted_coeffs)
    orders = np.arange(len(fitted_coeffs))
    weighted_sum = np.dot(magnitudes, orders)
    weighted_degree = weighted_sum / np.sum(magnitudes)
    return weighted_degree, mse, pred

def weighted_degree_pytorch(x, y, max_degree=40, cheb=False):
    """Return the |coefficient|-weighted mean degree of a ``fit_pytorch`` fit.

    The fit is run at degree ``max_degree`` (with the Chebyshev module when
    ``cheb`` is true); the result is ``sum(|c_d| * d) / sum(|c_d|)``.

    Returns:
        Tuple ``(weighted_degree, mse, pred)``.
    """
    coeffs, mse, pred = fit_pytorch(x, y, max_degree, cheb=cheb)
    weights = np.abs(coeffs)
    degree_index = np.arange(len(coeffs))
    numerator = np.dot(weights, degree_index)
    denominator = np.sum(weights)
    return numerator / denominator, mse, pred


def min_degree(x, y, max_degree = 20, threshold = 0.07):
    """Return the lowest degree whose ``fit_121`` MSE is at most ``threshold``.

    Degrees ``0..max_degree`` are tried in increasing order. If none reaches
    the threshold, a fit of degree ``max_degree + 1`` is returned instead.

    Returns:
        Tuple ``(degree, mse, y_pred)``.
    """
    degree = 0
    while degree <= max_degree:
        mse, y_pred = fit_121(x, y, degree)
        if not mse > threshold:
            return degree, mse, y_pred
        degree += 1

    # No candidate met the threshold: report one degree past the search range.
    bound_deg = max_degree + 1
    bound_mse, bound_pred = fit_121(x, y, bound_deg)
    return bound_deg, bound_mse, bound_pred


def best_degree(net, images, device, mode = 'AIC', max_degree = 40, resolution = 500, to_plot = False, pca_dim=1, it = 0, args=None):
    """
    Estimate the polynomial degree of ``net``'s output along the straight line
    between two images.

    The segment ``images[0] -> images[1]`` is sampled at ``resolution`` points
    placed at ``cos((2k - 1) * pi / (2 * resolution))`` mapped from (-1, 1) to
    (0, 1). The network outputs at those points are reduced with ``apply_pca``
    and each retained dimension is fitted by the selector chosen with ``mode``.

    Args:
        net: Model called as ``net(batch)``; put into eval mode here.
        images: Sequence whose first two items are image tensors of equal shape.
        device: Device the interpolated batch is moved to.
        mode: One of 'AIC', 'BIC', 'min', 'cv', 'w', 'w_cheb', 'w_pytorch'.
        max_degree: Forwarded to the 'w_cheb' selectors only.
        resolution: Number of interpolation points.
        to_plot: Save a figure with the per-dimension fits.
        pca_dim: Number of PCA dimensions that are fitted.
        it: Unused; kept for interface compatibility.
        args: Optional namespace; ``softmax``, ``norm`` and ``net`` are read
            when present and default to False / False / 'model' otherwise.

    Returns:
        Tuple ``(deg_output, mse_output)``: mean degree and mean MSE over the
        fitted dimensions.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    valid_modes = ('AIC', 'BIC', 'min', 'cv', 'w', 'w_cheb', 'w_pytorch')
    if mode not in valid_modes:
        # Fail before the forward pass instead of hitting an unbound variable.
        raise ValueError(f"unknown mode {mode!r}; expected one of {valid_modes}")

    # ``args`` defaults to None, so read its options defensively.
    use_softmax = getattr(args, 'softmax', False)
    use_norm = getattr(args, 'norm', False)
    net_name = getattr(args, 'net', 'model')

    net.eval()
    img_0 = images[0]
    img_1 = images[1]
    vec = img_1 - img_0

    k = np.arange(1, resolution + 1)
    x = np.cos((2 * k - 1) * np.pi / (2 * resolution))
    x = np.flip(x).reshape(-1, 1)  # ascending order, as a column
    # Interpolation weights in (0, 1), shaped to broadcast over image dims.
    x_sample = torch.from_numpy((x + 1) / 2).float().view(-1, 1, 1, 1)
    img = (vec.unsqueeze(0) * x_sample + img_0.unsqueeze(0)).to(device)

    with torch.no_grad():
        y_true = net(img)
        if use_softmax:
            y_true = F.softmax(y_true, dim=1)

    y_projected = apply_pca(y_true, n_components=pca_dim).detach().cpu()  # [resolution, pca_dim]
    num_eval_dims = y_projected.shape[1]

    deg_all = []
    mse_all = []
    if to_plot:
        # At least the original 2x5 grid; more rows when more dims are fitted.
        n_rows = max(2, -(-num_eval_dims // 5))
        fig, axes = plt.subplots(n_rows, 5, figsize=(20, 4 * n_rows), squeeze=False)
        plt.subplots_adjust(hspace=0.4, wspace=0.3)

    for i in range(num_eval_dims):
        y_layer_tensor = y_projected[:, i]
        # The numpy/scikit-learn selectors get a plain array.
        y_true_layer = y_layer_tensor.numpy()
        if mode == 'AIC':
            # Pass the real sample count so the criterion uses the right n.
            deg, mse, y_pred = best_degree_AIC(x, y_true_layer, resolution=resolution, i=i)
        elif mode == 'BIC':
            deg, mse, y_pred = best_degree_BIC(x, y_true_layer, resolution=resolution, i=i)
        elif mode == 'min':
            deg, mse, y_pred = min_degree(x, y_true_layer)
        elif mode == 'cv':
            deg, mse, y_pred = best_degree_cross_validation_new(x, y_true_layer)
        elif mode == 'w':
            deg, mse, y_pred = weighted_degree(x, y_true_layer)
        elif mode == 'w_cheb':
            if use_norm:
                deg, mse, y_pred, _ = weighted_degree_chebyshev(x, y_true_layer, resolution, max_degree)
            else:
                deg, mse, y_pred, _ = weighted_degree_chebyshev_no_norm(x, y_true_layer, resolution, max_degree)
        else:  # 'w_pytorch'
            # This selector fits on a uniform grid over [-1, 1].
            x_ = torch.linspace(-1, 1, resolution)
            deg, mse, y_pred = weighted_degree_pytorch(x_, y_layer_tensor)

        deg_all.append(deg)
        mse_all.append(mse)

        if to_plot:
            ax = axes[i//5, i%5]
            ax.plot(x, y_true_layer, 'b-', linewidth=1.5, label='Original')
            ax.plot(x, y_pred, 'r--', linewidth=1.5, label=f'Fitted')
            ax.set_title(f'Dim {i} (Degree={deg:.4f} MSE={mse:.4f})')
            ax.set_xlabel('x')
            ax.set_ylabel('f(x)')
            ax.legend(loc='best')
            ax.grid(True, alpha=0.3)
            ax.axhline(0, color='black', linewidth=0.5)
            ax.axvline(0, color='black', linewidth=0.5)

    deg_output = np.mean(deg_all)
    mse_output = np.mean(mse_all)

    if to_plot:
        fig.suptitle(f'Polynomial Fit (Degree={deg_output}) (MSE: {mse_output:.4f})', fontsize=16)
        plot_dir = f'./poly/plots/{mode}'
        os.makedirs(plot_dir, exist_ok=True)
        plt.savefig(f'./poly/plots/{mode}/{net_name}_aug_vizsoftmax_no_reg_train.png')
        plt.close(fig)  # Close the figure to free memory
    return deg_output, mse_output


def avg_best_degree(net, dataset, device, mode = 'AIC', repeats=100, max_degree=40, resolution=500, to_plot = False, pca_dim = 1, args=None):
    """
    Average the ``best_degree`` estimate over randomly drawn image pairs.

    Args:
        net: Model forwarded to ``best_degree``.
        dataset: Indexable dataset; item ``[0]`` of each entry is used as the image.
        device: Device forwarded to ``best_degree``.
        mode: Degree-selection mode forwarded to ``best_degree``.
        repeats: Number of image pairs to evaluate (at least 1).
        max_degree: Forwarded to ``best_degree``.
        resolution: Forwarded to ``best_degree``.
        to_plot: Save the running average to ``./poly/plots/poly_evolution.png``.
            Per-pair plotting in ``best_degree`` stays disabled.
        pca_dim: Forwarded to ``best_degree``.
        args: Namespace forwarded to ``best_degree``.

    Returns:
        Mean of the per-pair degree estimates.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        # Otherwise the summary statistics below fail on an empty list.
        raise ValueError(f"repeats must be at least 1, got {repeats}")

    best_degrees = []
    n_samples = len(dataset)

    for i in tqdm(range(repeats)):
        # Two distinct image indices per round
        idx1, idx2 = np.random.choice(n_samples, 2, replace=False)
        img_pair = [dataset[idx1][0], dataset[idx2][0]]
        deg, _ = best_degree(net, img_pair, device, mode, max_degree, resolution, False, pca_dim, i, args)
        best_degrees.append(deg)

    if to_plot:
        # Running mean after each evaluated pair
        cumulative_avg = []
        current_sum = 0
        for count, value in enumerate(best_degrees, start=1):
            current_sum += value
            cumulative_avg.append(current_sum / count)

        fig = plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(cumulative_avg) + 1), cumulative_avg, 'b-o', linewidth=2)
        plt.title(f'average best degree Evolution')
        plt.xlabel('Epoch')
        plt.ylabel('ABD')
        plt.grid(True, linestyle='--', alpha=0.7)

        # Creating the per-mode directory also creates ./poly/plots itself.
        plot_dir = f'./poly/plots/{mode}'
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = f'./poly/plots/poly_evolution.png'
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close(fig)  # release the figure once it is on disk
        print(f"Saved abd evolution plot to {plot_path}")

    # Summary statistics
    avg_deg = np.mean(best_degrees)
    std_deg = np.std(best_degrees)
    median_deg = np.median(best_degrees)

    print(f"Rounds: {repeats}")
    print(f"Average Minimum Degree: {avg_deg:.2f} ± {std_deg:.2f}")
    print(f"Median: {median_deg}, Min: {min(best_degrees)}, Max: {max(best_degrees)}")

    return avg_deg







def parse_arguments():
    """Build the command-line parser for the evaluation script and parse ``sys.argv``.

    ``--model_path`` and ``--output_dir`` are required; every other option has
    a default.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description='Weighted Degree evaluation for CLIP')

    # (flag, add_argument keyword options), in the order shown by --help.
    option_specs = (
        # Paths and model parameters
        ('--data_location', dict(type=str, default='/newdata_nvme/datasets/xxx/imagenet',
                                 help='Root directory for ImageNet')),
        ('--model_path', dict(type=str, required=True,
                              help='Path to the fine-tuned model .pt file')),
        ('--split', dict(type=str, default='train', choices=['train', 'val'],
                         help='Evaluate on "train" or "val" set')),
        ('--output_dir', dict(type=str, required=True,
                              help='Path to save wd results')),
        # Algorithm parameters
        ('--repeats', dict(type=int, default=50, help='Number of image pairs to evaluate')),
        ('--resolution', dict(type=int, default=200, help='Interpolation steps')),
        ('--max_degree', dict(type=int, default=40, help='Max polynomial degree')),
        ('--pca_dim', dict(type=int, default=1, help='dim after pca')),
        ('--softmax', dict(action='store_true', help='fit softmax landscape')),
        ('--norm', dict(action='store_true', help='use norm in cheb fit')),
        ('--clip_backbone', dict(type=str, default='ViT-B/32', help='CLIP backbone')),
        ('--seed', dict(type=int, default=0, help='Random seed')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def main():
    """Script entry point: load a CLIP model, run the evaluation, append the result.

    Steps: parse arguments and seed the RNGs, load the CLIP backbone and the
    checkpoint at ``--model_path``, build an ``ImageFolder`` for the chosen
    split, call ``avg_best_degree`` in 'w_cheb' mode, then append one JSON
    line to a ``.jsonl`` file under ``--output_dir``. Returns early (after
    printing an error) if the checkpoint is missing or the dataset fails to load.
    """
    args = parse_arguments()
    print(args)

    # --- Environment and seeding ---
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if cuda_available:
        torch.cuda.manual_seed(args.seed)

    # --- Model ---
    print(f"Loading CLIP base model: {args.clip_backbone}...")
    base_model, preprocess = clip.load(args.clip_backbone, device='cpu', jit=False)

    print(f"Loading weights from {args.model_path}...")
    if not os.path.exists(args.model_path):
        print(f"Error: Model file {args.model_path} not found.")
        return

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(args.model_path, map_location='cpu')
    # Unwrap checkpoints that nest the weights under a 'state_dict' key.
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    model = get_model_from_sd(state_dict, base_model)
    model = model.to(device)

    # --- Data ---
    use_train = args.split == 'train'
    primary_dir = os.path.join(args.data_location, 'train' if use_train else 'val')
    data_dir = primary_dir
    if not os.path.exists(primary_dir):
        # Fall back to the ILSVRC2012 folder naming, then to the root itself.
        ilsvrc_name = 'ILSVRC2012_img_train' if use_train else 'ILSVRC2012_img_val'
        ilsvrc_dir = os.path.join(args.data_location, ilsvrc_name)
        if os.path.exists(ilsvrc_dir):
            data_dir = ilsvrc_dir
        else:
            print(f"Warning: {primary_dir} not found. Trying root...")
            data_dir = args.data_location

    print(f"Loading {args.split} data from: {data_dir}")

    try:
        # Images go through the preprocessing returned by clip.load.
        dataset = ImageFolder(root=data_dir, transform=preprocess)
        print(f"Found {len(dataset)} images, {len(dataset.classes)} classes.")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    # --- Evaluation ---
    print("Starting Weighted Degree Evaluation...")

    abd = avg_best_degree(
        net=model,
        dataset=dataset,
        device=device,
        mode='w_cheb',
        repeats=args.repeats,
        max_degree=args.max_degree,
        resolution=args.resolution,
        to_plot=False,  # no evolution plot is written
        pca_dim=args.pca_dim,
        args=args
    )

    separator = "="*30
    model_name = os.path.basename(args.model_path)
    print(separator)
    print(f"Results for {model_name} on {args.split} set:")
    print(f"Average Best Degree (ABD): {abd:.4f}")
    print(separator)

    # --- Persist the result ---
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    output_filename = os.path.join(args.output_dir, f"wd_softmax_{args.softmax}_norm_{args.norm}_results_pca_{args.pca_dim}_{args.split}.jsonl")

    result_entry = {
        "model_name": model_name,
        "abd": float(f"{abd:.6f}"),  # stored with 6 decimal places
        "pca_dim": args.pca_dim,
        "split": args.split,
    }

    # One JSON object per line, appended so earlier runs are kept.
    with open(output_filename, "a") as f:
        f.write(json.dumps(result_entry) + "\n")

    print(f"Saved result to {output_filename}")

if __name__ == '__main__':
    main()