import numpy as np
import torch
import math
from tqdm import tqdm
from torch import amp
from regression.regression_utils import load_pkl, save_pkl
import os
from regression.block_kfold import DatasetBlockKfold
from regression.numpy_dataset import RegressionNumpyDataset
from regression.losses import r2_loss
import json


def block_gpu_multiply(X, Y, X_block_size, Y_block_size, device="cuda"):
    """Compute ``X @ Y`` tile by tile on ``device``, collecting the result on the host.

    Only one row-block of ``X`` and one column-block of ``Y`` are on ``device``
    at a time; the full operands and the output stay in host (numpy) memory.
    Each tile is transferred in its original dtype and cast to float32 after
    the transfer, so the product is always computed in single precision.

    Args:
        X: 2-D numpy array of shape (n, k).
        Y: 2-D numpy array of shape (k, m).
        X_block_size: number of rows of ``X`` per tile (must be >= 1).
        Y_block_size: number of columns of ``Y`` per tile (must be >= 1).
        device: torch device the tiles are multiplied on. Defaults to
            ``"cuda"`` (the original behaviour); ``"cpu"`` runs without a GPU.

    Returns:
        float32 numpy array of shape (n, m).

    Raises:
        ValueError: if either block size is smaller than 1. Such a size would
            otherwise produce an empty loop and return uninitialised memory.
    """
    if X_block_size < 1 or Y_block_size < 1:
        raise ValueError(
            f"block sizes must be >= 1, got X_block_size={X_block_size}, "
            f"Y_block_size={Y_block_size}"
        )
    n_rows = X.shape[0]
    n_cols = Y.shape[1]
    num_X_blocks = math.ceil(n_rows / X_block_size)
    num_Y_blocks = math.ceil(n_cols / Y_block_size)
    out = np.empty((n_rows, n_cols), dtype=np.float32)
    for x_block in tqdm(range(num_X_blocks)):
        x_start = x_block * X_block_size
        x_end = min(x_start + X_block_size, n_rows)
        # Transfer first, then cast: a float16 tile crosses the bus at half width.
        X_block = torch.from_numpy(X[x_start:x_end, :]).to(device).to(torch.float32)
        for y_block in range(num_Y_blocks):
            y_start = y_block * Y_block_size
            y_end = min(y_start + Y_block_size, n_cols)
            Y_block = torch.from_numpy(Y[:, y_start:y_end]).to(device).to(torch.float32)
            out[x_start:x_end, y_start:y_end] = torch.matmul(X_block, Y_block).cpu().numpy()
    return out

def cholesky_solve(A, b):
    """Return ``W`` solving ``A @ W = b`` for a symmetric positive-definite ``A``.

    ``A`` is factored with ``torch.linalg.cholesky`` (lower-triangular factor).
    The factor is then passed to ``torch.cholesky_solve`` together with the
    right-hand side ``b``. The factorisation step fails if ``A`` is not
    positive-definite.
    """
    return torch.cholesky_solve(b, torch.linalg.cholesky(A))

class FullLinearModel():
    """Affine model computing ``X @ weights + bias``.

    Attributes:
        weights: coefficient matrix used as the right operand of the product
            (presumably shaped (n_features, n_outputs) — TODO confirm).
        bias: offset added (with numpy broadcasting) to every product row.
    """

    def __init__(self, weights, bias):
        self.bias = bias
        self.weights = weights

    def np_predict(self, X, block_size = 1000):
        """Return ``X @ self.weights + self.bias`` as a numpy array.

        The product is computed by ``block_gpu_multiply``. The same
        ``block_size`` tiles both the rows of ``X`` and the columns of
        ``self.weights``.
        """
        product = block_gpu_multiply(X, self.weights, block_size, block_size)
        return product + self.bias
    
def ridge_regression(X, Y, ridges = (1.0, 0.1, 0.001), block_size = 500, mean_mse_loss = True):
    """Solve ridge regression for several penalties that share one Gram matrix.

    For each ridge value ``r`` this solves, by Cholesky on the GPU,
    ``(X.T @ X + s * r * I) @ W = X.T @ Y``. Here ``s = len(X)`` when
    ``mean_mse_loss`` is true (the penalty is relative to a mean squared
    error) and ``s = 1`` otherwise.

    Memory strategy: ``X.T @ X`` is kept in host memory. Only ``X.T @ Y`` and
    a single regularised system matrix (plus its solution) live on the GPU at
    any one time.

    Args:
        X: design matrix as a numpy array, rows are samples.
        Y: targets as a numpy array with the same number of rows as ``X``.
        ridges: iterable of ridge penalties; solved in the given order.
        block_size: tile size passed to ``block_gpu_multiply`` for both operands.
        mean_mse_loss: if true, scale each penalty by the number of samples.

    Returns:
        list of float32 numpy arrays, one ``W`` per ridge value in ``ridges``
        order, each shaped ``(X.shape[1], Y.shape[1])``.
    """
    # Materialise once: the default is an immutable tuple and callers may pass
    # any iterable, but tqdm needs a length.
    ridges = list(ridges)
    scaling = len(X) if mean_mse_loss else 1.0
    print("Initializing Regression")
    # Kept off the GPU for memory's sake; one regularised copy is uploaded per ridge.
    X_tX = block_gpu_multiply(X.T, X, block_size, block_size)
    ridge_identity = np.eye(X_tX.shape[0], dtype=np.float32)
    # X^T Y is identical for every ridge value, so it is uploaded only once.
    b = torch.from_numpy(block_gpu_multiply(X.T, Y, block_size, block_size)).cuda().to(torch.float32)
    print("Running Ridges")
    Ws = []
    for ridge in tqdm(ridges, colour="cyan", desc = "Ridge Fits"):
        A = torch.from_numpy(X_tX + scaling*ridge*ridge_identity).cuda().to(torch.float32)
        W = cholesky_solve(A, b)
        Ws.append(W.cpu().numpy())
        # Drop this ridge's GPU tensors before the next system matrix is built.
        # Otherwise two d x d matrices would be resident at the same time.
        del A, W
    return Ws

def ridge_regression_ridge_per_channel(X, Y, ridges, ridge_index_per_channel, block_size = 500, mean_mse_loss=True):
    """Ridge regression where every output channel may use its own penalty.

    ``ridge_index_per_channel[run, c]`` selects which entry of ``ridges``
    regularises output channel ``c`` in run ``run``. Each ridge value that is
    actually selected somewhere is built and Cholesky-factored exactly once.
    That factor is reused for every run and channel that selects it.

    Args:
        X: design matrix as a numpy array, rows are samples.
        Y: targets as a numpy array, one column per output channel.
        ridges: sequence of ridge penalties.
        ridge_index_per_channel: integer array of shape (n_runs, n_channels)
            indexing into ``ridges``. A 1-D array is treated as a single run.
        block_size: tile size passed to ``block_gpu_multiply``.
        mean_mse_loss: if true, scale each penalty by the number of samples.

    Returns:
        list with one ``(X.shape[1], Y.shape[1])`` numpy array per run.

    Raises:
        ValueError: if an index lies outside ``[0, len(ridges))``. Such an
            index would leave that channel's column uninitialised. Also raised
            if the number of channels does not match ``Y``.
    """
    ridge_index_per_channel = np.atleast_2d(np.asarray(ridge_index_per_channel))
    num_runs = ridge_index_per_channel.shape[0]
    if ridge_index_per_channel.shape[1] != Y.shape[1]:
        raise ValueError(
            f"ridge_index_per_channel has {ridge_index_per_channel.shape[1]} channels, "
            f"but Y has {Y.shape[1]}"
        )
    if ridge_index_per_channel.size and (
        ridge_index_per_channel.min() < 0 or ridge_index_per_channel.max() >= len(ridges)
    ):
        raise ValueError(
            f"ridge_index_per_channel must index into ridges (0..{len(ridges) - 1})"
        )

    scaling = len(X) if mean_mse_loss else 1.0
    X_tX = block_gpu_multiply(X.T, X, block_size, block_size)
    ridge_identity = np.eye(X_tX.shape[0], dtype=np.float32)
    # X^T Y is shared by every ridge and run, so it is uploaded once.
    b = torch.from_numpy(block_gpu_multiply(X.T, Y, block_size, block_size)).cuda().to(torch.float32)
    print("Running Ridges")
    # Every entry is overwritten below because every channel maps to exactly
    # one valid ridge (validated above).
    Ws = [np.empty((X.shape[1], Y.shape[1])) for _ in range(num_runs)]
    for ridge_index in tqdm(range(len(ridges))):
        run_masks = ridge_index_per_channel == ridge_index  # (n_runs, n_channels)
        if not run_masks.any():
            continue  # no run selects this ridge: skip the factorisation entirely
        A = torch.from_numpy(X_tX + scaling*ridges[ridge_index]*ridge_identity).cuda().to(torch.float32)
        # Factor once per ridge value (the expensive O(d^3) step) and reuse it for all runs.
        L = torch.linalg.cholesky(A)
        del A
        for run_index in range(num_runs):
            channel_selection_indices = run_masks[run_index]
            if not channel_selection_indices.any():
                continue
            W_selected = torch.cholesky_solve(b[:, channel_selection_indices], L).cpu().numpy()
            Ws[run_index][:, channel_selection_indices] = W_selected
        del L
    return Ws
                
class RidgeRegression():
    """Fit, cache and apply one ridge model per penalty on a training dataset.

    On construction the whole training dataset is loaded into host memory.
    Features are cast to float16 and a column of ones is appended, so the
    last row of every solved weight matrix acts as the bias.
    """

    def __init__(self, train_dataset, save_loc = "./runs/ridge_regression/run",
                 block_size = 5000, validation_index = None, mean_mse_loss = True):
        """
        Args:
            train_dataset: indexable dataset such that
                ``train_dataset[np.arange(n)]`` returns an ``(X, Y)`` pair of
                torch tensors (assumed 2-D, rows are samples — TODO confirm).
            save_loc: directory, created if missing, where fitted models are pickled.
            block_size: tile size forwarded to ``ridge_regression``.
            validation_index: fold id. When not None it is embedded in the
                cached model file names.
            mean_mse_loss: forwarded to ``ridge_regression``.
        """
        self.train_dataset = train_dataset
        X_torch, Y_torch = self.train_dataset[np.arange(0, len(self.train_dataset))]
        X = X_torch.numpy().astype(np.float16, copy=False)
        # Ones column lets the solver fit the intercept as an extra weight row.
        self.X_with_bias = np.hstack([X, np.ones((X.shape[0], 1), dtype=np.float16)])
        self.Y = Y_torch.numpy()
        self.save_loc = save_loc
        self.block_size = block_size
        self.validation_index = validation_index
        self.mean_mse_loss = mean_mse_loss
        os.makedirs(save_loc, exist_ok=True)

    def _model_path(self, ridge_index):
        """Return the pickle path for the model at position ``ridge_index``."""
        if self.validation_index is not None:
            return self.save_loc + f"/ridge_{ridge_index}_fold_{self.validation_index}.pkl"
        return self.save_loc + f"/ridge_{ridge_index}.pkl"

    def fit(self, ridges = (0.1, 1.0, 10.0)):
        """Fit one ``FullLinearModel`` per ridge value and store them in ``self.fit_models``.

        If a pickled model already exists for every ridge position, those
        models are loaded instead of refitting. Otherwise all ridges are
        refit and every model is re-pickled.

        NOTE(review): the cache is keyed by the ridge's *position* in
        ``ridges``, not its value. Calling ``fit`` with different ridge values
        against an existing ``save_loc`` therefore silently reloads models fit
        with the old values.
        """
        paths = [self._model_path(ridge_index) for ridge_index in range(len(ridges))]
        if all(os.path.exists(path) for path in paths):
            print("Model Fit Found!")
            self.fit_models = [load_pkl(path) for path in paths]
            return

        Ws = ridge_regression(self.X_with_bias, self.Y, ridges, self.block_size, self.mean_mse_loss)
        self.fit_models = []
        for path, W in zip(paths, Ws):
            # The final row of W multiplies the ones column, i.e. it is the bias.
            model = FullLinearModel(W[:-1, :], W[-1, :])
            self.fit_models.append(model)
            save_pkl(model, path)

    def predict(self, validation_dataset, ridge_index):
        """Predict ``validation_dataset`` with the model fitted for ``ridges[ridge_index]``.

        Returns:
            ``(Y_predicted, Y)``: the model output and the dataset's targets,
            both as numpy arrays.
        """
        X_torch, Y_torch = validation_dataset[np.arange(0, len(validation_dataset))]
        X = X_torch.numpy().astype(np.float16, copy=False)
        Y = Y_torch.numpy()
        model: FullLinearModel = self.fit_models[ridge_index]
        return model.np_predict(X), Y

def cross_val_ridge_regression(dataset, k_folds, channels = 306, ridges = (0.1, 1.0, 10.0, 100.0),
                               model_save_loc = "./runs/ridge_regression",
                               block_size = 1000, mean_mse_loss = True):
    """Choose a ridge penalty per output channel by k-fold CV, then refit on all data.

    For every fold a ``RidgeRegression`` is fit (or loaded from its cache) for
    each ridge value. ``r2_loss`` on the held-out fold is summed per
    (ridge, channel). The best ridge per channel is then used to refit on the
    full dataset, with a bias column appended exactly as ``RidgeRegression``
    does.

    Side effects: writes ``final_model.pkl`` and ``final_model_ridges.json``
    into ``model_save_loc``. Per-fold models are also cached there.

    Args:
        dataset: dataset exposing ``dataset_lens``, ``subset(indices)``,
            ``len()`` and ``dataset[np.arange(n)] -> (X, Y)`` torch tensors.
        k_folds: number of cross-validation folds.
        channels: number of output channels; must match the width of the
            ``r2_loss`` result.
        ridges: iterable of candidate ridge penalties.
        model_save_loc: output directory for cached and final models.
        block_size: tile size used both for the fold fits and the final refit.
        mean_mse_loss: forwarded to the ridge solvers.

    Returns:
        ``(final_model, final_ridge_params)``: the refit ``FullLinearModel``
        and a numpy array holding the selected ridge value per channel.
    """
    ridges = list(ridges)
    dataset_folds = DatasetBlockKfold(dataset.dataset_lens, n_splits = k_folds, block_shuffle = True)
    # Accumulated across folds with +=, so it must start at zero (np.empty would
    # add onto uninitialised memory).
    ridge_scores = np.zeros((len(ridges), channels))
    for fold_index, (train_indices, validation_indices) in enumerate(dataset_folds.split()):
        train_dataset = dataset.subset(train_indices)
        validation_dataset = dataset.subset(validation_indices)
        optimization = RidgeRegression(train_dataset, model_save_loc, block_size = block_size, validation_index = fold_index, mean_mse_loss=mean_mse_loss)
        optimization.fit(ridges)
        for ridge_index in range(len(ridges)):
            predicted, validation = optimization.predict(validation_dataset, ridge_index)
            ridge_scores[ridge_index, :] += r2_loss(predicted, validation)

    torch.cuda.empty_cache()
    print("fitting final")
    # NOTE(review): argmax assumes r2_loss returns a score where larger is
    # better (e.g. R^2). If it returns a loss, this should be argmin.
    best_ridge_index_per_channel = np.argmax(ridge_scores, axis=0)
    final_ridge_params = np.array(ridges)[best_ridge_index_per_channel]

    all_X_torch, all_Y_torch = dataset[np.arange(len(dataset))]
    all_X = all_X_torch.numpy().astype(np.float16, copy=False)
    all_Y = all_Y_torch.numpy()
    # Append the ones column so the last solved row really is the bias (as in
    # RidgeRegression). Without it, the last feature's weights were
    # mis-treated as the bias.
    all_X_with_bias = np.hstack([all_X, np.ones((all_X.shape[0], 1), dtype=np.float16)])
    final_W = ridge_regression_ridge_per_channel(all_X_with_bias, all_Y, ridges, best_ridge_index_per_channel[None, :], block_size=block_size, mean_mse_loss=mean_mse_loss)[0]
    final_model = FullLinearModel(final_W[:-1, :], final_W[-1, :])

    os.makedirs(model_save_loc, exist_ok=True)
    save_pkl(final_model, model_save_loc + "/final_model.pkl")

    with open(model_save_loc + "/final_model_ridges.json", "w") as f:
        # numpy arrays are not JSON-serialisable; convert to a plain list.
        json.dump({"ridge": final_ridge_params.tolist()}, f)

    return final_model, final_ridge_params
        
    
            