import torch
from tqdm import tqdm
from utility import*
from torch.utils.data import DataLoader

@torch.no_grad()
def batch_pmf_sgd(replay_buffer,K,n,gamma,theta:torch.Tensor,F,alpha,state_space,w=None,w_star=None,pi=None,writer=None,batch_size=25):
    '''Batched original (un-preconditioned) semi-gradient descent for a categorical/pmf parameterisation.

    Each iteration draws a mini-batch of transitions from ``replay_buffer`` and applies
    ``w <- w - mean_over_batch((A w - b) C)``, where ``A w`` and ``b`` are built with
    ``alpha * C.T`` on the right. The only difference from ``batch_cdf_sgd`` is this extra
    lower-triangular ``C`` around ``A w - b``. The returned parameter is the running
    (uniform) average of the iterates.

    Args:
        replay_buffer: tensor whose rows are ``(state, next_state, reward)``. The two state
            columns are cast to integer indices into ``F``. Assumed shape ``(N, 3)``.
        K: number of entries of ``theta``. The free parameter has ``K - 1`` columns.
        n: number of columns of ``F`` (feature dimension of ``w``).
        gamma: discount applied in the TD target ``reward + gamma * theta``.
        theta: 1-D tensor of the K atom locations. Assumed evenly spaced, because the
            stride is derived from its endpoints only.
        F: feature matrix; ``F[s]`` is the feature row of state ``s``.
        alpha: step size.
        state_space: unused; kept for interface compatibility.
        w: initial ``(n, K - 1)`` parameter. Defaults to zeros.
        w_star: reference ``(n, K)`` parameter against which the losses are measured.
            When ``None`` it is obtained from ``info``, which also yields the matrix used
            for the pmf loss. When supplied, the pmf loss is unavailable and is recorded
            as NaN.
        pi: object exposing ``stationary`` (per-state weights of the loss). It is also
            passed to ``info``, and it is required.
        writer: optional object with ``add_scalar(tag, value, step)``. Presumably a
            TensorBoard ``SummaryWriter``; verify against callers.
        batch_size: transitions per update (25 was the previously hard-coded value).

    Returns:
        ``(w_mat, loss_list, value_loss_list, pmfloss_list)``.
        - ``w_mat`` is the averaged ``(n, K)`` parameter; its last column is minus the
          sum of the others.
        - The three lists have length ``replay_buffer.shape[0]``. Only the first
          ``ceil(N / batch_size)`` entries (one per batch) are filled; the rest stay zero.
    '''
    # S lifts the K-1 free coordinates to all K atoms: identity on top, a row of -1 at the bottom.
    S = torch.zeros([K, K - 1])
    S[:(K - 1), :] = torch.eye(K - 1)
    S[-1, :] = -torch.ones(K - 1)
    C = torch.tril(torch.ones((K - 1, K - 1)))
    alpha_CT = alpha * C.T  # loop-invariant factor shared by A w and b
    C1 = torch.tril(torch.ones((K, K)))  # lower-triangular ones: cumulative sum over the K atoms
    stride = (theta[-1] - theta[0]) / (K - 1)

    num_samples = replay_buffer.shape[0]
    loss_list = torch.zeros(num_samples)
    pmfloss_list = torch.zeros(num_samples)
    value_loss_list = torch.zeros(num_samples)

    if w is None:
        w = torch.zeros([n, K - 1])
    if w_star is None:
        w_star, _, _, _, Apmf = info(pi, F, gamma, stride, alpha, theta, p=2, verbose=True)
    else:
        # Previously Apmf/real_value were left unbound here, which raised NameError in the loop.
        Apmf = None
    real_value = (F @ w_star + 1 / K) @ theta

    w_ema = w
    # Bind w_mat up-front so an empty replay buffer still returns the initial parameter.
    w_mat = torch.cat((w_ema, -w_ema.sum(dim=1, keepdim=True)), dim=1)

    batched_rb = DataLoader(replay_buffer, batch_size=batch_size)
    for i, batch in tqdm(enumerate(batched_rb), total=len(batched_rb)):
        x = batch[:, 0].view((-1, 1)).long()
        next_x = batch[:, 1].view((-1, 1)).long()
        reward = batch[:, 2].view((-1, 1))
        TD = reward[:, :, None] + gamma * theta[None, None, :] - theta[None, :, None]
        # reshape instead of squeeze: squeeze() would also drop the batch dim for a batch of one.
        zeta = h(TD / stride).reshape(-1, K, K)
        G = torch.einsum('bij,jk->bik', zeta[:, :-1], S)
        phi = F[x].squeeze(1)
        phi_next = F[next_x].squeeze(1)
        Phi1 = torch.einsum('bi,bj->bij', phi, phi_next)
        Phi2 = torch.einsum('bi,bj->bij', phi, phi)
        g_row_sum = zeta[:, :-1].sum(dim=-1)
        Aw = Phi2 @ (w @ alpha_CT) - Phi1 @ (w @ G.transpose(-1, -2) @ alpha_CT)
        b = 1 / K * (phi[:, :, None] @ ((g_row_sum - 1)[:, None, :] @ alpha_CT))
        ctd = torch.einsum('bij,jk->bik', Aw - b, C)
        w = w - ctd.mean(dim=0)

        # Cumulative mean of the iterates seen so far (not an exponential average).
        w_ema = 1 / (i + 1) * w + i / (i + 1) * w_ema
        w_mat = torch.cat((w_ema, -w_ema.sum(dim=1, keepdim=True)), dim=1)

        # Squared CDF error per state, weighted by pi.stationary and normalised by K-1.
        Cp_difference = C1 @ (w_mat - w_star).T @ F.T
        loss = torch.einsum('ij,ij->j', Cp_difference, Cp_difference) @ pi.stationary
        loss = loss / (K - 1)
        # Divergence shows up as inf/nan. The old check (-log(loss) == inf) only fired on loss == 0.
        if not torch.isfinite(loss):
            print(f"The step size alpha is too large to converge. The algorithm exited at {i}-th iteration.")
            break
        loss_list[i] = loss

        if Apmf is not None:
            w_vec = (w_mat - w_star).T.reshape([K * n])[:(K - 1) * n]
            losspmf = torch.norm(Apmf @ w_vec) ** 2 / (K - 1)
        else:
            losspmf = torch.tensor(float('nan'))
        pmfloss_list[i] = losspmf

        value = (F @ w_mat + 1 / K) @ theta
        value_loss = torch.norm(real_value - value)
        value_loss_list[i] = value_loss

        if writer is not None:
            writer.add_scalar('loss', loss, i)
            writer.add_scalar('-log_loss', -torch.log(loss), i)
            writer.add_scalar('value_loss', value_loss, i)
            if Apmf is not None:
                writer.add_scalar('pmf_loss', losspmf, i)
    return w_mat, loss_list, value_loss_list, pmfloss_list



@torch.no_grad()
def batch_cdf_sgd(replay_buffer,K,n,gamma,theta:torch.Tensor,F,alpha,state_space,w=None,w_star=None,pi=None,writer=None,batch_size=25):
    '''Batched preconditioned semi-gradient descent, written in matrix form to reduce computation.

    Each iteration draws a mini-batch of transitions from ``replay_buffer`` and applies
    ``w <- w - mean_over_batch(A w - b)``, with ``A w`` and ``b`` scaled by ``alpha``.
    Unlike ``batch_pmf_sgd``, there is no lower-triangular ``C`` around ``A w - b``.
    The returned parameter is the running (uniform) average of the iterates.

    Args:
        replay_buffer: tensor whose rows are ``(state, next_state, reward)``. The two state
            columns are cast to integer indices into ``F``. Assumed shape ``(N, 3)``.
        K: number of entries of ``theta``. The free parameter has ``K - 1`` columns.
        n: number of columns of ``F`` (feature dimension of ``w``).
        gamma: discount applied in the TD target ``reward + gamma * theta``.
        theta: 1-D tensor of the K atom locations. Assumed evenly spaced, because the
            stride is derived from its endpoints only.
        F: feature matrix; ``F[s]`` is the feature row of state ``s``.
        alpha: step size.
        state_space: unused; kept for interface compatibility.
        w: initial ``(n, K - 1)`` parameter. Defaults to zeros.
        w_star: reference ``(n, K)`` parameter against which the losses are measured.
            When ``None`` it is obtained from ``info``, which also yields the matrix used
            for the pmf loss. When supplied, the pmf loss is unavailable and is recorded
            as NaN.
        pi: object exposing ``stationary`` (per-state weights of the loss). It is also
            passed to ``info``, and it is required.
        writer: optional object with ``add_scalar(tag, value, step)``. Presumably a
            TensorBoard ``SummaryWriter``; verify against callers.
        batch_size: transitions per update (25 was the previously hard-coded value).

    Returns:
        ``(w_mat, loss_list, value_loss_list, pmfloss_list)``.
        - ``w_mat`` is the averaged ``(n, K)`` parameter; its last column is minus the
          sum of the others.
        - The three lists have length ``replay_buffer.shape[0]``. Only the first
          ``ceil(N / batch_size)`` entries (one per batch) are filled; the rest stay zero.
        - ``loss_list`` holds the loss normalised by ``K - 1``, the same value sent to
          ``writer`` and the same convention as ``batch_pmf_sgd``.
    '''
    # S lifts the K-1 free coordinates to all K atoms: identity on top, a row of -1 at the bottom.
    S = torch.zeros([K, K - 1])
    S[:(K - 1), :] = torch.eye(K - 1)
    S[-1, :] = -torch.ones(K - 1)
    C1 = torch.tril(torch.ones((K, K)))  # lower-triangular ones: cumulative sum over the K atoms
    stride = (theta[-1] - theta[0]) / (K - 1)

    num_samples = replay_buffer.shape[0]
    loss_list = torch.zeros(num_samples)
    pmfloss_list = torch.zeros(num_samples)
    value_loss_list = torch.zeros(num_samples)

    if w is None:
        w = torch.zeros([n, K - 1])
    if w_star is None:
        w_star, _, _, _, Apmf = info(pi, F, gamma, stride, alpha, theta, p=2, verbose=True)
    else:
        # Previously Apmf/real_value were left unbound here, which raised NameError in the loop.
        Apmf = None
    real_value = (F @ w_star + 1 / K) @ theta

    w_ema = w
    # Bind w_mat up-front so an empty replay buffer still returns the initial parameter.
    w_mat = torch.cat((w_ema, -w_ema.sum(dim=1, keepdim=True)), dim=1)

    batched_rb = DataLoader(replay_buffer, batch_size=batch_size)
    for i, batch in tqdm(enumerate(batched_rb), total=len(batched_rb)):
        x = batch[:, 0].view((-1, 1)).long()
        next_x = batch[:, 1].view((-1, 1)).long()
        reward = batch[:, 2].view((-1, 1))
        TD = reward[:, :, None] + gamma * theta[None, None, :] - theta[None, :, None]
        # reshape instead of squeeze: squeeze() would also drop the batch dim for a batch of one.
        zeta = h(TD / stride).reshape(-1, K, K)
        G = torch.einsum('bij,jk->bik', zeta[:, :-1], S)
        phi = F[x].squeeze(1)
        phi_next = F[next_x].squeeze(1)
        Phi1 = torch.einsum('bi,bj->bij', phi, phi_next)
        Phi2 = torch.einsum('bi,bj->bij', phi, phi)
        g_row_sum = zeta[:, :-1].sum(dim=-1)
        Aw = Phi2 @ (w * alpha) - Phi1 @ (w @ G.transpose(-1, -2) * alpha)
        b = 1 / K * (phi[:, :, None] @ ((g_row_sum - 1)[:, None, :] * alpha))
        w = w - (Aw - b).mean(dim=0)

        # Cumulative mean of the iterates seen so far (not an exponential average).
        w_ema = 1 / (i + 1) * w + i / (i + 1) * w_ema
        w_mat = torch.cat((w_ema, -w_ema.sum(dim=1, keepdim=True)), dim=1)

        value = (F @ w_mat + 1 / K) @ theta
        # Squared CDF error per state, weighted by pi.stationary and normalised by K-1
        # before storing (it used to be stored un-normalised, unlike what was logged).
        Cp_difference = C1 @ (w_mat - w_star).T @ F.T
        loss = torch.einsum('ij,ij->j', Cp_difference, Cp_difference) @ pi.stationary
        loss = loss / (K - 1)
        loss_list[i] = loss

        if Apmf is not None:
            w_vec = (w_mat - w_star).T.reshape([K * n])[:(K - 1) * n]
            losspmf = torch.norm(Apmf @ w_vec) ** 2 / (K - 1)
        else:
            losspmf = torch.tensor(float('nan'))
        pmfloss_list[i] = losspmf

        # Divergence shows up as inf/nan. The old check (-log(loss) == inf) only fired on loss == 0.
        if not torch.isfinite(loss):
            print(f"The step size alpha is too large to converge. The algorithm exited at {i}-th iteration.")
            break
        value_loss = torch.norm(real_value - value)
        value_loss_list[i] = value_loss

        if writer is not None:
            writer.add_scalar('loss', loss, i)
            writer.add_scalar('-log_loss', -torch.log(loss), i)
            writer.add_scalar('value_loss', value_loss, i)
            if Apmf is not None:
                writer.add_scalar('pmf_loss', losspmf, i)
    return w_mat, loss_list, value_loss_list, pmfloss_list
