import warnings
warnings.filterwarnings("ignore")

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
from utils import CustomDataset
from torch.utils.data import DataLoader, Dataset

from scipy import stats
from model import estimate_alpha
from smm import SMM
from gof import GOFScorer


class GalaxyBase_SMM():
    """Iteratively train ``rep_model`` while excluding outliers, using a mixture
    model (``SMM``) fit on the encoder's embeddings as prototypes.

    ``fit`` alternates between:
      1. fitting ``self.smm`` on the embeddings of the current data
         (``update_prototypes``),
      2. dropping roughly the top ``l`` fraction of the full input by
         ``GOFScorer`` score (``_exclude_outlier_set``), and
      3. training ``rep_model`` on the kept samples (``update_network``).

    Args:
        rep_model: network exposing an ``encoder`` sub-module and whose forward
            returns ``(embedding, reconstruction)``. It is moved to ``device``.
        k: number of mixture components (``SMM(n_components=k)``).
        seed: ``random_state`` passed to the SMM.
        device: torch device used for data, prototypes and the network.
        version: network loss variant, ``'scalar'`` or ``'vector'``.
        l: fraction of samples excluded on every iteration of ``fit``.
        iters: number of exclude / train / refit iterations in ``fit``.
        recon_term: if True, add a summed MSE reconstruction loss.
        abla_clust: when not None, no SMM is constructed here.
            NOTE(review): ``update_prototypes`` still calls ``self.smm.fit``,
            so ``self.smm`` must then be provided elsewhere (e.g. by a subclass).
        score_type: ``'scalar'`` or ``'vector'``; forwarded to
            ``GOFScorer.get_score``.
    """

    def __init__(self, 
                 rep_model, 
                 k=10, 
                 seed=42, 
                 device='cuda', 
                 version='scalar',
                 l=0.01,
                 iters=3,
                 recon_term=True,
                 abla_clust=None,
                 score_type='scalar',
                 ):
        self.k = k
        self.seed = seed
        self.device = device
        self.version = version
        self.iters = iters
        self.l = l
        self.recon_term = recon_term
        self.abla_clust = abla_clust
        self.score_type = score_type

        if abla_clust is None:
            self.smm = SMM(
                    n_components=self.k,
                    covariance_type='diag',
                    n_iter=100,
                    tol=1e-3,
                    random_state=self.seed,
                    params='wmc', 
                    init_params='wmcd'
                )

        if rep_model is not None:
            self.rep_model = rep_model.to(self.device)

        self.scorer = GOFScorer(device=self.device)

    def fit(self, X):
        """Fit prototypes and network on ``X`` and return ``self``.

        Args:
            X: 2-D numpy array or tensor of samples. It is converted to a float32
                tensor on ``self.device``.

        Note:
            Outliers are re-selected from the *full* ``X`` on every iteration;
            exclusions do not accumulate across iterations.
        """
        # Accept numpy arrays or tensors (on any device / dtype); the network
        # and prototypes expect float32 on self.device.
        X = torch.as_tensor(X, dtype=torch.float32, device=self.device)

        self.update_prototypes(X)

        for _ in range(self.iters):
            X_i = self._exclude_outlier_set(X)

            assert not torch.isnan(X_i).any()

            self.update_network(X_i)
            self.update_prototypes(X_i)

        return self

    def _exclude_outlier_set(self, X):
        """Return the rows of ``X`` whose score is at or below the
        ``100 * (1 - l)`` percentile, i.e. drop roughly the top ``l`` fraction.

        Raises:
            ValueError: if ``self.score_type`` is not ``'scalar'`` or ``'vector'``.
        """
        if self.score_type not in ('scalar', 'vector'):
            raise ValueError(
                f"unsupported score_type {self.score_type!r}; "
                "expected 'scalar' or 'vector'"
            )

        Z = self.rep_model.encoder(X)
        score = self.scorer.get_score(Z, self.means, score_type=self.score_type)

        # The scorer may hand back a numpy array or a tensor (possibly on CUDA
        # and/or still attached to the graph). np.percentile cannot consume the
        # latter, so normalise to numpy and build the mask on X's device.
        if torch.is_tensor(score):
            score = score.detach().cpu().numpy()
        score = np.asarray(score)

        threshold = np.percentile(score, 100 * (1 - self.l))
        keep = torch.as_tensor(score <= threshold, device=X.device)
        return X[keep]

    def update_prototypes(self, X):
        """Refit ``self.smm`` on the encoder embeddings of ``X``.

        When ``abla_clust`` is None, the fitted means / weights / covariances
        are cached as float32 tensors on ``self.device``.
        """
        # Only a numpy copy of the embeddings is used, so don't build a graph.
        with torch.no_grad():
            Z = self.rep_model.encoder(X)
        self.smm.fit(Z.detach().cpu().numpy())

        if self.abla_clust is None:
            self.means = torch.as_tensor(
                self.smm.means_, dtype=torch.float32, device=self.device)
            self.weights = torch.as_tensor(
                self.smm.weights_, dtype=torch.float32, device=self.device)
            self.covars = torch.as_tensor(
                self.smm.covars_, dtype=torch.float32, device=self.device)

            assert not torch.isnan(self.means).any()
            assert not torch.isnan(self.weights).any()
            assert not torch.isnan(self.covars).any()

    def update_network(self, X, n_steps=100, lr=3e-4):
        """Train ``rep_model`` on ``X`` with full-batch Adam.

        The loss is the optional summed MSE reconstruction term plus a
        prototype term. The prototype term is built from per-component
        "forces" ``w_c / (pi * sqrt(prod(var_c)) * (1 + maha(z_n, mu_c)))``:

        - ``'scalar'`` adds ``-sum_n log sum_c force_nc``.
        - ``'vector'`` adds ``-sum_n log ||sum_c force_nc * unit(mu_c - z_n)||``.

        Args:
            X: float32 tensor on ``self.device``.
            n_steps: number of optimisation steps (default 100).
            lr: Adam learning rate (default 3e-4).

        Raises:
            ValueError: if ``self.version`` is not ``'scalar'`` or ``'vector'``.
        """
        if self.version not in ('scalar', 'vector'):
            raise ValueError(
                f"unsupported version {self.version!r}; "
                "expected 'scalar' or 'vector'"
            )

        optimizer = torch.optim.Adam(self.rep_model.parameters(), lr=lr)

        # Prototypes are fixed for the duration of this call, so the normaliser
        # is loop-invariant. It is taken from the numpy ``smm.covars_`` rather
        # than the float32 ``self.covars``, presumably so that the product of
        # variances keeps the SMM's precision. NOTE(review): confirm.
        det_covars = torch.from_numpy(
            np.prod(self.smm.covars_, axis=1)).to(self.device)
        norm_const = torch.pi * torch.sqrt(det_covars)

        # NOTE(review): the model is never switched back to eval(), so the
        # encoder calls made later by fit() run in train mode. Confirm this is
        # intended if the encoder has dropout / batch-norm layers.
        self.rep_model.train()

        for _ in range(n_steps):
            embed, x_hat = self.rep_model(X)
            if self.recon_term:
                loss = F.mse_loss(x_hat, X, reduction='sum')
            else:
                loss = 0

            maha = self._mahalanobis_distance_mix_diag(
                embed, self.means, self.covars)
            forces = self.weights / (norm_const * (1.0 + maha))  # (n, k)

            if self.version == 'scalar':
                scalar_loss = -forces.sum(dim=1).log().sum()
                assert not torch.isinf(scalar_loss).any()
                loss += scalar_loss
            else:  # 'vector' (validated above)
                unit_vec = F.normalize(
                    self.means.unsqueeze(0) - embed.unsqueeze(1),
                    p=2, dim=-1)  # (n, k, d)
                force_vec = (forces.unsqueeze(2) * unit_vec).sum(dim=1)
                vector_loss = -torch.norm(force_vec, dim=-1).log().sum()
                loss += vector_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def _mahalanobis_distance_mix_diag(self, X, means, covars):
        """Squared Mahalanobis distance from each row of ``X`` to each
        diagonal-covariance component.

        Args:
            X: (n, d) tensor.
            means: (k, d) tensor of component means.
            covars: (k, d) tensor of per-dimension variances.

        Returns:
            (n, k) tensor ``D`` with
            ``D[i, c] = sum_j (X[i, j] - means[c, j])**2 / covars[c, j]``,
            on the inputs' device.
        """
        # Broadcast instead of looping over components; the intermediate is
        # (n, k, d), the same size the loop's autograd graph kept anyway.
        centred = X.unsqueeze(1) - means.unsqueeze(0)   # (n, k, d)
        inv_cov = (1.0 / covars).unsqueeze(0)            # (1, k, d)
        return (centred * inv_cov * centred).sum(dim=2)