import os
import numpy as np
from collections import defaultdict

def _read_accuracies(file_path, test_dataset, train_datasets):
    """Parse one result file into ``(id_accs, ood_accs)`` lists of floats.

    Lines are consumed in groups of three starting at index 1: the line at
    index 1, 4, 7, ... must contain ``train_dataset: <name>`` and the line
    right after it ``accuracy: <float>``. An accuracy counts as ID when
    ``<name>`` equals *test_dataset*, as OOD when ``<name>`` is in
    *train_datasets*, and is ignored otherwise. Malformed groups are reported
    on stdout and skipped.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    id_accs, ood_accs = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # The first line of each 3-line group is not inspected here.
    for line_idx in range(1, len(lines) + 1, 3):
        try:
            current_dataset = lines[line_idx].strip().split('train_dataset: ')[1]
            acc = float(lines[line_idx + 1].strip().split('accuracy: ')[1])
        except (IndexError, ValueError) as e:
            print(f"Warning: error at line {line_idx} in {file_path}: {e}")
            continue
        if current_dataset == test_dataset:
            id_accs.append(acc)
        elif current_dataset in train_datasets:
            ood_accs.append(acc)
    return id_accs, ood_accs


def _summary_stats(prefix, accs):
    """Return ``{prefix}_acc/_std/_len/_max/_min`` for *accs* (all None if empty)."""
    if not accs:
        return {f'{prefix}_{stat}': None for stat in ('acc', 'std', 'len', 'max', 'min')}
    return {
        f'{prefix}_acc': np.mean(accs),
        f'{prefix}_std': np.std(accs),  # numpy default ddof=0 (population std)
        f'{prefix}_len': len(accs),
        f'{prefix}_max': np.max(accs),
        f'{prefix}_min': np.min(accs),
    }


def calculate_accuracies(base_path, methods, models, test_datasets, train_datasets, k_values):
    """Aggregate ID and OOD accuracies for every (method, model, k) combination.

    For each combination, the file
    ``base_path/method/model/<test_dataset>/<k>/all-roberta-large-v1.txt``
    is read for every entry of *test_datasets*, and the accuracies are pooled
    across those files. Missing files are reported on stdout and skipped.

    Args:
        base_path: root directory of the results tree.
        methods, models, k_values: values to sweep; ``k`` is rendered with ``str``.
        test_datasets: datasets whose own accuracy lines count as ID.
        train_datasets: other datasets whose accuracy lines count as OOD.

    Returns:
        dict keyed ``"method/model/k"``. Each value has the keys ``id_acc``,
        ``id_std``, ``id_len``, ``id_max``, ``id_min`` and the same five with an
        ``ood_`` prefix. A group with no accuracies has all five values set to
        ``None``. Combinations with neither ID nor OOD accuracies are omitted.
    """
    results = {}
    for method in methods:
        for model in models:
            for k in k_values:
                id_accs, ood_accs = [], []
                for test_dataset in test_datasets:
                    file_path = os.path.join(base_path, method, model, test_dataset, str(k), "all-roberta-large-v1.txt")
                    try:
                        file_id, file_ood = _read_accuracies(file_path, test_dataset, train_datasets)
                    except FileNotFoundError:
                        print(f"Warning: file not found - {file_path}")
                        continue
                    id_accs.extend(file_id)
                    ood_accs.extend(file_ood)

                # Create the entry whenever either group has data, so an
                # OOD-only configuration no longer hits a missing key.
                if id_accs or ood_accs:
                    results[f"{method}/{model}/{k}"] = {
                        **_summary_stats('id', id_accs),
                        **_summary_stats('ood', ood_accs),
                    }
    return results

def test_knn_diversity(models, test_datasets, k_values, model_method_dataset_results):
    """Print the relative accuracy drop between two pools of kNN variants.

    For each (model, test_dataset, k):
      * pool 0 holds methods ``knn_diversity_0.1`` .. ``knn_diversity_0.4``
        plus plain ``knn``;
      * pool 1 holds ``knn_diversity_0.5`` .. ``knn_diversity_0.9``.

    The suffix is presumably a diversity weight; confirm with the producer of
    these results. The function prints
    ``(mean(pool0) - mean(pool1)) / mean(pool0) * 100`` followed by a
    LaTeX-escaped percent sign.

    Despite its ``test_`` prefix, this is a reporting helper, not a unit test.

    Args:
        models, test_datasets, k_values: values to sweep.
        model_method_dataset_results: nested mapping indexed
            ``[model][method][test_dataset][int(k)]`` -> accuracy.
    """
    low_weights = [i / 10 for i in range(1, 5)]    # 0.1 .. 0.4
    high_weights = [i / 10 for i in range(5, 10)]  # 0.5 .. 0.9
    for model in models:
        for test_dataset in test_datasets:
            for k in k_values:
                k_idx = int(k)
                acc_0 = [model_method_dataset_results[model][f'knn_diversity_{w}'][test_dataset][k_idx]
                         for w in low_weights]
                acc_0.append(model_method_dataset_results[model]['knn'][test_dataset][k_idx])
                acc_1 = [model_method_dataset_results[model][f'knn_diversity_{w}'][test_dataset][k_idx]
                         for w in high_weights]
                baseline = np.mean(acc_0)
                drop_pct = (baseline - np.mean(acc_1)) / baseline * 100
                # "\\%" prints a literal backslash-percent (LaTeX escape); the
                # former "\%" was an invalid escape sequence (SyntaxWarning).
                print(f"{model} - {test_dataset} - {k}: {drop_pct:.2f}\\% ")


# Without this, the ``test_`` prefix makes pytest collect this helper as a
# test (and fail looking for fixtures named ``models`` etc.) whenever it is
# imported into a test module.
test_knn_diversity.__test__ = False

def main():
    """Print ID/OOD accuracy summaries for each method/model/k under ./results.

    For every (test_dataset, train_dataset) pair, ``calculate_accuracies`` is
    run. The ID summary is printed when the two datasets are equal; otherwise
    the OOD summary is printed. Each summary is formatted as a LaTeX-style
    ``$mean_{std}$`` string in percent. ID means are also collected into
    ``model_method_dataset_results[model][method][test_dataset][k]``.
    """

    def fmt(mean, std):
        # LaTeX-style "$mean_{std}$" in percent, or "N/A" when the group was
        # empty (calculate_accuracies reports None for missing groups).
        if mean is None:
            return "N/A"
        return f"${100*mean:.2f}_{{{100*std:.2f}}}$"

    methods = ['knn', 'knn_diversity', 'diversity', 'random']

    models = ["llama-3.1-8b", 'gemma-2-9b', 'mistral-7b-v0.3']

    test_datasets = ['glue-sst2']

    train_datasets = test_datasets

    k_values = [4, 8]

    model_method_dataset_results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
    for test_dataset in test_datasets:
        for train_dataset in train_datasets:
            results = calculate_accuracies('./results', methods, models, [test_dataset], [train_dataset], k_values)
            is_id = test_dataset == train_dataset

            if is_id:
                for key, values in results.items():
                    if values['id_acc'] is None:
                        continue
                    method, model, k = key.split('/')
                    model_method_dataset_results[model][method][test_dataset][int(k)] = values['id_acc']

            print(f"test_dataset: {test_dataset}, train_dataset: {train_dataset}")
            for key, values in results.items():
                if is_id:
                    print(f"\n{key}'s results:",
                          f"ID average accuracy: {fmt(values['id_acc'], values['id_std'])}, id_len: {values['id_len']}")
                else:
                    # ood_acc is None when no line matched train_dataset;
                    # formatting it directly used to raise TypeError.
                    print(f"\n{key}'s results:",
                          f"OOD average accuracy: {fmt(values['ood_acc'], values['ood_std'])}, ood_len: {values['ood_len']}")
            print("\n\n\n\n")

if __name__ == "__main__":
    # Run the results summary only when executed as a script, not on import.
    main()