import os
import glob
import numpy as np
import itertools
from collections import defaultdict



# Distribution families understood by this script. Each family lays out its
# results tree differently and groups its losses by different parameters.
NORMAL_MODES = ('normal', 'mix_normal', 'truncated_normal')
FLAT_MODES = ('binary', 'uniform')
BETA_MODES = ('beta',)


def _default_params():
    """Return a fresh copy of the default parameter grid analyzed by ``main``."""
    return {
        'mode': ['uniform',],
        'd': [200],
        'train_scale': [1],
        'l': [3, 4, 6, 8],
        'k': range(4, 9, 4),
        'p': [0.5],
        'lam': [0.5],
        'mu': [0],
        'sigma': [1],
        'alpha': [0.1, 0.5, 1, 2, 5],
        'beta': [0.1, 0.5, 1, 2, 5],
        'random_seed': range(42, 42 + 3, 1),
    }


def _result_dir(param, results_root):
    """Build the directory that holds ``loss.txt`` for one parameter combination.

    Raises:
        ValueError: if ``param['mode']`` is not a supported mode.
    """
    mode = param['mode']
    # Components shared by every mode, in the order used by the results tree.
    parts = [results_root, mode, param['d'], param['train_scale'], param['l'],
             param['k'], param['p'], param['lam']]
    if mode in BETA_MODES:
        parts += [param['alpha'], param['beta']]
    elif mode in NORMAL_MODES:
        parts += [param['mu'], param['sigma']]
    elif mode not in FLAT_MODES:
        raise ValueError(f"unsupported mode: {mode!r}")
    parts.append(param['random_seed'])
    return os.path.join(*map(str, parts))


def _read_losses(filepath):
    """Return ``(knn_loss, knn_div_loss)`` parsed from the first two lines of a file.

    Each line is parsed as ``"<label>: <value>"``; the value is the text after
    the first ``': '`` (``float`` ignores the trailing newline).

    Raises:
        ValueError: if the file has fewer than two lines or a value is not a float.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError(f"{filepath}: expected at least 2 lines, got {len(lines)}")
    knn_loss = float(lines[0].split(': ', 1)[1])
    knn_div_loss = float(lines[1].split(': ', 1)[1])
    return knn_loss, knn_div_loss


def _dist_key(mode, param):
    """Return the distribution-parameter key losses are grouped by (None if none)."""
    if mode in NORMAL_MODES:
        # Grouped by sigma only: runs differing only in mu are averaged together.
        return param['sigma']
    if mode in BETA_MODES:
        return (param['alpha'], param['beta'])
    return None


def _dist_label(mode, key):
    """Return the header line printed before a group's mean losses, or None."""
    if mode in NORMAL_MODES:
        return f"sigma={key}"
    if mode in BETA_MODES:
        alpha, beta = key
        return f"alpha={alpha}, beta={beta}"
    return None


def main(params=None, results_root="results"):
    """Print mean KNN and KNN_DIV losses across the runs of a parameter grid.

    For every combination of the grid, reads ``loss.txt`` from the matching
    directory under ``results_root``. It then prints the mean of each loss,
    grouped by train_scale, l, k and, for normal/beta modes, by the
    distribution parameters (sigma, or (alpha, beta)).

    Args:
        params: Mapping of parameter name to a list of values. The first entry
            of ``params['mode']`` selects which distribution parameters apply.
            Defaults to the grid from ``_default_params()``. Never mutated.
        results_root: Root directory of the results tree.

    Raises:
        ValueError: if the mode is unsupported or a loss file is malformed.
        FileNotFoundError: if the loss file for some combination is missing.
    """
    grid = dict(_default_params() if params is None else params)
    mode = grid['mode'][0]

    # Drop grid dimensions that do not apply to this distribution family so
    # they do not multiply the number of combinations.
    if mode in BETA_MODES:
        unused = ('mu', 'sigma')
    elif mode in NORMAL_MODES:
        unused = ('alpha', 'beta')
    elif mode in FLAT_MODES:
        unused = ('mu', 'sigma', 'alpha', 'beta')
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    for name in unused:
        grid.pop(name, None)

    names = list(grid)
    combos = [dict(zip(names, values)) for values in itertools.product(*grid.values())]

    # results[train_scale][l][k][dist_key] -> (knn losses, knn_div losses).
    # Insertion order follows the grid order, which fixes the print order.
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for param in combos:
        filepath = os.path.join(_result_dir(param, results_root), 'loss.txt')
        knn_loss, knn_div_loss = _read_losses(filepath)
        by_dist = results[param['train_scale']][param['l']][param['k']]
        knn_list, div_list = by_dist.setdefault(_dist_key(mode, param), ([], []))
        knn_list.append(knn_loss)
        div_list.append(knn_div_loss)

    for train_scale, by_l in results.items():
        print(f"train_scale={train_scale}")
        print('-'*100)
        for l_value, by_k in by_l.items():
            print(f"l={l_value}")
            print('-'*100)
            for k, groups in by_k.items():
                print(f"k={k}")
                for dist_key, (knn_list, div_list) in groups.items():
                    label = _dist_label(mode, dist_key)
                    if label is not None:
                        print(label)
                    print(f"KNN loss: {np.mean(knn_list)}")
                    print(f"KNN_DIV loss: {np.mean(div_list)}")
            print('-'*50)
        print('-'*100)
if __name__ == '__main__':
    # Run the loss summary only when executed as a script, not on import.
    main()