import os
import io
import time
import json
import torch
import numpy as np
import shutil
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from PIL import Image, ImageFile
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from collections import Counter


def _chunk_files(chunk_dir):
    """Return the paths of all ``chunk_<N>.json`` files in *chunk_dir*, ordered by N.

    Entries that do not follow the ``chunk_<N>.json`` naming scheme (editor
    backups, ``.DS_Store``, sub-directories with other names, ...) are ignored
    so that they can neither inflate the chunk count nor be opened by mistake.
    """
    indexed = []
    for name in os.listdir(chunk_dir):
        stem, ext = os.path.splitext(name)
        prefix, _, number = stem.partition("_")
        if ext == ".json" and prefix == "chunk" and number.isdecimal():
            indexed.append((int(number), name))
    indexed.sort()
    return [os.path.join(chunk_dir, name) for _, name in indexed]


def _summarize_clusters(tags, embeds, labels, centers, num_clusters):
    """Build one JSON-serializable summary dict per cluster index.

    Each cluster is named after the tag of the member whose embedding is
    closest (Euclidean distance) to the cluster center. A cluster without any
    member is still reported, with ``cluster_name`` set to ``None`` and
    ``center_diagram_idx`` set to ``-1``, instead of crashing in ``np.argmin``.
    """
    labels = np.asarray(labels)
    summaries = []
    for cluster_idx in range(num_clusters):
        member_indices = np.flatnonzero(labels == cluster_idx)
        if member_indices.size:
            distances = np.linalg.norm(embeds[member_indices] - centers[cluster_idx], axis=1)
            center_diagram_idx = int(member_indices[np.argmin(distances)])
            cluster_name = tags[center_diagram_idx]
        else:
            # KMeans can return fewer distinct labels than n_clusters when the
            # input contains many duplicate points (e.g. repeated tags).
            center_diagram_idx = -1
            cluster_name = None

        summaries.append({
            "cluster_idx": cluster_idx,
            "cluster_name": cluster_name,
            "center_diagram_idx": center_diagram_idx,
            "diagram_list": ", ".join(map(str, member_indices)),
            # NOTE(review): the trailing ": " in this key looks like a typo, but
            # it is kept unchanged because readers of the JSON may rely on it.
            "num_diagrams: ": len(member_indices)
        })
    return summaries


def _save_cluster_images(dataset, labels, images_path, num_clusters):
    """Recreate *images_path* and write every labelled image into its cluster folder.

    Any existing directory at *images_path* is deleted first. One
    ``cluster_<idx>`` sub-directory is created per cluster, and the image at
    ``dataset[idx]['image_bytes']`` is saved there as ``<idx>.png``.
    """
    if os.path.exists(images_path):
        shutil.rmtree(images_path)
    os.makedirs(images_path)
    for cluster_idx in range(num_clusters):
        os.makedirs(os.path.join(images_path, f"cluster_{cluster_idx}"))

    for idx, label in tqdm(enumerate(labels), total=len(labels), desc="Copying images to clusters"):
        if label < 0:
            # Negative labels are skipped (no cluster folder exists for them).
            continue
        img_path = os.path.join(images_path, f"cluster_{label}", f"{idx}.png")
        # Context manager releases the decoder/file handle after each image.
        with Image.open(io.BytesIO(dataset[idx]['image_bytes'])) as image:
            image.convert('RGB').save(img_path, "PNG", optimize=True)


def cluster(args, goal):
    """Cluster the ``goal`` tags of all tagged items and save the result.

    Reads every ``chunk_<N>.json`` file under
    ``../../output/tagging/<lvlm>_<use_option>/chunks`` (relative to the
    current working directory), embeds ``item['tag'][goal]`` for each item with
    a SentenceTransformer model, standardizes the embeddings and clusters them.
    A JSON summary of the clusters is written next to the ``chunks`` folder;
    when ``args.save_images == "image"`` the images are additionally exported
    into one folder per cluster.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``lvlm``, ``use_option``, ``max_data``, ``cluster_method``,
            ``n_clusters``, ``save_images``, ``dataset`` and ``data_folder``.
        goal: Key inside each item's ``tag`` dict to cluster on (the script
            calls this with ``"Domain"`` and ``"Type"``). Its lower-cased form
            is used in output file names.

    Raises:
        ValueError: If no tags are available to cluster.
        NotImplementedError: If ``args.cluster_method`` is not ``"KMeans"``, or
            images are requested for a dataset other than ``"WikiWeb"``.
    """
    output_dir = os.path.join(os.getcwd(), f"../../output/tagging/{args.lvlm}_{args.use_option}")

    all_chunk_data = []
    for file_name in _chunk_files(os.path.join(output_dir, "chunks")):
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    tags = [item['tag'][goal] for item in all_chunk_data]
    # max_data == -1 means "use everything".
    tags = tags[:args.max_data] if args.max_data != -1 else tags
    if not tags:
        raise ValueError(f"No '{goal}' tags found under {output_dir}; nothing to cluster.")
    goal = goal.lower()

    ### Getting embeddings
    model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
    embeds = model.encode(tags)
    scaler = StandardScaler()
    embeds = scaler.fit_transform(embeds)

    ### Clustering
    if args.cluster_method == "KMeans":
        kmeans = KMeans(n_clusters=args.n_clusters, random_state=0)
        kmeans.fit(embeds)
        labels = kmeans.labels_
        centers = kmeans.cluster_centers_
        num_clusters = args.n_clusters
    else:
        raise NotImplementedError

    ### Saving results
    result_stem = f"cluster_{args.lvlm}_{goal}_{args.cluster_method}_{args.n_clusters}"
    cluster_info = _summarize_clusters(tags, embeds, labels, centers, num_clusters)
    with open(os.path.join(output_dir, f"{result_stem}.json"), "w") as f:
        json.dump(cluster_info, f, indent=4)

    if args.save_images == "image":
        if args.dataset == "WikiWeb":
            dataset_path = os.path.join(args.data_folder, "WikiWeb/wiki_data.pkl")
            # SECURITY: pickle.load executes arbitrary code from the file; only
            # point --data_folder at trusted data.
            with open(dataset_path, 'rb') as file:
                dataset = pickle.load(file)
        else:
            raise NotImplementedError

        _save_cluster_images(dataset, labels, os.path.join(output_dir, result_stem), num_clusters)


if __name__ == '__main__':
    # Command-line interface: each entry is (flag, keyword arguments for
    # add_argument). Bracketed lists in the help texts name the expected values.
    option_specs = (
        ('--dataset', dict(type=str, default="WikiWeb", help="[WikiWeb]")),
        ('--data_folder', dict(type=str, default="../../../Datasets")),
        ('--lvlm', dict(type=str, default="molmo", help="[molmo, llama]")),
        ('--use_option', dict(default='no-option', help="[option, no-option]")),
        ('--save_images', dict(type=str, default="no-image", help="[image, no-image]")),
        ('--max_data', dict(type=int, default=-1)),
        ('--cluster_method', dict(type=str, default="KMeans")),
        ('--n_clusters', dict(type=int, default=50)),
    )
    parser = argparse.ArgumentParser(description='')
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Run the clustering once per tag category, recording wall-clock times
    # (hours:minutes:seconds only) before and after.
    clock_format = "%H:%M:%S"
    start_time = datetime.now().strftime(clock_format)
    for goal in ("Domain", "Type"):
        cluster(args, goal)
    end_time = datetime.now().strftime(clock_format)

    report_lines = (
        "########## Information ##########",
        f"Starting time: {start_time}",
        f"Ending time: {end_time}",
    )
    for line in report_lines:
        print(line)
