import torch

def mpower(A,power=1/2):
    """Raise a matrix to a real power through its singular value decomposition.

    Computes ``U @ diag(S ** power) @ Vh`` where ``U, S, Vh = svd(A)``.  For a
    symmetric positive semi-definite ``A`` this is the matrix power
    ``A ** power`` (the default gives the matrix square root).

    Args:
        A: Square matrix, or a batch of square matrices in the trailing two
            dimensions.
        power: Exponent applied to the singular values.  Negative exponents
            require all singular values to be non-zero.

    Returns:
        Tensor with the same shape as ``A``.
    """
    u, sing_vals, vh = torch.linalg.svd(A)
    # Exponentiate the singular values *before* building the diagonal matrix:
    # powering the dense diagonal matrix elementwise would also transform the
    # off-diagonal zeros (0 ** -p == inf, 0 ** 0 == 1).
    return u @ torch.diag_embed(sing_vals ** power) @ vh

def h(x):
    """Evaluate a triangular "hat" kernel with one-sided first and last rows.

    Every entry is mapped to ``max(0, 1 - |x|)``.  Along dimension 1 the first
    slice is instead ``clamp(1 - x, 0, 1)`` and the last slice is
    ``clamp(1 + x, 0, 1)``, both computed from the untouched input values.

    A 2-D input gets a leading dimension of size 1, which is kept in the
    result.  The input tensor itself is not modified.
    """
    if x.dim() == 2:
        x = x.unsqueeze(0)

    zero = torch.tensor(0)
    one = torch.tensor(1)

    # Views of the raw input; `hat` below is a fresh tensor, so these keep
    # referring to the original values.
    first_slice = x[:, 0, :]
    last_slice = x[:, -1, :]

    hat = torch.maximum(zero, 1 - x.abs())
    hat[:, 0, :] = torch.minimum(one, torch.maximum(zero, 1 - first_slice))
    hat[:, -1, :] = torch.minimum(one, torch.maximum(zero, 1 + last_slice))
    return hat

def find_stationary(P):
    """Return the stationary distribution of a row-stochastic matrix.

    The distribution is the left eigenvector of ``P`` (an eigenvector of
    ``P.T``) for the eigenvalue 1, rescaled so that its entries sum to one.

    Args:
        P: Square 2-D floating-point transition matrix whose rows sum to one.

    Returns:
        1-D real tensor with one entry per state, in the real dtype matching
        ``P``.
    """
    evals_c, evecs_c = torch.linalg.eig(P.T)
    evals = evals_c.real
    evecs = evecs_c.real

    # Build the comparison target from `evals` itself so the dtypes match for
    # any floating-point input, not only float32.
    unit_mask = torch.isclose(evals, torch.ones_like(evals))
    if bool(unit_mask.any()):
        # First eigenvalue within tolerance of 1.
        col = int(torch.nonzero(unit_mask).flatten()[0])
    else:
        # Round-off can push the unit eigenvalue just outside the isclose
        # tolerance; use the eigenvalue nearest to 1 instead of failing.
        col = int(torch.argmin((evals_c - 1).abs()))

    evec1 = evecs[:, col]
    # Dividing by the sum fixes both the scale and the sign of the eigenvector.
    return evec1 / evec1.sum()

class Policy:
    """Finite-state Markov reward process that can be sampled as transitions.

    The constructor stores its arguments under both private (``_S``, ``_P``,
    ``_R``) and public (``S``, ``P``, ``R``) names and pre-computes the
    stationary distribution of the transition matrix.
    """

    def __init__(self,state_space,prob_of_action,dist_of_reward,reward_list,type_env='discrete'):
        """Store the process description.

        Args:
            state_space: Tensor of states; its first dimension is the number
                of states.
            prob_of_action: Transition weights; row ``i`` holds the sampling
                weights of the next-state index given state index ``i``.
            dist_of_reward: Reward weights; entry ``[i, j]`` holds the
                sampling weights over ``reward_list`` for transition i -> j.
            reward_list: Tensor of the possible reward values.
            type_env: Environment kind; only ``'discrete'`` can be sampled.
        """
        self._S = state_space
        self.size = state_space.shape[0]
        self._P = prob_of_action
        self._R = dist_of_reward
        # Public aliases of the same tensors.
        self.S = self._S
        self.P = self._P
        self.R = self._R
        self.reward_list = reward_list
        self.stationary = find_stationary(self._P)
        self.env_type = type_env

    def sample(self,sample_number):
        """Draw independent ``(state, next_state, reward)`` transitions.

        States are drawn from the stationary distribution, next states from
        the matching transition row, and rewards from the reward weights of
        the sampled transition.

        Args:
            sample_number: Number of transitions to draw (>= 1).

        Returns:
            Tensor of shape ``(sample_number, 3)`` with columns state, next
            state and reward, or ``None`` when the environment type is not
            ``'discrete'``.
        """
        if self.env_type != 'discrete':
            # Only the discrete case is implemented.
            return None

        N = sample_number
        state_space = self._S

        x_index = torch.multinomial(self.stationary, num_samples=N, replacement=True)
        x = state_space[x_index]

        # multinomial returns shape (N, 1).  Flatten with view(N) rather than
        # squeeze(), which would collapse to a 0-dim tensor for N == 1 and
        # make the final stack fail.
        next_x_index = torch.multinomial(self._P[x_index], 1, replacement=True).view(N)
        next_x = state_space[next_x_index]

        reward_dist = self._R[x_index, next_x_index]
        reward_index = torch.multinomial(reward_dist, 1, replacement=True).view(N)
        reward = self.reward_list[reward_index]

        return torch.stack((x, next_x, reward)).transpose(0, 1)
    

def random_Policy(s,r_number):
    """Build a random discrete ``Policy`` with ``s`` states and ``r_number`` rewards.

    Transition rows and per-transition reward weights are uniform random
    numbers normalised to sum to one along the last dimension; the reward
    values are ``r_number`` evenly spaced points on [0, 1].

    Returns:
        ``(policy, cfg)`` where ``cfg`` is the tuple
        ``(states, transition, reward_weights, reward_values)`` that was
        passed to the ``Policy`` constructor.
    """
    states = torch.arange(0, s).long()

    raw_transition = torch.rand((s, s))
    transition = raw_transition / raw_transition.sum(dim=-1, keepdim=True)

    reward_values = torch.linspace(0, 1, r_number)

    # This draw is discarded; it is kept only so the random stream (and hence
    # seeded results) stays the same.
    torch.rand((r_number, r_number))

    raw_reward = torch.rand(s, s, r_number)
    reward_weights = raw_reward / raw_reward.sum(dim=-1, keepdim=True)

    cfg = (states, transition, reward_weights, reward_values)
    return Policy(*cfg), cfg


def ltd_info(pi:"Policy",F,gamma):
    """Solve the expected linear TD system ``E_A @ w = E_b`` in closed form.

    With ``phi = F[pi.S]`` and expectations taken under ``pi.stationary`` and
    the transition weights ``pi.P``::

        E_A = E[phi(x) phi(x)^T] - gamma * E[phi(x) phi(x')^T]
        E_b = E[phi(x) * r]

    Args:
        pi: Object exposing ``S``, ``P``, ``R``, ``reward_list`` and
            ``stationary`` (as ``Policy`` does).
        F: Feature table indexed by state; ``F[pi.S]`` must be 2-D
            (states x features).
        gamma: Discount factor.

    Returns:
        1-D weight tensor with one entry per feature.
    """
    x = pi.S
    phi = F[x]
    P_x = pi.stationary
    # Joint weight of (state, next state).
    P_x_next_x = P_x[:,None]*pi.P

    phi_phi = phi[:,:,None]*phi[:,None,:]
    E_phi_phi = (P_x[:,None,None]*phi_phi).sum(dim=0)

    # phi_phi_next[i, j, a, b] = phi[i, a] * phi[j, b]
    phi_phi_next = phi[None,:,None,:]*phi[:,None,:,None]
    E_phi_phi_next = (P_x_next_x[:,:,None,None]*phi_phi_next).sum(dim=(0,1))
    E_A = E_phi_phi-gamma*E_phi_phi_next

    # Mean reward per transition, then per state.
    E_r_x_next_x = (pi.reward_list[None,None,:]*pi.R).sum(dim=-1)
    E_r_x = (pi.P*E_r_x_next_x).sum(dim=-1)
    E_b = (P_x[:,None]*(phi*E_r_x[:,None])).sum(dim=0)

    # Solve the linear system directly rather than forming the explicit
    # inverse: same solution, better numerical behaviour.
    return torch.linalg.solve(E_A, E_b)

def info(pi:Policy,F,gamma,stride,alpha,theta,p=2,verbose = True):
    """Assemble and solve the expected linear system over ``K`` support points.

    With ``K = theta.shape[0]`` and ``n`` the feature dimension, the function
    builds a ``(K-1)*n`` square matrix ``E_A`` and vector ``E_b`` from
    expectations under ``pi.stationary``, ``pi.P`` and ``pi.R``, and solves
    ``w = -E_A^{-1} E_b``.
    NOTE(review): this looks like a categorical / distributional TD fixed
    point with ``theta`` as the atoms and ``stride`` their spacing -- confirm.

    Only quantities that feed the returned values are computed.

    Args:
        pi: Policy providing ``S``, ``P``, ``R``, ``reward_list``,
            ``stationary`` and ``size``.
        F: Feature table indexed by state; ``F[0].shape[0]`` is ``n``.
        gamma: Factor multiplying ``theta`` in ``gamma*theta[k] - theta[j]``.
        stride: Divisor applied to the shifted rewards before ``h``.
        alpha: Accepted for interface compatibility; does not influence any
            returned value.
        theta: Tensor whose first dimension has length ``K``.
        p: Unused.
        verbose: Unused.

    Returns:
        Tuple ``(w_mat, E_A, E_b, E_phi_phi, Apmf)``:
        ``w_mat`` is ``(n, K)``, its last column being minus the sum of the
        other columns; ``E_A`` is ``((K-1)*n, (K-1)*n)``; ``E_b`` has length
        ``(K-1)*n``; ``E_phi_phi`` is ``(n, n)``; ``Apmf`` is
        ``kron(L @ L.T, I_n) @ E_A`` with ``L`` the ``(K-1)``-square
        lower-triangular matrix of ones.
    """
    reward_list = pi.reward_list
    K = theta.shape[0]
    n = F[0].shape[0]
    IKn = torch.ones((K*n))
    One = torch.ones((K,K))
    # Lower-triangular ones (cumulative-sum operator).
    C = torch.tril(One)
    E = torch.eye(K)
    # S is K x (K-1): identity on top, a row of -1 at the bottom.
    S = torch.zeros([K,K-1])
    S[:(K-1),:] = torch.eye(K-1)
    S[-1,:] = -torch.ones(K-1)
    r = pi.R
    P_x = pi.stationary
    # Joint weight of (state, next state).
    P_x_next_x = P_x[:,None]*pi.P
    x = pi.S
    size = pi.size
    phi = F[x]
    ECTC = E[:-1]@C.T@C
    K2 = ECTC@S
    # temp[j, k] = gamma*theta[k] - theta[j]
    temp = gamma*theta[None,:] - theta[:,None]
    zeta_r = h((reward_list[:,None,None]+temp[None,:,:])/stride)
    # Average the kernel over the reward weights of each transition.
    E_zeta_x_next_x = (r[:,:,:,None,None]*zeta_r[None,None,:,:,:]).sum(dim=2)
    In = torch.ones((n))
    phi_phi = phi[:,:,None]*phi[:,None,:]
    phi_I_x = phi[:,:,None]*In[None,None,:]
    E_phi_phi = (P_x[:,None,None]*phi_phi).sum(dim=0)
    # phi_phi_next[i, j, a, b] = phi[i, a] * phi[j, b]
    phi_phi_next = phi[None,:,None,:]*phi[:,None,:,None]
    E_zeta_x = (pi.P[:,:,None,None]*E_zeta_x_next_x).sum(dim=1)
    E_K1_x_next_x = ECTC@E_zeta_x_next_x@S[None,None,:,:]
    # Per-transition Kronecker product of the kernel block and feature block.
    E_A1_x_next_x = torch.einsum('ijab,ijcd->ijacbd',E_K1_x_next_x,phi_phi_next).view(size,size,(K-1)*n,(K-1)*n)
    E_A1= (P_x_next_x[:,:,None,None]*E_A1_x_next_x).sum(dim=(0,1))
    E_A2 = torch.kron(K2,E_phi_phi)
    E_A = -E_A1+E_A2
    E_A31_x = torch.einsum('iab,icd->iacbd',E@C.T@C@E_zeta_x,phi_I_x).view(size,K*n,K*n)
    E_A32_x = torch.einsum('iab,icd->iacbd',(E@C.T@C)[None,:,:],phi_I_x).view(size,K*n,K*n)
    E_A31 = (P_x[:,None,None]*E_A31_x).sum(dim=0)
    E_A32 = (P_x[:,None,None]*E_A32_x).sum(dim=0)
    E_A3 = E_A31-E_A32
    # Drop the last n entries to match the (K-1)*n unknowns.
    E_b = ((-1/(K*n)*E_A3)@IKn)[:-n]
    w = -E_A.inverse()@E_b
    w_mat = w.view([K-1,n]).T
    # Final column is minus the sum of the others, so each row sums to zero.
    w_ = -w_mat.sum(dim=1,keepdim=True)
    w_mat = torch.concat((w_mat,w_),dim=1)
    C = torch.tril(torch.ones((K-1,K-1)))
    Apmf = torch.kron(C@C.T,torch.eye(n))@E_A

    return w_mat,E_A,E_b,E_phi_phi,Apmf


import torch







