import numpy as np
import pickle
import os
# Root directory for data files, read from the environment. A single trailing
# slash is dropped so the path assembled below does not contain "//".
BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH.endswith("/"):
    BASE_PATH = BASE_PATH[:-1]

save_path = f"{BASE_PATH}/data/equities/equity_dataset_2018.npz"

# NOTE(review): the archive holds pickled payloads, so allow_pickle=True and
# pickle.loads will execute whatever the file contains -- only load files
# from a trusted source.
with np.load(save_path, allow_pickle=True) as data:
    splits, stats = (
        pickle.loads(data[field].item()) for field in ("splits", "stats")
    )
print(f"Dataset loaded from {save_path}")
print(stats)

train, val = splits["train"], splits["val"]
print(train.keys())

# Shape notes carried over from the original author: Y is expected to be
# (T=2012, N=500) for train and (T=251, N=500) for val.
train_X, train_Y = train["X"], train["Y"]
val_X, val_Y = val["X"], val["Y"]

# Print the shape of every split array, one per line, as a quick sanity check.
for _split_array in (train_X, train_Y, val_X, val_Y):
    print(_split_array.shape)


import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

# train_Y: shape (T_train, N) = (2012, 500)
# val_Y:   shape (T_val,   N) = (251,  500)


def _per_asset_r2(y_true, y_pred):
    """Return the coefficient of determination of every column (asset).

    Args:
        y_true: 2-D array of observed values, one column per asset.
        y_pred: Array of the same shape holding the reconstructed values.

    Returns:
        1-D array with one R^2 per column, computed as ``1 - SSE / SST``
        where SST is taken around the column's own mean in ``y_true``.
        Columns with zero total variance have no defined R^2 and are
        reported as NaN instead of triggering a division by zero.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    sse = np.sum((y_true - y_pred) ** 2, axis=0)
    sst = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    # Pre-fill with NaN; np.divide only overwrites entries where SST > 0.
    ratio = np.full(sst.shape, np.nan)
    np.divide(sse, sst, out=ratio, where=sst > 0)
    return 1.0 - ratio


# 1) Fit a one-component PCA on the training targets only.
pca = PCA(n_components=1)
pca.fit(train_Y)

# 2) Project train_Y onto the component and map it back to asset space.
train_Y_pca = pca.transform(train_Y)
train_Y_hat = pca.inverse_transform(train_Y_pca)

# 3) Per-asset R^2 of the reconstruction on the train set.  The aggregate is
#    NaN-aware so zero-variance assets (undefined R^2) do not poison it.
train_r2_list = list(_per_asset_r2(train_Y, train_Y_hat))

train_r2_avg = np.nanmean(train_r2_list)
print("Average per-asset R^2 on train:", train_r2_avg)

# --- Same evaluation on the val set, reusing the PCA fitted on train ---
val_Y_pca = pca.transform(val_Y)
val_Y_hat = pca.inverse_transform(val_Y_pca)

val_r2_list = list(_per_asset_r2(val_Y, val_Y_hat))

val_r2_avg = np.nanmean(val_r2_list)
val_r2_median = np.nanmedian(val_r2_list)
val_r2_min = np.nanmin(val_r2_list)
val_r2_max = np.nanmax(val_r2_list)
print("Average per-asset R^2 on val:", val_r2_avg)
print("Median per-asset R^2 on val:", val_r2_median)
print("Min per-asset R^2 on val:", val_r2_min)
print("Max per-asset R^2 on val:", val_r2_max)

