import numpy as np
import itertools
import fire
from tqdm import tqdm
import os
from multiprocessing import Pool, cpu_count
import time
import psutil
from utils.compute_utils import min_norm_solution, compute_y
from utils.select_utils import select_knn, select_knn_diversity
from utils.gen_utils import generate_x_equal_size, generate_theta
from utils.params_utils import params_combinations
from numpy.random import default_rng

def _mode_extra_params(params):
    """Return the mode-specific distribution parameters, in path order.

    The returned dict is passed as keyword arguments to ``generate_theta``
    and its values (in insertion order) become path components of the
    result directory.

    Raises:
        ValueError: if ``params['mode']`` is not one of the supported modes.
    """
    mode = params['mode']
    if mode in ('binary', 'uniform'):
        return {}
    if mode in ('normal', 'mix_normal', 'truncated_normal'):
        return {'mu': params['mu'], 'sigma': params['sigma']}
    if mode == 'beta':
        return {'alpha': params['alpha'], 'beta': params['beta']}
    raise ValueError(f"Invalid mode: {params['mode']}")


def _coverage_stats(X_train, indices, x_test):
    """Measure how well the selected training rows cover one test row.

    Returns a ``(coverage, occurrence_count, success)`` tuple where
    ``coverage`` is the fraction of non-zero test coordinates that are
    non-zero in the sum of the selected rows, ``occurrence_count`` is the
    dot product of that sum with the test row divided by the test row's
    sum, and ``success`` is 1.0 when coverage is exactly 1.0, else 0.0.
    """
    selected_sum = X_train[indices].sum(axis=0)
    test_nonzero = x_test != 0
    covered = np.logical_and(test_nonzero, selected_sum != 0)
    coverage = np.sum(covered) / np.sum(test_nonzero)
    occurrence_count = np.sum(selected_sum * x_test) / np.sum(x_test)
    success = 1.0 if coverage == 1.0 else 0.0
    return coverage, occurrence_count, success


def _squared_error(X_train, Y_train, indices, x_test, y_test):
    """Fit a min-norm solution on the selected rows and return its squared error on one test row."""
    theta_hat = min_norm_solution(X_train[indices], Y_train[indices])
    return (theta_hat @ x_test - y_test) ** 2


def _mean(total, count):
    """Return ``total / count``, or NaN when nothing was counted."""
    return total / count if count else float('nan')


def run_experiment(params):
    """Compare plain kNN selection against diversity-aware kNN selection.

    Builds 0/1 train and test design matrices, selects ``k`` training rows
    for each test row with both selectors, and, for the test rows that meet
    the evaluation criterion, fits a min-norm solution on the selected rows
    and accumulates squared error and coverage statistics. Averages are
    written to ``results/<mode>/.../<random_seed>/loss.txt``.

    Args:
        params: mapping with keys ``mode``, ``d``, ``train_scale``, ``l``,
            ``k``, ``p``, ``lam``, ``random_seed`` plus ``mu``/``sigma`` for
            the normal-family modes or ``alpha``/``beta`` for ``beta`` mode.

    Returns:
        Tuple ``(loss_knn, loss_knn_diversity)`` of the *summed* (not
        averaged) squared errors over the evaluated test rows.

    Raises:
        ValueError: if ``params['mode']`` is not a supported mode.
    """
    max_evaluated = 100  # stop once this many test rows have been evaluated
    test_scale = 2000    # test set holds d * test_scale candidate rows

    rng = default_rng(params['random_seed'])
    mode = params['mode']
    # Validate the mode up front; previously the error was raised inside the
    # 'beta' branch, which made 'beta' unusable and let bad modes slip through.
    extra = _mode_extra_params(params)
    d, train_scale, l, k, p, lam, random_seed = (
        params[key] for key in ('d', 'train_scale', 'l', 'k', 'p', 'lam', 'random_seed')
    )
    theta = generate_theta(mode=mode, d=d, random_seed=random_seed, **extra)

    train_size = d * train_scale
    test_size = d * test_scale
    train_set = set()

    train_set = generate_x_equal_size(rng, random_seed=random_seed, d=d, l=l, p=p, set_size=train_size, train_set=train_set)
    test_set = generate_x_equal_size(rng, random_seed=random_seed, d=d, l=2*l, p=1-p, set_size=test_size, train_set=train_set)

    X_train = np.zeros((train_size, d))
    X_test = np.zeros((test_size, d))

    # Row i gets a 1 in every column listed by the i-th element of the set.
    # NOTE(review): presumably each element is a fixed-length tuple of column
    # indices -- confirm against generate_x_equal_size.
    X_train[np.arange(train_size)[:, None], list(train_set)] = 1
    X_test[np.arange(test_size)[:, None], list(test_set)] = 1

    # Release the index sets before the matrices are copied by the shuffle.
    del train_set, test_set

    # Shuffle rows with the seeded generator so a run is reproducible from
    # `random_seed` (the global np.random state is not seeded in this file).
    X_train = rng.permutation(X_train)
    X_test = rng.permutation(X_test)

    Y_train = compute_y(X_train, theta)
    Y_test = compute_y(X_test, theta)

    del theta

    loss_knn = 0.0
    total_cov_knn = 0.0
    total_cov_count_knn = 0.0
    total_succ_knn = 0.0

    loss_knn_diversity = 0.0
    total_cov_knn_diversity = 0.0
    total_cov_count_knn_diversity = 0.0
    total_succ_knn_diversity = 0.0

    total_test_size = 0
    for i in range(test_size):
        x_test = X_test[i]
        knn_indices = select_knn(rng, X_train, x_test, k)
        knn_diversity_indices = select_knn_diversity(rng, X_train, x_test, k, l, d, lam)

        cov_knn, cov_count_knn, succ_knn = _coverage_stats(X_train, knn_indices, x_test)
        cov_knn_diversity, cov_count_knn_diversity, succ_knn_diversity = _coverage_stats(
            X_train, knn_diversity_indices, x_test
        )

        # Only evaluate rows where both selectors fully cover the test row
        # (l < 5) or both fail to (l > 5).
        # NOTE(review): l == 5 matches neither case, so nothing is evaluated
        # for it -- confirm this is intended.
        both_succeeded = succ_knn == 1.0 and succ_knn_diversity == 1.0
        both_failed = succ_knn == 0.0 and succ_knn_diversity == 0.0
        if not ((l < 5 and both_succeeded) or (l > 5 and both_failed)):
            continue

        total_test_size += 1

        loss_knn += _squared_error(X_train, Y_train, knn_indices, x_test, Y_test[i])
        total_cov_knn += cov_knn
        total_cov_count_knn += cov_count_knn
        total_succ_knn += succ_knn

        loss_knn_diversity += _squared_error(X_train, Y_train, knn_diversity_indices, x_test, Y_test[i])
        total_cov_knn_diversity += cov_knn_diversity
        total_cov_count_knn_diversity += cov_count_knn_diversity
        total_succ_knn_diversity += succ_knn_diversity

        if total_test_size == max_evaluated:
            break

    del X_train, X_test, Y_train, Y_test

    # Mode-specific parameters sit between `lam` and the seed in the path.
    result_dir = os.path.join(
        "results", mode, str(d), str(train_scale), str(l), str(k), str(p), str(lam),
        *(str(value) for value in extra.values()),
        str(random_seed),
    )
    os.makedirs(result_dir, exist_ok=True)
    result_file = os.path.join(result_dir, "loss.txt")

    # One line per selector; averages are NaN when no test row qualified
    # (previously this case crashed with ZeroDivisionError).
    rows = (
        ("loss_knn", loss_knn, total_cov_knn, total_cov_count_knn, total_succ_knn),
        ("loss_knn_diversity", loss_knn_diversity, total_cov_knn_diversity,
         total_cov_count_knn_diversity, total_succ_knn_diversity),
    )
    with open(result_file, "w", encoding="utf-8") as f:
        for label, loss, cov, cov_count, succ in rows:
            f.write(
                f"{label}: {_mean(loss, total_test_size)}, "
                f"coverage: {_mean(cov, total_test_size)}, "
                f"point occurrence count: {_mean(cov_count, total_test_size)}, "
                f"success rate: {_mean(succ, total_test_size)}, "
                f"total test count: {total_test_size}\n"
            )
    return loss_knn, loss_knn_diversity
    
def main(mode: str):
    """Run every experiment configuration produced for ``mode``, one after another.

    The configurations come from ``params_combinations(mode=mode)``; the
    number of configurations is printed first and a progress bar tracks
    the sequential runs.
    """
    all_params = params_combinations(mode=mode)
    print(f"total tasks: {len(all_params)}")

    for experiment_params in tqdm(all_params):
        run_experiment(experiment_params)

# Script entry point: python-fire exposes `main` as a command-line interface,
# so its `mode` argument is taken from the command line.
if __name__ == "__main__":
    fire.Fire(main)
