import numpy as np
import random
from sentence_transformers import SentenceTransformer, util
import os 
import json
from tqdm import tqdm
from sklearn.manifold import TSNE

def encode_sentences(file_path):
    """
    Encode the answer text of every record in a JSON-lines file and save the result.

    Each non-blank line of ``file_path`` is parsed as a JSON object holding an
    ``'answer'`` string. When the answer contains the marker
    ``**Explanation:**``, only the text after its first occurrence is kept;
    otherwise the whole answer is used. The kept text is stripped of
    surrounding whitespace and encoded with the ``stsb-roberta-large``
    SentenceTransformer model.

    The encodings are also written, one positional array per sentence and in
    file order, to ``<file_path without extension>_encodings.npz``.

    Parameters:
    file_path (str): Path to the JSON-lines input file.

    Returns:
    list: One encoding per sentence, as returned by
        ``model.encode(sentence, convert_to_tensor=True)``.
    """
    marker = '**Explanation:**'

    # Read the input before loading the model so a bad path or malformed
    # record fails fast instead of after an expensive model load.
    sentences = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Whitespace-only lines (e.g. a trailing newline) hold no record.
            if not line.strip():
                continue
            ans = json.loads(line)['answer']
            _, found, explain = ans.partition(marker)
            sentences.append((explain if found else ans).strip())
    print('len(sentences):', len(sentences))

    model = SentenceTransformer('stsb-roberta-large')

    # Encode all sentences
    print('Encoding sentences...')
    encoded_sentences = [
        model.encode(sentence, convert_to_tensor=True)
        for sentence in tqdm(sentences)
    ]

    # Save encoded sentences: move each tensor to the CPU and convert it to a
    # numpy array so it can be stored in the .npz archive.
    numpy_encodings = [encoding.cpu().numpy() for encoding in encoded_sentences]
    output_file = os.path.splitext(file_path)[0] + '_encodings.npz'
    np.savez(output_file, *numpy_encodings)
    print(f'Saved encodings to {output_file}')

    return encoded_sentences


def kmedoids(data, k, max_iter=50):
    """
    Apply K-Medoids clustering algorithm to the input data.

    The medoids are initialised with a seeded random sample and then refined
    by alternating two steps until they stop changing or ``max_iter`` is
    reached: assign every sample to its nearest medoid (Euclidean distance),
    then replace each medoid by the cluster member with the smallest total
    distance to the other members of its cluster.

    Note: the function seeds the global ``random`` and ``numpy.random``
    generators with a fixed value, so repeated calls give the same result.

    Parameters:
    data (numpy array): The input data samples, shape (n_samples, n_features).
    k (int): The number of representative samples to return.
    max_iter (int): The maximum number of iterations for the algorithm.

    Returns:
    list: The indices (rows of ``data``) of the representative samples.

    Raises:
    ValueError: If ``k`` is negative or larger than the number of samples
        (raised by ``random.sample``).
    """
    random_seed = 10
    random.seed(random_seed)
    np.random.seed(random_seed)

    m, _ = data.shape
    # Randomly initialize the medoids
    medoid_indices = random.sample(range(m), k)
    if k == 0:
        return medoid_indices

    for _ in range(max_iter):
        medoids = data[medoid_indices, :]

        # Assign each point to the nearest medoid. Looping over points keeps
        # the temporary at (k, n_features) instead of (m, k, n_features).
        labels = np.empty(m, dtype=int)
        for i in range(m):
            labels[i] = np.argmin(np.linalg.norm(medoids - data[i], axis=1))

        # Update the medoids, tracking them by index into `data` so that the
        # returned indices reflect the refined medoids, not the initial draw.
        new_medoid_indices = list(medoid_indices)
        for cluster_id in range(k):
            members = np.flatnonzero(labels == cluster_id)
            if members.size == 0:
                # Empty cluster: keep the previous medoid.
                continue
            cluster_points = data[members, :]
            # Row-wise accumulation avoids a (c, c, n_features) temporary.
            total_distances = np.array([
                np.linalg.norm(cluster_points - point, axis=1).sum()
                for point in cluster_points
            ])
            new_medoid_indices[cluster_id] = int(members[np.argmin(total_distances)])

        # Check for convergence
        if new_medoid_indices == medoid_indices:
            break
        medoid_indices = new_medoid_indices

    return medoid_indices

def plot_samples(data, representative_indices, output_path='kmedoids.png'):
    """
        Plot all samples and highlight the representative samples in a 2D space.

        Data with more than two columns is first reduced to 2D with t-SNE.
        The figure is written to ``output_path`` and then closed.

        Parameters:
        data (numpy array): The input data samples.
        representative_indices (list): Indices of the representative samples.
        output_path (str): File the plot is saved to (default 'kmedoids.png').
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 8))
    try:
        # Check if data has more than 2 dimensions
        if data.shape[1] > 2:
            print("Data has more than 2 dimensions. Using t-SNE to reduce to 2D for visualization.")
            # Cap the perplexity at len(data) - 1 so it stays below the sample count.
            tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(data)-1))
            data = tsne.fit_transform(data)

        # Plot all samples
        plt.scatter(data[:, 0], data[:, 1], c='lightblue', marker='o', s=50, label='Regular samples')

        # Highlight representative samples
        plt.scatter(data[representative_indices, 0], data[representative_indices, 1],
                    c='red', marker='*', s=200, label='Representative samples')

        # Add labels for representative points
        for i, idx in enumerate(representative_indices):
            plt.annotate(f"Rep {i+1}", (data[idx, 0], data[idx, 1]),
                         textcoords="offset points", xytext=(0, 10), ha='center')

        plt.title('K-Medoids Clustering Visualization')
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save the plot
        plt.savefig(output_path)
        print(f'Saved plot to {output_path}')
    finally:
        # pyplot keeps every figure registered until it is closed explicitly;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)


# Example usage
if __name__ == "__main__":
    # Pipeline: encode the answers of one dataset, group the samples by label,
    # pick representatives per label with K-Medoids and save their IDs.

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset = 'S1'
    encode_sentences(f'output/DB-{dataset}-ku2-2shot.jsonl')

    # Load the encoded data from NPZ file
    def load_encodings(npz_file):
        """Return the arrays arr_0 .. arr_{n-1} of an .npz file as one numpy array."""
        # np.load keeps the archive open; the context manager closes it once
        # every array has been read into memory.
        with np.load(npz_file) as loaded:
            encodings = [loaded[f'arr_{i}'] for i in range(len(loaded.files))]
        return np.array(encodings)

    # Path to files
    encodings_path = f'output/DB-{dataset}-ku2-2shot_encodings.npz'
    path_file = f'DB_{dataset}train_path.jsonl'
    output_path = f'DB_{dataset}_representative_samples.json'

    # Load encodings
    print("Loading encodings...")
    encodings = load_encodings(encodings_path)
    print(f"Loaded {len(encodings)} encodings with shape {encodings.shape}")

    # Load path file with labels
    print("Loading labels from path file...")
    with open(path_file, 'r', encoding='utf-8') as file:
        path_data = [json.loads(line) for line in file]

    # Group data by labels: the label is the last element of 'test_triplet'.
    print("Grouping data by labels...")
    label_groups = {}
    for item in path_data:
        label = item['test_triplet'][-1]
        label_groups.setdefault(label, []).append(item['id'])

    # Perform K-medoids clustering for each group
    representative_ids = []
    print(f"Found {len(label_groups)} different labels")

    for label, ids in label_groups.items():
        print(f"Processing label {label} with {len(ids)} samples")
        # NOTE(review): 'id' values are used directly as row indices into
        # `encodings`; confirm they match the line order of the encoded file.
        group_encodings = encodings[ids]

        # Calculate K for this group (max(10, 5% of group size))
        k = max(10, int(len(ids) * 0.05))

        # Apply K-medoids to get representative indices within this group
        if len(ids) > k:
            rep_indices = kmedoids(group_encodings, k)
            # Map back to original IDs
            group_rep_ids = [ids[idx] for idx in rep_indices]
        else:
            # If the group has at most k samples, take all of them
            group_rep_ids = ids

        representative_ids.extend(group_rep_ids)
        print(f"Added {len(group_rep_ids)} representatives for label {label}")

    print(f"Total representative samples: {len(representative_ids)}")

    # Save representative IDs
    with open(output_path, 'w') as f:
        json.dump(representative_ids, f)
    # Report the file that was actually written.
    print(f"Saved representative IDs to {output_path}")

