import os
import glob
import numpy as np
import itertools
from collections import defaultdict



# Distribution hyperparameters that appear in the results path, per mode.
_MODE_DIST_PARAMS = {
    'beta': ('alpha', 'beta'),
    'normal': ('mu', 'sigma'),
    'mix_normal': ('mu', 'sigma'),
    'truncated_normal': ('mu', 'sigma'),
    'binary': (),
    'uniform': (),
}

# Distribution hyperparameters that results are grouped (and printed) by, per
# mode. Normal-family runs are grouped by sigma only, so runs that differ only
# in mu are pooled together.
_MODE_GROUP_PARAMS = {
    'beta': ('alpha', 'beta'),
    'normal': ('sigma',),
    'mix_normal': ('sigma',),
    'truncated_normal': ('sigma',),
    'binary': (),
    'uniform': (),
}

# Every distribution hyperparameter; the ones a mode does not use are dropped
# from the grid so the product does not enumerate the same run repeatedly.
_ALL_DIST_PARAMS = ('mu', 'sigma', 'alpha', 'beta')

# Rate metrics stored in each loss.txt line, in file order (the fifth field is
# the test count).
_RATE_METRICS = ('loss', 'coverage', 'freq', 'success_rate')


def _parse_metrics_line(line):
    """Parse one ``Name: value, Name: value, ...`` line from ``loss.txt``.

    The first four fields are rate metrics (loss, coverage, frequency, success
    rate) and the fifth is the test count. Rates are returned multiplied by the
    test count so they can be summed across seeds and later divided by the
    total count; the raw count is returned under ``'test_num'``.
    """
    fields = line.split(', ')
    values = [float(field.split(': ')[1]) for field in fields[:5]]
    test_num = values[4]
    metrics = {name: value * test_num for name, value in zip(_RATE_METRICS, values)}
    metrics['test_num'] = test_num
    return metrics


def _result_dir(param, results_root):
    """Return the directory holding ``loss.txt`` for one parameter combination."""
    parts = [param['mode'], param['d'], param['train_scale'], param['l'],
             param['k'], param['p'], param['lam']]
    parts.extend(param[name] for name in _MODE_DIST_PARAMS[param['mode']])
    parts.append(param['random_seed'])
    return os.path.join(results_root, *(str(part) for part in parts))


def _weighted_mean(weighted_values, counts):
    """Return ``sum(weighted_values) / sum(counts)``, or NaN when no tests ran."""
    total = np.sum(counts)
    if total == 0:
        return float('nan')
    return np.sum(weighted_values) / total


def _new_group():
    """Create an empty accumulator: method -> metric name -> list of values."""
    return {'knn': defaultdict(list), 'knn_div': defaultdict(list)}


def _format_detailed(label, metrics):
    """Format the full metric summary line for one method (KNN or KNN_DIV)."""
    counts = metrics['test_num']
    return (
        f"{label} - Loss: {_weighted_mean(metrics['loss'], counts):.4f}, "
        f"Coverage: {_weighted_mean(metrics['coverage'], counts):.4f}, "
        f"Frequency: {_weighted_mean(metrics['freq'], counts):.4f}, "
        f"Success Rate: {_weighted_mean(metrics['success_rate'], counts):.4f}, "
        f"Test Num: {np.mean(counts):.1f}"
    )


def main(params=None, results_root="results"):
    """Aggregate per-seed ``loss.txt`` results and print per-setting summaries.

    For every combination of the hyperparameter grid this reads
    ``<results_root>/<mode>/<d>/<train_scale>/<l>/<k>/<p>/<lam>/[dist]/<random_seed>/loss.txt``.
    ``[dist]`` is ``<alpha>/<beta>`` for ``beta`` mode, ``<mu>/<sigma>`` for
    the normal family, and empty for ``binary``/``uniform``. Line 1 of the
    file holds the KNN metrics and line 2 the KNN_DIV metrics. Metrics are
    pooled across seeds as test-count-weighted means and printed grouped by
    ``train_scale``, ``l``, ``k`` and the mode's grouping parameters.

    Args:
        params: Optional grid mapping each hyperparameter name to a sequence
            of values. Defaults to the built-in sweep. The mapping is copied,
            never mutated. Which distribution parameters are used follows the
            first entry of ``params['mode']``.
        results_root: Root directory of the results tree.

    Raises:
        ValueError: If the first entry of ``params['mode']`` is not supported.
    """
    if params is None:
        params = {
            'mode': ['uniform', ],
            'd': [200],
            'train_scale': [1, 5, 10],
            'l': [3, 4, 8],
            'k': range(4, 9, 4),
            'p': [0.5],
            'lam': [0.5],
            'mu': [0],
            'sigma': [1],
            'alpha': [0.1, 0.5, 1, 2, 5],
            'beta': [0.1, 0.5, 1, 2, 5],
            'random_seed': range(42, 42 + 3, 1),
        }
    # Work on a copy so dropping unused keys never affects the caller's dict.
    params = dict(params)

    mode = params['mode'][0]
    if mode not in _MODE_DIST_PARAMS:
        raise ValueError(f"Unsupported mode: {mode!r}")

    for name in _ALL_DIST_PARAMS:
        if name not in _MODE_DIST_PARAMS[mode]:
            params.pop(name, None)

    # train_scale -> l -> k -> group key (tuple of grouping params) -> accumulator
    results = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(_new_group))))

    names = list(params)
    for values in itertools.product(*params.values()):
        param = dict(zip(names, values))
        filepath = os.path.join(_result_dir(param, results_root), 'loss.txt')
        if not os.path.exists(filepath):
            print(f"File does not exist: {filepath}")
            continue

        with open(filepath, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        knn_metrics = _parse_metrics_line(lines[0])
        knn_div_metrics = _parse_metrics_line(lines[1])

        group_key = tuple(param[name] for name in _MODE_GROUP_PARAMS[param['mode']])
        group = results[param['train_scale']][param['l']][param['k']][group_key]
        for method, metrics in (('knn', knn_metrics), ('knn_div', knn_div_metrics)):
            for metric_name, value in metrics.items():
                group[method][metric_name].append(value)

    group_names = _MODE_GROUP_PARAMS[mode]
    for train_scale, by_l in results.items():
        print(f"train_scale={train_scale}")
        for l_value, by_k in by_l.items():
            print(f"l={l_value}")
            for k, by_group in by_k.items():
                print(f"k={k}")
                for group_key, group in by_group.items():
                    if mode in ('binary', 'uniform'):
                        # Compact summary: only the diversified variant is reported.
                        div = group['knn_div']
                        div_loss = _weighted_mean(div['loss'], div['test_num'])
                        div_coverage = _weighted_mean(div['coverage'], div['test_num'])
                        print(f"TopK-Div - Loss: {div_loss:.2f}, Coverage: {div_coverage:.2f}")
                    else:
                        print(", ".join(f"{name}={value}"
                                        for name, value in zip(group_names, group_key)))
                        print(_format_detailed("KNN", group['knn']))
                        print(_format_detailed("KNN_DIV", group['knn_div']))
if __name__ == "__main__":
    # Run the aggregation only when executed as a script, not on import.
    main()