import numpy as np

from regression.sgd_regression import SGDRegression, LinearModel
from regression.block_kfold import DatasetBlockKfold, ConcatedBlockKfold
from regression.numpy_dataset import RegressionNumpyDataset
from torch.utils.data import Subset
from src.helpers import load_pkl, save_pkl
from tqdm.auto import tqdm
import os
from regression.sgd_regression import StandardizeData

#change this to just take in a dataset
def R2(Pred, Real):
    """Return the coefficient of determination of ``Pred`` against ``Real``.

    Both the squared residuals and the variance of ``Real`` are reduced along
    axis 0, so 2-D inputs yield one score per column.

    Args:
        Pred: predicted values, same shape as ``Real``.
        Real: observed values.

    Returns:
        ``1 - mean((Real - Pred)**2) / var(Real)`` per column. Columns of
        ``Real`` with zero variance have no defined R^2 and are scored 0.0;
        any remaining NaN (e.g. from NaN predictions) is also mapped to 0.0.
    """
    ss_res = np.mean((Real - Pred) ** 2, 0)
    ss_tot = np.var(Real, 0)
    constant = ss_tot == 0
    # Swap in a dummy denominator for constant columns: dividing by zero would
    # give -inf, which nan_to_num turns into ~-1.8e308 and which then swamps
    # any sum of scores taken across folds.
    scores = 1 - ss_res / np.where(constant, 1, ss_tot)
    return np.nan_to_num(np.where(constant, 0.0, scores))

def fit_ridge_weights(X_train, Y_train, regularization, batch_size, epochs, lr, device):
    """Train a bias-free ``LinearModel`` with ``SGDRegression`` and return its weights.

    Args:
        X_train: training inputs; ``X_train.shape[1]`` sets the model's input size.
        Y_train: training targets; ``Y_train.shape[1]`` sets the model's output size.
        regularization: penalty strength handed to ``SGDRegression``
            (presumably an L2/ridge penalty, going by this function's name).
        batch_size, epochs, lr: forwarded to ``SGDRegression.fit``.
        device: device argument for ``SGDRegression``.

    Returns:
        The fitted model's ``weights`` detached, moved to the CPU and
        converted to a NumPy array.
    """
    n_inputs = X_train.shape[1]
    n_outputs = Y_train.shape[1]
    regressor = SGDRegression(
        LinearModel(n_inputs, n_outputs, use_bias=False),
        regularization,
        device=device,
    )
    regressor.fit(X_train, Y_train, epochs, lr, batch_size)
    fitted_weights = regressor.model.weights
    return fitted_weights.detach().cpu().numpy()

def ridge_error_over_regularizations(X_train, Y_train, X_validation,
                                      Y_validation, batch_size, epochs, lr,
                                      regularizations=np.array([0.1, 1, 10, 100, 1000]),
                                      label="Ridge Fit",
                                      device="cuda",
                                      log_folder = None):
    """Score every candidate regularization on a held-out validation split.

    For each entry of ``regularizations`` a model is fitted on the training
    data with ``fit_ridge_weights``; the validation inputs are projected
    through the transposed weights and scored against ``Y_validation`` with
    ``R2``.

    Args:
        X_train, Y_train: training inputs and targets.
        X_validation, Y_validation: held-out inputs and targets.
        batch_size, epochs, lr: forwarded to ``fit_ridge_weights``.
        regularizations: candidate penalty strengths, tried in order.
        label: description shown on the progress bar.
        device: forwarded to ``fit_ridge_weights``.
        log_folder: accepted but not read by this function.

    Returns:
        Array of shape ``(len(regularizations), Y_train.shape[1])`` holding
        the validation R^2 for each candidate (rows) and output channel
        (columns).
    """
    n_candidates = len(regularizations)
    scores = np.zeros((n_candidates, Y_train.shape[1]))
    progress = tqdm(enumerate(regularizations), colour="green",
                    total=n_candidates, position=0, desc=label)
    for row, strength in progress:
        fitted = fit_ridge_weights(X_train, Y_train, strength, batch_size, epochs, lr, device=device)
        predictions = np.dot(X_validation, fitted.T)
        scores[row] = R2(predictions, Y_validation)
    return scores

def regularization_cross_val_search(regression_dataset:RegressionNumpyDataset, folds, regularizations = [0.001, 0.01, 0.1],
                                    batch_size=100, epochs=1000, lr=1e-4, device="cuda", log_loc = None):
    '''
    Choose a regularization per predicted channel by cross-validation, then refit.

    For every split produced by ConcatedBlockKfold the validation R2 of each
    candidate regularization is accumulated per channel. Each channel is then
    assigned the candidate with the highest accumulated R2, and one fit per
    distinct selected candidate fills that candidate's channels in the output.

    regression_dataset - dataset to split; must expose `cumulative_indices` and integer indexing
    folds              - forwarded to ConcatedBlockKfold
    regularizations    - candidate regularization strengths (only read, never mutated here)
    batch_size, epochs, lr, device - training settings forwarded to the fitting helpers
    log_loc            - never read: the lambda below declares its own `log_loc` parameter

    Returns (out_weights, channel_regularization):
    out_weights            - array of shape (num_regression_features, num_predicted_channels)
    channel_regularization - array with the selected regularization for each channel

    NOTE(review): as written this function does not appear to be runnable - see the
    NOTE(review) comments on the two helper calls and on the weight assignment below.
    '''
    #TODO: add ability to drop zero predictions out, should speed up per epoch by a good margin
    # Builds a per-split log folder name. It is handed to ridge_error_over_regularizations
    # as `log_folder` (the callable itself, not a path), and that function never reads it.
    cross_val_log_folder_func = lambda log_loc, validation_index: f"{log_loc}/val_split_{validation_index}"
    
    num_regularizations = len(regularizations)
    # Both sizes come from the shape of the first element of the first sample.
    # assumes that element is 2-D with layout (features, channels) - TODO confirm
    num_regression_features, num_predicted_channels = regression_dataset[0][0].shape

    # Despite the name this holds a running SUM over splits (never divided by the number
    # of splits); the per-channel argmax below is the same for a sum as for a mean.
    crossval_regularization_mean_R2s = np.zeros((num_regularizations, num_predicted_channels))
    kfold = ConcatedBlockKfold(regression_dataset.cumulative_indices[1:], folds)
    #finds the error for each channel over the crossval sets
    for (split_index, (fold_train_indices, fold_validation_indices)) in tqdm(enumerate(kfold.split()),
                                                               total=kfold.n_splits,
                                                                colour="cyan", desc="Cross Val: ", position=1):
        train_dataset = Subset(regression_dataset, fold_train_indices)
        validation_dataset = Subset(regression_dataset, fold_validation_indices)
        
        # NOTE(review): ridge_error_over_regularizations declares seven required positional
        # parameters (X_train, Y_train, X_validation, Y_validation, batch_size, epochs, lr);
        # only six positionals are passed here, so the two Subsets land in X_train/Y_train,
        # the remaining values shift into the wrong slots and `lr` is missing (TypeError).
        fold_R2s = ridge_error_over_regularizations(train_dataset, validation_dataset, 
                                                    batch_size, epochs, lr, regularizations, label= f"Split {split_index}",
                                                    device=device, log_folder = cross_val_log_folder_func)
        crossval_regularization_mean_R2s += fold_R2s
    
    # Index of the best-scoring regularization for every channel, and the distinct
    # indices that actually need a final fit.
    channel_best_regularization_index = np.argmax(crossval_regularization_mean_R2s, axis=0)
    regularization_indices_to_batch = np.unique(channel_best_regularization_index)

    out_weights = np.zeros((num_regression_features, num_predicted_channels))
    for regularization_index in regularization_indices_to_batch:
        regularization = regularizations[regularization_index]
        # Boolean mask of the channels whose best regularization is this one.
        selected_channels = (channel_best_regularization_index == regularization_index)
        # NOTE(review): fit_ridge_weights expects (X_train, Y_train, regularization, batch_size,
        # epochs, lr, device); this call omits Y_train, so every positional is shifted by one
        # and `lr` is missing (TypeError).
        # NOTE(review): the right-hand side is not restricted to `selected_channels`; if the fit
        # returns weights for all channels, the assignment only matches in shape when every
        # channel is selected - confirm the intended indexing.
        out_weights[:,selected_channels] = fit_ridge_weights(regression_dataset, regularization, batch_size, epochs, lr, device=device).T

    channel_regularization = np.array([regularizations[i] for i in channel_best_regularization_index])
    return out_weights, channel_regularization

class CrossValidationRegularizationSearch():
    """Per-channel regularization search using blocked k-fold cross-validation.

    One model is trained for every (regularization, fold) pair by ``fit_fold``,
    which stores the validation R^2 of that run below ``save_loc``.
    ``collect_and_fit`` then selects, for every output channel, the
    regularization with the highest R^2 summed over folds and trains a final
    model on the full dataset with those per-channel values.

    The instance pickles itself to ``<save_loc>/cross_val_model.pkl`` when it is
    constructed, and ``find_remaining_runs`` lists the pairs that have no stored
    result yet, so individual folds can be fitted independently of each other.
    """

    def __init__(self, k_folds, regression_dataset_lens, regularizations, train_stories, delay_dim, embedding_dim,
                 channel_dim, save_loc = "./runs/ridge_test", onset_bias = True, num_workers = 96, max_epochs = 10,lr = 5*1e-3, batch_size = 30000):
        """Store the configuration, draw the fold indices and persist the instance.

        Args:
            k_folds: number of cross-validation folds.
            regression_dataset_lens: handed to ``DatasetBlockKfold``
                (presumably the lengths of the concatenated datasets - confirm).
            regularizations: candidate regularization strengths; stored as a NumPy array.
            train_stories: stored on the instance; not read by any method of this class.
            delay_dim, embedding_dim, channel_dim: dimensions passed to ``LinearModel``.
            save_loc: directory for all outputs; created if missing.
            onset_bias: forwarded to ``LinearModel``.
            num_workers, max_epochs, lr, batch_size: default training settings
                used by ``fit_fold`` and ``collect_and_fit``.
        """
        self.train_stories = train_stories
        self.save_loc = save_loc
        self.k_folds = k_folds
        self.dataset_lens = regression_dataset_lens
        self.regularizations = np.array(regularizations)
        self.delay_dim = delay_dim
        self.embedding_dim = embedding_dim
        self.channel_dim = channel_dim
        self.onset_bias = onset_bias
        self.num_workers = num_workers
        self.max_epochs = max_epochs
        self.lr = lr
        self.batch_size = batch_size

        # The splits are drawn once (with block shuffling) and kept on the
        # instance so every later call sees the same folds.
        kfold = DatasetBlockKfold(self.dataset_lens, k_folds, block_shuffle=True)
        self.fold_train_indices_list = []
        self.fold_validation_indices_list = []
        for fold_train_indices, fold_validation_indices in kfold.split():
            self.fold_train_indices_list.append(fold_train_indices)
            self.fold_validation_indices_list.append(fold_validation_indices)

        os.makedirs(save_loc, exist_ok=True)
        save_pkl(self.save_loc + "/cross_val_model.pkl", self)

    def fold_loc(self, regularization_index, cross_val_index):
        """Return the output directory of one (regularization, fold) run."""
        return self.save_loc + f"/regularization_{regularization_index}_split_{cross_val_index}"

    def _validation_r2_path(self, regularization_index, cross_val_index):
        """Return the path of the stored validation R^2 of one run."""
        return self.fold_loc(regularization_index, cross_val_index) + "/validation_R2.pkl"

    def _resolve_training_config(self, max_epochs, lr, batch_size, num_workers):
        """Replace every override left as ``None`` by the constructor's value."""
        return (
            self.max_epochs if max_epochs is None else max_epochs,
            self.lr if lr is None else lr,
            self.batch_size if batch_size is None else batch_size,
            self.num_workers if num_workers is None else num_workers,
        )

    def _new_linear_model(self):
        """Build an untrained ``LinearModel`` from the stored dimensions."""
        # Input standardisation (StandardizeData) is currently disabled, hence
        # mean/std are passed as None.
        return LinearModel(self.delay_dim, self.embedding_dim, self.channel_dim, use_bias=True,
                           mean=None, std=None, onset_bias=self.onset_bias)

    def fit_fold(self, regularization_index, cross_val_index, regression_dataset: "RegressionNumpyDataset", train_device = "cuda",
                 max_epochs = None, lr = None, batch_size = None, num_workers = None):
        """Train one (regularization, fold) run and store its validation R^2.

        Args:
            regularization_index: index into ``self.regularizations``.
            cross_val_index: index of the fold to hold out.
            regression_dataset: full dataset; must provide ``subset(indices)``.
            train_device: device handed to ``SGDRegression``.
            max_epochs, lr, batch_size, num_workers: optional overrides;
                ``None`` means "use the value given to the constructor".

        Returns:
            The validation R^2 as computed by ``R2``; the same value is
            pickled to ``<fold_loc>/validation_R2.pkl``.
        """
        fold_save_loc = self.fold_loc(regularization_index, cross_val_index)
        os.makedirs(fold_save_loc, exist_ok=True)
        max_epochs, lr, batch_size, num_workers = self._resolve_training_config(
            max_epochs, lr, batch_size, num_workers)

        train_dataset = regression_dataset.subset(self.fold_train_indices_list[cross_val_index])
        validation_dataset = regression_dataset.subset(self.fold_validation_indices_list[cross_val_index])
        optimization = SGDRegression(self._new_linear_model(), self.regularizations[regularization_index],
                                     device=train_device)
        optimization.fit(train_dataset, max_epochs, lr, batch_size,
                         num_workers=num_workers, model_save_loc=fold_save_loc)
        Y_hat, Y = optimization.predict(validation_dataset, batch_size)
        validation_r2 = R2(Y_hat, Y)
        save_pkl(self._validation_r2_path(regularization_index, cross_val_index), validation_r2)
        return validation_r2

    def collect_and_fit(self, regression_dataset, train_device="cuda",
                        max_epochs = None, lr = None, batch_size = None, num_workers = None):
        """Select the best regularization per channel and fit the final model.

        Loads the stored validation R^2 of every (regularization, fold) run,
        sums it over folds, takes the arg-max over regularizations for each
        channel and trains on the whole ``regression_dataset`` with the
        resulting per-channel regularization array. The selected values are
        pickled to ``<save_loc>/crossval_regularizations.pkl`` and the model is
        saved under ``<save_loc>/final_model``.

        Args:
            regression_dataset: full dataset used for the final fit.
            train_device: device handed to ``SGDRegression``.
            max_epochs, lr, batch_size, num_workers: optional overrides;
                ``None`` means "use the value given to the constructor".
        """
        final_model_save_loc = self.save_loc + "/final_model"
        # Mirror fit_fold, which also creates its output directory before training.
        os.makedirs(final_model_save_loc, exist_ok=True)
        max_epochs, lr, batch_size, num_workers = self._resolve_training_config(
            max_epochs, lr, batch_size, num_workers)

        # Rows are channels, columns are regularizations; entries are R^2 summed over folds.
        validation_total_r2 = np.zeros((self.channel_dim, len(self.regularizations)))
        for regularization_index in range(len(self.regularizations)):
            for cross_val_index in range(self.k_folds):
                validation_r2 = load_pkl(self._validation_r2_path(regularization_index, cross_val_index))
                validation_total_r2[:, regularization_index] += validation_r2
        best_regularization_indices = np.argmax(validation_total_r2, axis=1)
        best_regularizations = self.regularizations[best_regularization_indices]

        optimization = SGDRegression(self._new_linear_model(), best_regularizations, device=train_device)
        optimization.fit(regression_dataset, max_epochs, lr, batch_size,
                         num_workers=num_workers, model_save_loc=final_model_save_loc)
        save_pkl(self.save_loc + "/crossval_regularizations.pkl", best_regularizations)

    def find_remaining_runs(self):
        """Return the ``(regularization_index, cross_val_index)`` pairs with no stored validation R^2."""
        return [
            (regularization_index, cross_val_index)
            for regularization_index in range(len(self.regularizations))
            for cross_val_index in range(self.k_folds)
            if not os.path.exists(self._validation_r2_path(regularization_index, cross_val_index))
        ]

    def run_all(self, regression_dataset: "RegressionNumpyDataset", train_device = "cuda",
                 max_epochs: int = 1000, lr: float = 0.001, batch_size: int = 3000, num_workers: int = 0):
        """Fit every (regularization, fold) run in sequence, then the final model.

        The training settings given here are forwarded to ``fit_fold`` and
        ``collect_and_fit`` as explicit overrides.

        NOTE(review): because they are always forwarded, arguments omitted
        here fall back to this method's own defaults, not to the values passed
        to the constructor - confirm that this is the intended precedence.

        Args:
            regression_dataset: full dataset.
            train_device: device handed to ``SGDRegression``.
            max_epochs, lr, batch_size, num_workers: training settings for
                every run started by this call.
        """
        for regularization_index in range(len(self.regularizations)):
            for cross_val_index in range(self.k_folds):
                self.fit_fold(regularization_index, cross_val_index, regression_dataset, train_device,
                              max_epochs=max_epochs, lr=lr, batch_size=batch_size, num_workers=num_workers)
        self.collect_and_fit(regression_dataset, train_device,
                             max_epochs=max_epochs, lr=lr, batch_size=batch_size, num_workers=num_workers)