import torch 
import torch.nn as nn 
import itertools 
import copy 
import math 

from sal.utils import Environment 
from scipy.stats import poisson, uniform
from scipy.special import binom 
 
# Default integer encoding used by `state_to_node` and `sample_state_on_depth`
# to map a set to a compute node: 'binary', 'ranking' or 'ranking-partition'.
ENCODING_METHOD = 'ranking-partition'

def subset_rank(state: 'Set'):
    """Return the combinatorial rank of every subset in ``state.state``.

    Row ``i`` of ``state.state`` is read as an indicator vector over
    ``n = state.state.shape[1]`` elements (an entry equal to 1 marks a member).
    A row with ``k`` members gets its position in ``[0, C(n, k))`` among all
    size-``k`` indicator vectors ordered lexicographically (0 < 1), so subsets
    containing an earlier element rank higher. ``subset_unrank`` inverts this.

    Returns:
        A float64 tensor of shape ``(batch,)`` holding integral values.
    """
    subset = state.state
    num_elements = subset.shape[1]
    # Members of each row that still have to be placed at or after position `idx`.
    remaining = subset.sum(dim=1)
    # float64 holds every integer below 2**53 exactly; float32 would round ranks
    # above 2**24 (e.g. C(32, 16) is about 6e8) and corrupt node assignments.
    rank = torch.zeros((subset.shape[0],), dtype=torch.float64, device=subset.device)

    for idx in range(num_elements):
        is_member = subset[:, idx] == 1
        # Number of size-`remaining` subsets of the elements after `idx`, i.e. the
        # subsets that skip `idx`. scipy's binom is floating point, so round it.
        skipped = torch.as_tensor(
            binom(num_elements - (idx + 1), remaining.cpu().numpy())
        ).round().to(rank.device)
        rank += is_member.to(rank.dtype) * skipped
        remaining = remaining - is_member.to(remaining.dtype)
    return rank

def subset_unrank(indices: torch.Tensor, subset_size: int, src_size: int):
    """Decode combinatorial ranks back into indicator vectors (inverse of ``subset_rank``).

    Args:
        indices: one rank per row, each in ``[0, C(src_size, subset_size))``.
        subset_size: number of members in every decoded subset.
        src_size: number of elements to choose from.

    Returns:
        A float tensor of shape ``(len(indices), src_size)`` with ones at the
        member positions.
    """
    states = torch.zeros((indices.shape[0], src_size), device=indices.device)
    # Members of each row that still have to be placed at or after position `idx`.
    remaining = torch.ones_like(indices) * subset_size

    for idx in range(src_size):
        # Number of subsets that skip `idx`; ranks at or above it include `idx`.
        # scipy's binom is floating point: round before the integer cast, since a
        # truncating cast can land one below the exact value.
        skipped = torch.as_tensor(
            binom(src_size - (idx + 1), remaining.cpu().numpy())
        ).round().to(remaining.device, dtype=indices.dtype)
        contains_idx = indices >= skipped
        states[contains_idx, idx] = 1
        indices = indices - (contains_idx * skipped).type(indices.dtype)
        remaining = remaining - contains_idx.type(remaining.dtype)

    return states

def subset_binary(state: 'Set'):
    """Read the first ``state.max_depth`` columns of each state as a binary number.

    Column ``j`` contributes ``2 ** j`` when set (least-significant bit first);
    columns from ``max_depth`` onwards are ignored.
    """
    depth = state.max_depth
    weights = torch.pow(2, torch.arange(depth, device=state.device))
    prefix = state.unique_input[:, :depth]
    return (prefix * weights).sum(dim=1)

def subset_unbinary(indices, subset_size, src_size):
    """Decode binary-encoded indices into indicator vectors over ``src_size`` elements.

    The low ``subset_size`` bits of each index fill the first ``subset_size``
    columns (least-significant bit first). In the remaining
    ``src_size - subset_size`` columns, ``subset_size - popcount`` ones are placed
    at random positions (as many as fit).
    """
    batch = indices.shape[0]
    free_slots = src_size - subset_size
    device = indices.device

    prefix = binary(indices, subset_size)
    # Members still missing after the binary prefix.
    missing = subset_size - prefix.sum(dim=1)
    # Left-aligned run of `missing` ones in each tail row...
    tail = missing.unsqueeze(1) > torch.arange(free_slots, device=device).unsqueeze(0)
    # ...scattered by an independent random permutation per row.
    order = torch.argsort(torch.rand(batch, free_slots, device=device), dim=1)
    tail = tail.gather(1, order)

    return torch.cat([prefix, tail], dim=1)

def state_to_node(state: 'Set', num_compute_nodes, method: str = ENCODING_METHOD):
    """Return the compute node responsible for each state in ``state``.

    Args:
        state: batch of sets. Reads ``cur_depth``, ``max_depth`` and
            ``node_indices``, ``src_size`` for 'ranking-partition', and whatever
            the chosen encoding reads.
        num_compute_nodes: number of nodes the states are partitioned across.
        method: integer encoding, one of 'binary', 'ranking' or 'ranking-partition'.

    Returns:
        ``state.node_indices`` when every state is deeper than ``max_depth``;
        ``None`` when every state is shallower; otherwise one node per state.
        'ranking-partition' cuts ``[0, C(src_size, depth))`` into contiguous bins
        of size ``C // num_compute_nodes`` and folds the remainder into the last
        node. The other methods take the encoding modulo ``num_compute_nodes``.

    Raises:
        ValueError: if ``method`` is not a supported encoding.
    """
    # Every state is past `max_depth`: keep the nodes already stored on the batch.
    if (state.cur_depth > state.max_depth).all():
        return state.node_indices
    # No state has reached `max_depth` yet, so no node can be assigned.
    if (state.cur_depth < state.max_depth).all():
        return None

    if method == 'binary':
        # Only the first `max_depth` elements decide the node.
        indices = subset_binary(state)
    elif method == 'ranking':
        indices = subset_rank(state)
    elif method == 'ranking-partition':
        indices = subset_rank(state)
        # scipy's binom is floating point; round to the exact number of subsets
        # so the bin edges do not drift by one.
        num_subsets = torch.as_tensor(
            binom(state.src_size, state.cur_depth.cpu().numpy())
        ).round().to(indices.device)
        bin_size = num_subsets // num_compute_nodes
        return torch.minimum(
            (indices / bin_size).long(), torch.ones_like(indices).long() * (num_compute_nodes - 1)
        )
    else:
        raise ValueError(f'Unknown encoding method: {method!r}')

    return indices % num_compute_nodes

def binary(x: torch.Tensor, bits):
    """Expand each integer in ``x`` into its ``bits`` lowest bits (LSB first), in ``x``'s dtype."""
    powers = torch.pow(2, torch.arange(bits)).to(x.device, x.dtype)
    is_set = torch.bitwise_and(x[..., None], powers) != 0
    return is_set.to(x.dtype)

def sample_from_truncated_poisson(min_val, max_val, shape, mu=5., device='cpu'):
    """Draw integers from a Poisson(``mu``) distribution truncated to ``[min_val, max_val]``.

    Uses inverse-CDF sampling: a uniform draw on ``[F(min_val - 1), F(max_val))``,
    with ``F`` the Poisson CDF, is mapped through the Poisson PPF. Both bounds are
    therefore attainable, and values keep their relative Poisson probabilities.
    Randomness comes from scipy's ``rvs`` (not from torch's generator).

    Returns:
        A ``torch.long`` tensor of the given ``shape`` on ``device``.
    """
    max_cutoff = poisson.cdf(max_val, mu)
    # F(min_val - 1) rather than F(min_val): PPF(u) == min_val only for
    # u in (F(min_val - 1), F(min_val)], which F(min_val) as a lower bound excludes.
    min_cutoff = poisson.cdf(min_val - 1, mu)
    # Uniform on [min_cutoff, max_cutoff).
    u = uniform.rvs(loc=min_cutoff, scale=max_cutoff - min_cutoff, size=shape)
    # Convert to Poisson through the inverse CDF.
    truncated_poisson = poisson.ppf(u, mu)
    return torch.tensor(
        truncated_poisson, device=device, dtype=torch.long
    )

def sample_state_on_depth(state: 'Set', depth, node_idx: int, num_nodes: int, inplace: bool = False, method: str = ENCODING_METHOD):
    """Sample ``state.batch_size`` subsets of size ``depth`` whose encoding maps to node ``node_idx``.

    * 'binary': codes ``num_nodes * q + node_idx``, where ``q`` comes from a
      truncated Poisson, decoded with ``subset_unbinary``.
    * 'ranking': ranks ``r < C(src_size, depth)`` with ``r % num_nodes == node_idx``,
      drawn uniformly and decoded with ``subset_unrank``.
    * 'ranking-partition': ranks drawn uniformly from the node's contiguous slice
      of ``[0, C(src_size, depth))``. The last node also owns the remainder left
      by the integer division.

    Args:
        state: batch of sets. Reads ``src_size``, ``batch_size`` and ``device``;
            also ``set_size`` when ``inplace``.
        depth: number of members in each sampled subset.
        node_idx: target node, in ``[0, num_nodes)``.
        num_nodes: number of compute nodes.
        inplace: if True, also write the samples into ``state`` and refresh its masks.
        method: integer encoding, one of 'binary', 'ranking' or 'ranking-partition'.

    Returns:
        The sampled states, one row per batch element.

    Raises:
        ValueError: if ``method`` is not a supported encoding.
    """
    if method == 'binary':
        # NOTE(review): `num_nodes * largest_quotient + node_idx` can exceed
        # 2 ** depth - 1 for some nodes; confirm this is acceptable.
        smallest_quotient = math.floor((2 ** max(2 * depth - state.src_size, 0) - 1) / num_nodes)
        largest_quotient = math.floor((2 ** depth - 1) / num_nodes)

        states_indices = num_nodes * sample_from_truncated_poisson(
            min_val=smallest_quotient, max_val=largest_quotient, shape=(state.batch_size,), device=state.device
        ) + node_idx
        states = subset_unbinary(states_indices, depth, state.src_size)
    elif method == 'ranking':
        num_subsets = math.comb(int(state.src_size), int(depth))
        # r = num_nodes * q + node_idx < num_subsets  <=>  q <= (num_subsets - 1 - node_idx) // num_nodes
        num_quotients = (num_subsets - 1 - node_idx) // num_nodes + 1
        states_indices = num_nodes * torch.randint(
            low=0, high=num_quotients, size=(state.batch_size,), device=state.device
        ) + node_idx
        states = subset_unrank(states_indices, depth, state.src_size)
        assert (states.sum(dim=1) == depth).all(), states.sum(dim=1)
    elif method == 'ranking-partition':
        # Partition the integer encoding into equally sized, contiguous bins.
        num_subsets = math.comb(int(state.src_size), int(depth))
        bin_size = num_subsets // num_nodes
        low = bin_size * node_idx
        # The last node also owns [bin_size * num_nodes, num_subsets), matching the
        # clamp to `num_nodes - 1` in `state_to_node`.
        high = num_subsets if node_idx == num_nodes - 1 else bin_size * (node_idx + 1)
        states_indices = torch.randint(low=low, high=high, size=(state.batch_size,), device=state.device)
        states = subset_unrank(states_indices, depth, state.src_size)
    else:
        raise ValueError(f'Unknown encoding method: {method!r}')

    # Return the sampled states
    if inplace:
        state.state = states
        state.max_trajectory_length = state.set_size - depth
        state.update_mask()
    return states

class LogReward(nn.Module):
    """Linear log-reward: a weighted sum of ``unique_input`` with one weight per element.

    The weights in ``values`` are drawn uniformly from [0, 1) with a generator
    seeded by ``seed``. When ``config.hide_important_region`` is set, the entries
    at ``config.force_mask_idx`` are overwritten with 25.
    """

    def __init__(self, src_size, seed, config, temperature=1., device='cpu', shift=0.):
        super().__init__()
        self.src_size = src_size
        self.seed = seed
        self.device = device

        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        self.values = torch.rand((self.src_size,), device=self.device, generator=rng)
        if config.hide_important_region:
            self.values[config.force_mask_idx] = 25
        self.shift = shift

        # Populated by `compute_avg_reward`.
        self.expected_reward = None
        self.numerical_shift = None
        self.temperature = temperature

    def forward(self, batch_state):
        """Return ``(sum_j values[j] * x[j] - shift) / (5 * temperature)`` per row ``x`` of ``unique_input``."""
        weighted = self.values * batch_state.unique_input
        scale = 5. * self.temperature
        return (weighted.sum(dim=1) - self.shift) / scale

    def compute_avg_reward(self, state, max_batch_size):
        """Compute ``sum(R**2) / sum(R)`` over every state yielded by ``state.list_all_states``.

        That ratio is the expected reward when states are drawn proportionally to
        ``R``. Rewards are rescaled by ``exp(-numerical_shift)``, with
        ``numerical_shift`` the largest log-reward, to avoid overflow. The result is
        stored in ``expected_reward`` in those rescaled units.
        """
        all_log_r = torch.hstack([
            env.log_reward() for env in state.list_all_states(max_batch_size)
        ])
        self.numerical_shift = all_log_r.max()

        rewards = torch.exp(all_log_r - self.numerical_shift)
        self.expected_reward = rewards.pow(2).sum() / rewards.sum()

class LogRewardModel(nn.Module):
    """Log-reward that defers to per-node learned flows once states reach ``max_depth``.

    States shallower than ``batch_state.max_depth`` keep the base log-reward. Every
    other state is routed with ``state_to_node`` to one of the ``gflownets`` and
    scored by that model's ``pf.mlp_flows`` network.
    """

    def __init__(self, log_reward_base: LogReward, device: str = 'cpu'):
        super(LogRewardModel, self).__init__()
        self.log_reward = log_reward_base
        self.device = device

    @torch.no_grad()
    def forward(self, batch_state, gflownets):
        """Return one log-reward per state in ``batch_state``.

        Args:
            batch_state: batch of sets.
            gflownets: one model per compute node. Model ``i`` scores the states
                that ``state_to_node`` assigns to node ``i``.

        Returns:
            The base log-reward for states shallower than ``batch_state.max_depth``,
            and the assigned model's flow output for every other state.
        """
        num_models = len(gflownets)
        node_indices = state_to_node(batch_state, num_models)

        log_rewards_base = self.log_reward(batch_state)
        # `state_to_node` returns None only when no state has reached `max_depth`.
        if node_indices is None:
            return log_rewards_base

        log_rewards_model = torch.empty_like(log_rewards_base)
        # True for states that keep the base log-reward.
        depth_mask = torch.ones_like(log_rewards_base, dtype=torch.bool) & (batch_state.cur_depth < batch_state.max_depth)

        for model_idx in range(num_models):
            model_mask = (node_indices == model_idx) & ~depth_mask
            if model_mask.any():
                masked_states = batch_state.state.to(torch.get_default_dtype())[model_mask]
                log_rewards_model[model_mask] = gflownets[model_idx].pf.mlp_flows(masked_states).squeeze(dim=-1)

        return torch.where(depth_mask, log_rewards_base, log_rewards_model)

class Set(Environment):
    """Environment whose states are subsets of ``range(src_size)``.

    ``state`` is a ``(batch_size, src_size)`` tensor with a 1 in column ``j`` when
    element ``j`` is in the set. ``apply`` adds one element per row and
    ``backward`` removes one. A row is marked ``stopped`` once it holds
    ``max_depth_sample`` elements (``set_size`` by default).
    """

    def __init__(self, src_size, set_size, batch_size, log_reward, device='cpu'):
        super(Set, self).__init__(batch_size, set_size, log_reward, device=device)
        self.src_size = src_size
        self.set_size = set_size
        self.state = torch.zeros((self.batch_size, self.src_size), device=self.device, dtype=int)
        self.forward_mask = torch.ones((self.batch_size, self.src_size), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, self.src_size), device=self.device)
        self.max_num_parents = src_size

        # Depth compared against `cur_depth` by `state_to_node` to pick a compute node.
        self.max_depth = self.set_size
        # Set size at which `apply` marks a row as stopped.
        self.max_depth_sample = self.set_size
        self.node_indices = torch.zeros((self.batch_size,), device=self.device, dtype=torch.long)

    @torch.no_grad()
    def update_mask(self):
        """Recompute the masks from ``state``: non-members may be added, members may be removed."""
        self.forward_mask = 1 - self.state.type(self.forward_mask.dtype)
        self.backward_mask = self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def apply(self, indices):
        """Add element ``indices[i]`` to row ``i`` of the batch."""
        # Build the new state out of place: the previous `self.state` tensor is
        # not modified, so anything still referencing it keeps the old values.
        new_state = self.state.clone()
        new_state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        self.state = new_state
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.max_depth_sample)  # == set_size by default
        self.update_mask()

    @torch.no_grad()
    def backward(self, indices):
        """Remove element ``indices[i]`` from row ``i`` (in place) and return ``indices``."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        self.stopped[:] = 0
        self.update_mask()
        return indices

    @torch.no_grad()
    def merge(self, batch_state):
        """Append ``batch_state``'s rows: base-class merge, then stack state and masks below ours."""
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])
        self.forward_mask = torch.vstack([self.forward_mask, batch_state.forward_mask])
        self.backward_mask = torch.vstack([self.backward_mask, batch_state.backward_mask])

    @torch.no_grad()
    def list_all_states(self, max_batch_size=None):
        """Enumerate every subset of size ``set_size``.

        Subsets come in ``itertools.combinations`` order and are emitted in batches
        of at most ``max_batch_size`` (all at once when ``None``). Each batch is
        written into ``self`` (``batch_size``, ``state`` and the per-row
        bookkeeping) and ``self`` is yielded, so use or copy a batch before
        advancing the generator.
        """
        # Exact binomial; also valid for set_size == 0 and set_size == src_size.
        total_states = math.comb(self.src_size, self.set_size)
        if max_batch_size is None:
            max_batch_size = total_states
        # Generate all states
        indices = list()
        visited_states = 0
        for comb in itertools.combinations(range(self.src_size), r=self.set_size):
            indices.append(comb)
            visited_states += 1
            if (visited_states % max_batch_size) == 0 or visited_states == total_states:
                # Row r of `cols` lists the members of the r-th subset in this batch.
                cols = torch.tensor(indices, dtype=torch.long, device=self.device).view(len(indices), self.set_size)
                rows = torch.arange(len(indices), device=self.device).view(-1, 1).expand_as(cols)
                stes = torch.zeros((len(indices), self.src_size), device=self.device)
                stes[rows, cols] = 1.

                print(f'{visited_states}/{total_states}', max_batch_size)
                # Update the states' attributes
                self.batch_size = len(indices)
                self.state = stes
                self._update_when_batch_size_changes()
                indices = list()
                yield self

    def _update_when_batch_size_changes(self):
        """Reset per-row bookkeeping and masks after ``batch_size``/``state`` were replaced."""
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        self.traj_size = self.src_size * torch.ones((self.batch_size,), device=self.device)
        self.stopped = torch.ones((self.batch_size), device=self.device)
        self.is_initial = torch.zeros((self.batch_size,), device=self.device)

        # Use >= for compatiblity with `Sets`
        self.forward_mask = 1 - (self.state >= 1.).type(self.forward_mask.dtype)
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)

    @property
    def unique_input(self):
        """``state`` cast to the floating dtype of ``backward_mask``."""
        return self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def get_children(self, return_actions=False):
        """Yield one deep copy per element of ``range(src_size)``, with that element added to every row.

        Yields ``(child, actions)`` when ``return_actions`` is True, else ``child``.
        """
        actions = torch.arange(self.src_size, device=self.device)

        for action in actions:
            child = copy.deepcopy(self)
            curr_actions = action * torch.ones((self.batch_size,), device=self.device, dtype=int)
            child.apply(
                curr_actions
            )
            if return_actions:
                yield child, curr_actions
            else:
                yield child

    @torch.no_grad()
    def get_parents(self):
        """Yield ``(parent, removed_elements)`` once per member position.

        Assumes every row has the same number of members; that count is read
        from the last row.
        """
        # torch.where scans row-major, so row r's members occupy a contiguous run
        # of `curr_size` entries in `actions`.
        _, actions = torch.where(self.state == 1.)  # curr_size * batch_size
        curr_size = self.state.sum(dim=1)[-1].int()

        for i in range(curr_size):
            # The i-th member of every row.
            backward_actions = actions[torch.arange(i, actions.shape[0], step=curr_size)]
            parent = copy.deepcopy(self)
            parent.backward(
                backward_actions
            )
            yield parent, backward_actions

    # This is the same for both Sets and Bags
    def get(self, indices):
        """Return a copy restricted to the rows selected by ``indices``."""
        copy_self = super().get(indices)
        copy_self.state = self.state[indices]
        copy_self.forward_mask = self.forward_mask[indices]
        copy_self.backward_mask = self.backward_mask[indices]
        return copy_self

    @property
    def cur_depth(self):
        """Number of elements currently in each row's set."""
        return self.unique_input.sum(dim=1)

    @staticmethod
    def create_env_on_depth(config, log_reward, model_idx):
        """Build a batch whose states are sampled at ``config.max_depth`` for node ``model_idx``."""
        env = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device)
        sample_state_on_depth(env, depth=config.max_depth, node_idx=model_idx, num_nodes=config.num_models, inplace=True)
        return env

    @staticmethod
    def create_env_maximum_depth(config, log_reward):
        """Build an empty batch whose trajectories stop at ``config.max_depth``."""
        env = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device)
        env.max_depth_sample = env.max_depth = config.max_depth
        return env

    @staticmethod
    def create_env_for_sal(config, log_reward):
        """Build an empty batch with ``max_depth = config.max_depth``; trajectories still run to ``set_size``."""
        env = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device)
        env.max_depth = config.max_depth
        return env
    
class Bag(Set):
    """Multiset variant of ``Set``: an element may be added more than once.

    ``state[i, j]`` counts the copies of element ``j`` in row ``i``. A row is
    marked ``stopped`` once it holds ``set_size`` elements.
    """

    @torch.no_grad()
    def apply(self, indices):
        """Add one copy of element ``indices[i]`` to row ``i`` (in place)."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.set_size)
        # `forward_mask` is left untouched, presumably because any element can be added again.
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)

    @torch.no_grad()
    def backward(self, indices):
        """Remove one copy of element ``indices[i]`` from row ``i`` (in place) and return ``indices``."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        self.stopped[:] = 0
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)
        return indices

    @torch.no_grad()
    def list_all_states(self, max_batch_size=None):
        """Enumerate every multiset of size ``set_size`` over ``range(src_size)``.

        Multisets come in ``itertools.combinations_with_replacement`` order and are
        emitted in batches of at most ``max_batch_size`` (all at once when
        ``None``). Each batch is written into ``self`` and ``self`` is yielded, so
        use or copy a batch before advancing the generator.
        """
        # Number of multisets: C(src_size + set_size - 1, set_size), computed exactly.
        total_num_bags = math.comb(self.src_size + self.set_size - 1, self.set_size)
        assert total_num_bags < 5e6

        max_batch_size = total_num_bags if max_batch_size is None else max_batch_size
        visited_states = list()
        num_visited_states = 0

        for comb in itertools.combinations_with_replacement(range(self.src_size), self.set_size):
            visited_states.append(comb)
            num_visited_states += 1
            # Flush a full batch, or the final partial batch once every bag was visited.
            if (
                num_visited_states % max_batch_size == 0 or
                num_visited_states == total_num_bags
            ):
                self.batch_size = len(visited_states)
                self.state = torch.zeros((self.batch_size, self.src_size), device=self.device)
                visited_states = torch.tensor(visited_states, device=self.device, dtype=torch.int64)
                # Count each element's occurrences in the combination.
                self.state.scatter_add_(dim=1,
                                        index=visited_states,
                                        src=torch.ones_like(visited_states).type(self.state.dtype))
                self._update_when_batch_size_changes()
                visited_states = list()
                yield self

if __name__ == '__main__':
    # Check 1: states sampled for a node must be assigned back to that node.
    depth, num_nodes = 8, 5
    for node_idx in range(num_nodes):
        batch = Set(32, 16, 128, None, device='cpu')
        batch.max_depth = depth
        sample_state_on_depth(batch, depth=depth, node_idx=node_idx, num_nodes=num_nodes, inplace=True)
        assigned = state_to_node(batch, num_nodes)
        assert (assigned == node_idx).all(), (assigned, node_idx)

    # Check 2: print the 'ranking' node of every 2-element subset of {0, 1, 2, 3}.
    depth, num_nodes = 2, 3
    all_subsets = torch.tensor([
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 0, 1],
        [0, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 1],
    ])
    batch = Set(4, 2, batch_size=len(all_subsets), log_reward=None, device='cpu')
    batch.state = all_subsets
    print(state_to_node(batch, num_nodes, method='ranking'))