import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from sklearn.cluster import SpectralClustering
from sklearn.metrics import silhouette_score

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from copy_data_loader import get_all_slot_description

from transformers import T5Tokenizer, T5Model, T5EncoderModel

# Seed handed as ``random_state`` to SpectralClustering and silhouette_score in
# the clustering functions below. The ``__main__`` block rebinds this name to a
# different value when the file is executed as a script.
random_seed = 3407

def load_encoder(model_name="models--t5-small"):
    """Build a stand-alone T5 encoder together with its tokenizer.

    The full ``T5Model`` checkpoint is loaded, a fresh ``T5EncoderModel`` is
    created from its config, and the encoder stack's weights are copied over.

    Args:
        model_name: Identifier or path handed to ``from_pretrained``.

    Returns:
        A tuple ``(encoder, tokenizer, device)``. The encoder has been moved
        to ``device``, which is CUDA when available and CPU otherwise.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_name)

    pretrained = T5Model.from_pretrained(model_name)
    encoder_only = T5EncoderModel(pretrained.config)
    encoder_only.encoder.load_state_dict(pretrained.encoder.state_dict())
    # Point the encoder's token embedding at the checkpoint's shared
    # embedding parameter.
    encoder_only.encoder.embed_tokens.weight = pretrained.shared.weight

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    return encoder_only.to(target), tokenizer, target


def get_embedding(text, encoder, tokenizer, device):
    """Embed ``text`` as the mean of the encoder's final hidden states.

    Args:
        text: Value to embed; it is wrapped as ``"<pad> {text} <eos>"``
            before tokenization.
        encoder: Callable that accepts the tokenized batch as keyword
            arguments and returns an object exposing ``last_hidden_state``.
        tokenizer: Callable that produces a batch object supporting
            ``.to(device)``.
        device: Device the tokenized batch is moved to.

    Returns:
        The last hidden state averaged over dim 1, with all singleton
        dimensions squeezed out. Computed without gradient tracking.
    """
    wrapped = f"<pad> {text} <eos>"
    tokenizer_options = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    batch = tokenizer(wrapped, **tokenizer_options).to(device)

    with torch.no_grad():
        hidden_states = encoder(**batch).last_hidden_state
        # Average over the sequence axis, then drop singleton dimensions.
        pooled = hidden_states.mean(dim=1)
        return pooled.squeeze()

def domain_slot_clustering(dataset='multiwoz'):
    """Cluster every domain-slot pair of ``dataset`` with spectral clustering.

    Each slot's ``question`` text is embedded via ``get_embedding``, the
    pairwise cosine similarities are used as a precomputed affinity matrix,
    and the candidate cluster count with the highest silhouette score wins.

    Args:
        dataset: Dataset name forwarded to ``get_all_slot_description``.

    Returns:
        A tuple ``(result, best_n_clusters, slot_embeddings)``:
        ``result`` is a list with one dict per slot (in input order) holding
        the keys ``domain``, ``slot``, ``question`` and ``cluster_label``;
        ``best_n_clusters`` is the selected cluster count; ``slot_embeddings``
        is the stacked, un-normalized embedding tensor.

    Raises:
        ValueError: If no slot data is found, or if the dataset has too few
            distinct domains for any candidate cluster count to exist.
    """
    # Validate the inputs before paying for the model load.
    slot_data = get_all_slot_description(dataset)
    if not slot_data:
        raise ValueError(f"No slot data found for dataset {dataset}")

    # Only the number of distinct domains matters here: it caps the range of
    # candidate cluster counts.
    num_domains = len({item['domain'] for item in slot_data})
    candidate_counts = range(3, min(5, num_domains))
    if not candidate_counts:
        # Previously this fell through and crashed later with
        # "'NoneType' object is not subscriptable".
        raise ValueError(
            f"Dataset {dataset} has {num_domains} domain(s); at least 4 are "
            "required to evaluate any candidate cluster count"
        )

    # Load T5 encoder
    encoder, tokenizer, device = load_encoder()
    encoder.eval()  # Set to evaluation mode, disable random operations
    # Ensure PyTorch computation determinism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Generate slot embeddings (using question as slot description)
    slot_embeddings = torch.stack([
        get_embedding(item['question'], encoder, tokenizer, device)
        for item in slot_data
    ])
    # L2-normalize so that the dot product below is the cosine similarity
    embeddings = F.normalize(slot_embeddings, p=2, dim=1)
    slot_similarities = torch.matmul(embeddings, embeddings.T)

    # Only the clustering step needs NumPy arrays on the CPU
    similarity_matrix = slot_similarities.cpu().numpy()
    # Cosine distance, with negatives truncated to 0
    distance_matrix = np.maximum(1 - similarity_matrix, 0.0)
    np.fill_diagonal(distance_matrix, 0)  # Keep diagonal as 0
    np.fill_diagonal(similarity_matrix, 0)  # Remove self-loops from the affinity

    best_score = float("-inf")
    best_labels = None
    best_n_clusters = 0

    # Find optimal number of clusters
    for n_clusters in candidate_counts:
        clustering = SpectralClustering(
            n_clusters=n_clusters,
            affinity='precomputed',  # Use precomputed similarity matrix
            random_state=random_seed,
            assign_labels='cluster_qr'  # Extract clusters directly from spectral clustering eigenvectors
        )
        slot_labels = clustering.fit_predict(similarity_matrix)

        score = silhouette_score(distance_matrix, slot_labels, metric='precomputed', random_state=random_seed)
        print(n_clusters, score)
        # Always accept the first candidate so a labeling is guaranteed to exist.
        if best_labels is None or score > best_score:
            best_score = score
            best_labels = slot_labels
            best_n_clusters = n_clusters

    # Organize clustering results
    result = [
        {
            "domain": item["domain"],
            "slot": item["slot"],
            "question": item["question"],
            "cluster_label": int(label),
        }
        for item, label in zip(slot_data, best_labels)
    ]
    print("Domain-slot clustering cluster count:", best_n_clusters)
    return result, best_n_clusters, slot_embeddings

def domain_clustering(dataset='multiwoz'):
    """Cluster the distinct domains of ``dataset`` with spectral clustering.

    Each domain name is embedded via ``get_embedding``, the pairwise cosine
    similarities are used as a precomputed affinity matrix, and the candidate
    cluster count with the highest silhouette score wins.

    Args:
        dataset: Dataset name forwarded to ``get_all_slot_description``.

    Returns:
        A tuple ``(result, best_n_clusters, domain_embeddings)``:
        ``result`` is a list with one dict per domain (in sorted domain
        order) holding the keys ``domain`` and ``cluster_label``;
        ``best_n_clusters`` is the selected cluster count;
        ``domain_embeddings`` is the stacked, un-normalized embedding tensor
        whose rows follow the same sorted order.

    Raises:
        ValueError: If no slot data is found, or if the dataset has too few
            distinct domains for any candidate cluster count to exist.
    """
    # Validate the inputs before paying for the model load.
    slot_data = get_all_slot_description(dataset)
    if not slot_data:
        raise ValueError(f"No slot data found for dataset {dataset}")

    # Sort the distinct domains: set iteration order depends on per-process
    # string hashing, which would make the row order of the embeddings (and
    # of the returned result) differ between runs.
    domains = sorted({item['domain'] for item in slot_data})
    num_domains = len(domains)
    candidate_counts = range(2, min(4, num_domains))
    if not candidate_counts:
        # Previously this fell through and crashed later with
        # "'NoneType' object is not subscriptable".
        raise ValueError(
            f"Dataset {dataset} has {num_domains} domain(s); at least 3 are "
            "required to evaluate any candidate cluster count"
        )

    # Load T5 encoder
    encoder, tokenizer, device = load_encoder()
    encoder.eval()  # Set to evaluation mode, disable random operations
    # Ensure PyTorch computation determinism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Generate domain embeddings (the domain name itself is the text)
    domain_embeddings = torch.stack([
        get_embedding(item, encoder, tokenizer, device) for item in domains
    ])
    # L2-normalize so that the dot product below is the cosine similarity
    embeddings = F.normalize(domain_embeddings, p=2, dim=1)
    domain_similarities = torch.matmul(embeddings, embeddings.T)

    # Only the clustering step needs NumPy arrays on the CPU
    similarity_matrix = domain_similarities.cpu().numpy()
    # Cosine distance, with negatives truncated to 0
    distance_matrix = np.maximum(1 - similarity_matrix, 0.0)
    # Zero both diagonals: no self-distance, and no self-loops in the affinity
    np.fill_diagonal(distance_matrix, 0)
    np.fill_diagonal(similarity_matrix, 0)

    best_score = float("-inf")
    best_labels = None
    best_n_clusters = 0

    # Find optimal number of clusters
    for n_clusters in candidate_counts:
        clustering = SpectralClustering(
            n_clusters=n_clusters,
            affinity='precomputed',  # Use precomputed similarity matrix
            random_state=random_seed,
            assign_labels='cluster_qr'  # Extract clusters directly from spectral clustering eigenvectors
        )
        domain_labels = clustering.fit_predict(similarity_matrix)

        score = silhouette_score(distance_matrix, domain_labels, metric='precomputed', random_state=random_seed)
        print(n_clusters, score)

        # Always accept the first candidate so a labeling is guaranteed to exist.
        if best_labels is None or score > best_score:
            best_score = score
            best_labels = domain_labels
            best_n_clusters = n_clusters

    # Organize clustering results
    result = [
        {"domain": item, "cluster_label": int(label)}
        for item, label in zip(domains, best_labels)
    ]
    print("Domain clustering cluster count:", best_n_clusters)
    return result, best_n_clusters, domain_embeddings

if __name__ == "__main__":
    # Script entry point: cluster the domain-slot pairs of the chosen dataset,
    # write the labels to a JSON file, then draw (1) a domain/slot bipartite
    # graph with edges colored by cluster and (2) PCA / t-SNE projections of
    # the slot embeddings.
    import argparse
    from collections import Counter

    parser = argparse.ArgumentParser(description='Domain-slot joint clustering using T5 embeddings and spectral clustering')
    parser.add_argument('--dataset', type=str, default='multiwoz', choices=['multiwoz', 'sgd'],
                        help='Dataset name (multiwoz or sgd)')
    parser.add_argument('--output', type=str, default='domain_slot_clusters.json',
                        help='Output file path for clustering results')
    args = parser.parse_args()

    # Set random seed for reproducibility
    # NOTE(review): this rebinds the module-level ``random_seed`` read by the
    # clustering functions, so a script run uses 43407 instead of 3407 --
    # confirm that the differing value is intended.
    random_seed = 43407
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so assigning
    # it here cannot change hashing in this process; it is only inherited by
    # child processes. Export it before launching for a fully seeded run.
    os.environ['PYTHONHASHSEED'] = str(random_seed)

    # Execute clustering (domain_clustering() is the domain-only variant)
    clusters, best_n_clusters, slot_embeddings = domain_slot_clustering(args.dataset)

    # Save results
    with open(args.output, 'w') as f:
        json.dump(clusters, f, indent=2)
    print(f"Clustering completed. Results saved to {args.output}")

    # Plotting dependencies are imported here so that importing this module
    # for its clustering functions does not require them.
    import networkx as nx
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
    from sklearn.preprocessing import LabelEncoder

    # Create bipartite graph
    G = nx.Graph()

    # Cluster label of each (domain, slot) edge
    edge_clusters = {}

    # Add domain and slot nodes
    domains = set()
    slots = set()
    for item in clusters:
        domain = item['domain']
        slot = item['slot']
        domains.add(domain)
        slots.add(slot)
        G.add_node(domain, bipartite=0)
        G.add_node(slot, bipartite=1)
        G.add_edge(domain, slot)
        edge_clusters[(domain, slot)] = item['cluster_label']

    # Count number of edges carrying each cluster label
    cluster_counts = Counter(edge_clusters.values())

    # Print statistical results
    print("Clustering label statistics:")
    for label, count in sorted(cluster_counts.items()):
        print(f"Clustering label {label}: {count} edges")

    # Define node positions: domains in column 1 (spread over the slot
    # column's height), slots in column 2
    pos = {}
    pos.update((n, (1, i*len(slots)//len(domains))) for i, n in enumerate(domains))
    pos.update((n, (2, i)) for i, n in enumerate(slots))

    # Define color mapping
    unique_clusters = sorted(set(edge_clusters.values()))
    num_clusters = len(unique_clusters)
    cmap = plt.get_cmap('coolwarm', num_clusters)

    # Get corresponding color for each edge. An undirected graph may report
    # an edge as (slot, domain), so fall back to the reversed key.
    edge_colors = []
    for edge in G.edges():
        cluster = edge_clusters.get(edge, edge_clusters.get(edge[::-1]))
        edge_colors.append(cmap(unique_clusters.index(cluster)))

    # Draw bipartite graph
    plt.figure(figsize=(12, 12))
    ax = plt.gca()  # Get current axes
    nx.draw_networkx(G, pos=pos, node_color='lightblue', edge_color=edge_colors, ax=ax)

    # Create colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=num_clusters - 1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ticks=np.arange(num_clusters), ax=ax)
    cbar.set_label('Cluster Label')

    plt.title('Domain-Slot Spectral Clustering')
    plt.axis('off')

    # Embedding visualization
    X = slot_embeddings.cpu().numpy()
    # Points are colored by cluster label, not by domain
    cluster_ids = [item['cluster_label'] for item in clusters]
    le = LabelEncoder()
    encoded_labels = le.fit_transform(cluster_ids)

    # Use PCA for dimensionality reduction
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X)

    # Use t-SNE for dimensionality reduction (perplexity=5)
    tsne = TSNE(n_components=2, perplexity=5)
    X_tsne = tsne.fit_transform(X)

    # Use t-SNE for dimensionality reduction (perplexity=25). This must fit
    # ``tsne1``: fitting ``tsne`` again silently produced a second
    # perplexity=5 projection under the perplexity=25 title.
    tsne1 = TSNE(n_components=2, perplexity=25)
    X_tsne1 = tsne1.fit_transform(X)

    # Visualize
    plt.figure(figsize=(18, 6))
    plt.subplot(1, 3, 1)
    plt.scatter(X_pca[:, 0], X_pca[:, 1], c=encoded_labels, cmap='coolwarm')
    plt.title('PCA Visualization (Colored by Cluster)')

    plt.subplot(1, 3, 2)
    scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=encoded_labels, cmap='coolwarm')
    plt.title('t-SNE Visualization (perplexity=5)')

    plt.subplot(1, 3, 3)
    scatter = plt.scatter(X_tsne1[:, 0], X_tsne1[:, 1], c=encoded_labels, cmap='coolwarm')
    plt.title('t-SNE Visualization (perplexity=25)')

    # Add colorbar and set labels
    plt.colorbar(scatter, ticks=range(len(le.classes_)), label='cluster')
    # Title the legend only if the current axes actually has one
    legend = plt.gca().get_legend()
    if legend is not None:
        legend.set_title('cluster')
    # Display the encoded-index -> cluster-label correspondence
    class_captions = [f'{i}: {cls}' for i, cls in enumerate(le.classes_)]
    plt.figtext(0.5, 0.01, ' | '.join(class_captions), ha='center', fontsize=8)

    plt.show()