from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_distances
import numpy as np
import baseline as ba
import multi_agent as ma
import os
from collections import defaultdict


# Number of times each model is queried per patient; the repeated answers
# are compared with each other to measure output stability.
n_try = 5
# Sentence-embedding model used to encode the generated answers before
# their pairwise cosine distances are computed.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Patient IDs are the names of the immediate sub-directories of "dataset".
# Only the first os.walk() entry is needed. The default passed to next()
# keeps `patients` defined (as an empty list) when the directory is missing
# or unreadable, instead of leaving the name unbound and failing later with
# a NameError.
patients = next(os.walk("dataset"), ("dataset", [], []))[1]


# Ordered fallback chains: for each benchmarked model family, the backend
# names passed to ``baseline.gf4_inference`` in turn until one of them
# produces a usable answer.
_FALLBACK_CHAINS = {
    'gemini-pro': ('gemini-pro', 'gemini', 'gemini-flash'),
    'gpt-4o': ('gpt-4-turbo', 'gpt-4o', 'gpt-4', 'gpt-4o-mini'),
    'llama-3.1-405b': ('llama-3.1-405b', 'llama-3.1-70b', 'llama-3-70b'),
    'mixtral-8x22b': ('mixtral-8x22b', 'mixtral-8x7b', 'mixtral-7b'),
}


def _warn_inference_failure(what, exc):
    """Log a warning that the inference call described by *what* raised *exc*."""
    import logging  # local import: keeps this block self-contained

    logging.getLogger(__name__).warning("%s failed: %r", what, exc)


def _first_usable_answer(candidates):
    """Try each backend in *candidates* in order and stop at the first usable answer.

    An answer is usable when it contains "Interpretation" and does not
    contain "message exceed". If no backend yields a usable answer, the
    last answer that was obtained is returned ("" if every call raised).
    """
    rst = ""
    for candidate in candidates:
        try:
            rst = baseline.gf4_inference(candidate)
        except Exception as exc:
            # Best effort: a failing backend must not abort the run, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            _warn_inference_failure(f"gf4_inference({candidate!r})", exc)
        if "Interpretation" in rst and "message exceed" not in rst:
            break
    return rst


def run(llm):
    """Generate one answer for the current patient with the model family *llm*.

    Reads the module-level globals ``patient`` and ``baseline``, which the
    caller must set before calling this function.

    Args:
        llm: One of 'zodiac', 'gemini-pro', 'gpt-4o', 'llama-3.1-405b',
            'mixtral-8x22b' or 'med42'.

    Returns:
        The value returned by the selected inference call. For the
        fallback-chain families and 'med42' this is "" when every
        attempted call raised.

    Raises:
        ValueError: If *llm* is not one of the supported names (previously
            this surfaced as an UnboundLocalError on ``rst``).
    """
    if llm == 'zodiac':
        return ma.metric_llm_inference(
            metric_path=f"dataset/{patient}/metrics.json",
            ecg_findings=baseline.ecg_findings,
            model_path="../../../huggingface/Meta-Llama-3.1-8B-Instruct"
        )

    if llm in _FALLBACK_CHAINS:
        return _first_usable_answer(_FALLBACK_CHAINS[llm])

    if llm == "med42":
        try:
            return baseline.med42_inference()
        except Exception as exc:
            _warn_inference_failure("med42_inference()", exc)
            return ""

    supported = ['zodiac', *_FALLBACK_CHAINS, 'med42']
    raise ValueError(f"unsupported llm {llm!r}; expected one of {supported}")


# Results are written under "stable/". Create it up front so that a missing
# directory does not fail the script only at save time, after the inference
# for a model has already been done.
os.makedirs('stable', exist_ok=True)

# For each model: query it n_try times per patient, embed the answers, and
# record the variance of the pairwise cosine distances between them. One
# array (one value per patient) is saved per model.
for llm in ['zodiac',
            'gpt-4o',
            'gemini-pro',
            'llama-3.1-405b',
            'mixtral-8x22b',
            # NOTE(review): run() as defined in this file has no branch for
            # 'biogpt' or 'meditron' -- confirm they are supported before
            # running the full list.
            'biogpt',
            # A comma was missing here, which silently fused this entry
            # with the next one into the single name 'meditronmed42'.
            'meditron',
            'med42',
        ]:
    all_data = []
    for patient in patients:
        # `patient` and `baseline` are intentionally module-level names:
        # run() reads both of them as globals.
        baseline = ba.BaselineInference(
            metric_path=f"dataset/{patient}/metrics.json",
            image_path=f"dataset/{patient}/ecg.json"
        )

        texts = [run(llm) for _ in range(n_try)]

        embeddings = model.encode(texts)
        distance_matrix = cosine_distances(embeddings)
        # Strict upper triangle: each unordered pair of answers counted once,
        # excluding the zero self-distances on the diagonal.
        pairwise_distances = distance_matrix[np.triu_indices(len(texts), k=1)]
        variance_distance = np.var(pairwise_distances)

        all_data.append(variance_distance)

    with open(f'stable/{llm}.npy', 'wb') as f:
        np.save(f, all_data)