import multiprocessing
import logging
from filelock import FileLock
import os
import torch
import numpy as np
import random

def normalize(X):
    """Standardise ``X`` along axis 0.

    Subtracts the axis-0 mean and divides by the axis-0 standard deviation.
    A small constant (1e-8) is added to the standard deviation so that
    constant columns do not cause a division by zero.
    """
    column_mean = X.mean(axis=0)
    column_scale = X.std(axis=0) + 1e-8
    centered = X - column_mean
    return centered / column_scale

def set_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and configure cuDNN.

    Also exports ``PYTHONHASHSEED`` into the process environment, switches
    cuDNN to deterministic mode and turns its benchmark mode off.
    No explicit ``torch.cuda`` seeding call is made here.
    """
    # Same order as before: torch, then numpy, then the stdlib RNG.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def select_device(cfg, role_id):
    """Return the ``'cuda:<index>'`` device string for a given role.

    If ``cfg['gpu_id']`` is non-negative, that GPU is used for every role.
    Otherwise the GPU is picked from ``cfg['gpus']`` in round-robin fashion,
    indexed by ``role_id``.
    """
    fixed_id = cfg['gpu_id']
    if fixed_id >= 0:
        index = fixed_id
    else:
        pool = cfg['gpus']
        index = pool[role_id % len(pool)]
    return f'cuda:{index}'

def rank(scores, order='ascend'):
    """Return 1-based ranks of ``scores`` along its last axis.

    ``scores`` must expose ``argsort()`` and ``shape`` (e.g. a NumPy array
    or a torch tensor). Ranking is done along the last axis, which is the
    axis ``argsort()`` uses by default, so inputs of any dimensionality
    (1-D, 2-D, ...) are supported.

    order:
        'ascend'  -- the largest entry gets rank 1, the smallest gets rank n.
        'descend' -- the smallest entry gets rank 1, the largest gets rank n.

    Raises:
        ValueError: if ``order`` is neither 'ascend' nor 'descend'.
    """
    if order not in ('ascend', 'descend'):
        raise ValueError("Invalid order. Use 'ascend' or 'descend'.")

    # Double argsort yields each entry's 0-based position in sorted order.
    positions = scores.argsort().argsort()

    if order == 'ascend':
        # Use the length of the ranked (last) axis, not a fixed axis 1, so the
        # result is correct for inputs that are not exactly 2-D.
        return scores.shape[-1] - positions
    return positions + 1

def parallel(N, func, *args, gpu_count=6):
    """Run ``func(*args, task_id)`` for task ids 1..N in batches of processes.

    Tasks are launched as separate ``multiprocessing.Process`` instances,
    at most ``gpu_count`` at a time. Every process of a batch is joined
    before the next batch is started.

    N: number of tasks; task ids are 1-based (1, 2, ..., N).
    func: callable executed in each child process.
    *args: positional arguments passed to ``func`` before the task id.
    gpu_count: maximum number of processes per batch (keyword-only,
        defaults to 6, the previously hard-coded value).

    Raises:
        ValueError: if ``gpu_count`` is smaller than 1.
    """
    if gpu_count < 1:
        raise ValueError("gpu_count must be a positive integer.")

    for batch_start in range(1, N + 1, gpu_count):
        batch_end = min(batch_start + gpu_count - 1, N)

        # Fresh list per batch so already-finished processes from earlier
        # batches are not joined again.
        batch = []
        for task_id in range(batch_start, batch_end + 1):
            worker = multiprocessing.Process(target=func, args=(*args, task_id))
            worker.start()
            batch.append(worker)

        for worker in batch:
            worker.join()

def setup_logger(log_file):
    """Return this module's logger with a file handler writing to ``log_file``.

    The logger level is set to INFO and records are formatted as
    ``'<asctime> - <message>'``. The handler is created while holding a
    ``FileLock`` on ``log_file + ".lock"``.

    Calling this function repeatedly with the same ``log_file`` attaches the
    handler only once; otherwise every call would add another handler to the
    same logger and each record would be written multiple times.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path of its target in ``baseFilename``.
    target_path = os.path.abspath(log_file)
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target_path
        for handler in logger.handlers
    )
    if already_attached:
        return logger

    formatter = logging.Formatter('%(asctime)s - %(message)s')

    with FileLock(log_file + ".lock"):
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger

def PCA(data: torch.Tensor, target_dim: int=1, center: bool=True) -> torch.Tensor:
    """Project ``data`` onto its leading principal components.

    data: (n_data, n_features) --> (n_data, target_dim)

    When ``n_data > n_features`` the (d, d) covariance matrix is
    eigendecomposed directly. Otherwise the smaller (n, n) Gram matrix is
    used and its eigenvectors are mapped back to feature space and
    re-normalised to unit length.

    data: 2D tensor of samples (rows) by features (columns).
    target_dim: number of leading components to keep. No more than
        ``min(n_data, n_features)`` columns can be returned.
    center: if True, subtract the per-feature mean before the decomposition.

    Raises:
        ValueError: if ``data`` is not a 2D tensor.
    """
    if data.dim() != 2:
        raise ValueError("Input data must be a 2D tensor.")
    n, d = data.shape

    if center:
        data = data - data.mean(dim=0, keepdim=True)

    # A positive scale factor does not change the eigenvectors; max() only
    # prevents a 0/0 (NaN matrix) when there is a single sample.
    denom = max(n - 1, 1)
    use_gram = n <= d

    if use_gram:
        # C = X X^T / (n - 1)
        C = data @ data.T / denom
    else:
        # C = X^T X / (n - 1)
        C = data.T @ data / denom

    eigenvalues, eigenvectors = torch.linalg.eigh(C)
    order = torch.argsort(eigenvalues, descending=True)
    principal_components = eigenvectors[:, order[:target_dim]]

    if use_gram:
        # Map Gram-matrix eigenvectors back to feature space: (d, d')
        principal_components = data.T @ principal_components

        # Normalise all columns at once, out of place, so no tensor that
        # autograd may have saved is overwritten. The clamp keeps zero-length
        # (null-space) directions at zero instead of producing NaN.
        norms = torch.norm(principal_components, dim=0, keepdim=True)
        floor = torch.finfo(norms.dtype).tiny
        principal_components = principal_components / norms.clamp_min(floor)

    return data @ principal_components  # (n, d')