import os
import json
import re
import statistics
import numpy as np
from utils.dataset_utils import get_dataset
from collections import defaultdict
import itertools
import math

def calculate_accuracy(gold_answers, pred_answers, file_dir, test_dataset, model, datas):
    """Score predictions against gold answers, overall and per support-length bucket.

    Answers are compared after removing all spaces and lower-casing. For SciQ the
    gold answer is a single string compared exactly. For SQuAD (and any other
    dataset) the gold answer may be a string or a list of acceptable strings; a
    prediction is correct if it matches any of them.

    Each example is also bucketed by the number of '.' characters in
    ``data['support']``: bucket 1 for at most 5, bucket 2 for more than 5.

    Args:
        gold_answers: Gold answer per example (str, or list of str).
        pred_answers: Predicted answer string per example.
        file_dir: Unused; kept for call-site compatibility.
        test_dataset: Dataset name; the comparison rule is chosen by substring.
        model: Unused; kept for call-site compatibility.
        datas: Raw examples aligned with ``gold_answers``; each needs a ``'support'`` key.

    Returns:
        ``(accuracy, bucket_accuracies)``: ``accuracy`` is the number of correct
        predictions divided by ``len(gold_answers)`` (0 when empty), and
        ``bucket_accuracies`` maps bucket id -> accuracy, ordered by bucket id.
    """
    def _normalise(text):
        return text.replace(" ", "").lower()

    correct = 0
    total = len(gold_answers)
    count = defaultdict(lambda: {'correct': 0, 'total': 0})

    for gold_answer, pred_answer, data in zip(gold_answers, pred_answers, datas):
        # Bucket by a rough sentence count of the supporting passage.
        bucket = 1 if data['support'].count(".") <= 5 else 2
        count[bucket]['total'] += 1

        pred_norm = _normalise(pred_answer)
        if "sciq" in test_dataset:
            is_correct = _normalise(gold_answer) == pred_norm
        else:
            # SQuAD may list several acceptable answers; other datasets previously
            # fell through without ever being scored, so they share this rule.
            candidates = [gold_answer] if isinstance(gold_answer, str) else gold_answer
            is_correct = any(pred_norm == _normalise(ans) for ans in candidates)

        if is_correct:
            correct += 1
            count[bucket]['correct'] += 1

    bucket_accuracies = {
        bucket: stats['correct'] / stats['total'] if stats['total'] > 0 else 0
        for bucket, stats in sorted(count.items())
    }

    return correct / total if total > 0 else 0, bucket_accuracies

def read_jsonl(file_dir, key='question'):
    """Load one field from every record of a JSON Lines file.

    Args:
        file_dir: Path to a UTF-8 ``.jsonl`` file with one JSON object per line.
        key: Field to extract from each record (default ``'question'``).

    Returns:
        A list containing ``record[key]`` for each non-blank line, in file order.

    Raises:
        KeyError: if a record lacks ``key``.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    results = []
    with open(file_dir, 'r', encoding='utf-8') as file:
        for line in file:
            # Skip whitespace-only lines (e.g. a stray empty line at the end of the
            # file) instead of failing inside json.loads.
            if not line.strip():
                continue
            results.append(json.loads(line)[key])
    return results

def test_knn_diversity(models, test_datasets, k_values, results):
    """Print the relative accuracy drop of high-diversity kNN vs low-diversity kNN.

    For each (model, dataset, k) two groups are pooled from
    ``results[model][method][dataset][int(k)]``:

    * low:  ``knn_diversity_0.1`` .. ``knn_diversity_0.4`` plus plain ``knn``
    * high: ``knn_diversity_0.5`` .. ``knn_diversity_0.9``

    Each entry may be a scalar accuracy or a list of per-run accuracies, and the
    lists may differ in length. The printed value is
    ``(mean_low - mean_high) / mean_low * 100`` followed by a LaTeX-escaped percent
    sign; ``nan`` is printed when the low-group mean is 0.

    Note: despite its name this is a reporting helper, not a pytest test.
    """
    for model in models:
        for test_dataset in test_datasets:
            for k in k_values:
                key = int(k)
                low = [results[model][f'knn_diversity_{i / 10}'][test_dataset][key] for i in range(1, 5)]
                low.append(results[model]['knn'][test_dataset][key])
                high = [results[model][f'knn_diversity_{i / 10}'][test_dataset][key] for i in range(5, 10)]

                # hstack flattens scalar and (possibly ragged) list entries into one
                # array, so runs with missing repetitions do not break np.mean.
                low_mean = np.mean(np.hstack(low))
                high_mean = np.mean(np.hstack(high))
                drop = (low_mean - high_mean) / low_mean * 100 if low_mean else float("nan")

                # "\\%" emits a literal backslash-percent for LaTeX; a bare "\%" is an
                # invalid escape sequence (SyntaxWarning on Python 3.12+).
                print(f"{model} - {test_dataset} - {k}: {drop:.2f}\\% ")


# Keep pytest from collecting this reporting helper because of its ``test_`` prefix.
test_knn_diversity.__test__ = False

def plot_hard_example(model, embs, ks, methods, test_datasets):
    """Record per-sample correctness of one run and save it for hard-example analysis.

    Only the first entry of ``embs``, ``ks``, ``methods`` and ``test_datasets`` is
    used. Predictions are read from
    ``./results/{method}/{model}/{dataset}/{dataset}/{k}/0/0/{emb}.jsonl`` and compared
    with the gold answers returned by ``get_dataset`` after removing spaces and
    lower-casing. For SQuAD the gold answer may be a list of acceptable strings.

    One JSON record per sample (``id``, ``question``, ``support``, ``gold_answer``,
    ``pred_answer``, ``acc``) is written to
    ``./data/{dataset}/{dataset}_acc_{model}.jsonl``, and the overall accuracy is
    printed. Returns early, after printing a message, if the prediction file is
    missing or its length differs from the dataset size.
    """
    test_dataset = test_datasets[0]
    emb = embs[0]
    k = ks[0]
    method = methods[0]

    _, _, datas, gold_answers = get_dataset(dataset=test_dataset, load_from_local=True)

    base_dir = './results'
    file_dir = f"{base_dir}/{method}/{model}/{test_dataset}/{test_dataset}/{k}/0/0/{emb}.jsonl"

    if not os.path.exists(file_dir):
        print(f"File {file_dir} does not exist")
        return

    pred_answers = read_jsonl(file_dir=file_dir, key="answer")

    if len(pred_answers) != len(gold_answers):
        print(f"Warning: Number of predicted answers in file {file_dir} does not match dataset size")
        return

    def _normalise(text):
        return text.replace(" ", "").lower()

    sample_accuracies = []
    for gold_answer, pred_answer in zip(gold_answers, pred_answers):
        pred_norm = _normalise(pred_answer)
        if "squad" in test_dataset and "sciq" not in test_dataset:
            # SQuAD may provide several acceptable answers.
            candidates = [gold_answer] if isinstance(gold_answer, str) else gold_answer
            is_correct = any(pred_norm == _normalise(ans) for ans in candidates)
        else:
            is_correct = _normalise(gold_answer) == pred_norm
        sample_accuracies.append(1 if is_correct else 0)

    output_file = f"./data/{test_dataset}/{test_dataset}_acc_{model}.jsonl"
    # Make sure the target directory exists so the write cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        for i, (data, acc) in enumerate(zip(datas, sample_accuracies)):
            record = {
                "id": i,
                "question": data.get("question", ""),
                "support": data.get("support", ""),
                "gold_answer": gold_answers[i],
                "pred_answer": pred_answers[i],
                "acc": acc,
            }
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

    print(f"Sample accuracy results saved to {output_file}")
    # Avoid ZeroDivisionError when the dataset is empty.
    overall = sum(sample_accuracies) / len(sample_accuracies) if sample_accuracies else 0
    print(f"acc_all: {overall}")

def _count_prediction_matches(examples, pred_answers, test_dataset):
    """Count examples whose prediction at index ``example['id']`` matches ``example['gold_answer']``.

    Answers are compared after removing spaces and lower-casing; for SQuAD the gold
    answer may be a list of acceptable strings. Examples whose id is past the end of
    ``pred_answers`` are skipped, i.e. not counted as matches.
    """
    def _normalise(text):
        return text.replace(" ", "").lower()

    matches = 0
    for example in examples:
        idx = example['id']
        if idx >= len(pred_answers):
            continue
        gold_answer = example['gold_answer']
        pred_norm = _normalise(pred_answers[idx])
        if "squad" in test_dataset and "sciq" not in test_dataset:
            candidates = [gold_answer] if isinstance(gold_answer, str) else gold_answer
            is_correct = any(pred_norm == _normalise(ans) for ans in candidates)
        else:
            is_correct = _normalise(gold_answer) == pred_norm
        if is_correct:
            matches += 1
    return matches


def test_hard_example(models, embs, ks, methods, test_datasets, fine_tune_model):
    """Compare ICL accuracy on samples the fine-tuned model got right vs wrong.

    Reads ``./data/{dataset}/{dataset}_acc_{fine_tune_model}.jsonl``. Each record
    needs ``id``, ``gold_answer`` and ``acc``; ``acc == 1`` means the fine-tuned
    model was correct. The samples are split on that flag. Then, for every model,
    method and k, the predictions in
    ``./results/{method}/{model}/{dataset}/{dataset}/{k}/0/0/{embs[0]}.jsonl`` are
    scored separately on both groups and the accuracies are printed. Only the first
    entry of ``test_datasets`` and ``embs`` is used. Missing files are reported and
    skipped.

    Note: despite its name this is a reporting helper, not a pytest test.
    """
    test_dataset = test_datasets[0]

    acc_file = f"./data/{test_dataset}/{test_dataset}_acc_{fine_tune_model}.jsonl"
    if not os.path.exists(acc_file):
        print(f"File {acc_file} does not exist")
        return

    correct_examples = []
    incorrect_examples = []
    with open(acc_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            if data['acc'] == 1:
                correct_examples.append(data)
            else:
                incorrect_examples.append(data)

    print(f"Fine-tuned model {fine_tune_model} on {test_dataset}:")
    print(f"Correct samples: {len(correct_examples)}")
    print(f"Incorrect samples: {len(incorrect_examples)}")

    groups = {"correct": correct_examples, "incorrect": incorrect_examples}
    results = {}
    for model in models:
        results[model] = {"correct": {}, "incorrect": {}}
        for method in methods:
            results[model]["correct"][method] = {}
            results[model]["incorrect"][method] = {}
            for k in ks:
                pred_file = f"./results/{method}/{model}/{test_dataset}/{test_dataset}/{k}/0/0/{embs[0]}.jsonl"
                if not os.path.exists(pred_file):
                    print(f"File {pred_file} does not exist")
                    continue

                pred_answers = read_jsonl(file_dir=pred_file, key="answer")
                for group, examples in groups.items():
                    hits = _count_prediction_matches(examples, pred_answers, test_dataset)
                    # Denominator is the full group size, so skipped ids count as misses.
                    results[model][group][method][k] = hits / len(examples) if examples else 0

    headings = {
        "correct": "On samples correctly classified by fine-tuned model:",
        "incorrect": "On samples incorrectly classified by fine-tuned model:",
    }
    print("\nTest Results:")
    for model, per_group in results.items():
        print(f"\nModel: {model}")
        for group, heading in headings.items():
            print(heading)
            for method, per_k in per_group[group].items():
                for k, acc in per_k.items():
                    print(f"  Method: {method}, k: {k}, Accuracy: {acc:.4f}")


# Keep pytest from collecting this reporting helper because of its ``test_`` prefix.
test_hard_example.__test__ = False


def main():
    """Aggregate and print in-domain accuracy for every configured evaluation run.

    For each test dataset, every (train dataset, model, embedding, k, method)
    combination whose train dataset equals the test dataset is evaluated over all
    (seed, permutation) result files under ``./results``. Methods containing
    ``knn`` or ``k_means`` use one seed and ``k!`` permutations when ``k < 4``
    (10 otherwise). Other methods use ``seed`` seeds and one permutation. Missing
    result files are reported and skipped.

    For each model/method/k the mean accuracy (percent) is printed as ``$mean$``
    for k == 0, ``$mean_{0}$`` for k == 1 kNN methods, and ``$mean_{std}$``
    otherwise. When there are no runs, or too few for a standard deviation, only
    ``method, k`` is printed.
    """
    models = ["llama-3.1-8b", "gemma-2-9b"]
    embs = ["all-roberta-large-v1"]
    ks = [4, 8]
    methods = ["knn", "diversity", "knn_diversity", "random", "k_means"]
    base_dir = './results'
    test_datasets = ["squad"]
    train_datasets = test_datasets

    seed = 1 if ks == [0] else 10
    # model -> method -> test dataset -> k -> list of accuracies (collected, not printed below).
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))

    for test_dataset in test_datasets:
        print(f"test_dataset: {test_dataset}")
        accuracy_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
        # Per support-length-bucket accuracies (collected, not printed below).
        count_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))))

        # Gold data depends only on the test dataset: load it once, not once per combination.
        _, _, datas, gold_answers = get_dataset(dataset=test_dataset, load_from_local=True)

        param_combinations = itertools.product(train_datasets, models, embs, ks, methods)
        for train_dataset, model, emb, k, method in param_combinations:
            if test_dataset != train_dataset:
                continue

            single_seed = 'knn' in method or 'k_means' in method
            n_seeds = 1 if single_seed else seed
            if single_seed:
                n_perms = math.factorial(k) if k < 4 else 10
            else:
                n_perms = 1

            for i in range(n_seeds):
                for perm in range(n_perms):
                    file_dir = f"{base_dir}/{method}/{model}/{test_dataset}/{train_dataset}/{k}/{perm}/{i}/{emb}.jsonl"

                    if not os.path.exists(file_dir):
                        print(f"File {file_dir} does not exist")
                        continue

                    pred_answers = read_jsonl(file_dir=file_dir, key="answer")
                    if len(pred_answers) != len(gold_answers):
                        print(f"Warning: Number of predicted answers in file {file_dir} does not match dataset size")

                    accuracy, count_accuracies = calculate_accuracy(gold_answers, pred_answers, file_dir, test_dataset, model, datas)
                    result_type = 'id' if train_dataset == test_dataset else 'ood'
                    accuracy_dict[model][result_type][method][k].append(accuracy)
                    results[model][method][test_dataset][k].append(accuracy)

                    for group, acc in count_accuracies.items():
                        count_dict[model][result_type][method][k][group].append(acc)

        for model in models:
            print(f"Model: {model}")
            print('-' * 100)
            print("id:")
            if not accuracy_dict[model]['id']:
                continue
            for method in methods:
                for k in ks:
                    accs = accuracy_dict[model]['id'][method][k]
                    try:
                        mean = 100 * sum(accs) / len(accs)
                        if k == 0:
                            print(f"Method: {method}, k: {k}, Average Accuracy: ${mean:.2f}$")
                        elif k == 1 and "knn" in method:
                            print(f"Method: {method}, k: {k}, Average Accuracy: ${mean:.2f}_{{0}}$")
                        else:
                            std = 100 * statistics.stdev(accs)
                            print(f"Method: {method}, k: {k}, Average Accuracy: ${mean:.2f}_{{{std:.2f}}}$")
                    except (ZeroDivisionError, statistics.StatisticsError):
                        # No runs at all (empty list) or a single run (stdev undefined).
                        print(f"{method}, {k}")

if __name__ == "__main__":
    # Script entry point: build and print the accuracy report.
    main()