from google.cloud import storage
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests
from scipy import stats
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity as cossim
import os
import torch as t
import gc
import psutil
import random

def flush_sae(model, optimizer, early_stopper, device):
    """Drop training objects and release cached accelerator memory.

    Args:
        model: Object whose local reference is dropped.
        optimizer: Object whose local reference is dropped.
        early_stopper: Object whose local reference is dropped.
        device: Device identifier. Accepts plain strings (``'cuda'``,
            ``'cuda:0'``, ``'mps'``, ``'cpu'``) as well as objects exposing
            a ``type`` attribute such as ``torch.device``.

    Note:
        ``del`` only removes this function's own references; the caller has
        to drop its references too before the memory can be reclaimed.
    """
    del model, optimizer, early_stopper
    # Collect before emptying the cache: memory freed by the collector goes
    # back to the caching allocator first, and only then can be released.
    gc.collect()

    # Normalise so that torch.device objects and indexed strings ('cuda:0')
    # are recognised, not only the bare strings 'cuda' / 'mps'.
    device_type = getattr(device, 'type', None) or str(device).split(':', 1)[0]
    if device_type == 'cuda':
        t.cuda.empty_cache()
        t.cuda.reset_peak_memory_stats()
    elif device_type == 'mps':
        t.mps.empty_cache()

def get_ram():
    """Return a one-line RAM usage summary with a 24-cell text gauge.

    Usage is reported in GiB as ``total - available`` over ``total``, using
    the figures from ``psutil.virtual_memory()``. Filled cells represent the
    used share, hollow cells the available share.
    """
    gib = 1024 ** 3
    total_cubes = 24

    snapshot = psutil.virtual_memory()
    free = snapshot.available / gib
    total = snapshot.total / gib

    # Truncate towards zero so a partially free cell is drawn as used.
    free_cubes = int(total_cubes * free / total)
    gauge = '▮' * (total_cubes - free_cubes) + '▯' * free_cubes
    return f'RAM: {total - free:.2f}/{total:.2f}GB\t RAM:[{gauge}]'


def get_vram():
    """Return a one-line VRAM usage summary with a 24-cell text gauge.

    Usage is reported in GiB as ``total - free`` over ``total`` for the
    current CUDA device. Filled cells represent the used share, hollow
    cells the free share.
    """
    gib = 1024 ** 3
    total_cubes = 24

    # Query once so that free and total come from the same snapshot.
    free_bytes, total_bytes = t.cuda.mem_get_info()
    free = free_bytes / gib
    total = total_bytes / gib

    # Truncate towards zero so a partially free cell is drawn as used.
    free_cubes = int(total_cubes * free / total)
    gauge = (total_cubes - free_cubes) * '▮' + free_cubes * '▯'
    return f'VRAM: {total - free:.2f}/{total:.2f}GB\t VRAM:[{gauge}]'


def set_seed(seed_value=42):
    """Seed the Python, NumPy and torch random generators.

    When CUDA is available, all CUDA generators are seeded as well and
    cuDNN is switched to deterministic mode with benchmarking disabled.

    Args:
        seed_value: Seed passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, t.manual_seed):
        seed_fn(seed_value)

    if not t.cuda.is_available():
        return

    t.cuda.manual_seed_all(seed_value)
    # Deterministic kernels and no auto-tuner, for reproducible CUDA runs.
    cudnn = t.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


####################
def my_iqr(x):
    """Return the interquartile range of ``x`` (0.75 quantile minus 0.25 quantile).

    ``x`` must expose a ``quantile`` method.
    """
    upper = x.quantile(0.75)
    lower = x.quantile(0.25)
    return upper - lower


def q1(x):
    """Return the 0.25 quantile (first quartile) of ``x``.

    ``x`` must expose a ``quantile`` method.
    """
    return x.quantile(0.25)


def q3(x):
    """Return the 0.75 quantile (third quartile) of ``x``.

    ``x`` must expose a ``quantile`` method.
    """
    return x.quantile(0.75)


def percentile(n):
    """Build an aggregation function returning the ``n`` quantile of its input.

    Args:
        n: Quantile as a fraction, e.g. ``0.5`` for the median.

    Returns:
        A one-argument callable that calls ``x.quantile(n)``. Its
        ``__name__`` is set to ``q_XX``, where ``XX`` is ``n * 100``
        rounded to an integer and zero-padded to two digits.
    """
    def percentile_(x):
        return x.quantile(n)

    # Give each generated function a quantile-specific name.
    percentile_.__name__ = f'q_{n * 100:02.0f}'
    return percentile_


##############################
def disable_grad(model_object):
    """Put a model into eval mode and freeze all of its parameters.

    Args:
        model_object: Object exposing ``eval()`` and ``named_parameters()``;
            every yielded parameter gets ``requires_grad`` set to ``False``.
    """
    model_object.eval()
    # The former post-assignment check (print the name if requires_grad was
    # still True) could never fire right after the flag had been cleared,
    # so it has been dropped.
    for _, param in model_object.named_parameters():
        param.requires_grad = False


######################################

# Function to compute cosine similarity ignoring NaNs
def cossim_nan(matrix):
    """Pairwise cosine similarity between rows, NaN where a row has NaNs.

    Args:
        matrix: 2-D array-like; each row is one vector.

    Returns:
        An ``(n, n)`` float array. Entry ``[i, j]`` is the cosine similarity
        of rows ``i`` and ``j`` when neither row contains a NaN, and NaN
        otherwise.
    """
    data = np.asarray(matrix)
    n = data.shape[0]

    # Start from all-NaN and fill in only the pairs of fully valid rows.
    similarity_matrix = np.full((n, n), np.nan)
    valid_rows = ~np.isnan(data).any(axis=1)

    # One batched call replaces n * n separate two-row calls. It is skipped
    # when no row is valid, since there is nothing to compare.
    if valid_rows.any():
        similarity_matrix[np.ix_(valid_rows, valid_rows)] = cossim(data[valid_rows])
    return similarity_matrix


########################################
def annotate_correlations(x, y, corr_mat, p_vals_mat):
    """Annotate the current axes with the correlation and p-value of a pair.

    Args:
        x: Object whose ``name`` is a column label of ``corr_mat``.
        y: Object whose ``name`` is a column label of ``corr_mat``.
        corr_mat: DataFrame of correlations, indexed positionally by the
            column positions of ``x.name`` and ``y.name``.
        p_vals_mat: Array of p-values indexed by the same two positions.
    """
    columns = corr_mat.columns
    row, col = columns.get_loc(x.name), columns.get_loc(y.name)
    label = f'r: {corr_mat.iloc[row, col]:.2f}; p: {p_vals_mat[row, col]:.2e}'

    # Place the text just above the top-left of the active axes.
    axes = plt.gca()
    axes.annotate(label, xy=(0.1, 1.1), xycoords=axes.transAxes)


def get_p_corrected(corr_mat, df_tmp):
    """Compute pairwise Pearson p-values and their Bonferroni correction.

    Args:
        corr_mat: Correlation matrix; only its shape is used, to size the
            p-value matrices.
        df_tmp: DataFrame whose columns are correlated pairwise.

    Returns:
        Tuple ``(p_vals_corrected_mat, p_values)``: the symmetric matrix of
        Bonferroni-corrected p-values and the symmetric matrix of raw
        p-values. Both have a zero diagonal. A pair with fewer than two
        valid rows gets a raw p-value of NaN.
    """
    columns = df_tmp.columns
    p_values = np.zeros(corr_mat.shape)

    # Visit each unordered column pair once (strict upper triangle).
    for i, j in zip(*np.triu_indices(len(columns), 1)):
        pair = df_tmp[[columns[i], columns[j]]].dropna()
        # Non-numeric entries become NaN and are dropped together with them.
        pair = pair.apply(pd.to_numeric, errors='coerce').dropna()

        if len(pair) > 1:  # Pearson requires at least 2 data points
            _, p = stats.pearsonr(pair.iloc[:, 0], pair.iloc[:, 1])
        else:
            p = np.nan
        p_values[i, j] = p_values[j, i] = p

    # Correct only the unique pairs, then mirror back into a full matrix.
    upper = np.triu_indices_from(p_values, 1)
    corrected_flat = multipletests(p_values[upper], method='bonferroni')[1]

    upper_only = np.zeros_like(p_values)
    upper_only[upper] = corrected_flat
    p_vals_corrected_mat = upper_only + upper_only.T

    return p_vals_corrected_mat, p_values


############################
def upload_blob_from_memory(bucket_name, contents, destination_blob_name):
    """Upload a local file to a blob in a Google Cloud Storage bucket.

    Args:
        bucket_name: Name of the target bucket.
        contents: Passed to ``Blob.upload_from_filename``, i.e. used as the
            path of the local file to upload.
        destination_blob_name: Name of the blob to create in the bucket.

    NOTE(review): despite the function name, ``contents`` is treated as a
    file path and not as in-memory data -- confirm this is what callers
    expect.
    """
    target = storage.Client().bucket(bucket_name).blob(destination_blob_name)
    target.upload_from_filename(contents)

    print(f"{destination_blob_name} with contents {contents} uploaded to {bucket_name}.")


def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Download a blob from a Google Cloud Storage bucket to a local file.

    Args:
        bucket_name: ID of the GCS bucket, e.g. ``"your-bucket-name"``.
        source_blob_name: ID of the GCS object, e.g. ``"storage-object-name"``.
        destination_file_name: Local path the file is written to, e.g.
            ``"local/path/to/file"``.
    """
    client = storage.Client()

    # `Bucket.blob` builds a client-side handle without fetching anything
    # from Cloud Storage (unlike `Bucket.get_blob`); no extra data is
    # needed before the download itself.
    source = client.bucket(bucket_name).blob(source_blob_name)
    source.download_to_filename(destination_file_name)

