import torch 
import torch.nn as nn 
from sal.utils import Environment, sample_from_dirichlet_multinomial 
# from sal.gflownet import GFlowNet 

from typing import List 

import math

@torch.no_grad()
def sample_state_on_depth(state: 'Hypergrid', depth: int, model_idx: int, num_models: int, inplace: bool = False):
    """Sample a batch of states at ``depth`` from the band owned by ``model_idx``.

    The first coordinate ``x1`` is drawn uniformly from the integers in
    ``[lwr, upr)``, where the bounds invert the mapping used by
    ``state_to_node``:
    ``floor(num_models * (x1 / depth) ** (dim - 1)) == model_idx``.
    The last model additionally owns ``x1 == depth``. The remaining
    ``state.dim - 1`` coordinates come from
    ``sample_from_dirichlet_multinomial`` with budget ``depth - x1`` and a
    flat concentration (presumably they sum to that budget -- confirm against
    the helper).

    Args:
        state: Environment providing ``dim``, ``batch_size`` and ``device``;
            mutated only when ``inplace`` is true.
        depth: Target depth of the sampled states.
        model_idx: Index of the model whose band is sampled.
        num_models: Total number of bands.
        inplace: When true, write the sample into ``state``, clear
            ``is_initial``, subtract ``depth`` from
            ``max_trajectory_length`` and refresh the masks.

    Returns:
        The sampled states: ``x1`` stacked in front of the remaining coordinates.

    Note:
        ``torch.randint`` is expected to raise when the band contains no
        integer (e.g. far more models than depth levels).
    """
    exponent = 1 / (state.dim - 1)
    lwr = depth * (model_idx / num_models) ** exponent
    upr = depth * ((model_idx + 1) / num_models) ** exponent
    if model_idx == num_models - 1:
        # The last band is closed on the right so that x1 == depth is reachable.
        upr += 1

    # torch.randint excludes its upper bound. Using ceil (not floor) keeps
    # x1 == floor(upr) in this band when upr is fractional; state_to_node
    # assigns that value to this model, and the next band starts at ceil(upr).
    x1 = torch.randint(math.ceil(lwr), math.ceil(upr), (state.batch_size,), device=state.device)

    max_val = depth - x1
    remaining_states = sample_from_dirichlet_multinomial(
        max_val,
        concentration=torch.ones((state.dim - 1,), device=state.device),
    )
    states = torch.hstack([x1.view(-1, 1), remaining_states])

    if inplace:
        state.state = states
        state.is_initial[:] = 0.
        state.max_trajectory_length -= depth
        state.update_mask()

    return states

@torch.no_grad()
def state_to_node(state: 'Hypergrid', num_models: int):
    """Refresh and return ``state.node_indices`` for rows sitting at ``max_depth``.

    Rows whose depth equals ``state.max_depth`` are assigned
    ``floor(num_models * (x1 / max_depth) ** (dim - 1))``, where ``x1`` is the
    first coordinate; a result equal to ``num_models`` (reached when
    ``x1 == max_depth``) is folded into the last index. All other rows keep
    their stored index. If every row is already deeper than ``max_depth``
    nothing is recomputed.

    Args:
        state: Environment exposing ``state``, ``cur_depth``, ``max_depth``,
            ``dim`` and ``node_indices``; ``node_indices`` is updated in place.
        num_models: Number of nodes the first coordinate is bucketed into.

    Returns:
        ``state.node_indices`` (the same tensor object, after the update).
    """
    node_indices = state.node_indices
    depth = state.cur_depth

    # Every row is past the split depth: the stored assignment stays valid.
    if (depth > state.max_depth).all():
        return node_indices

    at_split_depth = depth == state.max_depth
    first_coord = state.state[at_split_depth, 0]
    bucket = (num_models * (first_coord / state.max_depth) ** (state.dim - 1)).floor().long()

    last_node = num_models - 1
    node_indices[at_split_depth] = torch.where(bucket == num_models, last_node, bucket)
    return node_indices

class LogReward(nn.Module):
    """Log-reward of the hypergrid.

    Each coordinate is rescaled to ``ax = |2 * s / (H - 1) - 1|`` (distance
    from the grid centre, in ``[0, 1]`` for coordinates in ``[0, H - 1]``).
    The reward is

        ``ro + 0.5 * [all ax > 0.5] + 2 * [all 0.6 < ax < 0.8]``

    and ``forward`` returns its logarithm.

    Args:
        ro: Base reward added everywhere so the logarithm stays finite.
    """

    def __init__(self, ro: float = 1e-3):
        super().__init__()
        self.ro = ro

    @torch.no_grad()
    def forward(self, batch_state: 'Hypergrid'):
        """Return the log-reward of every row of ``batch_state.state``.

        Args:
            batch_state: Object exposing ``state`` (one row per sample) and
                the grid side length ``H``.

        Returns:
            Tensor with one log-reward per row.
        """
        ax = torch.abs(batch_state.state / (batch_state.H - 1) * 2 - 1)
        outer_band = (ax > 0.5).all(dim=-1)
        inner_band = ((ax > 0.6) & (ax < 0.8)).all(dim=-1)
        # Use the configured base reward; previously 1e-3 was hard-coded here
        # and the ``ro`` constructor argument was silently ignored.
        return (outer_band * 0.5 + inner_band * 2 + self.ro).log()

class LogRewardModel(nn.Module):
    """Log-reward that switches to learned flow estimates at ``max_depth``.

    Rows shallower than ``batch_state.max_depth`` get the base log-reward.
    Every other row gets the output of ``pf.mlp_flows`` from the gflownet
    selected by ``state_to_node`` for that row.

    Args:
        log_reward_base: Module computing the base log-reward of a batch.
        device: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, log_reward_base: LogReward, device: str = 'cpu'):
        super().__init__()
        self.log_reward = log_reward_base
        self.device = device

    @torch.no_grad()
    def forward(self, batch_state: 'Hypergrid', gflownets: List[nn.Module]):
        """Return one log-reward per row of ``batch_state``.

        Args:
            batch_state: Batch of grid states; its ``node_indices`` are
                refreshed by ``state_to_node`` as a side effect.
            gflownets: One model per node; each must expose ``pf.mlp_flows``.

        Returns:
            Tensor shaped like the base log-reward, mixing base values and
            flow estimates as described on the class.
        """
        node_indices = state_to_node(batch_state, len(gflownets))

        base_values = self.log_reward(batch_state)
        # Only rows not below max_depth are read from this buffer, and each
        # of those is expected to be written by exactly one model below.
        flow_values = torch.empty_like(base_values)

        below_max_depth = batch_state.cur_depth < batch_state.max_depth

        for model_idx, gflownet in enumerate(gflownets):
            owned = (node_indices == model_idx) & ~below_max_depth
            if not owned.any():
                continue
            owned_states = batch_state.state.to(torch.get_default_dtype())[owned]
            flow_values[owned] = gflownet.pf.mlp_flows(owned_states).squeeze(dim=-1)

        return torch.where(below_max_depth, base_values, flow_values)

class Hypergrid(Environment):
    """Batched ``dim``-dimensional grid environment with side length ``H``.

    A state is a vector of per-axis coordinates starting at the origin.
    Action ``i < dim`` increments coordinate ``i``; action ``dim`` is the
    stop action. ``batch_size``, ``device``, ``stopped``, ``is_initial`` and
    ``max_trajectory_length`` are presumably provided by ``Environment``
    (not visible here -- confirm against the base class).

    Args:
        dim: Number of grid axes.
        H: Side length; coordinates may grow up to ``H - 1``.
        batch_size: Number of parallel states.
        log_reward: Module used to score states.
        max_depth: Depth (sum of coordinates) at which ``state_to_node``
            assigns rows to nodes; ``None`` means unbounded.
        max_depth_sample: Depth at which move actions are masked out;
            ``None`` means unbounded.
        device: Device holding all tensors.
    """

    def __init__(self, dim: int, H: int, batch_size: int, log_reward: nn.Module,
                 max_depth: int = None, max_depth_sample: int = None, device: torch.device = 'cpu'):
        super().__init__(batch_size, H * dim, log_reward, device=device)
        self.dim = dim
        self.H = H
        # ``None`` means "no limit"; infinity keeps the depth comparisons valid.
        self.max_depth = torch.inf if max_depth is None else max_depth
        self.max_depth_sample = torch.inf if max_depth_sample is None else max_depth_sample

        self.state = torch.zeros(
            (self.batch_size, self.dim), device=self.device
        )
        # Row i < dim is the unit step along axis i; the final all-zero row
        # is the stop action, which leaves the state untouched.
        self.actions = torch.vstack([
            torch.eye(self.dim), torch.zeros((1, self.dim))
        ]).to(self.device)

        self.forward_mask = torch.ones((self.batch_size, self.dim + 1), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, self.dim), device=self.device)

        # Node assignment of every row, maintained by ``state_to_node``.
        self.node_indices = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device)

    @torch.no_grad()
    def update_mask(self):
        """Recompute the forward and backward action masks from the state.

        A move along an axis is allowed while that coordinate is below
        ``H - 1``, the row has not stopped, and its depth is below
        ``max_depth_sample``. The stop column is left as it was. A backward
        move along an axis is allowed when that coordinate is non-zero.
        """
        umask = self.forward_mask.clone()
        umask[:, :-1] = (self.state < self.H - 1) & (self.stopped != 1).view(-1, 1) & (self.cur_depth < self.max_depth_sample).view(-1, 1)
        self.forward_mask = umask
        self.backward_mask = (self.state != 0)

    @torch.no_grad()
    def apply(self, indices):
        """Apply one forward action per row; index ``dim`` marks the row as stopped."""
        self.state = self.state + self.actions[indices]
        self.stopped = (indices == self.dim)
        self.is_initial[:] = 0.
        self.update_mask()

    @torch.no_grad()
    def backward(self, indices):
        """Undo one action per row, clear ``stopped`` and return ``indices``."""
        self.state = self.state - self.actions[indices]
        self.stopped[:] = 0.
        self.is_initial = (self.state == 0).all(dim=1)
        self.update_mask()
        return indices

    @property
    def unique_input(self):
        """The raw state tensor."""
        return self.state

    @property
    def cur_depth(self):
        """Depth of every row: the sum of its coordinates."""
        return self.state.sum(dim=1)

    @staticmethod
    def create_env_on_depth(config, log_reward, model_idx):
        """Build an environment whose rows start at ``config.max_depth`` in the band of ``model_idx``."""
        env = Hypergrid(config.dim, config.H, batch_size=config.batch_size, log_reward=log_reward, device=config.device)
        sample_state_on_depth(env, depth=config.max_depth, model_idx=model_idx,
                              num_models=config.num_models, inplace=True)
        return env

    @staticmethod
    def create_env_maximum_depth(config, log_reward):
        """Build an environment that stops moving once ``config.max_depth`` is reached."""
        env = Hypergrid(config.dim, config.H, batch_size=config.batch_size, log_reward=log_reward, device=config.device)
        env.max_depth = config.max_depth
        env.max_depth_sample = config.max_depth
        return env

    @staticmethod
    def create_env_for_sal(config, log_reward):
        """Build an environment with ``max_depth`` set but no limit on sampling depth."""
        env = Hypergrid(config.dim, config.H, batch_size=config.batch_size, log_reward=log_reward, device=config.device)
        env.max_depth = config.max_depth
        return env

    @torch.no_grad()
    def get(self, indices):
        """Return the sub-batch selected by ``indices``.

        The base class builds the copy; the grid-specific per-row tensors are
        then sliced from this instance.
        """
        copy_self = super().get(indices)
        copy_self.state = self.state[indices]
        copy_self.forward_mask = self.forward_mask[indices]
        copy_self.backward_mask = self.backward_mask[indices]
        # Keep the node assignment aligned with the selected rows:
        # state_to_node indexes it with a mask of the sub-batch size.
        copy_self.node_indices = self.node_indices[indices]
        return copy_self

    @torch.no_grad()
    def merge(self, batch_state: 'Hypergrid'):
        """Append the rows of ``batch_state`` to this environment in place."""
        # Built before delegating so the result is the same whether or not
        # the base class touches ``node_indices`` itself.
        merged_node_indices = torch.cat([self.node_indices, batch_state.node_indices])
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])
        self.forward_mask = torch.vstack([self.forward_mask, batch_state.forward_mask])
        self.backward_mask = torch.vstack([self.backward_mask, batch_state.backward_mask])
        self.node_indices = merged_node_indices