import numpy as np
import torch
from .base import BaseDivergence


class BregmanHardClustering:
    """Hard (k-means style) clustering under a user-supplied divergence.

    The algorithm alternates between assigning every row of ``X`` to the
    centroid with the smallest divergence and re-estimating each centroid as
    the arithmetic mean of the rows assigned to it.

    Attributes set by ``fit``:
        labels_: cluster index for every row of ``X``. They come from the
            last assignment step, i.e. they were computed against the
            centroids as they were *before* the final re-estimation.
        cluster_centers_: ``(n_clusters, p)`` tensor of centroids.
    """

    def __init__(self, X, div_fn, n_clusters):
        """Store the data and the clustering configuration.

        Args:
            X: 2-D tensor with one sample per row; converted to float32.
            div_fn: callable invoked as ``div_fn(X, M)`` where ``M`` has the
                same shape as ``X`` (one centroid repeated on every row). Its
                result is written into one column of an ``(n, n_clusters)``
                matrix, so it must yield one value per row.
            n_clusters: number of clusters to fit.
        """
        self.X = X.float()
        self.n_clusters = n_clusters
        self.div_fn = div_fn
        self.labels_ = None
        self.cluster_centers_ = None

    def fit(self, n_iter=100, conv_eps=1e-4):
        """Run the assign / re-estimate loop.

        Args:
            n_iter: maximum number of iterations.
            conv_eps: stop once the summed absolute change of all centroid
                coordinates between two iterations drops below this value.
        """
        mu = self._initialize()
        clust_h = None

        for step in range(n_iter):
            # _reestimate_step returns a fresh tensor, so no copy is needed.
            mu_prev = mu

            clust_h = self._assign_step(mu)
            mu = self._reestimate_step(clust_h, mu)

            if torch.abs(mu - mu_prev).sum() < conv_eps:
                print(f'converged at step {step}')
                break

        if clust_h is None:
            # The loop never ran (n_iter <= 0): still expose valid labels
            # for the initial centroids instead of failing on an unbound name.
            clust_h = self._assign_step(mu)

        self.labels_ = clust_h
        self.cluster_centers_ = mu

    def _initialize(self):
        """Pick ``n_clusters`` distinct rows of ``X`` as starting centroids.

        Raises:
            ValueError: if there are fewer rows than requested clusters.
        """
        n = self.X.shape[0]
        if self.n_clusters > n:
            raise ValueError(
                f'n_clusters={self.n_clusters} exceeds the number of '
                f'samples ({n})')
        chosen = torch.randperm(n)[:self.n_clusters]
        return self.X[chosen]

    def _assign_step(self, mu):
        """Return, for every row of ``X``, the index of its closest centroid."""
        n = self.X.shape[0]
        dist_mat = torch.zeros(n, self.n_clusters, device=self.X.device)
        for h in range(self.n_clusters):
            # div_fn receives two (n, p) tensors: the data and centroid h
            # repeated on every row.
            dist_mat[:, h] = self.div_fn(self.X, mu[h].repeat(n, 1))
        return torch.argmin(dist_mat, dim=1)

    def _reestimate_step(self, clust_h, mu):
        """Recompute every centroid as the mean of its assigned rows.

        A cluster that received no rows keeps its previous centroid; taking
        the mean of an empty selection would produce NaN and poison all
        subsequent iterations (and the convergence test).

        Args:
            clust_h: cluster index per row, as returned by ``_assign_step``.
            mu: current ``(n_clusters, p)`` centroids; not modified.

        Returns:
            A new ``(n_clusters, p)`` tensor of centroids.
        """
        new_mu = mu.clone()
        for h in range(self.n_clusters):
            members = self.X[clust_h == h]
            if members.shape[0] > 0:
                new_mu[h] = members.mean(dim=0)
        return new_mu


class BregmanSoftClustering:
    """Soft (EM / mixture style) clustering under a user-supplied divergence.

    Each component ``h`` has a mean ``mu[h]`` and a mixing weight ``pi[h]``.
    The E-step computes responsibilities proportional to
    ``pi[h] * exp(-div_fn(x, mu[h]))``; the M-step sets the weights to the
    average responsibilities and the means to the responsibility-weighted
    averages of the data.

    Attributes set by ``fit``:
        weights_: ``(n_clusters,)`` mixing weights.
        means_: ``(n_clusters, p)`` component means.
        soft_labels_: ``(n, n_clusters)`` responsibilities from the last
            E-step (computed against the parameters as they were before the
            final M-step).
    """

    def __init__(self, X, div_fn, n_clusters):
        """Store the data and the clustering configuration.

        Args:
            X: 2-D tensor with one sample per row; converted to float32.
            div_fn: callable invoked as ``div_fn(X, M)`` where ``M`` has the
                same shape as ``X`` (one mean repeated on every row). Its
                result is written into one column of an ``(n, n_clusters)``
                matrix, so it must yield one value per row.
            n_clusters: number of mixture components.
        """
        self.X = X.float()
        self.n_clusters = n_clusters
        self.div_fn = div_fn
        self.weights_ = None
        self.means_ = None
        self.soft_labels_ = None

    def fit(self, X=None, n_iter=1000, conv_eps=1e-4):
        """Run EM until convergence or ``n_iter`` iterations.

        Args:
            X: ignored. Kept (now optional) only so existing positional
                callers keep working; the data given to the constructor is
                always used.
            n_iter: maximum number of EM iterations.
            conv_eps: stop once the summed absolute change of all means plus
                all mixing weights between two iterations drops below this.
        """
        mu, pi = self._initialize()
        cond_prob_k = None

        for step in range(n_iter):
            # _M_step returns fresh tensors, so no copies are needed here.
            mu_prev, pi_prev = mu, pi

            cond_prob_k = self._E_step(mu, pi)
            mu, pi = self._M_step(cond_prob_k, mu, pi)

            if torch.abs(mu - mu_prev).sum() + torch.abs(pi - pi_prev).sum() < conv_eps:
                print(f'converged at step {step}')
                break

        if cond_prob_k is None:
            # The loop never ran (n_iter <= 0): still expose responsibilities
            # for the initial parameters instead of failing on an unbound name.
            cond_prob_k = self._E_step(mu, pi)

        self.weights_ = pi
        self.means_ = mu
        self.soft_labels_ = cond_prob_k

    def _initialize(self):
        """Draw initial means uniformly inside the bounding box of ``X``.

        Returns:
            ``(mu, pi)``: means of shape ``(n_clusters, p)`` and uniform
            mixing weights of shape ``(n_clusters,)``, both with the dtype
            and device of ``X``.
        """
        p = self.X.shape[1]
        lower = torch.min(self.X, dim=0)[0]
        upper = torch.max(self.X, dim=0)[0]
        # Sample with torch so the result matches X's dtype/device (a numpy
        # round trip yields float64 and fails for non-CPU tensors).
        unif = torch.rand(
            self.n_clusters, p, dtype=self.X.dtype, device=self.X.device)
        mu = lower.view(1, p) + (upper - lower).view(1, p) * unif
        pi = torch.full(
            (self.n_clusters,), 1.0 / float(self.n_clusters),
            dtype=self.X.dtype, device=self.X.device)
        return mu, pi

    def _E_step(self, mu, pi):
        """Compute the responsibilities of every component for every row.

        Args:
            mu: ``(n_clusters, p)`` component means.
            pi: ``(n_clusters,)`` mixing weights.

        Returns:
            ``(n, n_clusters)`` tensor whose rows sum to one.
        """
        n = self.X.shape[0]
        div_mat = torch.zeros(
            n, self.n_clusters, dtype=self.X.dtype, device=self.X.device)
        for h in range(self.n_clusters):
            div_mat[:, h] = self.div_fn(self.X, mu[h].repeat(n, 1))
        # Normalise pi[h] * exp(-d) in log space: softmax subtracts the row
        # maximum first, so large divergences no longer underflow to 0/0.
        # A zero weight becomes -inf and simply yields zero responsibility.
        log_post = torch.log(pi).view(1, -1) - div_mat
        return torch.softmax(log_post, dim=1)

    def _M_step(self, cond_prob_k, mu, pi):
        """Re-estimate mixing weights and means from the responsibilities.

        A component whose total responsibility is exactly zero keeps its
        previous mean, since the weighted average would be 0/0.

        Args:
            cond_prob_k: ``(n, n_clusters)`` responsibilities.
            mu: current ``(n_clusters, p)`` means; not modified.
            pi: current mixing weights; unused, replaced by the new estimate.

        Returns:
            ``(mu, pi)``: the updated means and mixing weights.
        """
        pi = cond_prob_k.mean(dim=0)
        totals = cond_prob_k.sum(dim=0)
        new_mu = mu.clone()

        for h in range(self.n_clusters):
            if totals[h] > 0:
                weighted_sum = (cond_prob_k[:, h:h + 1] * self.X).sum(dim=0)
                new_mu[h, :] = weighted_sum / totals[h]
        return new_mu, pi
