import logging
import sys
from typing import Union, Optional

import numpy as np
from pykronecker import KroneckerSum

sys.path.append('./')
from est.models.Model import Model


class Model2(Model):
    r"""
    Model2: linear SDE with multiplicative noise.

    parameters: A, G_k for k = 1, ..., m
    dX_t = A X_t dt + \sum_{k=1}^m G_k X_t dW_{k,t}

    ``self._params`` is read as one flat array: the first ``dim**2`` entries
    are A and the remaining ``m * dim**2`` entries are G_1, ..., G_m, each
    reshaped to ``(dim, dim)`` in NumPy's default (row-major) order.
    """

    def __init__(self,
                 dim: int,
                 m: int):
        """Create a model with state dimension ``dim`` and ``m`` noise terms.

        Both arguments are forwarded unchanged to ``Model.__init__``.
        """
        super().__init__(dim=dim, m=m)

    def _drift_matrix(self) -> np.ndarray:
        """Return A: the first ``dim**2`` parameters reshaped to (dim, dim)."""
        d = self._dim
        return self._params[:d ** 2].reshape(d, d)

    def _diffusion_matrices(self) -> np.ndarray:
        """Return G_1..G_m: the remaining parameters reshaped to (m, dim, dim)."""
        d = self._dim
        return self._params[d ** 2:].reshape(self._m, d, d)

    @staticmethod
    def _krylov_rank(mat: np.ndarray, v: np.ndarray, n_cols: int) -> int:
        """Return the numerical rank of [v, mat v, ..., mat^(n_cols-1) v].

        ``v`` is a column vector of shape (k, 1) and ``mat`` has shape (k, k).
        """
        cols = [v]
        for _ in range(n_cols - 1):
            cols.append(mat @ cols[-1])
        return int(np.linalg.matrix_rank(np.hstack(cols)))

    def drift(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Return the drift term ``A @ x``."""
        return self._drift_matrix() @ x

    def diffusion(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Return the (dim, m) diffusion coefficient whose k-th column is G_k @ x.

        ``x`` is flattened to length ``dim`` first, so (dim,), (dim, 1) and,
        for dim == 1, scalar inputs are all accepted.
        """
        x_vec = np.asarray(x).reshape(self._dim)
        # out[i, k] = sum_j G_k[i, j] * x[j]; one call instead of hstack-in-a-loop
        return np.einsum('kij,j->ik', self._diffusion_matrices(), x_vec)

    def check_condition1(self, x0: Union[float, np.ndarray]) -> bool:
        """Check the Krylov rank condition on A and the initial state x0.

        Returns:
            True iff rank [x0, A x0, ..., A^(d-1) x0] == d, where d = dim.
        """
        d = self._dim
        v = np.asarray(x0).reshape(d, 1)
        return self._krylov_rank(self._drift_matrix(), v, d) == d

    def check_condition2(self, x0: Union[float, np.ndarray]) -> bool:
        """Check the Krylov rank condition on vec(x0 x0^T).

        Builds the d^2 x d^2 matrix  A (+) A + sum_k G_k (x) G_k  and tests
        whether [v, M v, ..., M^(n-1) v] has full rank n = d(d+1)/2, where
        v = vec(x0 x0^T) (row-major flatten).

        Returns:
            True iff the Krylov matrix has rank d(d+1)/2.
        """
        d = self._dim
        # d(d+1)/2 is the dimension of the space of symmetric d x d matrices;
        # d(d+1) is always even, so integer division is exact.
        n_sym = d * (d + 1) // 2
        x = np.asarray(x0).reshape(d, 1)
        v = (x @ x.T).reshape(d ** 2, 1)

        A = self._drift_matrix()
        eye = np.eye(d)
        # Kronecker sum A (+) A = A (x) I + I (x) A, built densely in O(d^4)
        # rather than densifying a lazy operator through a d^2 x d^2 matmul.
        a_ksum = np.kron(A, eye) + np.kron(eye, A)

        # sum of Kronecker products G_k (x) G_k over the m noise terms
        g_kron = np.zeros((d ** 2, d ** 2))
        for G in self._diffusion_matrices():
            g_kron += np.kron(G, G)

        # Acting on row-major vec(P) this matches the Ito second-moment ODE
        # dP/dt = A P + P A^T + sum_k G_k P G_k^T for P = E[X X^T].
        hua_a = a_ksum + g_kron

        rank = self._krylov_rank(hua_a, v, n_sym)
        # Diagnostic only; previously printed unconditionally to stdout.
        logging.getLogger(__name__).debug('rank: %s, %s', n_sym, rank)
        return rank == n_sym
        