import json
import os
import sys
import re
from coalition_map import *
import csv



def parse_inference_files(folder_path="./inference_files", output_csv="parsed_inference_files.csv"):
    """Parse inference file names in a folder, score each file, and write a CSV.

    Every entry of ``folder_path`` whose name matches
    ``<model_name>_neuron_weight_removed_from_coalition_<llm>_<layer_number>_``
    ``<coalition_strategy>_<coalition_length>_<neuron_i>_<neuron_j>_.txt``
    is turned into one row. The numeric name parts are converted to ``int``,
    ``llm`` is lower-cased, and an ``mse`` value is computed against the
    reference file registered for that ``llm`` in ``base_llm_to_gt_output``.

    Args:
        folder_path: Directory that is scanned (non-recursively).
        output_csv: Path of the CSV file written when at least one file matched.

    Returns:
        A list with one dict per matching file, sorted by file name. The list
        is empty (and no CSV is written) when nothing matched. ``mse`` is
        ``None`` when the two score files share no (query id, doc id) pair;
        ``csv.DictWriter`` writes such a value as an empty cell.

    Raises:
        KeyError: If a parsed ``llm`` has no entry in ``base_llm_to_gt_output``.
        OSError: If the folder or one of the score files cannot be read.
    """
    pattern = re.compile(
        r"^(?P<model_name>.+?)_neuron_weight_removed_from_coalition_"
        r"(?P<llm>\w+)_"
        r"(?P<layer_number>\d+)_"
        r"(?P<coalition_strategy>\w+)_"
        r"(?P<coalition_length>\d+)_"
        r"(?P<neuron_i>\d+)_"
        r"(?P<neuron_j>\d+)_\.txt$"
    )
    int_fields = ("layer_number", "coalition_length", "neuron_i", "neuron_j")

    parsed_files = []

    # os.listdir returns entries in arbitrary order; sort so that the CSV and
    # the returned list are reproducible between runs.
    for filename in sorted(os.listdir(folder_path)):
        match = pattern.match(filename)
        if match is None:
            continue

        file_info = match.groupdict()
        for field in int_fields:
            file_info[field] = int(file_info[field])
        file_info['file_name'] = filename
        file_info['llm'] = file_info['llm'].lower()
        # Resolve both paths here and use the two-argument compute_mse form.
        # The previous three-argument call raised TypeError because the
        # two-argument compute_mse defined later in this module rebinds the
        # name. base_llm_to_gt_output is presumably supplied by the
        # `from coalition_map import *` at the top of the file -- TODO confirm.
        file_info['mse'] = compute_mse(
            os.path.join(folder_path, filename),
            base_llm_to_gt_output[file_info['llm']],
        )
        parsed_files.append(file_info)

    print(f"Parsed {len(parsed_files)} rows.")

    if parsed_files:
        fieldnames = [
            "file_name", "model_name", "llm", "layer_number",
            "coalition_strategy", "coalition_length", "neuron_i", "neuron_j", "mse"
        ]
        with open(output_csv, mode='w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(parsed_files)
        print(f"CSV written to {output_csv}")
    else:
        print("No matching files found.")

    return parsed_files




def load_scores(file_path):
    """Read a whitespace-separated score file into a lookup table.

    Only lines made of exactly three whitespace-separated fields
    (query id, document id, score) are used. Lines with a different field
    count, or whose third field is not parseable by ``float``, are skipped.
    When a (query id, document id) pair occurs more than once, the last
    occurrence wins.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A dict mapping ``(query_id, doc_id)`` string tuples to float scores.
    """
    table = {}
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if len(fields) != 3:
                continue
            query_id, doc_id, raw_score = fields
            try:
                value = float(raw_score)
            except ValueError:
                # Non-numeric score: ignore this line.
                continue
            table[(query_id, doc_id)] = value
    return table

def compute_mse(file1, file2, folder=None):
    """Return the mean squared error between two score files.

    This module used to define ``compute_mse`` twice. The second,
    two-argument definition rebound the name, so the three-argument variant
    was unreachable and three-argument calls raised ``TypeError``. Both call
    forms are supported here:

    * ``compute_mse(file1, file2)`` -- both arguments are score file paths.
    * ``compute_mse(file1, llm, folder)`` -- ``file1`` is a file name inside
      ``folder`` and the second argument is a key into
      ``base_llm_to_gt_output`` that selects the reference score file.

    Only (query id, doc id) pairs present in both files contribute.

    Args:
        file1: Path of the first score file, or its name inside ``folder``
            when ``folder`` is given.
        file2: Path of the second score file, or an ``llm`` key into
            ``base_llm_to_gt_output`` when ``folder`` is given.
        folder: Optional directory containing ``file1``; supplying it selects
            the three-argument form.

    Returns:
        The MSE over the shared pairs as a float, or ``None`` when the two
        files have no pair in common.

    Raises:
        KeyError: In the three-argument form, if the key is not present in
            ``base_llm_to_gt_output``.
    """
    if folder is not None:
        # base_llm_to_gt_output is presumably provided by the
        # `from coalition_map import *` at the top of the file -- TODO confirm.
        reference_path = base_llm_to_gt_output[file2]
        scores1 = load_scores(f"{folder}/{file1}")
        scores2 = load_scores(f"{reference_path}")
    else:
        scores1 = load_scores(file1)
        scores2 = load_scores(file2)

    total_squared_error = 0.0
    count = 0

    for key, score in scores1.items():
        if key in scores2:
            diff = score - scores2[key]
            total_squared_error += diff ** 2
            count += 1

    if count == 0:
        # No overlapping pairs: an MSE is undefined.
        return None

    return total_squared_error / count





def main():
    """Print the MSE of each sampled model output against its ground truth.

    Twelve comparisons are reported, in this order:

    * checkpoint-0 outputs of llama3, mistral and pythia vs. the cqtr ground truth,
    * the same checkpoint-0 outputs vs. the mtf ground truth,
    * cqtr fine-tuned outputs vs. the cqtr ground truth,
    * mtf fine-tuned outputs vs. the mtf ground truth.

    All file names are relative paths, so they are resolved against the
    current working directory. All MSE values are computed before anything is
    printed. A comparison whose files share no (query id, doc id) pair is
    reported as "n/a" instead of failing on the float format.
    """
    # To parse a folder of ablation inference files instead, call
    # parse_inference_files("./inference_files/").
    cqtr_gt = "dl19_sampled_subset_gt_cqtr.txt"
    mtf_gt = "dl19_sampled_subset_gt_mtf.txt"
    ground_truths = (("cqtr", cqtr_gt), ("mtf", mtf_gt))
    models = ("llama3", "mistral", "pythia")

    # Each entry: (label used in the report, model output file, ground-truth file).
    comparisons = []
    for task, gt_file in ground_truths:
        for model in models:
            comparisons.append((
                f"cp0_{model} {task}",
                f"sampled_{model}_checkpoint_0_output.psg.dl19.txt",
                gt_file,
            ))
    for task, gt_file in ground_truths:
        for model in models:
            comparisons.append((
                f"{task}_{model}",
                f"sampled_{model}_{task}_ft_output.psg.dl19.txt",
                gt_file,
            ))

    results = [
        (label, compute_mse(output_file, gt_file))
        for label, output_file, gt_file in comparisons
    ]

    for label, mse in results:
        if mse is None:
            # compute_mse returns None when there is nothing to compare;
            # formatting None with ":.2f" would raise TypeError.
            print(f"Mean Squared Error (MSE) for {label}: n/a (no matching (queryid, docid) pairs found)")
        else:
            print(f"Mean Squared Error (MSE) for {label}: {mse:.2f}")
    

# Run the MSE report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
    
    
    