import numpy as np
import itertools
import fire
from tqdm import tqdm
import os
from multiprocessing import Pool, cpu_count
import time
import psutil
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error

def find_best_alpha(X_train, y_train, alphas=None, n_splits=5):
    """Pick the ridge penalty with the lowest K-fold cross-validated MSE.

    Args:
        X_train: 2-D array of samples (rows) by features (columns).
        y_train: 1-D array of targets aligned with the rows of ``X_train``.
        alphas: Candidate penalties. Defaults to ``np.logspace(-3, 3, 7)``.
        n_splits: Requested number of folds. It is reduced to the number of
            samples when fewer samples than folds are available, so small
            neighbourhoods are scored with leave-one-out instead of failing.

    Returns:
        The element of ``alphas`` with the smallest mean validation error
        (the first such element on ties).

    Raises:
        ValueError: Propagated from ``KFold`` when fewer than two samples
            are supplied, since no train/validation split exists then.
    """
    if alphas is None:
        alphas = np.logspace(-3, 3, 7)

    # KFold rejects n_splits > n_samples; clamp so that e.g. a 4-sample
    # neighbourhood is still cross-validated rather than raising.
    n_folds = min(n_splits, len(X_train))
    kf = KFold(n_splits=n_folds, shuffle=True)

    # With shuffle=True and no random_state, every call to split() draws a
    # new shuffle. Materialise the folds once so that all alphas are scored
    # on the same partitions and their errors are directly comparable.
    folds = list(kf.split(X_train))

    mean_errors = []
    for alpha in alphas:
        fold_errors = []
        for train_idx, val_idx in folds:
            theta = ridge_solution(X_train[train_idx], y_train[train_idx], alpha)
            y_pred = X_train[val_idx] @ theta
            fold_errors.append(mean_squared_error(y_train[val_idx], y_pred))
        mean_errors.append(np.mean(fold_errors))

    return alphas[int(np.argmin(mean_errors))]

def generate_x_expc_size():
    """Unimplemented placeholder; takes no arguments and returns ``None``."""
    return None

def generate_x_equal_size(random_seed=0, d=100, l=2, p=0.5, set_size=100, train_set=None):
    """Sample ``set_size`` distinct index combinations of size ``l`` from ``range(d)``.

    Each combination picks ``Binomial(l, p)`` indices from the first half
    ``[0, d // 2)`` and the remaining ones from the second half
    ``[d // 2, d)``, without replacement inside each half. Draws come from
    NumPy's global random state, so seed it beforehand for reproducibility.

    Args:
        random_seed: Accepted for interface compatibility; not used here.
        d: Size of the index universe.
        l: Number of indices in every combination.
        p: Per-index probability of being drawn from the first half.
        set_size: Number of distinct combinations to return.
        train_set: Optional collection of combinations (tuples) that must not
            be returned. ``None`` means nothing is excluded.

    Returns:
        A list of ``set_size`` distinct tuples, each holding ``l`` ascending
        indices.

    Note:
        The loop keeps sampling until enough combinations are found, so it
        does not terminate if fewer than ``set_size`` admissible combinations
        exist. ``np.random.choice`` raises ``ValueError`` if a draw requests
        more indices from one half than that half contains.
    """
    # Normalise once: tolerates the None default and gives O(1) membership
    # checks even when the caller passes a list.
    excluded = set() if train_set is None else set(train_set)

    half = d // 2
    results = set()
    while len(results) < set_size:
        front_count = np.random.binomial(l, p)
        back_count = l - front_count
        front_selection = tuple(sorted(np.random.choice(
            range(half),
            size=front_count,
            replace=False
        )))
        back_selection = tuple(sorted(np.random.choice(
            range(half, d),
            size=back_count,
            replace=False
        )))
        # Front indices are all smaller than back indices, so the
        # concatenation is already sorted.
        combination = front_selection + back_selection
        if combination not in excluded:
            results.add(combination)  # the set itself drops duplicates
    return list(results)

def generate_theta(random_seed=None, d=100, mu=0.0, sigma=1.0, mode="normal", alpha=1, beta=1):
    """Draw a length-``d`` parameter vector from the distribution named by ``mode``.

    Draws come from NumPy's global random state.

    Args:
        random_seed: Accepted for interface compatibility; not used here.
        d: Length of the returned vector.
        mu: Mean for ``"normal"``; first component mean for ``"mix_normal"``
            (the second component is centred at ``mu + 1``).
        sigma: Standard deviation for ``"normal"`` and ``"mix_normal"``.
        mode: One of ``"normal"``, ``"mix_normal"``, ``"uniform"``,
            ``"binary"`` or ``"beta"``.
        alpha: First shape parameter of the Beta distribution (``"beta"`` only).
        beta: Second shape parameter of the Beta distribution (``"beta"`` only).

    Returns:
        A 1-D NumPy array of length ``d``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode == "normal":
        return np.random.normal(mu, sigma, d)
    if mode == "mix_normal":
        # Equal-weight mixture of N(mu, sigma) and N(mu + 1, sigma).
        weights = [0.5, 0.5]
        means = [mu, mu + 1]
        components = np.random.choice(len(weights), size=d, p=weights)
        return np.random.normal(loc=[means[c] for c in components], scale=sigma)
    if mode == "uniform":
        return np.random.uniform(0, 1, d)
    if mode == "binary":
        return np.random.choice([0, 1], size=d)
    if mode == "beta":
        return np.random.beta(alpha, beta, d)
    # Previously an unknown mode fell through to an UnboundLocalError.
    raise ValueError(
        f"Unknown mode {mode!r}; expected one of "
        "'normal', 'mix_normal', 'uniform', 'binary', 'beta'"
    )

def compute_y(X, theta, noise_std=0.1):
    """Return the noisy linear response ``X @ theta + eps``.

    ``eps`` holds one zero-mean Gaussian draw per row of ``X`` with standard
    deviation ``noise_std``, taken from NumPy's global random state.
    """
    n_rows = X.shape[0]
    eps = np.random.normal(0, noise_std, size=n_rows)
    signal = X @ theta
    return signal + eps

def select_knn(X_train, x_i, k):
    """Return the indices of the ``k`` rows of ``X_train`` with the largest
    inner product with ``x_i``, ordered from highest to lowest score."""
    similarity = X_train @ x_i
    descending = np.argsort(similarity)[::-1]
    return descending[:k]

def select_knn_diversity(X_train, x_i, k, l, d, lam=0.5):
    """Greedily pick up to ``k`` training rows, trading relevance against diversity.

    The first pick is the row with the largest inner product with ``x_i``.
    Each later pick maximises
    ``lam * diversity + (1 - lam) * relevance`` over the rows not yet chosen,
    where ``relevance = (X_train @ x_i) / (d - 1)`` and
    ``diversity = (m * l - s) / (m * l)``, with ``m`` the number of rows
    already chosen and ``s`` the sum of a row's inner products with them.

    Args:
        X_train: 2-D array of candidate rows.
        x_i: Query vector.
        k: Number of rows to select. Capped at the number of candidates.
        l: Scale of the diversity term (multiplied by the number of rows
            already chosen).
        d: Relevance scores are divided by ``d - 1``.
        lam: Weight of the diversity term, ``1 - lam`` goes to relevance.

    Returns:
        1-D integer array of distinct row indices in selection order. Empty
        when ``k <= 0`` or ``X_train`` has no rows.
    """
    n_train = X_train.shape[0]
    # Never ask for more rows than exist: once every row is masked, argmax
    # over an all -inf vector would keep returning index 0.
    k = min(k, n_train)
    if k <= 0:
        return np.empty(0, dtype=np.intp)

    train_scores = X_train @ x_i
    knn_scores = train_scores / (d - 1)

    first_idx = int(np.argmax(train_scores))
    selected = [first_idx]
    remain_mask = np.ones(n_train, dtype=bool)
    remain_mask[first_idx] = False

    # Running sum of inner products with the chosen rows. This replaces
    # building the full N x N Gram matrix on every call when only the
    # columns of the chosen rows are ever read.
    simi_scores = X_train @ X_train[first_idx]

    for _ in range(k - 1):
        scale = len(selected) * l
        diversity_scores = (scale - simi_scores) / scale
        total_score = lam * diversity_scores + (1 - lam) * knn_scores
        total_score[~remain_mask] = -np.inf  # already-chosen rows are ineligible
        next_idx = int(np.argmax(total_score))
        selected.append(next_idx)
        remain_mask[next_idx] = False
        simi_scores = simi_scores + X_train @ X_train[next_idx]

    return np.array(selected, dtype=np.intp)

def ridge_solution(X_sub, y_sub, alpha=1.0):
    """Solve the ridge normal equations ``(X^T X + alpha * I) theta = X^T y``.

    ``X_sub`` must be 2-D (samples by features). Returns the coefficient
    vector ``theta`` with one entry per feature.
    """
    _, n_features = X_sub.shape
    X_t = X_sub.T
    regularized_gram = X_t @ X_sub + alpha * np.eye(n_features)
    rhs = X_t @ y_sub
    return np.linalg.solve(regularized_gram, rhs)

def _local_ridge_loss(X_train, Y_train, X_test, Y_test, select_indices):
    """Sum squared errors of per-query ridge models fit on selected training rows."""
    total = 0.0
    for x_query, y_true in zip(X_test, Y_test):
        idx = select_indices(x_query)
        X_sel = X_train[idx]
        Y_sel = Y_train[idx]
        best_alpha = find_best_alpha(X_sel, Y_sel)
        theta_local = ridge_solution(X_sel, Y_sel, alpha=best_alpha)
        pred = theta_local @ x_query
        total += (pred - y_true) ** 2
    return total


def run_experiment(params):
    """Run one synthetic experiment and write its two losses to disk.

    Builds disjoint binary train/test design matrices, draws a ground-truth
    parameter vector, and for every test row fits a local ridge model on
    ``k`` training rows chosen first by plain similarity (``select_knn``) and
    then by the diversity-aware rule (``select_knn_diversity``).

    Args:
        params: 13-tuple ``(random_seed, l, d, mu, sigma, k, p, mode,
            train_scale, lam, alpha, beta, noise_std)``. ``alpha`` and
            ``beta`` are the Beta-distribution parameters passed to
            ``generate_theta``, not ridge penalties.

    Returns:
        ``(loss_knn, loss_knn_diversity)``, the summed squared errors over
        the test rows, or ``None`` if any exception was raised (the error is
        printed instead of propagated so one failed task does not abort a
        sweep).

    Side effects:
        Seeds NumPy's global random state with ``random_seed`` and writes
        ``results/<mode>/<random_seed>/<l>/<d>/<sigma>/<k>/<p>/<train_scale>/
        <lam>/<alpha>/<beta>/<noise_std>/loss.txt``.
        NOTE(review): ``mu`` is not part of that path, so runs that differ
        only in ``mu`` write to the same file.
    """
    (random_seed, l, d, mu, sigma, k, p, mode, train_scale, lam,
     alpha, beta, noise_std) = params

    try:
        np.random.seed(random_seed)
        train_size = d * train_scale
        test_size = d

        # Test combinations are drawn with the mirrored half-probability
        # (1 - p) and must not repeat any training combination.
        train_set = generate_x_equal_size(random_seed=random_seed, d=d, l=l, p=p,
                                          set_size=train_size, train_set=set())
        test_set = generate_x_equal_size(random_seed=random_seed, d=d, l=l, p=1-p,
                                         set_size=test_size, train_set=train_set)

        # Turn each index combination into a 0/1 indicator row.
        X_train = np.zeros((train_size, d))
        X_test = np.zeros((test_size, d))
        X_train[np.arange(train_size)[:, None], train_set] = 1
        X_test[np.arange(test_size)[:, None], test_set] = 1

        del train_set, test_set

        X_train = np.random.permutation(X_train)
        X_test = np.random.permutation(X_test)

        theta = generate_theta(random_seed=random_seed, d=d, mu=mu, sigma=sigma,
                               mode=mode, alpha=alpha, beta=beta)
        Y_train = compute_y(X_train, theta, noise_std=noise_std)
        Y_test = compute_y(X_test, theta, noise_std=noise_std)

        # The plain-kNN pass runs to completion before the diversity pass so
        # that the global random state is consumed in a fixed order.
        loss_knn = _local_ridge_loss(
            X_train, Y_train, X_test, Y_test,
            lambda x_query: select_knn(X_train, x_query, k),
        )
        loss_knn_diversity = _local_ridge_loss(
            X_train, Y_train, X_test, Y_test,
            lambda x_query: select_knn_diversity(X_train, x_query, k, l, d, lam=lam),
        )

        result_dir = os.path.join("results", mode, str(random_seed), str(l), str(d),
                                  str(sigma), str(k), str(p), str(train_scale), str(lam),
                                  str(alpha), str(beta), str(noise_std))
        os.makedirs(result_dir, exist_ok=True)
        result_file = os.path.join(result_dir, "loss.txt")
        with open(result_file, "w") as f:
            f.write(f"loss_knn: {loss_knn}\nloss_knn_diversity: {loss_knn_diversity}")

        return loss_knn, loss_knn_diversity

    except Exception as e:
        # Best-effort by design: report and continue. The exception type is
        # included because some exceptions have an empty message.
        print(f"Error in experiment with params {params}: {type(e).__name__}: {e}")
        return None

def main(random_seed=42, l=2, d=100, mu=0.0, sigma=1.0, k=4, p=0.5,
         mode="normal", train_scale=10, lam=0.5, parallel=True,
         cores_to_reserve=24):
    """Run the hard-coded parameter sweep, in a process pool or sequentially.

    Args:
        random_seed, l, d, mu, sigma, k, p, mode, train_scale, lam: Accepted
            for command-line compatibility but currently unused; the sweep is
            defined entirely by the grid below.
        parallel: If true, distribute tasks over a ``multiprocessing.Pool``;
            otherwise run them one after another in this process.
        cores_to_reserve: Number of CPU cores left idle when sizing the pool.
            At least one worker is always used.
    """
    param_grid = {
        'random_seed': range(42, 42+3),
        'l': range(4, 9),
        'd': [300],
        'mu': [10],
        'sigma': [1],
        'k': range(4, 9),
        'p': [0.75],
        'mode': ['beta'],
        'train_scale': [1, 5, 10],
        'lam': [0.5],
        'alpha': [0.1, 0.5, 1, 2, 5],
        'beta': [0.1, 0.5, 1, 2, 5],
        'noise_std': [0.1, 0.5, 1.0]
    }

    # Positional order expected by run_experiment's tuple unpacking. Spelled
    # out so that reordering the grid dict cannot silently permute arguments.
    param_order = ('random_seed', 'l', 'd', 'mu', 'sigma', 'k', 'p', 'mode',
                   'train_scale', 'lam', 'alpha', 'beta', 'noise_std')
    param_tuples = list(itertools.product(
        *(param_grid[name] for name in param_order)
    ))
    total_tasks = len(param_tuples)
    print(f"Total tasks: {total_tasks}")

    if not parallel:
        # The grid is built before branching: the sequential path previously
        # referenced a name that only existed in the parallel branch.
        for params in param_tuples:
            run_experiment(params)
        return

    total_cores = cpu_count()
    cores_to_use = max(1, total_cores - cores_to_reserve)
    print(f"Total CPU cores: {total_cores}")
    print(f"Cores to use: {cores_to_use}")

    try:
        with Pool(processes=cores_to_use) as pool:
            # Drain the iterator so every task runs; results are persisted by
            # run_experiment itself, so the return values are discarded.
            list(tqdm(
                pool.imap_unordered(run_experiment, param_tuples),
                total=total_tasks,
                desc="Running experiments"
            ))
    except Exception as e:
        print(f"Error in multiprocessing: {e}")

# When executed as a script, hand main() to Python Fire so its keyword
# arguments can be supplied as command-line flags.
if __name__ == "__main__":
    fire.Fire(main)