from copy import deepcopy
import pickle
import time
import numpy as np
import tensorly as tl
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.decomposition import NMF
from tensorly import unfolding_dot_khatri_rao
from tqdm import trange
from termcolor import cprint
from ncp import ncp
from utils import compute_error, compute_model_cost, compute_coding_cost



class ASSMF:
    """Streaming seasonal matrix factorization with switchable regimes.

    The model keeps one non-negative factor matrix per non-temporal mode
    (``self.U``) and, for each of up to ``max_regimes`` regimes, a sequence of
    per-time-step component weights (``self.W``).  A slice at time ``t`` under
    regime ``r`` is reconstructed as ``U[0] @ diag(W[r, t]) @ U[1].T``.

    NOTE(review): the update rules below are written for exactly two
    non-temporal modes (``rows x cols x time``), although ``initialize``
    allocates one factor per leading axis of ``X``.
    """

    def __init__(self, period, n_components, max_regimes=100, zero=1e-12,
                 alpha=0.1, beta=0.1, init_cycles=3, update_freq=0, compression=False):
        """Store hyper-parameters; arrays are allocated later by ``initialize``.

        Parameters
        ----------
        period : int
            Season length ``s`` in time steps.
        n_components : int
            Number of latent components ``k``.
        max_regimes : int
            Capacity ``r`` of the regime axis of ``self.W``.
        zero : float
            Small positive floor used to avoid division by zero.
        alpha : float
            Step size (and gradient-clipping scale) of the factor updates.
        beta : float
            In ``update_v2``, ``beta * cost1`` is added to the cost of every
            alternative to keeping the current regime.
        init_cycles : int
            Number of leading seasons averaged to initialise the factors.
        update_freq : int
            Period, in time steps, of regime-structure updates.  Values
            ``<= 0`` disable them.
        compression : bool
            Whether regime compression is attempted on update steps.
        """
        # dimensionality
        self.s = period
        self.k = n_components
        self.r = max_regimes
        self.m = 1  # current number of regimes

        # Learning setting
        self.zero = zero
        self.alpha = alpha
        self.beta = beta
        self.init_cycles = init_cycles
        self.update_freq = update_freq
        self.compression = compression

    def _is_update_step(self, t):
        """Return True when ``t`` is on the periodic structure-update schedule.

        A non-positive ``update_freq`` disables the schedule instead of
        triggering a modulo-by-zero.
        """
        return bool(self.update_freq > 0 and t % self.update_freq == 0)

    def _gradient_step(self, factors, At, w):
        """Take one clipped, projected gradient step on both factor matrices.

        ``factors`` is a two-element list ``[U0, U1]`` whose entries are
        replaced by the updated matrices.  Both gradients are evaluated at the
        incoming factors with component weights ``w`` before either factor is
        changed.

        Returns the list of per-mode column norms that were divided out of the
        factors, so the caller can fold them into the seasonal weights.
        """
        U0, U1 = factors
        D = np.diag(w)
        grads = [
            At @ U1 @ D - U0 @ D @ (U1.T @ U1) @ D,
            At.T @ U0 @ D - U1 @ D @ (U0.T @ U0) @ D,
        ]

        scales = []
        for mode, grad in enumerate(grads):
            # Clip the gradient norm to alpha * sqrt(k); a zero gradient is
            # left untouched (avoids a divide-by-zero warning).
            grad_norm = np.sqrt(np.sum(grad ** 2))
            if grad_norm > 0:
                grad = grad * min(1, self.alpha * np.sqrt(self.k) / grad_norm)
            stepped = factors[mode] + self.alpha * grad
            # Normalise columns, flooring the norm so an all-zero column does
            # not turn the factor into NaN, then project onto the
            # non-negative orthant.
            col_norms = np.sqrt(np.sum(stepped ** 2, axis=0))
            col_norms = np.maximum(col_norms, self.zero)
            factors[mode] = (stepped @ np.diag(1 / col_norms)).clip(min=0)
            scales.append(col_norms)

        return scales

    def initialize(self, X):
        """Allocate model arrays and initialise factors from the first seasons.

        The first ``init_cycles`` seasons of ``X`` are averaged into a single
        season, factorised with ``ncp``, and the resulting non-temporal factors
        are column-normalised with their norms folded into the seasonal
        weights of every regime slot.

        Parameters
        ----------
        X : ndarray
            Data tensor whose last axis is time.

        Raises
        ------
        ValueError
            If ``X`` holds fewer than ``init_cycles`` full seasons.
        """
        k = self.k
        r = self.r
        s = self.s
        init_cycles = self.init_cycles

        self.d = d = X.shape[:-1]  # dims except for the time axis
        self.n = n = X.shape[-1]  # maximum length you will observe

        if n < init_cycles * s:
            raise ValueError(
                "X has {} time steps but initialization needs "
                "init_cycles * period = {}".format(n, init_cycles * s))

        # Components
        self.U = [np.zeros((di, k)) for di in d]  # U and V
        self.W = np.zeros((r, s + n, k))  # W(1), ..., W(r)

        # Outputs
        self.E = np.zeros((r, n))  # history of errors
        self.R = np.zeros(n, dtype=int)  # history of regime indices
        self.O = np.zeros(n, dtype=int)  # history of operations

        # Initialize regime history
        self.R[:s] = 0

        # Initialize parameters from the average of the leading seasons
        X_fold = np.array([X[..., i*s:(i + 1)*s] for i in range(init_cycles)])
        X_fold = X_fold.sum(axis=0) / init_cycles
        factors = ncp(X_fold, self.k, maxit=3)
        self.W[:, :s] = factors[-1]  # same seasonal pattern in every slot

        # Normalize each non-temporal factor; push its scale into W
        for i in range(X.ndim - 1):
            weights = np.sqrt(np.sum(factors[i] ** 2, axis=0))
            self.U[i] = factors[i] @ np.diag(1 / weights)
            self.W[:, :s] = self.W[:, :s] @ np.diag(weights)  # k x k

    def apply_grad(self, At, t, ridx):
        """Update the factors with one gradient step on the slice ``At``.

        The step uses the weights ``W[ridx, t]``; the column norms removed
        from the factors are multiplied back into ``W[ridx, t]`` only.

        Parameters
        ----------
        At : ndarray
            Observed two-dimensional slice at time ``t``.
        t : int
            Time index into ``self.W``.
        ridx : int
            Regime whose weights are used and rescaled.
        """
        Wfit = np.copy(self.W[ridx, t])

        for col_norms in self._gradient_step(self.U, At, self.W[ridx, t]):
            # Update seasonal weights
            Wfit = Wfit * col_norms

        self.W[ridx, t] = Wfit

    def fit(self, X, t):
        """Run ``update_v2`` over ``X`` and return per-step CPU times.

        For every local index ``tt`` in ``[period, X.shape[-1])`` the window
        ``X[..., tt - period:tt]`` is passed to ``update_v2`` together with the
        global time ``t + tt``.

        Parameters
        ----------
        X : ndarray
            Data tensor whose last axis is time.
        t : int
            Offset added to local indices to obtain the model's time index.

        Returns
        -------
        ndarray
            ``time.process_time`` deltas, one per processed step.
        """
        n = X.shape[-1]
        elapsed_time = np.zeros(n - self.s)
        print(n, t, X.shape)

        for tt in range(self.s, n):
            tc = t + tt
            print("tc", tc)
            tic = time.process_time()
            self.update_v2(X[..., tt-self.s:tt], tc)
            elapsed_time[tt - self.s] = time.process_time() - tic

        return elapsed_time

    def search_regime(self, X, t):
        """Pick the existing regime that best reconstructs the window ``X``.

        Each of the ``self.m`` regimes reconstructs the window ending at time
        ``t``; the one with the smallest ``compute_coding_cost`` wins.

        Returns
        -------
        tuple
            ``(total_cost, best_idx)`` where ``total_cost`` is the winning
            coding cost plus ``compute_model_cost`` of that regime's last
            season of weights and of both factor matrices.
        """
        U, V = self.U
        E = [None] * self.m

        n = X.shape[-1]
        Y = np.zeros(X.shape)
        for i in trange(self.m, desc="SEARCH"):
            for tt in range(n):
                Wt = self.W[i, t - n + 1 + tt]
                Y[..., tt] = U @ np.diag(Wt) @ V.T

            E[i] = compute_coding_cost(X, Y)

        best_idx = np.argmin(E)
        coding_cost = E[best_idx]
        model_cost = compute_model_cost(self.W[best_idx, t - self.s + 1:t + 1])
        model_cost += compute_model_cost(self.U[0])
        model_cost += compute_model_cost(self.U[1])

        return coding_cost + model_cost, best_idx

    def estimate_regime(self, X, t, ridx):
        """Estimate a candidate new regime starting from regime ``ridx``.

        Copies of the current factors and of regime ``ridx``'s weights over
        ``[t - period + 1, t]`` are refined with one gradient step per slice of
        ``X``.  The live factors ``self.U`` are not modified; the refined
        copies are returned so the caller can adopt them.  The refined weights
        are written into the unused regime slot ``self.m`` (when one exists),
        so they take effect only if the caller adopts that slot.

        Parameters
        ----------
        X : ndarray
            Window of slices; its last axis must not exceed ``period``.
        t : int
            Time index of the last slice in the window.
        ridx : int
            Existing regime used as the starting point.

        Returns
        -------
        tuple
            ``(total_cost, new_regime_index, factors)`` with
            ``new_regime_index == self.m``.
        """
        print("Estimation")
        U, V = self.U

        window = slice(t - self.s + 1, t + 1)  # must include t
        W = np.copy(self.W[ridx, window])
        P = [np.copy(U), np.copy(V)]

        # Refine the candidate factors with X in the last season
        for step in range(X.shape[-1]):
            for col_norms in self._gradient_step(P, X[..., step], W[step]):
                # Rescaling a factor column affects every time step, so the
                # whole window of weights absorbs the norm.
                W = W * col_norms

        # Cost evaluation with the candidate parameters
        Y = np.zeros(X.shape)
        for tt in range(X.shape[-1]):
            Y[..., tt] = P[0] @ np.diag(W[tt]) @ P[1].T

        coding_cost = compute_coding_cost(X, Y)
        print("Estimation error=", compute_error(X, Y))
        model_cost = compute_model_cost(W)
        model_cost += compute_model_cost(P[0])
        model_cost += compute_model_cost(P[1])

        # Park the candidate weights in the spare slot; slots >= self.m are
        # ignored by the search until the regime count is increased.
        if self.m < self.r:
            self.W[self.m, window] = W

        return coding_cost + model_cost, self.m, P

    def compress_regime(self, X, t):
        """Placeholder for regime compression in the ``update_v2`` path.

        Not implemented: returns ``None``, which ``update_v2`` treats as
        "no compression candidate".
        """
        return

    def update_v2(self, X, t):
        """Process one time step: choose keep / add / delete, then learn.

        1. Propagate last season's weights to time ``t`` for every regime.
        2. Find the best existing regime for the window ``X``.
        3. On scheduled update steps (and never at ``t == period``), estimate
           a candidate new regime if capacity allows and, when compression is
           enabled, ask ``compress_regime`` for a candidate.
        4. Apply the cheapest option and take a gradient step on the latest
           slice with the selected regime.

        Parameters
        ----------
        X : ndarray
            Window of the most recent slices (last axis is time).
        t : int
            Model time index associated with the window.
        """
        print("input:", X.shape)
        cost2 = cost3 = np.inf
        P = ridx2 = ridx3 = Wtmp = None
        self.W[:, t] = self.W[:, t - self.s]

        # 1: keep
        cost1, ridx1 = self.search_regime(X, t)

        if self._is_update_step(t) and t != self.s:
            # 2: increase (only while a free regime slot exists)
            if self.m < self.r:
                cost2, ridx2, P = self.estimate_regime(X, t, ridx1)

            # 3: decrease
            if self.compression:
                compressed = self.compress_regime(X, t)
                if compressed is not None:
                    cost3, ridx3, Wtmp = compressed

        bias = self.beta * cost1
        costs = [cost1, bias + cost2, bias + cost3]
        print("costs=", costs)
        self.O[t] = case = np.argmin(costs)

        if case == 0:
            cprint(">>>> KEEP", "blue")
            self.R[t] = ridx1

        elif case == 1:
            cprint(">>>> ADD", "green")
            self.U = P
            self.R[t - self.s + 1:t + 1] = ridx2
            self.m += 1

        elif case == 2:
            cprint(">>>> DELETE", "red")
            self.R[t] = ridx3
            self.m = len(Wtmp)  # overwrite
            self.W[:self.m, t - self.s + 1: t + 1] = Wtmp

        print("R[{}] = {} is updated".format(t, self.R[t]))
        self.apply_grad(X[..., -1], t, self.R[t])

        print(np.unique(self.R))
        return

    def update(self, X, t, verbose=False):
        """Main algorithm to manage regimes, called for each time point.

        When the current slice ``X[..., -1]`` arrives, this method:

        1. Chooses the best regime among the existing ones (coding cost of
           the current slice only).
        2. On scheduled update steps, fits a candidate new regime with
           ``fit_seasonal_weights`` (if a free slot exists) and, when
           compression is enabled and more than one regime exists, compresses
           the regimes with ``regime_compression``.
        3. Selects the cheapest of keep / add / compress / compress-and-add.
        4. Copies the weights at ``t`` to ``t + period`` and takes a gradient
           step on the factors with the selected regime.

        Parameters
        ----------
        X : ndarray
            Window of recent slices; the last one is the current sample.
        t : int
            Model time index of the current sample.
        verbose : bool
            Unused.
        """
        cost2 = cost3 = cost4 = np.inf
        West = Wcmp = best_comp_regime = None

        m = self.m  # current number of regimes
        Xt = X[..., -1]  # latest data sample
        update_step = self._is_update_step(t)

        # Case 1: selecting the best regime in an existing set
        for i in trange(m, desc="search"):
            # Make prediction
            Yt = self.U[0] @ np.diag(self.W[i, t]) @ self.U[1].T
            # Evaluate errors
            self.E[i, t] = compute_coding_cost(Xt, Yt)

        best_ridx_in_existing = np.argmin(self.E[:m, t])
        coding_cost = self.E[best_ridx_in_existing, t]

        cost1 = coding_cost
        print("coding_cost1=", coding_cost)
        print("model_cost1=", 0)

        # Case 2: generating a new regime (needs a free slot in self.W)
        if update_step and m < self.r:
            print("TIME TO UPDATE")
            West = fit_seasonal_weights(X, *self.U, self.W[best_ridx_in_existing, t-self.s:t], zero=self.zero)
            Yte = self.U[0] @ np.diag(West[-1]) @ self.U[1].T
            coding_cost_new = compute_coding_cost(Xt, Yte)

            model_cost_new = compute_model_cost(West)
            cost2 = coding_cost_new + model_cost_new
            print("coding_cost2=", coding_cost_new)
            print("model_cost2=", model_cost_new)

        # Case 3: compressing existing regimes
        print(self.compression)
        if self.compression and update_step:
            if m > 1:
                Wcmp, model_cost_comp = regime_compression(
                    self.W[:m, t-self.s:t])
                print(Wcmp.shape)

                # Search best regimes in compressed regimes
                mc = len(Wcmp)
                Ecmp = [None] * mc
                for i in range(mc):
                    Yt = self.U[0] @ np.diag(Wcmp[i, -1]) @ self.U[1].T
                    # Evaluate errors
                    Ecmp[i] = compute_coding_cost(Xt, Yt)

                best_comp_regime = np.argmin(Ecmp)
                coding_cost_comp = Ecmp[best_comp_regime]
                cost3 = coding_cost_comp + model_cost_comp
                print("cosding_cost3=", coding_cost_comp)
                print("model_cost3=", model_cost_comp)

                # Stays infinite when no new-regime candidate was fitted
                cost4 = cost2 + model_cost_comp

            else:
                Wcmp, model_cost_comp = None, np.inf

        # Model structure selection
        result = [
            cost1,  # coding cost
            cost2,  # coding cost + model cost
            cost3,  # coding cost + saving cost
            cost4   # coding cost + model cost + saving cost
        ]
        print(result)
        self.O[t] = case = np.argmin(result)

        if case == 0:
            cprint(">>>> KEEP", "blue")
            self.R[t] = best_ridx_in_existing

        elif case == 1:
            cprint(">>>> ADD", "red")
            self.R[t] = m
            self.W[m, t - self.s:t] = West
            self.W[m, t:t + self.s] = West
            self.m += 1

        elif case == 2:
            cprint(">>>> COMPRESS", "green")
            self.R[t] = best_comp_regime
            self.m = len(Wcmp)
            self.W[:self.m, t-self.s:t] = Wcmp

        elif case == 3:
            cprint(">>>> REARRANGE", "yellow")
            # Compressed regimes occupy slots [0, mc); the newly fitted
            # regime goes right after them and becomes the active one.
            mc = len(Wcmp)
            self.W[:mc, t-self.s:t] = Wcmp
            self.W[mc, t-self.s:t] = West
            self.m = mc + 1
            self.R[t] = mc

        self.W[:, t + self.s] = self.W[:, t]
        cprint(" Number of regimes= {} ".format(self.m),
                "grey", "on_white")

        # Update U, V with the selected W
        ridx = self.R[t]
        self.apply_grad(Xt, t, ridx)

        return

    def predict_seq(self, current_time, forecast_step):
        """Forecast ``forecast_step`` slices by repeating the last season.

        The regime used is the most frequent entry of ``self.R`` over the
        season preceding ``current_time``; its weights over that season are
        cycled to build the forecast.

        Returns
        -------
        ndarray
            Array of shape ``(U.shape[0], V.shape[0], forecast_step)``.
        """
        U, V = self.U
        pred = np.zeros((U.shape[0], V.shape[0], forecast_step))
        tcur = current_time

        idxs, counts = np.unique(self.R[tcur - self.s:tcur], return_counts=True)
        for i, c in zip(idxs, counts):
            print(i, "=", c)
        r_idx = idxs[np.argmax(counts)]
        print("Prediction with regime {}".format(r_idx))

        for t in trange(forecast_step, desc="forecast"):
            tseas = tcur - self.s + np.mod(t, self.s)
            W = self.W[r_idx, tseas]
            pred[..., t] = U @ np.diag(W) @ V.T

        return pred

    def predict(self, current_time, forecast_step):
        """Return a single reconstructed slice for a future step.

        Uses the regime recorded at ``current_time`` and the weights stored at
        ``current_time + period - (current_time + forecast_step) % period``.

        NOTE(review): the seasonal index arithmetic is reproduced as written;
        confirm that it addresses the intended forecast step.
        """
        t_seas = current_time + self.s - np.mod(
            current_time + forecast_step, self.s)

        r_idx = self.R[current_time]

        return self.U[0] @ np.diag(self.W[r_idx, t_seas]) @ self.U[1].T

    def compute_anom_score(self, At, t):
        """Store per-row squared reconstruction errors of ``At`` in ``self.E``.

        The reconstruction uses regime ``self.R[t]`` and the weights stored at
        ``t + period``.  For each row the score starts as the squared norm of
        the reconstructed row (the error if every observation were zero) and
        is then corrected entry by entry for the non-zero observations, so
        only non-zeros are visited individually.

        NOTE(review): scores are written to ``self.E[:At.shape[0], t]``, a
        buffer otherwise indexed by regime; this needs
        ``At.shape[0] <= max_regimes`` -- confirm the intended storage.
        """
        U0, U1 = self.U
        # W is (regime, time, component): a regime index is required here.
        Wfit = self.W[self.R[t], t + self.s]
        D = np.diag(Wfit)

        WVVW = D @ U1.T @ U1 @ D

        # for zero elements: u_i WVVW u_i^T for every row i
        n_rows = At.shape[0]
        rows = U0[:n_rows]
        self.E[:n_rows, t] = np.sum((rows @ WVVW) * rows, axis=1)

        # for non-zero elements
        for r, c in zip(*np.nonzero(At)):
            Aobs = At[r, c]
            Ahat = np.sum(U0[r] * Wfit * U1[c])
            self.E[r, t] += (Aobs - Ahat) ** 2 - Ahat ** 2

    def set_params(self, args):
        """Build a keyword dictionary from an argument namespace.

        Reads ``n_components``, ``n_regimes``, ``learning_rate`` and
        ``compression`` from ``args``; the instance itself is not modified.

        NOTE(review): the ``n_regimes`` key does not match the constructor's
        ``max_regimes`` parameter -- confirm how callers use this dict.
        """
        return dict(n_components=args.n_components,
                    n_regimes=args.n_regimes,
                    alpha=args.learning_rate,
                    compression=args.compression)

    def save(self, fp):
        """Write histories and factors into the directory ``fp``.

        ``O.txt`` and ``R.txt`` hold the operation and regime histories,
        ``W.npy`` the weights, and ``U<i>.txt`` each factor matrix.
        """
        # Log
        np.savetxt(fp + "/O.txt", self.O)
        np.savetxt(fp + "/R.txt", self.R)
        # Factors
        np.save(fp + "/W.npy", self.W)
        for i, U in enumerate(self.U):
            np.savetxt(fp + f"/U{i}.txt", U)

    def save_full_model(self, fp):
        """Pickle the whole model to ``fp + 'model.pkl'``.

        Unlike ``save``, no path separator is inserted, so ``fp`` must already
        end with one to target a directory.
        """
        with open(fp + 'model.pkl', 'wb') as f:
            pickle.dump(self, f)

def regime_compression(W):
    """Compress a stack of regimes into fewer ones with NMF.

    Each regime ``W[i]`` is flattened to a row.  Ranks ``m - 1, m - 2, ..., 1``
    are tried in that order; the search stops at the first rank whose cost
    (``compute_coding_cost`` of the NMF reconstruction plus
    ``compute_model_cost`` of the activation matrix) does not improve on the
    previous one.  The NMF components of the best rank are the compressed
    regimes.

    Parameters
    ----------
    W : ndarray
        Array whose first axis indexes the ``m >= 2`` regimes.

    Returns
    -------
    tuple
        ``(Wopt, loss)`` where ``Wopt`` has shape ``(optm, *W.shape[1:])`` and
        ``loss`` is the sum, over the original regimes, of the Euclidean
        distance to the closest compressed regime.

    Raises
    ------
    ValueError
        If fewer than two regimes are given, or if no candidate rank yields a
        cost below infinity.
    """
    m = len(W)
    if m < 2:
        raise ValueError(
            "regime_compression needs at least two regimes, got {}".format(m))

    flat = W.reshape((m, -1))
    optm = None
    Wopt = None
    prev = np.inf

    for i in reversed(range(1, m)):
        # No `regularization` argument: it only chose where `alpha` applied,
        # which defaults to zero, and newer scikit-learn releases reject it.
        model = NMF(n_components=i, init="random", max_iter=500)

        Hnew = model.fit_transform(flat)
        Wnew = model.components_
        cost = compute_coding_cost(flat, Hnew @ Wnew)
        cost += compute_model_cost(Hnew)
        print("m=", i, "compression cost=", cost)

        if cost < prev:
            prev = cost
            optm = i
            Wopt = np.copy(Wnew)
        else:
            break

    if Wopt is None:
        raise ValueError("regime_compression found no rank with a finite cost")

    print("optm=", optm, Wopt.shape)

    # 2. Compute the information loss: distance from every original regime
    #    to its nearest compressed regime.
    cost_all = []
    for row in flat:
        cost = [
            np.linalg.norm((Wopt[j] - row).flatten().astype("float32"))
            for j in range(optm)
        ]
        print(cost)
        cost_all.append(min(cost))

    print(cost_all, "all cost")
    Wopt = Wopt.reshape((optm, *W.shape[1:]))

    loss = sum(cost_all)
    print("compresstion loss=", loss)

    return Wopt, loss

def fit_seasonal_weights(X, U, V, W, zero=1e-12):
    """Apply one multiplicative update to the seasonal weights ``W``.

    The input ``W`` is copied, then scaled element-wise by the ratio of the
    mode-2 ``unfolding_dot_khatri_rao`` product of ``X`` with the factors
    ``[U, V, W]`` to ``W @ ((U.T @ U) * (V.T @ V))``.  Both terms are floored
    at ``zero`` before dividing.  As a side effect the result is plotted to
    ``W.png`` in the working directory.

    Returns the updated weights; the caller's ``W`` is left untouched.
    """
    weights = deepcopy(W)

    # Hadamard product of the two factor Gram matrices (k x k)
    gram = (U.T @ U) * (V.T @ V)

    numerator = unfolding_dot_khatri_rao(
        X, (None, [U, V, weights]), 2).clip(min=zero)
    denominator = np.dot(weights, gram).clip(min=zero)

    weights = weights * numerator / denominator

    # Diagnostic plot of the fitted weights
    plt.plot(weights)
    plt.savefig("W.png")
    plt.close()

    return weights

