import torch 
import torch.nn as nn 
import numpy as np 
import matplotlib.pyplot as plt 

from gfn.utils import Environment 

class LogReward(nn.Module):
    """Dense log-reward based on the L1 distance to a set of reference points.

    For each state, ``d`` is the smallest L1 (Manhattan) distance from
    ``batch_state.pos`` to any row of ``refs``; the log-reward is
    ``-beta * log(d + 1)``, so it is 0 on a reference point and decreases as
    ``d`` grows.

    Args:
        refs: reference positions of shape ``(num_refs, D)`` where ``D``
            matches the last dimension of ``batch_state.pos``. May be a tensor
            or any sequence accepted by ``torch.as_tensor``.
        beta: strength of the decay with distance.
        device: device on which ``refs`` is stored.
    """

    def __init__(self, refs, beta=1.5, device='cpu'):
        super(LogReward, self).__init__()
        self.device = device
        # Actually place the references on ``device`` (previously the argument
        # was stored but ignored) and accept plain nested sequences as well.
        # ``as_tensor`` avoids a copy when ``refs`` is already a tensor there.
        self.refs = torch.as_tensor(refs, device=device)
        self.beta = beta

    @torch.no_grad()
    def forward(self, batch_state):
        """Return the log-reward for every state in the batch.

        Args:
            batch_state: object exposing ``pos`` of shape ``(batch, D)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        # (batch, 1, D) - (num_refs, D) -> (batch, num_refs, D) -> L1 -> nearest.
        diffs = batch_state.pos[:, None, :] - self.refs
        dist_to_refs = diffs.abs().sum(dim=-1).min(dim=-1).values
        return -self.beta * torch.log(dist_to_refs + 1)

class LogRewardSparse(nn.Module):
    """Sparse grid reward, returned in log space.

    Each coordinate is normalized by its own axis extent ``(width, height)``,
    which maps positions inside the grid to ``[0, 1]``. With
    ``u = |pos / (width, height) - 0.5|`` (per coordinate):

        R = ro + r1 * [all(u > 0.25)] + r2 * [all(0.3 < u < 0.4)]

    The ``r1`` term covers the four corner regions. The ``r2`` term covers a
    smaller band inside those corners.

    Args:
        ro: baseline reward everywhere (keeps ``log R`` finite).
        r1: bonus for the corner regions.
        r2: additional bonus for the inner band.
        device: accepted for interface compatibility; not used.
    """

    def __init__(self, ro=1e-3, r1=.5, r2=2, device='cpu'):
        super(LogRewardSparse, self).__init__()
        self.ro = ro
        self.r1 = r1
        self.r2 = r2

    @torch.no_grad()
    def forward(self, batch_state):
        """Return ``log R`` for every state in the batch.

        Args:
            batch_state: object exposing ``pos`` of shape ``(batch, 2)`` and
                the grid extents ``width`` and ``height``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        pos = batch_state.pos
        # Normalize x by width and y by height. Previously both coordinates were
        # divided by width in the r1 term and by height in the r2 term, which
        # mis-placed the rewarded regions on non-square grids.
        extent = torch.tensor([batch_state.width, batch_state.height],
                              dtype=pos.dtype, device=pos.device)
        offset = (pos / extent - .5).abs()
        in_corners = (offset > .25).all(dim=1)
        in_band = ((offset > .3) & (offset < .4)).all(dim=1)
        r = self.ro + \
            self.r1 * in_corners.to(offset.dtype) + \
            self.r2 * in_band.to(offset.dtype)
        return torch.log(r)

class Grid2D(Environment):
    """Batched 2-D grid environment.

    Each of the ``batch_size`` environments holds a position in ``pos`` (one
    row per environment), starting at the origin. The masks keep positions
    within ``[0, width] x [0, height]``.

    Forward action indices:
        * ``0``: +1 on the first coordinate (``self.actions[0]``).
        * ``1``: +1 on the second coordinate (``self.actions[1]``).
        * ``2``: stop the trajectory.

    Backward action ``i`` subtracts ``self.actions[i]``, undoing forward
    action ``i``.
    """

    def __init__(self, width, height, batch_size, log_reward, device='cpu'):
        # Longest trajectory: width + height unit moves followed by a stop.
        super(Grid2D, self).__init__(batch_size, width + height + 1, log_reward, device=device)
        self.width = width
        self.height = height
        self.topright = torch.tensor([width, height], dtype=torch.get_default_dtype(), device=device)
        self.bottomleft = torch.tensor([0, 0], dtype=torch.get_default_dtype(), device=device)

        # Every environment starts at the origin.
        self.pos = torch.zeros((batch_size, 2), device=device)

        self.actions = torch.tensor([[1, 0], [0, 1]],
                                    dtype=torch.get_default_dtype(), device=device)
        # One column per move plus a final column for the stop action.
        self.forward_mask = torch.ones((batch_size, self.actions.shape[0] + 1), device=device)
        self.backward_mask = torch.zeros((batch_size, 2), device=self.device)
        # Derive the initial forward mask from the position so that degenerate
        # grids (width == 0 or height == 0) cannot step off the grid.
        self.update_forward_mask()

    @torch.no_grad()
    def update_forward_mask(self):
        """Recompute ``forward_mask`` from ``pos``.

        A move column is 0 when taking that move would exceed ``topright`` in
        some coordinate, and 1 otherwise. The stop column is always 1.
        """
        # (batch, 1, 2) + (n_actions, 2) -> (batch, n_actions, 2) -> (batch, n_actions)
        would_exceed = torch.gt(self.pos[:, None, :] + self.actions, self.topright).any(dim=-1)
        # Row count follows pos so the mask stays consistent after merges.
        mask = self.forward_mask.new_ones((self.pos.shape[0], self.actions.shape[0] + 1))
        mask[:, :-1] = (~would_exceed).to(mask.dtype)
        self.forward_mask = mask

    @torch.no_grad()
    def update_backward_mask(self):
        """Recompute ``backward_mask`` from ``pos``.

        Backward action ``i`` is allowed (1) only when coordinate ``i`` is
        strictly above ``bottomleft``. The result keeps the dtype and device
        of the previous mask and has one row per row of ``pos``.
        """
        self.backward_mask = torch.gt(self.pos, self.bottomleft.view(1, -1)).to(self.backward_mask)

    @torch.no_grad()
    def apply(self, indices):
        """Apply one forward action per environment.

        Args:
            indices: tensor of forward action indices, one per environment.
                The value ``self.actions.shape[0]`` (== 2) means stop.
        """
        stop_index = self.actions.shape[0]
        is_stop_action = (indices == stop_index)
        moving = ~is_stop_action

        # Build a new tensor rather than editing self.pos in place, so tensors
        # previously handed out (e.g. via unique_input) are not modified.
        new_pos = self.pos.clone()
        new_pos[moving] += self.actions[indices[moving]]
        self.pos = new_pos

        self.update_forward_mask()
        self.update_backward_mask()
        # An environment that just stopped may only choose the stop action next.
        # forward_mask was freshly allocated above, so editing it here is safe.
        self.forward_mask[is_stop_action, :-1] = 0

        self.stopped = is_stop_action.long()
        self.is_initial = (self.pos == 0).all(dim=1).long()

    @torch.no_grad()
    def backward(self, indices):
        """Undo one step per environment.

        Environments flagged as stopped only have their stop flag cleared, and
        their position is unchanged. Environments at the origin are left as
        is. All others move by ``-self.actions[indices]``.

        Args:
            indices: tensor of backward action indices, one per environment.

        Returns:
            ``indices``, unchanged.
        """
        stopped = (self.stopped == 1)
        movable = (self.is_initial == 0) & ~stopped
        # Out-of-place, mirroring apply(): do not mutate a pos tensor that a
        # caller may still hold through unique_input.
        new_pos = self.pos.clone()
        new_pos[movable] -= self.actions[indices[movable]]
        self.pos = new_pos

        self.is_initial = (self.pos == 0).all(dim=1).long()
        self.stopped[:] = 0.

        self.update_forward_mask()
        self.update_backward_mask()
        return indices

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the environments of ``batch_state`` to this batch."""
        super().merge(batch_state)
        self.pos = torch.vstack([self.pos, batch_state.pos])
        self.forward_mask = torch.vstack([self.forward_mask, batch_state.forward_mask])
        # backward_mask is not stacked here. It depends only on pos, so
        # recompute it to guarantee exactly one row per merged environment.
        self.update_backward_mask()

    @property
    def unique_input(self):
        """Current positions (returned directly, not copied)."""
        return self.pos