import os
import numpy as np

def _read_accuracy_records(file_path):
    """Parse one result file into ``(train_dataset, accuracy)`` pairs.

    The file is read as consecutive 3-line records: the first line of each
    record is ignored, the second must contain ``train_dataset: <name>`` and
    the third ``accuracy: <float>``.  Records that cannot be parsed are
    reported with a warning and skipped.

    Args:
        file_path: Path of the result file to read.

    Returns:
        A list of ``(train_dataset, accuracy)`` tuples in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist (handled by the
            caller so that the warning text stays in one place).
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    records = []
    # Offsets 1, 4, 7, ... are the ``train_dataset`` lines.  The range end is
    # ``len(lines) + 1``, so a dangling line after the last complete record is
    # probed as well and shows up as a parsing warning instead of vanishing.
    for line_idx in range(1, len(lines) + 1, 3):
        try:
            train_dataset = lines[line_idx].strip().split('train_dataset: ')[1]
            acc = float(lines[line_idx + 1].strip().split('accuracy: ')[1])
        except (IndexError, ValueError) as e:
            print(f"Warning: Parsing error at line {line_idx} in {file_path}: {e}")
            continue
        records.append((train_dataset, acc))
    return records


def _summarize_accuracies(prefix, accs):
    """Return mean/std/count/max/min of ``accs`` under ``<prefix>_*`` keys."""
    return {
        f'{prefix}_acc': np.mean(accs),
        f'{prefix}_std': np.std(accs),
        f'{prefix}_len': len(accs),
        f'{prefix}_max': np.max(accs),
        f'{prefix}_min': np.min(accs),
    }


def calculate_accuracies(base_path, methods, models, test_datasets, train_datasets, k_values):
    """Aggregate accuracies from result files for every method/model/k.

    For each combination the file
    ``<base_path>/<method>/<model>/<test_dataset>/<k>/all-roberta-large-v1.txt``
    is read for every entry of ``test_datasets``.  A record counts as
    in-distribution (ID) when its ``train_dataset`` equals the test dataset,
    and as out-of-distribution (OOD) when it differs from the test dataset
    but is listed in ``train_datasets``.  Accuracies are pooled over all test
    datasets of a combination.  Missing files and unparsable records only
    produce a warning on stdout.

    Args:
        base_path: Root directory of the result tree.
        methods: Iterable of method names (first path component).
        models: Iterable of model names.
        test_datasets: Iterable of test dataset names.
        train_datasets: Container of train dataset names eligible as OOD.
        k_values: Iterable of ``k`` values (converted with ``str`` for the path).

    Returns:
        Dict mapping ``"<method>/<model>/<k>"`` to a dict with ``id_acc``,
        ``id_std``, ``id_len``, ``id_max`` and ``id_min``.  When OOD records
        were found as well, the matching ``ood_*`` keys are added.  A
        combination without any ID record has no entry at all.
    """
    results = {}

    for method in methods:
        for model in models:
            for k in k_values:
                id_accs = []
                ood_accs = []
                for test_dataset in test_datasets:
                    file_path = os.path.join(base_path, method, model, test_dataset, str(k), "all-roberta-large-v1.txt")
                    try:
                        records = _read_accuracy_records(file_path)
                    except FileNotFoundError:
                        print(f"Warning: File not found - {file_path}")
                        continue

                    for train_dataset, acc in records:
                        if train_dataset == test_dataset:
                            id_accs.append(acc)
                        elif train_dataset in train_datasets:
                            ood_accs.append(acc)

                if id_accs:
                    entry = _summarize_accuracies('id', id_accs)
                    # OOD keys are only present when there is something to
                    # report, so ID-only runs return exactly the same dict
                    # as before.
                    if ood_accs:
                        entry.update(_summarize_accuracies('ood', ood_accs))
                    results[f"{method}/{model}/{k}"] = entry

    return results

def test_knn_diversity(models, test_dataset, k_values, results):
    """Print the relative accuracy drop from low to high kNN diversity.

    For every ``(model, k)`` two groups of ``id_acc`` values are collected
    from ``results``:

    * low diversity: ``knn/<model>/<k>`` plus
      ``knn_diversity_0.1`` ... ``knn_diversity_0.4``;
    * high diversity: ``knn_diversity_0.5`` ... ``knn_diversity_0.9``.

    When both groups are non-empty, one line with
    ``(mean(low) - mean(high)) / mean(low) * 100`` is printed.

    Args:
        models: Iterable of model names.
        test_dataset: Dataset name, used only in the printed line.
        k_values: Iterable of ``k`` values.
        results: Mapping ``"<method>/<model>/<k>"`` -> dict with ``id_acc``.
    """
    for model in models:
        for k in k_values:
            # Both groups are built once per (model, k).  Creating them
            # inside the 0.1-0.4 loop would discard everything but the last
            # iteration's entries.
            acc_0 = []
            acc_1 = []

            # The plain kNN baseline is counted once in the low group.
            knn_key = f"knn/{model}/{k}"
            if knn_key in results:
                acc_0.append(results[knn_key]['id_acc'])

            for i in range(1, 5):
                method_key = f"knn_diversity_{i/10}/{model}/{k}"
                if method_key in results:
                    acc_0.append(results[method_key]['id_acc'])

            for i in range(5, 10):
                method_key = f"knn_diversity_{i/10}/{model}/{k}"
                if method_key in results:
                    acc_1.append(results[method_key]['id_acc'])

            if acc_0 and acc_1:
                # "\\%" prints a literal backslash-percent (LaTeX) without
                # relying on an invalid escape sequence.
                print(f"{model} - {test_dataset} - {k}: {((np.mean(acc_0) - np.mean(acc_1)) / np.mean(acc_0) )*100:.2f}\\% ")


# Despite the ``test_`` prefix this is a reporting helper, not a pytest test;
# opt out of collection so importing it into a test module does not fail on
# missing fixtures.
test_knn_diversity.__test__ = False


def main():
    """Print accuracy reports for a hard-coded experiment configuration.

    The report type is chosen from the first entry of ``test_datasets``:

    * it does not contain ``'winogrande_0'``: per-dataset ID statistics
      followed by the kNN-diversity comparison;
    * it equals ``'winogrande_0'``: per-split ID statistics, size-weighted
      averages over splits 0-3 and 4-5 (printed under the label "4-6"), and
      the difference of every non-``knn`` method to ``knn``;
    * it equals ``'winogrande_0_new'``: per-dataset ID statistics only.
    """
    methods = ['knn']
    models = ['gemma-2-2b', 'llama-3.2-3b']
    test_datasets = ['commonsense_qa']
    k_values = [4, 8]

    def show_id_stats(values):
        # Mean and std as percentages in LaTeX ``$mean_{std}$`` form.
        print(f"ID average accuracy: ${100*values['id_acc']:.2f}_{{{100*values['id_std']:.2f}}}$, len: {values['id_len']}")

    first_dataset = test_datasets[0]

    # --- Generic datasets -------------------------------------------------
    if 'winogrande_0' not in first_dataset:
        for test_dataset in test_datasets:
            test_dataset_tmp = [test_dataset]
            train_dataset_tmp = test_dataset_tmp
            results = calculate_accuracies(
                './results', methods, models, test_dataset_tmp, train_dataset_tmp, k_values
            )
            print(f"train_dataset: {train_dataset_tmp}, test_dataset: {test_dataset_tmp}")
            for key, values in results.items():
                print(f"\n{key}/{test_dataset} statistics:")
                show_id_stats(values)
            print('\n'+'-'*100+'\n')

            test_knn_diversity(models, test_dataset, k_values, results)
        return

    # --- Winogrande splits with size-weighted aggregation -------------------
    if first_dataset == 'winogrande_0':
        dataset_sizes = {
            'winogrande_0': 100,
            'winogrande_1': 227,
            'winogrande_2': 597,
            'winogrande_3': 1000,
            'winogrande_4': 1000,
            'winogrande_5': 1000
        }
        combined_results = {}

        for test_dataset in test_datasets:
            print(f"test_dataset: {test_dataset}")
            results = calculate_accuracies(
                './results', methods, models, [test_dataset], [test_dataset], k_values
            )

            for key, values in results.items():
                print(f"\n{key} statistics:")
                show_id_stats(values)

                totals = combined_results.setdefault(key, {
                    'weighted_sum_0_3': 0,
                    'total_size_0_3': 0,
                    'weighted_sum_4_6': 0,
                    'total_size_4_6': 0
                })

                # The split index is the token after the first underscore.
                dataset_num = int(test_dataset.split('_')[1])
                dataset_size = dataset_sizes[test_dataset]

                if 0 <= dataset_num <= 3:
                    totals['weighted_sum_0_3'] += values['id_acc'] * dataset_size
                    totals['total_size_0_3'] += dataset_size
                elif 4 <= dataset_num <= 5:
                    totals['weighted_sum_4_6'] += values['id_acc'] * dataset_size
                    totals['total_size_4_6'] += dataset_size

            print("\n\n\n")

        print("=== Winogrande 0-3 and 4-6 weighted average accuracy ===")
        for key, values in combined_results.items():
            if values['total_size_0_3'] > 0:
                avg_0_3 = values['weighted_sum_0_3'] / values['total_size_0_3']
                print(f"\n{key} (Winogrande 0-3):")
                print(f"Weighted average accuracy: ${100*avg_0_3:.2f}$")

            if values['total_size_4_6'] > 0:
                avg_4_6 = values['weighted_sum_4_6'] / values['total_size_4_6']
                print(f"\n{key} (Winogrande 4-6):")
                print(f"Weighted average accuracy: ${100*avg_4_6:.2f}$")

        print("\n=== knn vs other methods on Winogrande average difference ===")

        # First pass: weighted knn baselines per (model, k, split group).
        knn_results = {}
        for key, values in combined_results.items():
            method, model, k = key.split('/')
            if method != 'knn':
                continue
            if values['total_size_0_3'] > 0:
                knn_results[(model, k, '0-3')] = values['weighted_sum_0_3'] / values['total_size_0_3']
            if values['total_size_4_6'] > 0:
                knn_results[(model, k, '4-6')] = values['weighted_sum_4_6'] / values['total_size_4_6']

        # Second pass: every other method relative to its knn baseline.
        for key, values in combined_results.items():
            method, model, k = key.split('/')
            if method == 'knn':
                continue

            baseline_0_3 = (model, k, '0-3')
            if values['total_size_0_3'] > 0 and baseline_0_3 in knn_results:
                avg_0_3 = values['weighted_sum_0_3'] / values['total_size_0_3']
                diff_0_3 = avg_0_3 - knn_results[baseline_0_3]
                print(f"\n{method} vs knn (k={k}, {model}, Winogrande 0-3):")
                print(f"Average difference: ${100*diff_0_3:.2f}$")

            baseline_4_6 = (model, k, '4-6')
            if values['total_size_4_6'] > 0 and baseline_4_6 in knn_results:
                avg_4_6 = values['weighted_sum_4_6'] / values['total_size_4_6']
                diff_4_6 = avg_4_6 - knn_results[baseline_4_6]
                print(f"\n{method} vs knn (k={k}, {model}, Winogrande 4-6):")
                print(f"Average difference: ${100*diff_4_6:.2f}$")
        return

    # --- Winogrande "new" variant: plain per-dataset statistics ------------
    if first_dataset == 'winogrande_0_new':
        for test_dataset in test_datasets:
            print(f"test_dataset: {test_dataset}")
            results = calculate_accuracies(
                './results', methods, models, [test_dataset], [test_dataset], k_values
            )

            for key, values in results.items():
                print(f"\n{key} statistics:")
                show_id_stats(values)
            print("-"*100)
# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()