import numpy as np
import time
from scipy.linalg import svd
from scipy.optimize import minimize


class nosvd_RidgeLOOCV:
    """Ridge regression whose penalty is chosen by leave-one-out cross-validation.

    The model never decomposes the design matrix itself: ``fit`` receives the
    thin SVD of the standardised design matrix from the caller, so a single
    decomposition can be reused for many responses.

    Parameters
    ----------
    alphas_ : sequence of numbers
        Candidate ridge penalties. Integer grids are accepted and evaluated
        in floating point.
    """

    def __init__(self, alphas_=np.logspace(-10, 10, 11, endpoint=True, base=10)):
        # The grid is only read, never mutated, so the shared array default is harmless.
        self.alphas_ = alphas_

    def fit(self, u, s, v_trans, y, n, p, a_x, b_x):
        """Evaluate every candidate penalty by LOO error and fit with the best one.

        Parameters
        ----------
        u, s, v_trans : ndarray
            Thin SVD factors (``U``, singular values, ``V^T``) of the
            standardised design matrix ``(x - a_x) / b_x``.
        y : ndarray
            Response vector with one entry per row of the design matrix.
        n, p : int
            Sample / feature counts. They are unused here and kept for a
            uniform interface with ``nosvd_RidgeEM.fit``.
        a_x, b_x : ndarray
            Per-feature shift and scale used to standardise the design
            matrix. They map the coefficients back to the raw feature scale.

        Returns
        -------
        self
            Sets ``alpha_``, ``coef_``, ``intercept_``, ``sigma_square_`` (LOO
            mean squared error on the original ``y`` scale) and ``iterations_``
            (number of candidates evaluated).

        Raises
        ------
        ValueError
            If the LOO error of every candidate is NaN.
        """
        a_y, b_y = y.mean(), y.std()
        y = (y - a_y) / b_y

        # Work on a float copy of the grid. With an integer grid the previous
        # ``np.zeros_like(self.alphas_)`` buffer was integer-typed and silently
        # truncated every LOO error before the argmin.
        alphas = np.asarray(self.alphas_, dtype=float)
        s_sq = s ** 2
        c = u.T.dot(y) * s   # diag(s) U^T y
        r = u * s            # U diag(s)
        u_sq = u ** 2        # lets each leverage vector be a single mat-vec product

        loo_mse = np.empty(alphas.shape[0])
        for i, alpha in enumerate(alphas):
            shrink = s_sq / (s_sq + alpha)
            # h_ii = diag(U diag(shrink) U^T), without forming the n x n hat matrix.
            h = u_sq.dot(shrink)
            err = y - r.dot(c / (s_sq + alpha))
            # Closed-form leave-one-out residuals of a linear smoother: e_i / (1 - h_ii).
            # NOTE(review): h omits the 1/n leverage of the intercept that is
            # absorbed by centring y. Confirm this approximation is intended.
            loo_mse[i] = np.mean((err / (1 - h)) ** 2)

        # np.argmin returns the index of the first NaN when one is present, so a
        # degenerate candidate (e.g. 0/0 when h_ii == 1) would be selected.
        i_star = np.nanargmin(loo_mse)
        self.alpha_ = alphas[i_star]

        # Map the solution from the right-singular basis back to feature space.
        beta = v_trans.T.dot(c / (s_sq + self.alpha_))
        self.sigma_square_ = loo_mse[i_star] * b_y ** 2
        # Undo the standardisation of both y and x.
        self.coef_ = beta * b_y / b_x
        self.intercept_ = a_y - self.coef_.dot(a_x)
        self.iterations_ = len(alphas)
        return self

    def predict(self, x):
        """Return ``x @ coef_ + intercept_`` for raw (unstandardised) rows ``x``."""
        return x.dot(self.coef_) + self.intercept_


class nosvd_RidgeEM:
    """Ridge regression whose penalty is estimated by an iterative fixed-point scheme.

    Like ``nosvd_RidgeLOOCV``, it works from a precomputed thin SVD of the
    standardised design matrix, so no decomposition happens inside ``fit``.
    The ridge penalty is ``1 / tau_square``. ``tau_square`` and
    ``sigma_square`` are refined alternately until the relative change in the
    residual sum of squares falls below ``epsilon``.

    Parameters
    ----------
    epsilon : float
        Convergence threshold on ``|RSS_old - RSS| / (1 + |RSS|)``.
    max_iter : int or None
        Optional cap on the number of iterations. ``None`` (default) keeps
        the loop unbounded.
    """

    def __init__(self, epsilon=0.00000001, max_iter=None):
        self.epsilon = epsilon
        self.max_iter = max_iter

    def fit(self, u, s, v_trans, y, n, p, a_x, b_x):
        """Estimate the penalty iteratively and fit the corresponding ridge solution.

        Parameters
        ----------
        u, s, v_trans : ndarray
            Thin SVD factors (``U``, singular values, ``V^T``) of the
            standardised design matrix ``(x - a_x) / b_x``.
        y : ndarray
            Response vector with one entry per row of the design matrix.
        n, p : int
            Number of samples and features; both enter the update formulas.
        a_x, b_x : ndarray
            Per-feature shift and scale used to standardise the design
            matrix. They map the coefficients back to the raw feature scale.

        Returns
        -------
        self
            Sets ``coef_``, ``intercept_``, ``sigma_square_`` (rescaled to the
            original ``y`` units), ``tau_square_``, ``alpha_ = 1 / tau_square_``
            and ``iterations_``. ``coef_`` is the exact ridge solution for the
            reported ``alpha_``.

        Raises
        ------
        FloatingPointError
            If the residual sum of squares becomes non-finite, e.g. for a
            constant ``y`` whose zero spread turns the standardised response
            into NaN. Before this check such inputs looped forever, because
            ``NaN < epsilon`` is never true.
        """
        a_y, b_y = y.mean(), y.std()
        y = (y - a_y) / b_y

        s_sq = s * s
        y_sqnorm = y.dot(y)
        c = u.T.dot(y) * s  # diag(s) U^T y
        # Unpenalised start; it only seeds the RSS used by the first convergence test.
        beta = c / s_sq
        tau_square = 1
        sigma_square = y.var()
        # ||y - U diag(s) beta||^2 expanded so no n-length residual is formed.
        rss = y_sqnorm - 2 * beta.dot(c) + (beta * beta).dot(s_sq)
        self.iterations_ = 0

        while self.max_iter is None or self.iterations_ < self.max_iter:
            rss_old = rss
            penalised = s_sq + 1 / tau_square
            beta = c / penalised

            w = beta.dot(beta) + sigma_square * ((1 / penalised).sum() + tau_square * max(p - n, 0))
            rss = y_sqnorm - 2 * beta.dot(c) + (beta * beta).dot(s_sq)
            if not np.isfinite(rss):
                # Once RSS is NaN/inf every later quantity is too; stop instead of spinning forever.
                raise FloatingPointError(
                    "nosvd_RidgeEM: residual sum of squares became non-finite; "
                    "check for a constant response or non-finite inputs"
                )
            z = rss + sigma_square * (s_sq / penalised).sum()

            # Closed-form updates, kept verbatim from the original derivation.
            tau_square = (
                w * (-1 + n)
                - z * (1 + p)
                + (4 * w * (n + 1) * z * (3 + p) + (w + z * (p + 1) - w * n) ** 2) ** 0.5
            ) / (2 * z * (3 + p))
            sigma_square = (z * tau_square + w) / ((n + p + 2) * tau_square)

            delta = abs(rss_old - rss) / (1 + abs(rss))
            self.iterations_ += 1
            if delta < self.epsilon:
                break

        # Re-solve with the final tau_square. The in-loop beta used the previous
        # value, so coef_ and alpha_ would otherwise disagree slightly.
        beta = v_trans.T.dot(c / (s_sq + 1 / tau_square))

        self.coef_ = beta * b_y / b_x
        self.intercept_ = a_y - self.coef_.dot(a_x)
        self.sigma_square_ = sigma_square * b_y ** 2
        self.tau_square_ = tau_square
        self.alpha_ = 1 / tau_square
        return self

    def predict(self, x):
        """Return ``x @ coef_ + intercept_`` for raw (unstandardised) rows ``x``."""
        return x.dot(self.coef_) + self.intercept_


class RidgeMT:
    """One-vs-rest ridge classifier that shares a single SVD across all classes.

    ``fit`` standardises the features once, decomposes them once, and then
    fits one ridge regression per class label on a 0/1 indicator response.
    The per-class penalty is chosen by ``nosvd_RidgeEM`` when
    ``estimator == "EM"`` and by ``nosvd_RidgeLOOCV`` otherwise.
    Prediction picks the class whose regressor scores highest.
    """

    def __init__(self, epsilon=0.00000001, alphas=np.logspace(-10, 10, 11, endpoint=True, base=10), estimator="EM"):
        # Convergence threshold forwarded to nosvd_RidgeEM.
        self.epsilon = epsilon
        # LOOCV grid: either an explicit sequence of penalties, or a scalar
        # grid size for a data-driven log grid (see fit).
        self.alphas = alphas
        self.classes = None
        # "EM" selects nosvd_RidgeEM; any other value selects nosvd_RidgeLOOCV.
        self.estimator = estimator
        self.iterations_ = 0

    @staticmethod
    def alpha_range_GMLNET(x, y):
        """Return ``(alpha_min, alpha_max)`` bounds for a penalty grid.

        ``alpha_max = max|x^T y| / (0.001 * n)``. ``alpha_min`` is
        ``1e-4 * alpha_max`` when ``n >= p`` and ``1e-2 * alpha_max`` otherwise.
        The name suggests the glmnet lambda-path heuristic; presumably that is
        the intended reference (not verified here).
        """
        n, p = x.shape
        alpha_max = 1 / (0.001 * n) * np.max(np.abs(x.T.dot(y)))
        alpha_min = 0.0001 * alpha_max if n >= p else 0.01 * alpha_max
        return alpha_min, alpha_max

    @staticmethod
    def alpha_log_grid(alpha_min, alpha_max, l=100, base=10.0):
        """Return ``l`` values from ``alpha_min`` to ``alpha_max`` (inclusive), evenly spaced in log-``base``."""
        log_min = np.log(alpha_min) / np.log(base)
        log_max = np.log(alpha_max) / np.log(base)
        # The exponents are expressed in ``base``, so logspace must use the same
        # base. Its default of 10 silently produced a wrong grid for any other base.
        return np.logspace(log_min, log_max, l, endpoint=True, base=base)

    def fit(self, x, y):
        """Fit one ridge regressor per class label.

        Parameters
        ----------
        x : array-like, 2-D
            Feature matrix with one row per sample.
        y : array-like
            Class label per row of ``x``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``y`` contains fewer than two distinct labels. A single class
            gives a constant indicator that cannot be standardised.
        """
        x = np.asarray(x)
        y = np.asarray(y)  # ``y == label`` on a plain list would be a single False
        n, p = x.shape

        classes = np.unique(y)
        if len(classes) < 2:
            raise ValueError("RidgeMT.fit requires at least two distinct class labels")

        a_x, b_x = x.mean(axis=0), x.std(axis=0)
        # Constant (or numerically constant) columns have zero spread. Dividing
        # by it filled x with NaN and made the SVD fail. Scaling them by 1
        # leaves a (near-)zero column that receives ~zero weight.
        constant = b_x <= 10 * np.finfo(float).eps * np.abs(a_x)
        b_x = np.where(constant, 1.0, b_x)
        x = (x - a_x) / b_x

        svd_start_time = time.time()
        u, s, v_trans = svd(x, full_matrices=False)
        self.svdTime = time.time() - svd_start_time

        self.classes = classes
        self.models = []
        iterations = np.zeros(len(self.classes))

        for i, label in enumerate(self.classes):
            # One-vs-rest indicator response for this class.
            y_c = (y == label).astype(float)

            if self.estimator == "EM":
                model_c = nosvd_RidgeEM(epsilon=self.epsilon)
            else:
                if np.isscalar(self.alphas):
                    # A scalar means "grid size": derive the grid from this class's data.
                    alpha_min, alpha_max = self.alpha_range_GMLNET(x, y_c)
                    self.alphas_ = self.alpha_log_grid(alpha_min, alpha_max, int(self.alphas))
                else:
                    self.alphas_ = self.alphas
                model_c = nosvd_RidgeLOOCV(alphas_=self.alphas_)

            model_c.fit(u, s, v_trans, y_c, n, p, a_x, b_x)
            self.models.append(model_c)
            iterations[i] = model_c.iterations_

        self.iterations_ = np.mean(iterations)
        return self

    def cl_predict(self, x):
        """Return, for each row of ``x``, the class whose regressor scores highest."""
        x = np.asarray(x)
        scores = np.column_stack([model.predict(x) for model in self.models])
        return self.classes[np.argmax(scores, axis=1)]

    def score(self, X, y):
        """Return the fraction of rows of ``X`` whose predicted class equals ``y``."""
        y_pred = self.cl_predict(X)
        return np.mean(y_pred == y)
