import json 
import numpy as np 
import os 
import torch 
import time 
import requests 
import os, argparse, torch, yaml

from tqdm import tqdm
from sklearn_extra.cluster import KMedoids
from sklearn.cluster import SpectralClustering
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from transformers import AutoModel, AutoImageProcessor
from joblib import Parallel, delayed
from itertools import combinations

from PIL import Image 
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel

def get_deepseek_response(prompt, api_key, api_url, temperature=0.9, timeout=60):
    """Send a single-turn chat request and return the reply text.

    Args:
        prompt: User message sent as the only entry of ``messages``.
        api_key: Bearer token placed in the ``Authorization`` header.
        api_url: Full URL of the chat-completions endpoint.
        temperature: Sampling temperature forwarded to the API.
        timeout: Seconds to wait for the HTTP call before giving up, so a
            stalled connection cannot block the run forever.

    Returns:
        The ``content`` string of the first choice in the JSON response.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200.
            Raising (instead of returning the error text) lets the caller
            retry rather than score the error message as a verdict.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "deepseek-chat",  # Replace with the correct model name if different
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
    }

    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(f"Error: {response.status_code}, {response.text}")
    return response.json()["choices"][0]["message"]["content"]


def get_correctness(judge_output):
    """Map the judge's free-text verdict to ``1`` (correct) or ``-1``.

    The verdict counts as correct only when the word "yes" appears and the
    word "no" does not. Matching is done on whole alphabetic words so that
    words such as "know" or "not" are not mistaken for "no", and "eyes" is
    not mistaken for "yes".
    """
    import re  # local import: the module's import block does not provide `re`

    verdict_words = set(re.findall(r"[a-z]+", judge_output.lower()))
    if "yes" in verdict_words and "no" not in verdict_words:
        return 1
    return -1


def extract_final_answer(generated_text):
    """Return the lower-cased text after the last "the final answer is:".

    If the marker is absent the whole lower-cased text is used. Surrounding
    whitespace is stripped and a single trailing period is removed.
    """
    answer = generated_text.lower().split("the final answer is:")[-1].strip()
    if answer.endswith('.'):
        answer = answer[:-1]
    return answer


def judge_with_retry(prompt, api_key, api_url, temperature=0.9,
                     max_attempts=10, base_delay=0.2, max_delay=30.0):
    """Call ``get_deepseek_response`` with bounded, exponentially backed-off retries.

    Every failure is printed. After ``max_attempts`` consecutive failures a
    ``RuntimeError`` chained to the last error is raised, so a permanent
    problem (bad URL, bad key) cannot spin forever.
    """
    last_error = None
    for attempt in range(max_attempts):
        try:
            return get_deepseek_response(prompt, api_key, api_url, temperature=temperature)
        except Exception as e:  # retry boundary: any failure of the call is retried
            last_error = e
            print(e)
            if attempt + 1 < max_attempts:
                time.sleep(min(base_delay * 2 ** attempt, max_delay))
    raise RuntimeError(
        f"Judge request failed after {max_attempts} attempts"
    ) from last_error


def main_(args):
    """Judge each inference result with the DeepSeek API and save the labels.

    Reads ``./inference_json_files/<experiment_name>.json``, unwraps
    single-element list values, asks the judge model whether each extracted
    final answer matches ``gt_texts``, stores the verdict under
    ``is_correct`` (1 or -1) and writes everything to
    ``./self_imp_processed_files/<experiment_name>_self_imp_processed_answer_check.json``.

    The endpoint and key are taken from the ``DEEPSEEK_API_URL`` and
    ``DEEPSEEK_API_KEY`` environment variables (empty string when unset).

    Args:
        args: Namespace providing ``dataset`` and ``experiment_name``.

    Raises:
        ValueError: If a list-valued field does not hold exactly one element,
            or if there is data to judge but no API URL is configured.
        RuntimeError: If the judge request keeps failing for one item.
    """
    dataset = args.dataset
    experiment_name = args.experiment_name

    with open(f"./inference_json_files/{experiment_name}.json", "r", encoding="utf-8") as file:
        data = json.load(file)

    for x in data:
        # Only existing keys are reassigned here, so iterating items() is safe.
        for k, v in x.items():
            if isinstance(v, list):
                if len(v) != 1:
                    raise ValueError(
                        f"Expected a single-element list for key {k!r}, got {len(v)} elements"
                    )
                x[k] = v[0]
        x['image_path'] = x['img_path']
        x['reasoning_answer'] = x['generated_texts']

    with open(f"./reasoning_datasets/{dataset}_train.json", 'r', encoding="utf-8") as f:
        reference_data = json.load(f)

    with open(f"./reasoning_datasets/{dataset}_valid.json", 'r', encoding="utf-8") as f:
        reference_data += json.load(f)

    print(F'data length : {len(data)}')
    print(F'reference_data length : {len(reference_data)}')

    JUDGE_PROMPT = """Evaluate whether the model's answer matches the correct result. 

    - If it does not align, respond with 'No'.
    - If there is a logical error in the reasoning steps, respond with 'No'.
    - If the model's answer aligns with the correct result, respond with 'Yes'. 

    Provide only 'Yes' or 'No' as the output, with no explanation.

    The question is: {question}

    The model's answer is: {model_answer}

    The correct result is: {gt_answer}"""

    API_URL = os.environ.get("DEEPSEEK_API_URL", "")
    API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
    if data and not API_URL:
        raise ValueError(
            "No judge endpoint configured: set the DEEPSEEK_API_URL "
            "(and DEEPSEEK_API_KEY) environment variables"
        )

    # Create the output directory up front so a missing folder is discovered
    # before any API calls are spent, not after all of them have finished.
    output_dir = './self_imp_processed_files'
    os.makedirs(output_dir, exist_ok=True)

    print(f'GENERATING  DEEP SEEK RESPONSES')
    for x in tqdm(data):
        model_pred = extract_final_answer(x['generated_texts'])
        prompt = JUDGE_PROMPT.format(
            question=x['question'],
            model_answer=model_pred,
            gt_answer=x['gt_texts'],
        )
        judge_output = judge_with_retry(prompt, API_KEY, API_URL, temperature=0.9)
        x['is_correct'] = get_correctness(judge_output)

    with open(f'{output_dir}/{experiment_name}_self_imp_processed_answer_check.json', 'w') as file:
        json.dump(data, file)


if __name__ == "__main__":
    # Time the whole run with a monotonic clock so system clock adjustments
    # cannot distort the reported duration.
    start_ = time.perf_counter()
    parser = argparse.ArgumentParser(
        description="Script with configurable hyperparameters"
    )
    # Both values are interpolated into file paths by main_; requiring them
    # avoids silently looking for files named after the literal "None".
    parser.add_argument("--dataset", type=str, required=True, help="dataset")
    parser.add_argument("--experiment_name", type=str, required=True, help="experiment_name")
    args = parser.parse_args()
    main_(args)
    end = time.perf_counter()
    time_ = (end - start_) / 60
    print(f'It takes {time_} minutes to finish the code')

