import json 
import numpy as np 
import os 
import torch 
import time 
import requests 
import os, argparse, torch, yaml
from tqdm import tqdm
from sklearn_extra.cluster import KMedoids
from sklearn.cluster import SpectralClustering
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from transformers import AutoModel, AutoImageProcessor
from joblib import Parallel, delayed
from itertools import combinations

from PIL import Image 
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel

def main_(args):
    """Build trajectory features and distance-based features for one experiment.

    Loads ``./inference_json_files/<experiment_name>.json`` and the train and
    valid splits of ``./reasoning_datasets/<dataset>_*.json``, turns every
    sample into a sequence of unit-norm vectors (one image vector followed by
    one vector per newline-separated reasoning line), computes DTW and
    discrete Frechet distances against the reference samples, derives
    K-medoids and spectral-clustering distance features, and writes the
    augmented records to ``./self_imp_processed_files``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset`` and ``experiment_name``.
    """
    model = SentenceTransformer("all-MiniLM-L6-v2")
    # model = SentenceTransformer("all-distilroberta-v1")
    DATASET = args.dataset
    EXPERIMET_NAME = args.experiment_name

    with open(f"./inference_json_files/{EXPERIMET_NAME}.json", "r") as file:
        data = json.load(file)

    # Every list-valued field is expected to hold exactly one element; unwrap
    # it, then mirror two fields under the key names the reference data uses.
    for record in data:
        for key, value in record.items():
            if not isinstance(value, list):
                continue
            assert len(value) == 1
            record[key] = value[0]
        record['image_path'] = record['img_path']
        record['reasoning_answer'] = record['generated_texts']

    # Reference set = train split followed by valid split.
    with open(f"./reasoning_datasets/{DATASET}_train.json", 'r') as f:
        reference_data = json.load(f)
    with open(f"./reasoning_datasets/{DATASET}_valid.json", 'r') as f:
        reference_data.extend(json.load(f))

    print(F'data length : {len(data)}')
    print(F'reference_data length : {len(reference_data)}')

    # data = data[:30]
    # reference_data = reference_data[:30]

    # One sentence embedding per line of reasoning text, each row scaled to
    # unit L2 norm. Generated samples are read from 'generated_texts',
    # reference samples from 'reasoning_answer'.
    text_jobs = (
        (f'generating data text features.....', data, 'generated_texts'),
        (f'generating reference data text features.....', reference_data, 'reasoning_answer'),
    )
    for banner, samples, text_key in text_jobs:
        print(banner)
        for sample in tqdm(samples):
            steps = sample[text_key].split("\n")
            vectors = np.array(model.encode(steps))
            row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
            sample['thought_features'] = vectors / row_norms

    def get_deepseek_response(prompt, api_key, api_url, temperature=0.9, timeout=120):
        """Send one single-turn chat request and return the reply text.

        Parameters
        ----------
        prompt : str
            Content of the single user message.
        api_key : str
            Bearer token placed in the ``Authorization`` header.
        api_url : str
            Chat-completions endpoint the request is POSTed to.
        temperature : float
            Sampling temperature forwarded to the API.
        timeout : float
            Seconds to wait for the server before ``requests`` gives up.
            Without a timeout a stalled connection would block forever.

        Returns
        -------
        str
            The first choice's message content on HTTP 200; otherwise a
            string of the form ``"Error: <status>, <body>"``.

        Raises
        ------
        requests.exceptions.RequestException
            Propagated unchanged on connection failures or timeouts.
        """
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        # Named `payload` so it does not shadow the enclosing `data` list.
        payload = {
            "model": "deepseek-chat",  # Replace with the correct model name if different
            "messages": [{"role": "user", "content": prompt}],
            # "max_tokens": 350,  # Adjust as needed
            "temperature": temperature,  # Adjust as needed
        }

        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)

        if response.status_code != 200:
            # NOTE: errors are reported in-band as text, so callers that parse
            # the reply must be prepared to receive this string.
            return f"Error: {response.status_code}, {response.text}"
        return response.json()["choices"][0]["message"]["content"]


    def get_correctness(judge_output):
        """Map a judge reply to ``1`` (correct) or ``-1`` (anything else).

        The reply counts as correct only when it contains the whole word
        "yes" and does not contain the whole word "no" (case-insensitive).
        Matching whole words avoids false hits inside other words, e.g. the
        "no" in "known" or the "yes" in "eyes".
        """
        import re  # local import: only this helper needs it

        words = set(re.findall(r"[a-z]+", judge_output.lower()))
        if 'yes' in words and 'no' not in words:
            return 1
        return -1


    # Prompt template for the yes/no judge. It is filled in with
    # str.format(question=..., model_answer=..., gt_answer=...); the text
    # below (including its indentation) is sent verbatim.
    JUDGE_PROMPT = """Evaluate whether the model's answer matches the correct result. 

    - If it does not align, respond with 'No'.
    - If there is a logical error in the reasoning steps, respond with 'No'.
    - If the model's answer aligns with the correct result, respond with 'Yes'. 

    Provide only 'Yes' or 'No' as the output, with no explanation.

    The question is: {question}

    The model's answer is: {model_answer}

    The correct result is: {gt_answer}"""


    # API_URL = ""
    # API_KEY = ""

    # print(f'GENERATING  DEEP SEEK RESPONSES')
    # for  x in tqdm(data):
    #     model_pred = x['generated_texts'].lower().split("the final answer is:")[-1]
    #     model_pred = model_pred.strip()
    #     if model_pred.endswith('.'):
    #         model_pred = model_pred[:-1]
    #     gt_answer = x['gt_texts']

    #     while True:
    #         try:
    #             judge_output = get_deepseek_response(JUDGE_PROMPT.format(question=x['question'], model_answer=model_pred, gt_answer=gt_answer),  API_KEY, API_URL, temperature=0.9)
    #             break
    #         except Exception as e:
    #             time.sleep(0.2)
    #             print(e)
    #     is_correct = get_correctness(judge_output)
    #     x['is_correct'] = is_correct


    ######################### CLIP MODEL IMAGE FEATURES ##########################
    ##############################################################################
    # Load CLIP model and processor
    # clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    # clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # print(f'generating data image features.....')
    # for sample in tqdm(data):
    #     image = Image.open(sample['image_path']).convert("RGB")
    #     inputs = clip_processor(images=image, return_tensors="pt")
    #     with torch.no_grad():
    #         image_features = clip_model.get_image_features(**inputs)
    #     image_features = image_features.cpu().numpy()

    #     norm = np.linalg.norm(image_features)  # Calculate the norm
    #     image_features = image_features / norm  # Divide by the norm
    #     sample['clip_image_features'] = image_features

    # print(f'generating reference data image features.....')
    # for sample in tqdm(reference_data):
    #     image = Image.open(sample['image_path']).convert("RGB")
    #     inputs = clip_processor(images=image, return_tensors="pt")
    #     with torch.no_grad():
    #         image_features = clip_model.get_image_features(**inputs)
    #     image_features = image_features.cpu().numpy()

    #     norm = np.linalg.norm(image_features)  # Calculate the norm
    #     image_features = image_features / norm  # Divide by the norm
    #     sample['clip_image_features'] = image_features


    # ########################## DINO MODEL IMAGE FEATURES ##########################
    # ###############################################################################
    # Load DINO model and processor
    # Load the DINO checkpoint and its matching image preprocessor.
    dino_model = AutoModel.from_pretrained("facebook/dino-vits16")
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dino-vits16")
    dino_model.eval()  # Set model to eval mode

    def _dino_image_feature(image_path):
        """Return the L2-normalised DINO [CLS] feature of one image.

        The result is a NumPy array that keeps the leading batch axis
        (one row), so it can be stacked directly on top of the per-line
        text features.
        """
        # Decode inside a context manager so the file handle is released
        # immediately instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = dino_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = dino_model(**inputs)
        # First token of the last hidden state is the [CLS] token. Convert to
        # NumPy explicitly rather than relying on implicit tensor coercion
        # further down the pipeline.
        cls_feature = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        return cls_feature / np.linalg.norm(cls_feature)

    # NOTE: the key is called 'clip_image_features' for historical reasons;
    # the vectors stored under it come from DINO.
    print(f'generating data image features.....')
    for sample in tqdm(data):
        sample['clip_image_features'] = _dino_image_feature(sample['image_path'])

    print(f'generating reference data image features.....')
    for sample in tqdm(reference_data):
        sample['clip_image_features'] = _dino_image_feature(sample['image_path'])

    # Trajectory = image vector first, then one vector per reasoning line.
    for samples in (data, reference_data):
        for x in samples:
            x['all_features'] = np.concatenate([x['clip_image_features'], x['thought_features']], axis=0)
            # Sanity check: every row is unit-norm, so the squared entries
            # must sum to (approximately) the number of rows.
            n_rows = x['all_features'].shape[0]
            assert n_rows-0.05 < (x['all_features']**2).sum() < n_rows+0.05

    # print(f'JUST FOR DEBUGGING DELETE AFTER DEBUGGING')
    # for x in data:
    #     x['all_features'] = x['thought_features']
    #     assert x['all_features'].shape[0]-0.05<(x['all_features']**2).sum()< x['all_features'].shape[0]+0.05
    # for x in reference_data:
    #     x['all_features'] = x['thought_features']
    #     assert x['all_features'].shape[0]-0.05<(x['all_features']**2).sum()< x['all_features'].shape[0]+0.05

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-


    def _c(ca, i, j, p, q):
        """Fill the coupling table up to cell ``(i, j)`` and return ``ca[i, j]``.

        ``ca`` is the memo table of the Eiter & Mannila recurrence; a cell
        holding a value ``<= -1`` means "not computed yet". The table is
        filled iteratively (row by row) rather than recursively, so the
        maximum curve length is not bounded by Python's recursion limit.

        Parameters
        ----------
        ca : numpy.ndarray
            2-D memo table, modified in place.
        i, j : int
            Indices of the cell whose coupling distance is requested.
        p, q : numpy.ndarray
            The two curves; ``p[a]`` and ``q[b]`` are points.
        """
        if ca[i, j] > -1:
            return ca[i, j]
        if i < 0 or j < 0:
            # Fall-through branch of the reference pseudo-code.
            ca[i, j] = float('inf')
            return ca[i, j]

        for row in range(i + 1):
            for col in range(j + 1):
                if ca[row, col] > -1:
                    continue  # already computed by an earlier call
                point_dist = np.linalg.norm(p[row] - q[col])
                if row == 0 and col == 0:
                    ca[row, col] = point_dist
                elif col == 0:
                    ca[row, col] = max(ca[row - 1, 0], point_dist)
                elif row == 0:
                    ca[row, col] = max(ca[0, col - 1], point_dist)
                else:
                    ca[row, col] = max(
                        min(
                            ca[row - 1, col],
                            ca[row - 1, col - 1],
                            ca[row, col - 1],
                        ),
                        point_dist,
                    )
        return ca[i, j]


    def frdist_manual(p, q):
        """
        Compute the discrete Fréchet distance between two curves.

        Implements the dynamic programme of Eiter, T. and Mannila, H., 1994,
        "Computing discrete Fréchet distance", Tech. Report CD-TR 94/64,
        Technical University of Vienna:

            c(1, 1) = d(u1, v1)
            c(i, 1) = max(c(i-1, 1), d(ui, v1))
            c(1, j) = max(c(1, j-1), d(u1, vj))
            c(i, j) = max(min(c(i-1, j), c(i-1, j-1), c(i, j-1)), d(ui, vj))

        where ``d`` is the Euclidean distance between two points and the
        result is ``c(len(p), len(q))``.

        Parameters
        ----------
        p : array-like
            First curve, a sequence of points.
        q : array-like
            Second curve, a sequence of points of the same dimensionality.

        Returns
        -------
        dist : float64
            The discrete Fréchet distance between curves `p` and `q`.

        Raises
        ------
        ValueError
            If either curve is empty.

        Examples
        --------
        >>> P = [[1, 1], [2, 1], [2, 2]]
        >>> Q = [[2, 2], [0, 1], [2, 4]]
        >>> float(frdist_manual(P, Q))
        2.0
        >>> float(frdist_manual(P, P))
        0.0
        """
        p = np.array(p, np.float64)
        q = np.array(q, np.float64)

        len_p = len(p)
        len_q = len(q)

        if len_p == 0 or len_q == 0:
            raise ValueError('Input curves are empty.')

        # -1 marks "not computed yet"; real distances are never negative.
        ca = np.full((len_p, len_q), -1.0, dtype=np.float64)

        return _c(ca, len_p - 1, len_q - 1, p, q)


    # ##################  DTW and Frecret Distances to reference trajectories  ##################
    # ###############################################################################

    def compute_distances(x, reference_data):
        """Measure one sample's trajectory against every reference trajectory.

        Returns a dict holding the sample itself under ``'original'`` plus two
        lists aligned with ``reference_data``: the fastdtw distances
        (Euclidean point metric) and the discrete Frechet distances.
        """
        query = x['all_features']

        dtw_distances = [
            fastdtw(query, ref['all_features'], dist=euclidean)[0]
            for ref in reference_data
        ]
        frechet_distances = [
            frdist_manual(query, ref['all_features'])
            for ref in reference_data
        ]

        return {
            'original': x,
            'dtw_distances': dtw_distances,
            'frechet_distances': frechet_distances,
        }

    print("Calculating distances (DTW and Frechet)...")
    # One job per generated sample, spread over all available cores.
    distance_jobs = (
        delayed(compute_distances)(sample, reference_data) for sample in tqdm(data)
    )
    results = Parallel(n_jobs=-1)(distance_jobs)

    # Results come back in submission order; store them rounded to 5 decimals.
    for sample, res in zip(data, results):
        sample['dtw_distances'] = [round(dist, 5) for dist in res['dtw_distances']]
        sample['frechet_distances'] = [round(dist, 5) for dist in res['frechet_distances']]


    ##################### FRECHET DISTANCE MATRIX CALCULATION #####################
    ####################################################################################

    trajectories = [ref['all_features'] for ref in reference_data]  # List of trajectories
    n = len(trajectories)

    def _pairwise_matrix(distance_fn):
        """Return the symmetric n x n matrix of ``distance_fn`` over all
        reference trajectory pairs (zero diagonal, upper triangle mirrored)."""
        matrix = np.zeros((n, n))
        for row in tqdm(range(n)):
            for col in range(row + 1, n):
                matrix[row, col] = distance_fn(trajectories[row], trajectories[col])
                matrix[col, row] = matrix[row, col]  # Symmetric
        return matrix

    print(f'Frechet distance calculation.....')
    frechet_dist_matrix = _pairwise_matrix(frdist_manual)

    print(f'DTW distance calculation.....')
    dtw_dist_matrix = _pairwise_matrix(
        lambda first, second: fastdtw(first, second, dist=euclidean)[0]
    )
    print(f'we are here : ')
    ##################### KMEDOID with FRECHET DISTANCE CLUSTERING #####################
    ####################################################################################

    for ncluster in [10,20]:
        # Cluster the reference trajectories with K-medoids on the
        # precomputed Frechet distance matrix.
        start_ = time.time()
        print(f'kmedoid clustering with frechet distance ncluster:{ncluster}....')
        kmedoids = KMedoids(n_clusters=ncluster, metric="precomputed", random_state=42)
        kmedoids.fit(frechet_dist_matrix)
        elapsed_minutes = (time.time() - start_) / 60
        # The elapsed time is divided by 60, so it is reported in minutes.
        print(f'it takes : {elapsed_minutes} minutes to do Kmedoid clustering for n {ncluster}')

        # Feature for each generated sample: its Frechet distance to every
        # medoid trajectory, rounded to 4 decimals.
        medoid_trajectories = [trajectories[medoid] for medoid in kmedoids.medoid_indices_]
        print(f'k-medoid w/ frechet distance calculation started ncluster:{ncluster}.....')
        for x in tqdm(data):
            new_distances = [frdist_manual(x['all_features'], medoid) for medoid in medoid_trajectories]
            x[f'Kmedoid_frdist_distances_ncluster_{ncluster}'] = [round(yy, 4) for yy in new_distances]


    ##################### KMEDOID with DTW DISTANCE CLUSTERING #####################
    ####################################################################################
    for ncluster in [10,20]:
        # Cluster the reference trajectories with K-medoids on the
        # precomputed DTW distance matrix.
        start_ = time.time()
        print(f'kmedoid clustering with dtw distance ncluster:{ncluster}....')
        kmedoids = KMedoids(n_clusters=ncluster, metric="precomputed", random_state=42)
        kmedoids.fit(dtw_dist_matrix)
        elapsed_minutes = (time.time() - start_) / 60
        # The elapsed time is divided by 60, so it is reported in minutes.
        print(f'it takes : {elapsed_minutes} minutes to do Kmedoid clustering with dtw distance for n {ncluster}')

        # Feature for each generated sample: its DTW distance to every
        # medoid trajectory, rounded to 4 decimals.
        medoid_trajectories = [trajectories[medoid] for medoid in kmedoids.medoid_indices_]
        print(f'k-medoid w/ dtw distance calculation started ncluster:{ncluster}.....')
        for x in tqdm(data):
            new_distances = [fastdtw(x['all_features'], medoid, dist=euclidean)[0] for medoid in medoid_trajectories]
            x[f'Kmedoid_dtw_distances_ncluster_{ncluster}'] = [round(yy, 4) for yy in new_distances]


    ##################### SpectralClustering with FRECHET DISTANCE CLUSTERING #####################
    ####################################################################################

    for ncluster in [10,20]:
        # Spectral clustering of the reference trajectories; Frechet distances
        # are turned into affinities with exp(-distance).
        start_ = time.time()
        print(f'spectral clustering with frechet distance ncluster:{ncluster}....')
        spectral = SpectralClustering(n_clusters=ncluster, affinity="precomputed", random_state=42)
        clusters = spectral.fit_predict(np.exp(-frechet_dist_matrix))  # Convert distance to similarity
        elapsed_minutes = (time.time() - start_) / 60
        # The elapsed time is divided by 60, so it is reported in minutes.
        print(f'it takes : {elapsed_minutes} minutes to do Spectrakl clustering with frechet distance for n {ncluster}')

        # Each reference trajectory is represented by its row of the affinity
        # matrix, i.e. its similarity to every reference trajectory. (This is
        # the affinity matrix itself, not a learned spectral embedding.)
        affinity_rows = spectral.affinity_matrix_

        # Group the affinity rows by assigned cluster label.
        rows_by_cluster = {c: [] for c in range(spectral.n_clusters)}
        for idx, label in enumerate(clusters):
            rows_by_cluster[label].append(affinity_rows[idx])

        # Centroid of a cluster = mean affinity row of its members.
        cluster_centroids = {c: np.mean(rows, axis=0) for c, rows in rows_by_cluster.items()}

        print(f'spectral clustering w/ frechet distance calculation started ncluster:{ncluster}.....')

        for x in tqdm(data):
            # Represent the generated sample the same way: exp(-distance) to
            # every reference trajectory, reusing the distances computed
            # earlier instead of recomputing them.
            new_embedding = np.exp(-np.array(x['frechet_distances']))
            new_embedding = new_embedding.reshape(1, -1)

            # Euclidean distance to each cluster centroid, in label order.
            distances_to_centroids = [
                np.linalg.norm(new_embedding - centroid)
                for centroid in cluster_centroids.values()
            ]
            x[f'SpectralClustering_frdist_distances_ncluster_{ncluster}'] = [round(yy, 4) for yy in distances_to_centroids]

        # print(f'-----new method starts!-----')
        # def compute_distances_to_spectral_clusters(x, cluster_centroids, trajectories):
        #     new_embedding = np.exp(-np.array([frdist_manual(x['all_features'], traj) for traj in trajectories]))
        #     new_embedding = new_embedding.reshape(1, -1)  # Reshape for consistency

        #     # Step 4: Compute distance to each cluster centroid
        #     distances_to_centroids = {c: np.linalg.norm(new_embedding - centroid) for c, centroid in cluster_centroids.items()}
        #     distances_to_centroids = [v for k,v in distances_to_centroids.items()]
        #     return {
        #         'SpectralClustering_frdist_distances_ncluster': distances_to_centroids, 
        #     }

        # print("Calculating distances (DTW and Frechet)...")
        # results = Parallel(n_jobs=-1)(
        #     delayed(compute_distances)(x, cluster_centroids, trajectories) for x in tqdm(data)
        # )

        # # Now assign results back to data
        # for i, res in enumerate(results):
        #     data[i][f'SpectralClustering_frdist_distances_ncluster_{ncluster}'] = res['SpectralClustering_frdist_distances_ncluster']

    ##################### SpectralClustering with DTW DISTANCE CLUSTERING #####################
    ####################################################################################

    for ncluster in [10,20]:
        # Spectral clustering of the reference trajectories; DTW distances
        # are turned into affinities with exp(-distance).
        print(f'spectral clustering with dtw distance ncluster:{ncluster}....')

        spectral = SpectralClustering(n_clusters=ncluster, affinity="precomputed", random_state=42)
        labels = spectral.fit_predict(np.exp(-dtw_dist_matrix))

        # Collect, per cluster label, the affinity-matrix rows of its members.
        member_rows = [[] for _ in range(spectral.n_clusters)]
        for affinity_row, label in zip(spectral.affinity_matrix_, labels):
            member_rows[label].append(affinity_row)

        # Centroid of a cluster = mean affinity row of its members.
        centroids = [np.mean(rows, axis=0) for rows in member_rows]

        for sample in data:
            # Same representation for the generated sample: exp(-distance) to
            # every reference trajectory, from the distances stored earlier.
            sample_affinity = np.exp(-np.array(sample['dtw_distances'])).reshape(1, -1)

            # Euclidean distance to each centroid, in label order, 4 decimals.
            sample[f'SpectralClustering_dtw_distances_ncluster_{ncluster}'] = [
                round(np.linalg.norm(sample_affinity - centroid), 4)
                for centroid in centroids
            ]


    # Remove the array/tensor-valued intermediate features before dumping;
    # only the scalar and list-valued fields are written out.
    for x in data:
        for feature_key in ('all_features', 'thought_features', 'clip_image_features'):
            del x[feature_key]

    # Create the output folder if needed so that a missing directory does not
    # discard the whole computation at the very last step.
    output_dir = './self_imp_processed_files'
    os.makedirs(output_dir, exist_ok=True)
    with open(f'{output_dir}/{EXPERIMET_NAME}_self_imp_processed_features.json', 'w') as file:
        json.dump(data, file)


if __name__ == "__main__":
    # Script entry point: parse the two mandatory options, run the pipeline
    # and report the total wall-clock time in minutes.
    start_ = time.time()
    parser = argparse.ArgumentParser(
        description="Script with configurable hyperparameters"
    )
    # Both options are used to build input/output file names, so a missing
    # value must be rejected up front instead of silently becoming "None".
    parser.add_argument("--dataset", type=str, required=True, help="dataset")
    parser.add_argument("--experiment_name", type=str, required=True, help="experiment_name")
    args = parser.parse_args()
    main_(args)
    end = time.time()
    time_ = (end-start_)/60
    print(f'It takes {time_} minutes to finish the code')

