import argparse
parser = argparse.ArgumentParser(
    description="Train 2-layer ReLU MLPs of several widths and polish them."
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    help="mini-batch size for both training and polishing (default: 128)",
)
args = parser.parse_args()
# The training loops step with range(0, n, batch_size), which needs a positive
# step; fail early with a usage message instead of deep inside training.
if args.batch_size < 1:
    parser.error("--batch_size must be a positive integer")
batch_size = args.batch_size
import adelie as ad4
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import scipy.io as sio
import os
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib.ticker import MaxNLocator
from tqdm import tqdm
# Seed NumPy and torch so weight initialisation, mini-batch shuffling and the
# train/test split are reproducible.
seed = 0            # random seed for data splitting (also seeds the global RNGs)
np.random.seed(seed)
torch.manual_seed(seed)
housing_path = "housing.csv"
train_prop = 0.8    # proportion of training set size
# Every column but the last is a feature; the last column is the target.
# NOTE(review): np.loadtxt's default delimiter is whitespace, so this expects a
# whitespace-separated file despite the .csv extension -- confirm.
housing = np.loadtxt(housing_path)
X_full = housing[:, :-1]
y_full = housing[:, -1]

# train-test split: a seeded random permutation of the row indices
n = X_full.shape[0]
n_train = int(train_prop * n)
np.random.seed(seed)  # re-seed so the split depends only on `seed`
order = np.random.choice(n, n, replace=False)
tr_order = order[:n_train]
ts_order = order[n_train:]
X_tr, y_tr = X_full[tr_order], y_full[tr_order]
X_ts, y_ts = X_full[ts_order], y_full[ts_order]

# standardize features and center targets using training-set statistics only
X_tr_centers = np.mean(X_tr, axis=0)
X_tr_scales = np.std(X_tr, axis=0, ddof=0)
# A constant training column has zero std; keep it centred instead of
# dividing by zero (which would fill the column with NaN/inf).
X_tr_scales = np.where(X_tr_scales > 0, X_tr_scales, 1.0)
X_tr = (X_tr - X_tr_centers[None]) / X_tr_scales[None]
X_ts = (X_ts - X_tr_centers[None]) / X_tr_scales[None]
y_tr_center = np.mean(y_tr)
y_tr -= y_tr_center
y_ts -= y_tr_center



def train_nonconvex(X_tr, y_tr, X_ts, y_ts, num_epochs=100, batch_size=32, learning_rate=0.1, weight_decay = 0.01, hidden_neurons = 1000, verbose=False, device=None):
    """
    Train a 2-layer ReLU MLP regressor with AdamW on mini-batch MSE.

    The network is ``MLP2``: Linear(d, hidden_neurons) -> ReLU ->
    Linear(hidden_neurons, 1). Its layers are exposed as ``layer1`` and
    ``layer2`` (``polish_two_layer`` accesses them by these names).

    Parameters
    ----------
    X_tr, y_tr : array-like
        Training features of shape (n, d) and targets (reshaped to (n, 1)).
    X_ts, y_ts : array-like
        Held-out features/targets, evaluated after every epoch.
    num_epochs : int
        Number of passes over the training set.
    batch_size : int
        Mini-batch size; the last batch of an epoch may be smaller.
    learning_rate, weight_decay : float
        AdamW hyper-parameters.
    hidden_neurons : int
        Width of the hidden layer.
    verbose : bool
        Print both losses after every epoch.
    device : str or torch.device, optional
        Where to train. Defaults to ``'cuda'`` when available, else ``'cpu'``.

    Returns
    -------
    train_losses : list of float
        Per epoch, the unweighted mean of the mini-batch MSE values.
    val_losses : list of float
        Per epoch, the MSE on the full held-out set (model in eval mode).
    net : nn.Module
        The trained network, left on ``device``.
    """
    # TODO: add geometric-algebra lasso and deeper networks.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    X_tr = torch.tensor(X_tr, dtype=torch.float32).to(device)
    y_tr = torch.tensor(y_tr, dtype=torch.float32).view(-1, 1).to(device)
    X_ts = torch.tensor(X_ts, dtype=torch.float32).to(device)
    y_ts = torch.tensor(y_ts, dtype=torch.float32).view(-1, 1).to(device)

    class MLP2(nn.Module):
        """Linear -> ReLU -> Linear network with a single output."""

        def __init__(self, in_features, hidden_neurons):
            super().__init__()
            self.layer1 = nn.Linear(in_features, hidden_neurons)
            self.layer2 = nn.Linear(hidden_neurons, 1)

        def forward(self, x):
            return self.layer2(torch.relu(self.layer1(x)))

    # int() accepts NumPy integer widths (e.g. from np.linspace(...).astype(int)).
    net = MLP2(X_tr.shape[1], int(hidden_neurons)).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

    n_train = X_tr.shape[0]
    n_batches = (n_train + batch_size - 1) // batch_size  # ceil(n_train / batch_size)
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        net.train()
        # Shuffle with the default (seeded) CPU generator.
        permutation = torch.randperm(n_train)
        epoch_loss = 0.0
        for start in range(0, n_train, batch_size):
            indices = permutation[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(net(X_tr[indices]), y_tr[indices])
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        train_losses.append(epoch_loss / n_batches)

        net.eval()
        with torch.no_grad():
            val_losses.append(criterion(net(X_ts), y_ts).item())
        if verbose:
            print(f'Hidden Neurons [{hidden_neurons}], Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Validation Loss: {val_losses[-1]:.4f}')

    return train_losses, val_losses, net



def transform_U(U, X):
    """
    Snap each neuron's weight vector onto a hyperplane spanned by data rows.

    For every row ``u`` of ``U``, the (d-1) rows of ``X`` with the smallest
    |<u, x>| are selected. ``u`` is replaced by the unit normal of the
    hyperplane through the origin and those rows, as computed by
    ``find_normal_of_hyperplane``.

    Parameters
    ----------
    U : ndarray, shape (m, d)
        One weight vector per row (e.g. an ``nn.Linear`` weight matrix).
    X : ndarray, shape (n, d)
        Data rows to choose hyperplane points from.

    Returns
    -------
    ndarray, shape (m, d)
        Unit normals. The dtype is that of ``U`` when it is floating point,
        otherwise float64 (so normals are not truncated to integers).
    """
    U = np.asarray(U)
    X = np.asarray(X)
    d = U.shape[1]
    out_dtype = U.dtype if np.issubdtype(U.dtype, np.floating) else np.float64
    transformed_U = np.zeros(U.shape, dtype=out_dtype)
    for i, u_row in tqdm(enumerate(U), total=U.shape[0], desc="Processing neurons"):
        # |<u, x>| for every data row in one vectorised product.
        inner_products = np.abs(X @ u_row)
        # Indices of the (d-1) rows closest to being orthogonal to u.
        smallest_indices = np.argsort(inner_products)[:d - 1]
        transformed_U[i] = find_normal_of_hyperplane(X[smallest_indices])
    return transformed_U
def find_normal_of_hyperplane(points, method='eigh'):
    """
    Return a unit vector orthogonal to every row of ``points``.

    ``points`` is a (k, d) array whose rows are points in d-dimensional
    space (``transform_U`` passes k = d - 1). The result is the normal of the
    hyperplane through the origin and those points, i.e. a vector in the
    (approximate) null space of ``points``.

    Parameters
    ----------
    points : array_like, shape (k, d)
    method : {'eigh', 'svd'}
        'eigh': eigenvector of the Gram matrix ``points.T @ points`` with the
        smallest eigenvalue. 'svd': last right-singular vector of ``points``.
        Both agree up to sign when the null space is one-dimensional.

    Returns
    -------
    ndarray, shape (d,)
        Unit-norm normal vector. Its sign is arbitrary.

    Raises
    ------
    ValueError
        If ``method`` is not 'eigh' or 'svd'.
    """
    points = np.asarray(points)
    if method == 'svd':
        # full_matrices=True makes vh (d, d) even when k < d; its last row
        # belongs to the smallest (possibly zero) singular value.
        _, _, vh = np.linalg.svd(points, full_matrices=True)
        normal = vh[-1]
    elif method == 'eigh':
        gram = points.T @ points
        # eigh returns eigenvalues in ascending order, so column 0 is the
        # direction of least energy.
        _, eigenvectors = np.linalg.eigh(gram)
        normal = eigenvectors[:, 0]
    else:
        raise ValueError(f"unknown method {method!r}; expected 'eigh' or 'svd'")
    return normal / np.linalg.norm(normal)
def relu(x):
    """Rectified linear unit applied element-wise: returns ``max(0, x)``."""
    floor = 0
    rectified = np.maximum(floor, x)
    return rectified

def polish_two_layer(X_tr, y_tr, X_ts, y_ts, net, weight_decay = 0.01, learning_rate=0.1, batch_size=32, num_epochs=20, verbose=False, device=None):
    """
    Polish a trained 2-layer MLP and re-train only its second layer.

    Steps:
      1. Replace ``net.layer1.weight`` with ``transform_U`` of its rows: each
         neuron becomes the unit normal of a hyperplane through the origin
         and (d-1) training rows. ``net.layer1.bias`` is left unchanged.
      2. Freeze layer 1 and fine-tune ``net.layer2``, starting from its
         current weights, with AdamW on mini-batch MSE.

    Parameters
    ----------
    X_tr, y_tr : ndarray
        Training features (n, d) and targets (reshaped to (n, 1)).
    X_ts, y_ts : ndarray
        Held-out features/targets, evaluated after every epoch.
    net : nn.Module
        Network with ``layer1`` / ``layer2`` ``nn.Linear`` attributes, such as
        the one returned by ``train_nonconvex``. Modified in place.
    weight_decay, learning_rate : float
        AdamW hyper-parameters for layer 2.
    batch_size, num_epochs : int
        Mini-batch size and number of passes over the training set.
    verbose : bool
        Print both losses after every epoch.
    device : str or torch.device, optional
        Defaults to ``'cuda'`` when available, else ``'cpu'``.

    Returns
    -------
    (train_losses, val_losses, net)
        Per-epoch mean mini-batch training MSE, per-epoch held-out MSE, and
        the polished network (on ``device``, with layer 1 frozen).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 1. Snap the first layer onto data-defined hyperplanes. transform_U
    #    works row-wise on an (hidden, d) array, i.e. the nn.Linear layout.
    first_layer = net.layer1.weight.detach().cpu().numpy().copy()
    new_weight = transform_U(first_layer, X_tr)
    net = net.to(device)
    with torch.no_grad():
        net.layer1.weight.copy_(
            torch.as_tensor(new_weight, dtype=net.layer1.weight.dtype, device=device))

    # 2. Freeze layer 1 and fine-tune layer 2.
    for param in net.layer1.parameters():
        param.requires_grad = False
        param.grad = None  # discard gradients left over from earlier training
    for param in net.layer2.parameters():
        param.requires_grad = True
    # Hand only layer-2 parameters to AdamW. Its decoupled weight decay is
    # applied to every parameter whose .grad is not None, so a zeroed (rather
    # than None) gradient on the frozen layer would otherwise shrink it.
    optimizer = optim.AdamW(net.layer2.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    X_tr = torch.tensor(X_tr, dtype=torch.float32).to(device)
    y_tr = torch.tensor(y_tr, dtype=torch.float32).view(-1, 1).to(device)
    X_ts = torch.tensor(X_ts, dtype=torch.float32).to(device)
    y_ts = torch.tensor(y_ts, dtype=torch.float32).view(-1, 1).to(device)

    n_train = X_tr.shape[0]
    n_batches = (n_train + batch_size - 1) // batch_size  # ceil(n_train / batch_size)
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        net.train()
        # Shuffle with the default (seeded) CPU generator.
        permutation = torch.randperm(n_train)
        epoch_loss = 0.0
        for start in range(0, n_train, batch_size):
            indices = permutation[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(net(X_tr[indices]), y_tr[indices])
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        train_losses.append(epoch_loss / n_batches)

        net.eval()
        with torch.no_grad():
            val_losses.append(criterion(net(X_ts), y_ts).item())
        if verbose:
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Validation Loss: {val_losses[-1]:.4f}')
    return train_losses, val_losses, net

# Train networks of increasing width, polish each one, and record the final
# train/validation MSE over `rand_trials` independent random runs.
num_neurons = np.round(np.linspace(200, 1000, 20)).astype(int)  # hidden widths, cast to int
rand_trials = 200
train_epochs = 50       # epochs of full-network training (train_nonconvex)
polish_epochs = 50      # epochs of layer-2 retraining (polish_two_layer)
weight_decay = 0.01
learning_rate = 0.01

# One row per random trial, one column per width.
results_shape = (rand_trials, len(num_neurons))
val_runs = np.zeros(results_shape)
val_polished_runs = np.zeros(results_shape)
train_runs = np.zeros(results_shape)
train_polished_runs = np.zeros(results_shape)

for rand_iter in range(rand_trials):
    print(f"Random iteration {rand_iter+1}/{rand_trials}")
    for j, neurons in enumerate(num_neurons):
        train_losses, val_losses, net = train_nonconvex(
            X_tr, y_tr, X_ts, y_ts, num_epochs=train_epochs, batch_size=batch_size,
            learning_rate=learning_rate, weight_decay=weight_decay,
            hidden_neurons=neurons, verbose=False)
        train_losses_polished, val_losses_polished, net = polish_two_layer(
            X_tr, y_tr, X_ts, y_ts, net, num_epochs=polish_epochs,
            weight_decay=weight_decay, learning_rate=learning_rate,
            batch_size=batch_size, verbose=False)
        val_runs[rand_iter, j] = val_losses[-1]
        val_polished_runs[rand_iter, j] = val_losses_polished[-1]
        train_runs[rand_iter, j] = train_losses[-1]
        train_polished_runs[rand_iter, j] = train_losses_polished[-1]

# Mean and population std (ddof=0) across trials, computed from the stored
# runs. sqrt(E[x^2] - E[x]^2) can go slightly negative in floating point and
# yield NaN.
val_avg = val_runs.mean(axis=0)
val_polished_avg = val_polished_runs.mean(axis=0)
train_avg = train_runs.mean(axis=0)
train_polished_avg = train_polished_runs.mean(axis=0)
val_std = val_runs.std(axis=0)
val_polished_std = val_polished_runs.std(axis=0)
train_std = train_runs.std(axis=0)
train_polished_std = train_polished_runs.std(axis=0)
print(val_avg)
print(val_polished_avg)
print(train_avg)
print(train_polished_avg)

# Save every summary array to its own pickle, tagged with the batch size.
import pickle
summaries = {
    'val_avg': val_avg,
    'val_polished_avg': val_polished_avg,
    'val_std': val_std,
    'val_polished_std': val_polished_std,
    'train_avg': train_avg,
    'train_polished_avg': train_polished_avg,
    'train_std': train_std,
    'train_polished_std': train_polished_std,
}
for name, values in summaries.items():
    with open(f'{name}_{batch_size}.pkl', 'wb') as f:
        pickle.dump(values, f)

# Validation curves dashed, train curves solid; blue = before polishing,
# red = after polishing.
plt.plot(num_neurons, val_avg, 'b--')
plt.plot(num_neurons, val_polished_avg, 'r--')
plt.plot(num_neurons, train_avg, 'b')
plt.plot(num_neurons, train_polished_avg, 'r')
plt.legend(['Validation', 'Validation polished', 'Train', 'Train polished'])
# +/- one standard deviation across trials
plt.errorbar(num_neurons, val_avg, yerr=val_std, fmt='o', color='b')
plt.errorbar(num_neurons, val_polished_avg, yerr=val_polished_std, fmt='o', color='r')
plt.errorbar(num_neurons, train_avg, yerr=train_std, fmt='o', color='b')
plt.errorbar(num_neurons, train_polished_avg, yerr=train_polished_std, fmt='o', color='r')
plt.xlabel('Number of neurons')
plt.ylabel('Train/validation MSE')

# Save to PDF *before* plt.show(): show() may close the figure, after which
# savefig writes a blank page. The filename records the actual settings plus
# a timestamp and the batch size.
from datetime import datetime
now = datetime.now()
dt_string = now.strftime("%d%m%Y%H%M%S")
plt.savefig(f'bsfixed_adamw_{rand_trials}trials_wd{weight_decay:g}_{train_epochs}_{polish_epochs}_epochs_polishing_boston_housing_{dt_string}_{batch_size}.pdf')
plt.show()

