import numpy as np
import cupy as cp
from scipy.linalg import expm, sinm, cosm
import scipy.linalg as la
from math import log
import math
import os,sys,time

class RelaxSolver():
    """Exponentiated-gradient solver for the relaxed (continuous) weighting problem.

    Minimizes ``f(omega) = tr(Sigma(omega)^{-1} I_t)`` over the probability
    simplex, where::

        Sigma(omega) = k / (1 + l) * sum_i (omega_i + l / n) * P_i P_i^T + D
    """

    def __init__(self, n, d, c, P, I_t, D):
        """Store the problem data.

        Args:
            n: number of candidate samples (length of the weight vector).
            d: first factor of the matrix dimension ``dc = d * c``.
            c: second factor of the matrix dimension ``dc = d * c``.
            P: per-sample factors of shape (n, dc, m); ``P[i] @ P[i].T`` is
                sample i's contribution to Sigma (presumably its Fisher
                information -- confirm with the caller).
            I_t: total Fisher information matrix, shape (dc, dc).
            D: (dc, dc) matrix added to Sigma; presumably a regularizer.
        """
        self.n = n
        self.c = c
        self.P = P
        self.I_t = I_t
        self.d = d
        self.dc = d * c

        self.D = D

        self.l = None  # lambda; set by run()
        self.k = None  # scale; set by run()

    def f(self, S_inv):
        """Return the objective ``tr(S_inv @ I_t)`` as a 0-d array."""
        return cp.einsum('ij,ji->', S_inv, self.I_t)

    def df(self, S_inv):
        """Return the gradient of ``f`` with respect to the weights ``omega``.

        Entry i is ``-k / (1 + l) * tr(P_i^T S_inv I_t S_inv P_i)``, which is
        the derivative of ``tr(Sigma^{-1} I_t)`` w.r.t. ``omega_i`` when
        ``S_inv = Sigma^{-1}``.
        """
        middle = S_inv @ self.I_t @ S_inv
        quad = (self.P.transpose((0, 2, 1)) @ middle) @ self.P
        scale = self.k / (1. + self.l)
        return -(scale * cp.trace(quad, axis1=1, axis2=2))

    def Sigma(self, omega):
        """Return ``sum_i pi_i P_i P_i^T + D`` with ``pi = k/(1+l) * (omega + l/n)``."""
        pi = (omega + self.l / self.n) * (self.k / (1. + self.l))
        S = cp.einsum('i,ijk,kli->jl', pi, self.P, self.P.transpose((2, 1, 0)), optimize=True)
        S += self.D
        return S

    def update(self, omega, eta, g):
        """Take one multiplicative step ``omega * exp(-eta * g)`` and renormalize to sum 1."""
        omega_new = omega * cp.exp(-eta * g)
        return omega_new / omega_new.sum()

    @staticmethod
    def _step_size(t):
        """Return the (hard-coded) step size used at iteration ``t``."""
        if t >= 100:
            return 1.e-1
        if t >= 40:
            return 1.e-2
        return 1.e-3

    def run(self, l, k, iter_max, print_unit, tol):
        """Run exponentiated-gradient descent starting from uniform weights.

        Args:
            l: ``lambda`` in ``Sigma``; stored as ``self.l``.
            k: scale in ``Sigma``; stored as ``self.k``.
            iter_max: maximum number of iterations.
            print_unit: print progress every ``print_unit`` steps. 0 disables
                periodic output (it used to raise ZeroDivisionError); the
                first step is always reported.
            tol: stop once the relative change of ``f`` between consecutive
                iterations is ``<= tol``.

        Returns:
            The final weight vector, copied to a host (NumPy) array.
        """
        self.l = l
        self.k = k
        omega = cp.full(self.n, 1. / self.n)

        history = []  # f evaluated at the weights *before* each update
        rel = 10.
        t = 1
        while t <= iter_max and rel > tol:
            show = print_unit != 0 and t % print_unit == 0
            if show or t == 1:
                print('step :', t)
            # 1. Sigma and its inverse at the current weights.
            S_inv = cp.linalg.inv(self.Sigma(omega))
            # 2. Gradient of the objective.
            grad = self.df(S_inv)
            # 3. Multiplicative update with the step-size schedule.
            eta_i = self._step_size(t)
            omega = self.update(omega, eta_i, grad)
            # 4. Record the objective and check relative progress.
            f_value = self.f(S_inv)
            history.append(f_value)
            if t == 1:
                print("f initial value:", f_value)
            else:
                delta = history[-1] - history[-2]
                rel = abs(delta / history[-1])
                if show:
                    print(" eta_i:%.2e, df: %.4e, f: %.5e, rel: %.2e" % (eta_i, delta, history[-1], rel))
            t += 1

        out = cp.sort(omega)
        print("smallest 20 weights in relaxed solution:", out[:20])
        print("largest 20 weights in relaxed solution:", out[-20:])
        return cp.asnumpy(omega)


class SolverRound():
    """Greedy rounding: select ``k`` samples one at a time (see ``run_l12``).

    Each step builds ``At_half = (c_t I + eta F + eta Dtilde)^{-1}``, where
    ``F`` accumulates ``S P_i (S P_i)^T + Dtilde`` over the samples selected
    so far and ``c_t`` comes from ``find_constant_l12``. It then scores every
    remaining sample and takes the one with the largest score.
    """

    def __init__(self, Dtilde, P, S):
        """
        Args:
            Dtilde: (dc, dc) matrix; added (scaled by ``eta``) before each
                inversion and added to ``F`` after each selection.
            P: per-sample factors, shape (n, dc, c1).
            S: (dc, dc) matrix; per the original note, Sigma*^{-1/2}.
        """
        self.Dtilde = Dtilde
        self.P = P
        self.S = S  # Sigma*^{-1/2}
        self.n = P.shape[0]
        self.dc = P.shape[1]
        self.c1 = P.shape[2]

    def find_constant_l12(self, w, tol=1.e-10, max_iter=200):
        """Return ``c`` such that ``sum_i (c + w_i)^{-2} == 1`` with ``c > -min(w)``.

        The left-hand side is strictly decreasing in ``c`` on that interval,
        so the root is unique. We bisect the shifted variable
        ``s = c + min(w)`` in ``(0, sqrt(dc)]``. At ``s = sqrt(dc)`` every
        term ``c + w_i`` is at least ``sqrt(dc)``, so the sum is at most 1
        whenever ``len(w) <= dc``. Working in ``s`` keeps the bracket small
        and well-conditioned. Bisecting ``c`` directly with an absolute
        ``tol`` could loop forever once ``|c|`` was large enough that the
        float spacing exceeded ``tol``.

        Args:
            w: eigenvalues (1-D array).
            tol: bracket width at which bisection stops.
            max_iter: hard cap on bisection steps.
        """
        w_min = float(cp.min(w))
        shifted = w - w_min  # all entries >= 0
        lo, hi = 0.0, self.dc ** .5
        mid = hi
        for _ in range(max_iter):
            mid = 0.5 * (lo + hi)
            if float(cp.sum(cp.power(mid + shifted, -2))) > 1.:
                lo = mid
            else:
                hi = mid
            if hi - lo <= tol:
                break
        return mid - w_min

    def run_l12(self, k, eta, print_unit=5):
        """Greedily select ``k`` distinct sample indices.

        Args:
            k: number of samples to select (must not exceed ``n``).
            eta: step-size / scale parameter.
            print_unit: if > 0, print a progress line every ``print_unit`` steps.

        Returns:
            List of selected sample indices, in selection order.

        Raises:
            ValueError: if ``k`` exceeds the number of samples.
        """
        if k > self.n:
            raise ValueError("cannot select %d samples from a pool of %d" % (k, self.n))

        sel = []
        pool = np.arange(self.n)
        F = cp.zeros((self.dc, self.dc))
        eye_c1 = cp.eye(self.c1)

        for t in range(1, k + 1):
            # 1. At_half = (c_t I + eta F + eta Dtilde)^{-1}; At = At_half^2.
            if t == 1:
                At_half = (self.dc ** 0.5) * cp.eye(self.dc)
            else:
                w, v = cp.linalg.eigh(eta * F)
                c_t = self.find_constant_l12(w)
                At_half = (v @ cp.diag(c_t + w)) @ v.T
            At_half = cp.linalg.inv(At_half + eta * self.Dtilde)
            At = At_half @ At_half

            # 2. Conjugate by S.
            t_At = (self.S @ At) @ self.S
            t_At_half = (self.S @ At_half) @ self.S

            # 3. Score each candidate i in pool:
            #    A_i = I + eta P_i^T t_At_half P_i,  B_i = P_i^T t_At P_i,
            #    score_i = sum of elementwise products of A_i^{-1} and B_i
            #    (= tr(A_i^{-1} B_i) when these matrices are symmetric).
            P = self.P[pool]
            Pt = P.transpose((0, 2, 1))
            A = eye_c1 + eta * ((Pt @ t_At_half) @ P)
            B = (Pt @ t_At) @ P
            A_inv = cp.linalg.inv(A)
            trace = cp.einsum('ijk,kji->i', A_inv, B.transpose((2, 1, 0)), optimize=True)

            # 4. Select the best candidate (argmax on device, one scalar to host).
            pool_id = int(cp.argmax(trace))
            select = pool[pool_id]
            pool = np.delete(pool, pool_id)
            sel.append(select)

            Pi = self.S @ self.P[select]
            F = F + (Pi @ Pi.T) + self.Dtilde
            if print_unit > 0 and t % print_unit == 0:
                w, _ = cp.linalg.eigh(F)
                print("===== step %i, pool_id %i, eigen_min %.4e" % (t, select, cp.min(w)))

        print('number of samples selected:', len(sel))
        return sel


class BaitSolver():
    """Forward-backward greedy selection of ``b`` samples (BAIT-style).

    The forward pass greedily adds ``2 * b`` samples. Each step takes the
    candidate with the largest score (see ``_scores``) and adds ``P_i P_i^T``
    to ``M``. The backward pass then removes ``b`` of them, each time the one
    with the smallest score, subtracting ``P_i P_i^T`` from ``M``.

    NOTE(review): ``run`` updates ``self.M`` in place, so a second call starts
    from the matrix left behind by the first -- confirm this is intended.
    """

    def __init__(self, nr, D, P, I_U, lamb=1.0):
        """
        Args:
            nr: scale; ``P`` is multiplied by ``1 / sqrt(nr)``.
            D: (dc, dc) matrix added to ``lamb * I`` to form the initial ``M``.
            P: per-sample factors, shape (n, dc, c1).
            I_U: (dc, dc) matrix used in the score numerator (presumably the
                Fisher information of the unlabeled pool -- confirm).
            lamb: weight of the identity term in ``M``.
        """
        a = 1. / (nr ** 0.5)
        self.P = a * P
        self.n = P.shape[0]
        self.dc = P.shape[1]
        self.c1 = P.shape[2]
        self.lamb = lamb
        self.nr = nr
        self.M = lamb * cp.eye(self.dc) + D
        self.I_U = I_U

    def _scores(self, pool):
        """Return the score of every candidate index in ``pool`` under the current ``M``.

        With ``A_i = I + P_i^T M^{-1} P_i`` and
        ``B_i = P_i^T M^{-1} I_U M^{-1} P_i``, the score is the sum of the
        elementwise products of ``A_i^{-1}`` and ``B_i`` (equal to
        ``tr(A_i^{-1} B_i)`` when these matrices are symmetric).
        """
        M_inv = cp.linalg.inv(self.M)
        N = M_inv @ self.I_U @ M_inv
        P = self.P[pool]
        Pt = P.transpose((0, 2, 1))
        B = (Pt @ N) @ P
        A = cp.eye(self.c1) + (Pt @ M_inv) @ P
        A_inv = cp.linalg.inv(A)
        return cp.einsum('ijk,kji->i', A_inv, B.transpose((2, 1, 0)), optimize=True)

    @staticmethod
    def _drop(pool, idx):
        """Return ``pool`` without position ``idx``, without leaving the device."""
        return cp.concatenate((pool[:idx], pool[idx + 1:]))

    def run(self, b):
        """Select ``b`` sample indices (forward ``2b`` picks, backward ``b`` removals).

        Args:
            b: number of samples to return; ``2 * b`` must not exceed ``n``.

        Returns:
            Array of the ``b`` selected sample indices.

        Raises:
            ValueError: if ``2 * b`` exceeds the number of samples.
        """
        if 2 * b > self.n:
            raise ValueError("forward pass needs %d samples but only %d are available" % (2 * b, self.n))

        selected = []
        pool = cp.arange(self.n)

        # Forward pass: greedily add the highest-scoring candidate.
        for t in range(1, 2 * b + 1):
            pool_id = int(cp.argmax(self._scores(pool)))
            select = int(pool[pool_id])
            pool = self._drop(pool, pool_id)
            selected.append(select)

            Pi = self.P[select]
            self.M += (Pi @ Pi.T)
            if t % 10 == 0:
                print("forward greedy step, selected sample id:", t, select)

        # Backward pass: repeatedly remove the lowest-scoring selected sample.
        pool = cp.asarray(selected, dtype=pool.dtype)
        for _ in range(b):
            pool_id = int(cp.argmin(self._scores(pool)))
            select = int(pool[pool_id])
            pool = self._drop(pool, pool_id)
            Pi = self.P[select]
            self.M -= (Pi @ Pi.T)

        return pool











