import numpy as np
from sklearn import linear_model
import pickle, time, os


class SupervisedKSVD:
    def __init__(self, n_components=220, max_iter=10, tol=0.001,
                 atom_for_cls_=None, rho=None):
        """Store the hyper-parameters; nothing is computed here.

        Args:
            n_components: Target total number of atoms. Atoms beyond those
                assigned to the individual classes form a shared part.
            max_iter: Upper bound on the alternating iterations run by
                ``fit``.
            tol: Training stops once the error still decreases, but by
                less than this amount.
            atom_for_cls_: Number of atoms given to every class. When
                ``None`` the per-class count is derived from ``rho``.
            rho: Per-class values (indexed by class) used as a fraction of
                the summed singular values when ``atom_for_cls_`` is
                ``None``.
        """
        # Dictionary layout.
        self.n_components = n_components
        self.atom_for_cls_ = atom_for_cls_
        self.rho = rho

        # Stopping rule of the alternating optimisation.
        self.max_iter = max_iter
        self.tol = tol

        # Diagnostics collected while fitting (e.g. the 'error' history).
        self.iteration_log = dict()
    
    def _initialize(self, y):
        self.dictionary, self.atom_for_cls = self.ksvd_cls_initialize(y, self.label, rho=self.rho)
        self.idxrange_for_cls_ = np.cumsum(self.atom_for_cls)
    
    def fit(self, y, label, save=False, save_name=None):
        self.label = label
        self.num_cls = label.shape[1]

        y = y.T
        # target, (num_features, num_samples)
        # x, sparse coefficient, (num_atoms, num_samples)
        self._initialize(y)
        last_e = np.inf
        for i in range(self.max_iter):
            x = self._update_coeff(y)
            e = np.linalg.norm(y - np.dot(self.dictionary, x))

            self.iteration_log.setdefault('error', []).append(e)
            print(f'Iter {i}, error {e}. \n' + '*' * 5)
            if last_e > e and last_e - e < self.tol:
                break

            last_e = e

            self._update_dict(y, self.dictionary, x)

        self.idxrange_for_cls = np.stack([self.idxrange_for_cls_[:-1], self.idxrange_for_cls_[1:]], axis=1)

        if save:
            del self.label

            if save_name is not None:
                self.save(save_name)
            else:
                if not os.path.exists('./model'):
                    os.makedirs('./model')
                self.save(f'./model/{self.__class__.__name__}_{time.time()}.pkl')
        return x

    def _update_coeff(self, y, label=None):
        # In case call this func in self.predict()
        if label is None:
            label = self.label

        # Either regress the coefficient sample by sample,
        # or allocate all samples with the same label pattern and
        # regress pattern by pattern
        do_unique_pattern = False
        unique_pattern = np.unique(label, axis=0)
        if len(unique_pattern) < y.shape[1]:
            do_unique_pattern = True
        
        x = np.zeros((self.dictionary.shape[1], y.shape[1]))

        if do_unique_pattern:
            for p in unique_pattern:
                p_idx = np.where((label == p).all(axis=1))[0]
                self._coeff_by_sample(y[:, p_idx], p, x, p_idx)
        
        else:
            for j in range(y.shape[1]):
                self._coeff_by_sample(y[:, [j]], label[j], x, [j])
        
        return x
    
    def _coeff_by_sample(self, sample, label, coeff_, index_for_sample):
        # Number of atoms changes with iteration adaptively
        num_atoms = self.dictionary.shape[1]
        common_part =  num_atoms - self.idxrange_for_cls_[-1]

        if common_part > 0:
            idx_common_part = np.arange(self.idxrange_for_cls_[-1], num_atoms)
        
        index_for_atom = [np.arange(self.idxrange_for_cls_[k],
                                    self.idxrange_for_cls_[k+1])
                          for k in range(self.num_cls) if label[k] == 1]
        index_for_atom = np.concatenate(index_for_atom) if len(index_for_atom) > 0 else []
        # Add common part
        if common_part > 0:
            index_for_atom = np.concatenate([index_for_atom, idx_common_part]).astype(int)
        
        reg = linear_model.LinearRegression(positive=True,
                                            fit_intercept=False).fit(self.dictionary[:, index_for_atom], sample)
        coeff_[np.ix_(index_for_atom, index_for_sample)] = reg.coef_.T

    def _update_dict(self, y, d, x):
        common_part = d.shape[1] - np.cumsum(self.atom_for_cls)[-1]
        for i in range(self.num_cls + 1):
            if i == self.num_cls:
                j = -np.arange(1, common_part + 1)[::-1]
            else:
                j = np.arange(self.idxrange_for_cls_[i], self.idxrange_for_cls_[i+1])
            if len(j) == 0:
                continue

            d, x = self._update_atom(x, y, d, i, j)

        self.dictionary = d
        print(f'Size {self.atom_for_cls}, total {d.shape[1]}.')
    
    def _update_atom(self, x, y, d, i, j):
        if i == self.num_cls:
            index = np.arange(self.label.shape[0])
        else:
            index = np.where(self.label[:, i] == 1)[0]
        if len(index) == 0:
            d[:, j] = 0
        else:
            for atom_idx in j:

                new_index = np.nonzero(x[atom_idx, :])[0]
                if len(new_index) == 0:
                    continue

                d[:, atom_idx] = 0
                r = (y - np.dot(d, x))[:, new_index] # (num_features, num_pos_samples)

                u, s, v = np.linalg.svd(r, full_matrices=False)

                k = 1
                
                coef_ = np.diag(s[:k]) @ v[:k]
                pos_norm = np.linalg.norm(coef_ * (coef_ > 0), axis=1) # k
                neg_norm = np.linalg.norm(coef_ * (coef_ < 0), axis=1)
                sign = (pos_norm > neg_norm).astype(float) * 2 - 1

                d[:, atom_idx] = (u[:, :k] * sign)[:, 0]

                reg = linear_model.LinearRegression(fit_intercept=False,
                                                    positive=True).fit(u[:, :k] * sign, r)
                
                x[atom_idx, new_index] = reg.coef_.T

                # u, s, v = topsing(r, maxiter=50)
                # d[:, atom_idx] = u
                # x[atom_idx, new_index] = (s * v).reshape(1, -1)

        return d, x
    
    def predict(self, y, designate_concept=None, regularized=False):
        x, active_cls = self._update_coeff_lasso_sep(y.T, designate_concept=designate_concept, regularized=regularized)
        self.label = active_cls[:,:].astype(int)
        return x.T
    
    def _update_coeff_lasso_sep(self, test_sample, alpha=1e-8, designate_concept=None, regularized=False):
        # test sample shape (num_samples, num_features)
        all_coeff_ = []
        all_active_group = []
        for sample in test_sample:
            if designate_concept is None:
                alpha = 1e-4
                stop = False
                while not stop:
                    print(alpha)
                    coeff_ = linear_model.Lasso(alpha=alpha, positive=True,
                                                fit_intercept=False).fit(self.dictionary, sample).coef_ # (num_atoms,)
                    if np.sum(coeff_ != 0) > int(self.dictionary.shape[1] * 0.3):
                        stop = True
                    else:
                        alpha = alpha * 0.8

                cls_norm = np.array([np.linalg.norm(coeff_[r[0]: r[1]]) for r in self.idxrange_for_cls])

                active_group = np.zeros(self.idxrange_for_cls.shape[0])
                active_group[np.argsort(cls_norm)[::-1][:3]] = 1
            
            else:
                active_group = designate_concept

            coeff_ = np.zeros(self.dictionary.shape[1])
            for i, c in enumerate(active_group):
                if c == 0:
                    continue
                
                atom_idx = np.arange(self.idxrange_for_cls[i][0], self.idxrange_for_cls[i][1])
                if not regularized:
                    coef_ = linear_model.LinearRegression(positive=True,
                                                          fit_intercept=False).fit(self.dictionary[:, atom_idx], sample).coef_
                else:
                    coef_ = linear_model.Ridge(positive=True, alpha=0.1,
                                               fit_intercept=False).fit(self.dictionary[:, atom_idx], sample).coef_
                coeff_[atom_idx] = coef_

            all_coeff_.append(coeff_)
            all_active_group.append(active_group)

        all_active_group = np.stack(all_active_group, axis=0)       
        all_coeff_ = np.stack(all_coeff_, axis=0) # (num_samples, num_atoms)

        return all_coeff_, all_active_group
    
    def save(self, save_file):
        with open(save_file, 'wb') as file:
            pickle.dump(self.__dict__, file)
    
    def load(self, load_file):
        with open(load_file, 'rb') as file:
            self.__dict__ = pickle.load(file)
    
    def ksvd_cls_initialize(self, samples, labels, rho=None):
        afc = [0]
        d = []
        for i in range(labels.shape[1]):
            sample_cls = samples[:, labels[:, i] == 1]
            u, s, v = np.linalg.svd(sample_cls, full_matrices=False)
            if self.atom_for_cls_ is not None:
                k = self.atom_for_cls_
            else:
                singv_thresh = rho[i] * np.sum(s)
                # Order is descending
                k = np.where(np.cumsum(s) > singv_thresh)[0][0]
            afc.append(k)

            coef_ = np.diag(s[:k]) @ v[:k]
            pos_norm = np.linalg.norm(coef_ * (coef_ > 0), axis=1) # k
            neg_norm = np.linalg.norm(coef_ * (coef_ < 0), axis=1)
            sign = (pos_norm > neg_norm).astype(float) * 2 - 1

            d.append(u[:, :k] * sign)
        
        afc = np.array(afc)
        d = np.concatenate(d, axis=1)

        # if common part
        common_part = self.n_components - np.sum(afc)
        if common_part > 0:
            # cp = np.random.normal(0, 1, (samples.shape[0], common_part))
            # cp /= np.linalg.norm(cp, axis=0, keepdims=True)
            u, s, v = np.linalg.svd(samples, full_matrices=False)

            coef_ = np.diag(s[:common_part]) @ v[:common_part]
            pos_norm = np.linalg.norm(coef_ * (coef_ > 0), axis=1)
            neg_norm = np.linalg.norm(coef_ * (coef_ < 0), axis=1)
            sign = (pos_norm > neg_norm).astype(float) * 2 - 1

            cp = u[:, :common_part] * sign
            d = np.concatenate((d, cp), axis=1)
        return d, afc

    def fit_minibatch(self, y, label, bs=2000, epoch=10, save=False, save_name=None):
        self.num_cls = label.shape[1]
        self.label = label
        sample_idx = np.arange(y.shape[0])

        y = y.T
        # target, (num_features, num_samples)
        # x, sparse coefficient, (num_atoms, num_samples)
        self._initialize(y)
        last_e = np.inf
        for i in range(epoch):
            np.random.shuffle(sample_idx)
            y = y[:, sample_idx]
            label = label[sample_idx]

            y_batches = np.array_split(y, y.shape[1] // bs, axis=1)
            label_batches = np.array_split(label, y.shape[1] // bs, axis=0)

            for batch_y, batch_label in zip(y_batches, label_batches):
                self.label = batch_label
                x = self._update_coeff(batch_y)
                e = np.linalg.norm(batch_y - np.dot(self.dictionary, x))

                self.iteration_log.setdefault('error', []).append(e)
                print(f'Iter {i}, error {e}. \n' + '*' * 5)
                if last_e > e and last_e - e < self.tol:
                    break

                last_e = e

                self._update_dict(batch_y, self.dictionary, x)

        self.idxrange_for_cls = np.stack([self.idxrange_for_cls_[:-1], self.idxrange_for_cls_[1:]], axis=1)

        if save:
            del self.label

            if save_name is not None:
                self.save(save_name)
            else:
                if not os.path.exists('./model'):
                    os.makedirs('./model')
                self.save(f'./model/{self.__class__.__name__}_{time.time()}.pkl')
        return x
