import os
import json
import os
import shutil
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

def load_json(filename):
    """Parse the UTF-8 encoded JSON file *filename* and return the decoded object."""
    with open(filename, encoding="utf-8") as fh:
        return json.load(fh)


def save_json(filename, data):
    """Write *data* to *filename* as pretty-printed (indent=4) UTF-8 JSON.

    Non-ASCII characters are written as-is rather than escaped. Any existing
    file at *filename* is overwritten.

    Raises:
        TypeError / ValueError: if *data* is not JSON-serialisable. In that
            case the target file is left untouched.
    """
    # Serialise before opening: opening with "w" truncates the file, so
    # streaming json.dump into it would leave a truncated/partial file behind
    # if encoding fails part-way through.
    text = json.dumps(data, ensure_ascii=False, indent=4)
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)


def read_lines(filepath):
    """Read a text file and return its lines with the newline removed.

    The file is decoded as UTF-8 first. If that raises a decoding error, it
    is re-read as GBK. Python's text mode normalises line endings before the
    trailing "\\n" is stripped from each line.

    Args:
        filepath: Path of the file to read.

    Returns:
        A list of strings, one per line, or ``None`` if the content is
        valid in neither UTF-8 nor GBK.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    for encoding in ("utf-8", "gbk"):
        try:
            with open(filepath, "r", encoding=encoding) as f:
                return [line.strip("\n") for line in f]
        except UnicodeDecodeError:
            # Only a decoding failure triggers the next encoding; I/O errors
            # propagate instead of being silently turned into None.
            continue
    # Undecodable in both encodings: keep the best-effort None result.
    return None


def save_lines(filepath, data):
    """Write the strings in *data* to *filepath* as UTF-8 text, one per line.

    Items are joined with "\\n" and no trailing newline is appended. Any
    existing file is overwritten.

    Args:
        filepath: Destination path.
        data: Iterable of strings to write.
    """
    # Explicit UTF-8 (matching the other file helpers here) instead of the
    # locale default, which may be unable to encode non-ASCII text.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("\n".join(data))


def mkdirp(p):
    """Create directory *p* together with any missing parents (like ``mkdir -p``).

    Does nothing if *p* already exists, even when it is not a directory.
    """
    if not os.path.exists(p):
        # exist_ok=True closes the race where another process creates *p*
        # between the existence check and this call.
        os.makedirs(p, exist_ok=True)


def deletedir(p):
    """Recursively remove the directory tree at *p*; a missing path is ignored."""
    if not os.path.exists(p):
        return
    shutil.rmtree(p)

def fileExist(p):
    """Return True if the path *p* exists (file or directory), otherwise False."""
    # os.path.exists already returns a bool; no if/else needed.
    return os.path.exists(p)

def l2_normalize(x, dim, eps=1e-12):
    """Scale *x* to unit L2 norm along dimension *dim*.

    Args:
        x: Input tensor.
        dim: Dimension along which the norm is computed.
        eps: Lower bound on the norm before dividing, so all-zero slices
            stay zero instead of becoming NaN. The default 1e-12 matches
            ``torch.nn.functional.normalize``.

    Returns:
        Tensor with the same shape as *x*.
    """
    return x / x.norm(p=2, dim=dim, keepdim=True).clamp_min(eps)


def spherical_kmeans_init(X, K, iters, seed, eps=1e-12):
    """Compute K unit-norm cluster centres for the rows of X with spherical k-means.

    Centres are seeded from K distinct rows of X, picked by a torch.Generator
    seeded with ``seed``. They are then refined for ``iters`` rounds, each of
    which does two things:
      1. assign every row to the centre with the largest dot product;
      2. replace each centre by the L2-normalised mean of its assigned rows.
    A centre that receives no rows is re-seeded from a randomly drawn row of X.

    Args:
        X: 2-D tensor of shape (N, d). This function does not normalise X.
            Spherical k-means usually expects unit-norm rows
            (NOTE(review): confirm callers pass normalised data).
        K: Number of centres; must not exceed N.
        iters: Number of refinement rounds; 0 returns the normalised seeds.
        seed: Seed for the random generator.
        eps: Lower bound applied to norms before dividing, to avoid 0/0.

    Returns:
        Tensor of shape (K, d) whose rows are L2-normalised (an all-zero
        row stays zero).

    Raises:
        ValueError: if K is larger than the number of rows N.
    """
    N, _ = X.shape
    if K > N:
        # randperm(N)[:K] would yield only N seeds, and the update loop
        # would then index past the end of the centre tensor.
        raise ValueError(f"K ({K}) must not exceed the number of rows N ({N})")

    g = torch.Generator(device=X.device)
    g.manual_seed(seed)

    idx = torch.randperm(N, generator=g, device=X.device)[:K]
    # Normalise the seeds too. The first assignment step is then
    # cosine-based like every later one, and iters=0 still returns
    # unit-norm centres. (Advanced indexing already copies, so no clone.)
    mu = l2_normalize(X[idx], dim=-1, eps=eps)

    for _ in range(iters):
        # Assign each row to the centre with the highest dot product.
        sim = X @ mu.t()  # (N, K)
        assign = sim.argmax(dim=1)

        # Update centres: mean of the assigned rows, or a random row if empty.
        new_mu = torch.zeros_like(mu)
        for k in range(K):
            mask = assign == k
            if mask.any():
                new_mu[k] = X[mask].mean(dim=0)
            else:
                # Re-seed the empty cluster from a random row of X.
                r = torch.randint(0, N, (1,), generator=g, device=X.device)
                new_mu[k] = X[r].squeeze(0)
        mu = l2_normalize(new_mu, dim=-1, eps=eps)
    return mu

def cosine_distance(a, b, eps):
    """Return ``1 - cosine_similarity(a, b)`` computed along the last dimension.

    Both inputs are L2-normalised along dim -1, with norms clamped below by
    ``eps``. Their element-wise product is then summed over that dimension,
    so the last axis of the broadcast shape is reduced away.
    """
    a_unit = l2_normalize(a, dim=-1, eps=eps)
    b_unit = l2_normalize(b, dim=-1, eps=eps)
    cos_sim = torch.sum(a_unit * b_unit, dim=-1)
    return 1.0 - cos_sim

