# -*- coding: utf-8 -*-
"""
Created on Sun Jun 30 15:49:02 2024

@author: User
"""
import numpy as np 
from numpy import linalg as LA
import torch
# Seed the torch, NumPy and stdlib `random` RNGs with a fixed value so that
# runs are reproducible. NOTE: this is a global side effect of importing this
# module; the normalization helpers below do not draw random numbers themselves.
torch.manual_seed(42)
np.random.seed(42)
import random
random.seed(42)
# def global_normalize(data_orig, C):
#     # Reshape data as before
#     data = np.reshape(data_orig, (data_orig.shape[0], int(data_orig.size / data_orig.shape[0])))

#     # Subtract the mean
#     means = np.mean(data, axis=0)
#     data = np.abs(data - means)

#     # Compute the normalization factor
#     # norm = np.sqrt(np.mean(np.sum(data ** 2, axis=1)))
#     # linalg.norm(x, ord=None, 
#     norm = np.mean(np.max(data,1))
#     print(norm)

#     # Normalize the data
#     data = data_orig * C / norm

#     return data



def global_normalize(data_orig, C, p=2, dim_correction=False):
    """Center ``data_orig`` per feature and rescale it by one global factor.

    The array is viewed as ``(n_samples, n_features)`` by flattening every
    axis after the first. Each feature is centered on its mean over the
    samples, then the whole centered array is divided by a single scalar
    norm and multiplied by ``C``.

    Parameters
    ----------
    data_orig : np.ndarray
        Array whose first axis indexes samples; any trailing shape is allowed.
    C : float
        Scale applied after normalization.
    p : {2, np.inf}
        How the global norm is computed from the centered samples:
        ``2`` -> square root of the mean (over samples) of each sample's
        squared L2 norm; ``np.inf`` -> mean (over samples) of each sample's
        maximum absolute deviation.
    dim_correction : bool
        If true, additionally divide the norm by ``sqrt(n_features)``.

    Returns
    -------
    np.ndarray
        Normalized array with the same shape as ``data_orig``.

    Raises
    ------
    ValueError
        If ``p`` is neither 2 nor ``np.inf``.
    """
    n_samples = data_orig.shape[0]
    data = np.reshape(data_orig, (n_samples, -1))

    # Per-feature mean over samples, and absolute deviation from it.
    means = np.mean(data, axis=0)
    abs_dev = np.abs(data - means)

    if p == 2:
        norm = np.sqrt(np.mean(np.sum(abs_dev ** 2, axis=1)))
    elif p == np.inf:
        norm = np.mean(np.max(abs_dev, axis=1))
    else:
        # Previously fell through with `norm` unbound (UnboundLocalError).
        raise ValueError(f"p must be 2 or np.inf, got {p!r}")

    if dim_correction:
        norm = norm / np.sqrt(data.shape[1])

    # NOTE: no epsilon guard here (unlike a zero-safe divide); constant data
    # yields norm == 0 and therefore inf/nan output.
    # Reshape the means back to the trailing (per-sample) shape so they align
    # with ``data_orig`` rather than broadcasting against the wrong axes.
    return C * (data_orig - means.reshape(data_orig.shape[1:])) / norm


# def local_normalize(data,C):
#     data = np.reshape(data, (data.shape[0],int(data.size/data.shape[0])))

#     means =  np.mean(data, axis=0) # find the mean for each dimension 
#     data = data - means # data - means for each dimension

#     norm = np.tile(np.sqrt(np.mean(data ** 2 ,axis=0)),(data.shape[0],1))
# #     norm =  np.sqrt(np.mean(np.sum(sqz,axis=1)))
#     normalized_data = C*data / (norm+(0.0000001))

#     return normalized_data


def local_normalize(data_orig, C, p=2):
    """Center and rescale each feature of ``data_orig`` independently.

    The array is viewed as ``(n_samples, n_features)`` by flattening every
    axis after the first. Each feature is centered on its mean over the
    samples and divided by its own spread (plus ``1e-7`` to avoid division
    by zero), then multiplied by ``C``.

    Parameters
    ----------
    data_orig : np.ndarray
        Array whose first axis indexes samples; any trailing shape is allowed.
    C : float
        Scale applied after normalization.
    p : {2, np.inf}
        Per-feature spread: ``2`` -> root-mean-square deviation over samples;
        ``np.inf`` -> mean absolute deviation over samples.

    Returns
    -------
    np.ndarray
        Normalized array with the same shape as ``data_orig``.

    Raises
    ------
    ValueError
        If ``p`` is neither 2 nor ``np.inf``.
    """
    n_samples = data_orig.shape[0]
    data = np.reshape(data_orig, (n_samples, -1))

    means = np.mean(data, axis=0)  # mean of each feature over samples
    abs_dev = np.abs(data - means)
    if p == 2:
        norm = np.sqrt(np.mean(abs_dev ** 2, axis=0))
    elif p == np.inf:
        # Note: despite the `inf` label this is the mean absolute deviation
        # per feature, not the maximum.
        norm = np.mean(abs_dev, axis=0)
    else:
        # Previously fell through with `norm` unbound (UnboundLocalError).
        raise ValueError(f"p must be 2 or np.inf, got {p!r}")

    # Reshape the per-feature statistics to the trailing (per-sample) shape so
    # they broadcast correctly against ``data_orig``; tiling the norm to
    # (n_samples, n_features) made 1-D input blow up to an (N, N) result.
    feature_shape = data_orig.shape[1:]
    centered = data_orig - means.reshape(feature_shape)
    return C * centered / (norm.reshape(feature_shape) + 0.0000001)


def get_data_norm(data):
    """Return the global L2 scale of ``data`` after per-feature centering.

    ``data`` is viewed as ``(n_samples, n_features)`` by flattening all axes
    after the first. Each feature is centered on its sample mean, and the
    result is ``sqrt(mean over samples of ||x_i - mean||_2 ** 2)``.
    """
    n_rows = data.shape[0]
    flat = np.reshape(data, (n_rows, data.size // n_rows))

    centered = flat - flat.mean(axis=0)
    squared_lengths = (centered ** 2).sum(axis=1)  # ||x_i - mean||^2 per row
    return np.sqrt(squared_lengths.mean())




