import os
import glob
import numpy as np
import itertools
from collections import defaultdict
def analyze_parameter_three_loss(param_name, param_values, other_params):
    """Print average KNN and KNN_DIV losses for each value of one parameter.

    Loss files are looked up under ``results/<d1>/<d2>/.../loss.txt`` with one
    directory level per key of ``other_params``, in the mapping's iteration
    order. For every ``value`` in ``param_values`` the level belonging to
    ``param_name`` is pinned to ``str(value)``. Every other level is matched
    by a wildcard and then kept only when its directory name is one of the
    accepted values in ``other_params[p]``. The name is compared as a string
    for ``'mode'`` and as a float for every other key.

    The first two lines of each ``loss.txt`` are parsed as
    ``<label>: <number>``. The first number is the KNN loss and the second is
    the KNN_DIV loss. Files that cannot be read or parsed are skipped, as are
    directories whose names cannot be converted to the expected type; neither
    aborts the analysis.

    Args:
        param_name: Key of ``other_params`` whose values are being compared.
        param_values: Values of ``param_name`` to report on, one section each.
        other_params: Mapping from parameter name to its list of accepted
            values; its key order must match the directory nesting.

    Returns:
        None. A summary is printed for every value that has at least one
        usable loss file.
    """
    keys = list(other_params)

    def _path_matches(parts, value):
        # ``parts`` holds one directory name per key of ``other_params``.
        for p, part in zip(keys, parts):
            if p == param_name:
                ok = part == str(value)
            elif p == 'mode':
                ok = part in other_params[p]
            else:
                try:
                    ok = float(part) in other_params[p]
                except ValueError:
                    # A stray non-numeric directory must not crash the sweep.
                    ok = False
            if not ok:
                return False
        return True

    for value in param_values:
        # Only the swept parameter is fixed in the glob pattern. It is escaped
        # so values containing glob metacharacters match literally.
        components = [
            glob.escape(str(value)) if p == param_name else "*" for p in keys
        ]
        pattern = os.path.join("results", *components, "loss.txt")

        knn_losses = []
        knn_div_losses = []
        for filepath in glob.glob(pattern):
            # Strip the leading "results" and the trailing "loss.txt".
            parts = filepath.split(os.sep)[1:-1]
            if not _path_matches(parts, value):
                continue
            try:
                with open(filepath, 'r') as f:
                    lines = f.readlines()
                if len(lines) < 2:
                    continue
                knn_loss = float(lines[0].split(': ')[1])
                knn_div_loss = float(lines[1].split(': ')[1])
            except (OSError, ValueError, IndexError):
                # Best effort: skip unreadable or malformed loss files.
                continue
            knn_losses.append(knn_loss)
            knn_div_losses.append(knn_div_loss)

        if not knn_losses:
            continue

        avg_knn = np.mean(knn_losses)
        avg_knn_div = np.mean(knn_div_losses)
        print(f"{param_name}={value}:")
        print(f"  KNN average loss: {avg_knn:.4f}")
        print(f"  KNN_DIV average loss: {avg_knn_div:.4f}")
        # Ties go to KNN because it is compared first.
        min_loss = min(avg_knn, avg_knn_div)
        if min_loss == avg_knn:
            print(f"*** At {param_name}={value}, KNN performs better on average ***")
        elif min_loss == avg_knn_div:
            print(f"*** At {param_name}={value}, KNN_DIV performs better on average ***")
        print()
def analyze_parameter_two_loss_alpha_beta(param_name, param_values, other_params):
    """Print average KNN and KNN_DIV losses for every (alpha, beta) pair.

    When ``param_name`` is ``'alpha'`` or ``'beta'``, every combination of
    ``other_params['alpha']`` and ``other_params['beta']`` is evaluated, and
    the function returns afterwards. ``param_values`` is not read by this
    branch.

    Loss files are looked up under ``results/<d1>/<d2>/.../loss.txt`` with one
    directory level per key of ``other_params``, in iteration order. The alpha
    and beta levels are pinned to ``str(alpha)`` / ``str(beta)``. Every other
    level is matched by a wildcard and then kept only when its directory name
    is an accepted value in ``other_params[p]``. The name is compared as a
    string for ``'mode'`` and as a float otherwise.

    The first two lines of each file are parsed as ``<label>: <number>``,
    giving the KNN and KNN_DIV losses. Unreadable or malformed files are
    skipped, as are directories whose names cannot be converted to the
    expected type.

    Args:
        param_name: Name of the parameter being analysed.
        param_values: Not used by the alpha/beta branch.
        other_params: Mapping from parameter name to its list of accepted
            values. It must contain ``'alpha'`` and ``'beta'`` for this
            branch, and its key order must match the directory nesting.

    Returns:
        None. Summaries are printed to stdout.
    """
    valid_params = other_params.copy()

    if param_name in ['alpha', 'beta']:
        print("\nAnalyzing alpha-beta combinations")
        print("-" * 50)

        keys = list(valid_params)

        def _path_matches(parts, pinned):
            # ``parts`` holds one directory name per key of ``valid_params``.
            for p, part in zip(keys, parts):
                if p in pinned:
                    ok = part == pinned[p]
                elif p == 'mode':
                    ok = part in valid_params[p]
                else:
                    try:
                        ok = float(part) in valid_params[p]
                    except ValueError:
                        # A stray non-numeric directory must not crash the sweep.
                        ok = False
                if not ok:
                    return False
            return True

        for alpha in other_params['alpha']:
            for beta in other_params['beta']:
                pinned = {'alpha': str(alpha), 'beta': str(beta)}
                # Escape pinned values so glob metacharacters match literally.
                components = [
                    glob.escape(pinned[p]) if p in pinned else "*" for p in keys
                ]
                pattern = os.path.join("results", *components, "loss.txt")

                knn_losses = []
                knn_div_losses = []
                for filepath in glob.glob(pattern):
                    # Strip the leading "results" and the trailing "loss.txt".
                    parts = filepath.split(os.sep)[1:-1]
                    if not _path_matches(parts, pinned):
                        continue
                    try:
                        with open(filepath, 'r') as f:
                            lines = f.readlines()
                        if len(lines) < 2:
                            continue
                        knn_loss = float(lines[0].split(': ')[1])
                        knn_div_loss = float(lines[1].split(': ')[1])
                    except (OSError, ValueError, IndexError):
                        # Best effort: skip unreadable or malformed loss files.
                        continue
                    knn_losses.append(knn_loss)
                    knn_div_losses.append(knn_div_loss)

                if not knn_losses:
                    continue

                avg_knn = np.mean(knn_losses)
                avg_knn_div = np.mean(knn_div_losses)
                print(f"alpha={alpha}, beta={beta}:")
                print(f"  KNN average loss: {avg_knn:.4f}")
                print(f"  KNN_DIV average loss: {avg_knn_div:.4f}")
                # Ties go to KNN because it is compared first.
                min_loss = min(avg_knn, avg_knn_div)
                if min_loss == avg_knn:
                    print(f"*** At alpha={alpha}, beta={beta}, KNN performs better on average ***")
                else:
                    print(f"*** At alpha={alpha}, beta={beta}, KNN_DIV performs better on average ***")
                print()
        return