import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
from pathlib import Path

def _collect_points(data):
    """Flatten the ``papers``/``ckm``/``batch`` sections of *data* into parallel lists.

    Returns a tuple ``(vectors, labels, info)``:
      * ``vectors[i]`` is the raw ``vector`` entry,
      * ``labels[i]`` is one of ``"Paper"``, ``"CKM"`` or ``"Batch"``,
      * ``info[i]`` is the annotation text ("" for papers, which are not annotated).
    Entries are ordered papers, then CKM, then batch; later groups draw on top.
    """
    vectors, labels, info = [], [], []
    for p in data.get('papers', []):
        vectors.append(p['vector'])
        labels.append("Paper")
        info.append("")
    for c in data.get('ckm', []):
        vectors.append(c['vector'])
        labels.append("CKM")
        # str() so non-string (e.g. integer) ids don't break slicing/concatenation.
        info.append("CKM: " + str(c['id'])[:8])
    for b in data.get('batch', []):
        vectors.append(b['vector'])
        labels.append("Batch")
        info.append(str(b['id']))
    return vectors, labels, info


def _project_2d(X):
    """Project the rows of the 2-D array *X* onto (up to) two principal components.

    PCA cannot fit more components than ``min(n_samples, n_features)``, so
    degenerate inputs are handled here instead of raising inside scikit-learn:
    fewer than two rows (or zero features) are placed at the origin, and a
    single-feature input gets a zero-filled second axis.

    Returns an array of shape ``(n_samples, 2)``.
    """
    n_samples, n_features = X.shape
    if n_samples < 2 or n_features == 0:
        return np.zeros((n_samples, 2))
    n_components = min(2, n_features)
    reduced = PCA(n_components=n_components).fit_transform(X)
    if n_components < 2:
        reduced = np.hstack([reduced, np.zeros((n_samples, 2 - n_components))])
    return reduced


def visualize_3d_projections(data_path, out_path):
    """Render a 2-D PCA scatter plot of paper, CKM and batch embedding vectors.

    Despite the historical name, the projection is two-dimensional.

    Args:
        data_path: Path to a UTF-8 JSON file with optional ``papers``, ``ckm``
            and ``batch`` lists. Every entry needs a ``vector``; ``ckm`` and
            ``batch`` entries also need an ``id`` (used as the point label).
        out_path: Destination image path. Missing parent directories are created.

    Raises:
        ValueError: if the vectors do not form a rectangular numeric array.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    vectors, labels, info = _collect_points(data)
    if not vectors:
        print("No vectors found. Cannot generate plot.")
        return

    X = np.asarray(vectors, dtype=float)  # ragged/non-numeric input raises ValueError here
    if X.ndim != 2:
        raise ValueError("Every 'vector' must be a list of numbers of the same length.")
    reduced = _project_2d(X)
    kinds = np.asarray(labels)

    fig, ax = plt.subplots(figsize=(11, 8))
    try:
        # One scatter call per category (same order as before: papers, CKM, batch)
        # instead of one artist per point.
        is_paper = kinds == "Paper"
        is_ckm = kinds == "CKM"
        is_batch = kinds == "Batch"
        ax.scatter(reduced[is_paper, 0], reduced[is_paper, 1], c='#CCCCCC', marker='o', s=60, edgecolors='none', alpha=0.6)
        ax.scatter(reduced[is_ckm, 0], reduced[is_ckm, 1], c='#2E86AB', marker='*', s=450, edgecolors='black', alpha=0.9)
        ax.scatter(reduced[is_batch, 0], reduced[is_batch, 1], c='#D64933', marker='X', s=300, edgecolors='black', alpha=0.9)

        # Offsets are in points, not data units: PCA coordinates have arbitrary
        # scale, so a fixed data-space offset could be invisible or far off.
        for i in np.flatnonzero(is_ckm):
            ax.annotate(info[i], (reduced[i, 0], reduced[i, 1]), xytext=(6, 6), textcoords='offset points',
                        fontsize=11, color='#283D3B', fontweight='bold')
        for i in np.flatnonzero(is_batch):
            ax.annotate(info[i], (reduced[i, 0], reduced[i, 1]), xytext=(6, -12), textcoords='offset points',
                        fontsize=11, color='#92140C', fontweight='bold')

        # Empty proxy artists give the legend fixed entries for all three categories.
        ax.scatter([], [], c='#CCCCCC', marker='o', s=60, label='Published Papers (Vector Space)', alpha=0.6)
        ax.scatter([], [], c='#2E86AB', marker='*', s=400, edgecolors='black', label='CKM Incremental Approach')
        ax.scatter([], [], c='#D64933', marker='X', s=250, edgecolors='black', label='Batch God-Mode Baseline')

        ax.set_title("Semantic Space: Papers vs CKM vs Batch Hypotheses (PCA Projection)", fontsize=14, fontweight='bold', pad=15)
        ax.set_xlabel("PCA Primary Axis")
        ax.set_ylabel("PCA Secondary Axis")
        ax.legend(loc='lower right', frameon=True, shadow=True)
        ax.grid(True, linestyle='--', alpha=0.4)

        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"Plot completely rendered and saved to {out_path}")

if __name__ == "__main__":
    # Script entry point: plot the default dataset, paths relative to the CWD.
    output_file = Path('metabolism/embedding_space.png')
    data_file = Path('metabolism/plot_data.json')
    if not data_file.exists():
        print(f"Data file {data_file} not found.")
    else:
        visualize_3d_projections(data_file, output_file)
