import copy
import itertools
import math

import torch
import torch.nn as nn

from gfn.utils import Environment

class LogReward(nn.Module):
    """Linear log-reward over per-element input vectors.

    Each of the ``src_size`` elements receives a weight drawn uniformly from
    [0, 1) by a generator seeded with ``seed``. The log-reward of a batch of
    states is the weighted sum of ``batch_state.unique_input`` along dim 1,
    minus ``shift``, divided by ``temperature``.
    """

    def __init__(self, src_size, seed, device='cpu', shift=0., temperature=5.):
        """
        Args:
            src_size: number of elements (length of the weight vector).
            seed: seed of the private generator used to draw the weights.
            device: device on which the weights are created.
            shift: constant subtracted from the weighted sum.
            temperature: divisor applied after the shift; the default 5.0
                reproduces the previously hard-coded value.
        """
        super().__init__()
        self.src_size = src_size
        self.seed = seed
        self.device = device
        self.shift = shift
        self.temperature = temperature

        # A dedicated generator makes the weights depend only on ``seed``,
        # not on the global RNG state.
        g = torch.Generator(device=device)
        g.manual_seed(seed)
        self.values = torch.rand((self.src_size,), device=self.device, generator=g)

        # Filled in by ``compute_avg_reward``.
        self.expected_reward = None
        self.numerical_shift = None

    def forward(self, batch_state):
        """Return ``((values * unique_input).sum(1) - shift) / temperature``.

        Assumes ``batch_state.unique_input`` is a 2-D tensor whose last
        dimension has size ``src_size`` (TODO confirm against callers).
        """
        log_reward = (self.values * batch_state.unique_input).sum(dim=1)
        return (log_reward - self.shift) / self.temperature

    def compute_avg_reward(self, state, max_batch_size):
        """Compute the reward-weighted mean reward over all enumerated states.

        Calls ``env.log_reward()`` on every batch yielded by
        ``state.list_all_states(max_batch_size)``. Each batch is consumed
        before the generator advances, because the environments in this
        module yield themselves, mutated in place. Note that ``state``
        itself is therefore left holding the last enumerated batch.

        With ``l`` the concatenated log-rewards, this sets:

        * ``numerical_shift = max(l)``
        * ``expected_reward = sum(r ** 2) / sum(r)``, where
          ``r = exp(l - numerical_shift)``. This equals
          ``E_{x ~ exp(l)}[exp(l(x))] / exp(numerical_shift)``, computed
          without overflow.
        """
        all_log_r = torch.hstack(
            [env.log_reward() for env in state.list_all_states(max_batch_size)]
        )
        self.numerical_shift = all_log_r.max()

        shifted_reward = (all_log_r - self.numerical_shift).exp()
        self.expected_reward = (shifted_reward ** 2).sum() / shifted_reward.sum()
        

class Set(Environment):
    """Environment whose states are subsets of ``{0, ..., src_size - 1}``.

    ``self.state`` has one row per batch element and one column per candidate
    element; each entry counts how often that element has been added.
    ``forward_mask`` is ``1 - state``, so only absent elements are marked as
    addable and valid states are expected to be 0/1 rows. A row is terminal
    (``stopped``) once it holds exactly ``set_size`` elements.

    NOTE(review): ``batch_size``, ``device``, ``batch_ids``, ``is_initial`` and
    ``stopped`` are assumed to be set up by ``gfn.utils.Environment`` —
    confirm there.
    """

    def __init__(self, src_size, set_size, batch_size, log_reward, device='cpu'):
        super().__init__(batch_size, set_size, log_reward, device=device)
        self.src_size = src_size
        self.set_size = set_size
        # Every row starts as the empty set.
        self.state = torch.zeros((self.batch_size, self.src_size), device=self.device, dtype=int)
        # Initially every element may be added and none may be removed.
        self.forward_mask = torch.ones((self.batch_size, self.src_size), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, self.src_size), device=self.device)
        self.max_num_parents = src_size

    @torch.no_grad()
    def apply(self, indices):
        """Add element ``indices[b]`` to row ``b`` for every batch row."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.set_size)
        # Present elements can no longer be added, only removed.
        self.forward_mask = 1 - self.state.type(self.forward_mask.dtype)
        self.backward_mask = self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def backward(self, indices):
        """Remove element ``indices[b]`` from row ``b``; return ``indices``."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        self.stopped[:] = 0
        self.forward_mask = (1 - self.state).type(self.forward_mask.dtype)
        self.backward_mask = self.state.type(self.backward_mask.dtype)
        return indices

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of ``batch_state`` to this environment."""
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])

    @torch.no_grad()
    def list_all_states(self, max_batch_size=None):
        """Enumerate every terminal state (all ``set_size``-subsets) in batches.

        The environment itself is overwritten with each batch: ``batch_size``,
        ``state`` and the derived attributes are reset, and ``self`` is
        yielded. Each batch must therefore be consumed before advancing the
        generator. The enumerated ``state`` uses the default float dtype.

        Args:
            max_batch_size: maximum number of states per yielded batch;
                ``None`` yields every state in a single batch.

        Raises:
            AssertionError: if there are 1e7 or more subsets to enumerate.
        """
        # math.comb is exact and handles k == 0 / k == n. The previous
        # recursive factorial never terminated for n == 0.
        total_states = math.comb(self.src_size, self.set_size)
        if max_batch_size is None:
            max_batch_size = total_states
        assert total_states < 1e7, f'too many states: {total_states}'

        indices = []
        visited_states = 0
        for comb in itertools.combinations(range(self.src_size), r=self.set_size):
            indices.append(comb)
            visited_states += 1
            if (visited_states % max_batch_size) == 0 or visited_states == total_states:
                num_states = len(indices)
                # Explicit long dtype keeps the index valid even for set_size == 0.
                cols = torch.tensor(indices, dtype=torch.long, device=self.device)
                rows = torch.arange(num_states, device=self.device).view(-1, 1).repeat(1, self.set_size)
                stes = torch.zeros((num_states, self.src_size), device=self.device)
                stes[rows, cols] = 1.

                print(f'{visited_states}/{total_states}', max_batch_size)
                # Update the states' attributes.
                self.batch_size = num_states
                self.state = stes
                self._update_when_batch_size_changes()
                indices = []
                yield self

    def _update_when_batch_size_changes(self):
        """Rebuild per-row bookkeeping for a freshly loaded batch of terminal states."""
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        # NOTE(review): uses src_size rather than set_size — confirm intended.
        self.traj_size = self.src_size * torch.ones((self.batch_size,), device=self.device)
        self.stopped = torch.ones((self.batch_size,), device=self.device)
        self.is_initial = torch.zeros((self.batch_size,), device=self.device)

        # Use >= for compatibility with `Bag` (entries may exceed 1).
        self.forward_mask = 1 - (self.state >= 1.).type(self.forward_mask.dtype)
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)

    @property
    def unique_input(self):
        """State cast to the mask dtype; fed to ``LogReward``-style models."""
        return self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def get_children(self, return_actions=False):
        """Yield one deep copy per element, with that element added to every row.

        If ``return_actions`` is true, yield ``(child, actions)`` where
        ``actions`` is the int tensor of the element index repeated per row.
        """
        actions = torch.arange(self.src_size, device=self.device)

        for action in actions:
            child = copy.deepcopy(self)
            curr_actions = action * torch.ones((self.batch_size,), device=self.device, dtype=int)
            child.apply(curr_actions)
            if return_actions:
                yield child, curr_actions
            else:
                yield child

    @torch.no_grad()
    def get_parents(self):
        """Yield ``(parent, actions)`` pairs, one per element present in a row.

        The i-th parent is a deep copy with the i-th present element (in
        ascending column order) removed from every row. Assumes every row
        holds the same number of entries equal to 1; the count of the last
        row is used for all.
        """
        # torch.where returns matches in row-major order: curr_size per row.
        _, actions = torch.where(self.state == 1.)
        curr_size = int(self.state.sum(dim=1)[-1])

        for i in range(curr_size):
            picks = torch.arange(i, actions.shape[0], step=curr_size, device=actions.device)
            backward_actions = actions[picks]
            parent = copy.deepcopy(self)
            parent.backward(backward_actions)
            yield parent, backward_actions

class Bag(Set):
    """Multiset variant of ``Set``: an element may be added more than once.

    ``apply``/``backward`` do not update ``forward_mask``, so repeated
    additions stay allowed; ``backward_mask`` marks elements with count >= 1.

    NOTE(review): ``get_parents`` is inherited from ``Set``. It only considers
    entries equal to 1 and assumes the same count in every row, which need
    not hold for bags with repeated elements — confirm before relying on it.
    """

    @torch.no_grad()
    def apply(self, indices):
        """Add one copy of element ``indices[b]`` to row ``b``."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.set_size)
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)

    @torch.no_grad()
    def backward(self, indices):
        """Remove one copy of element ``indices[b]`` from row ``b``; return ``indices``."""
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        self.stopped[:] = 0
        self.backward_mask = (self.state >= 1.).type(self.backward_mask.dtype)
        return indices

    @torch.no_grad()
    def list_all_states(self, max_batch_size=None):
        """Enumerate every terminal bag (multisets of size ``set_size``) in batches.

        The environment itself is overwritten with each batch and ``self`` is
        yielded, so each batch must be consumed before advancing the generator.

        Args:
            max_batch_size: maximum number of states per yielded batch;
                ``None`` yields every state in a single batch.

        Raises:
            AssertionError: if there are 5e6 or more bags to enumerate.
        """
        # Number of multisets of size k over n elements: C(n + k - 1, k).
        # math.comb is exact and avoids the non-terminating factorial(0).
        total_num_bags = math.comb(self.src_size + self.set_size - 1, self.set_size)
        assert total_num_bags < 5e6, f'too many states: {total_num_bags}'

        max_batch_size = total_num_bags if max_batch_size is None else max_batch_size
        visited_states = []
        num_visited_states = 0

        for comb in itertools.combinations_with_replacement(range(self.src_size), self.set_size):
            visited_states.append(comb)
            num_visited_states += 1
            # Emit a full batch, or the final partial batch. The former
            # `num + max_batch_size >= total` test fired on every one of the
            # last max_batch_size states (all states when max_batch_size=None),
            # degenerating into batches of size 1.
            if num_visited_states % max_batch_size == 0 or num_visited_states == total_num_bags:
                self.batch_size = len(visited_states)
                self.state = torch.zeros((self.batch_size, self.src_size), device=self.device)
                index = torch.tensor(visited_states, device=self.device, dtype=torch.int64)
                # Each element index in a combination adds 1 to its column (its multiplicity).
                self.state.scatter_add_(dim=1,
                                        index=index,
                                        src=torch.ones_like(index).type(self.state.dtype))
                self._update_when_batch_size_changes()
                visited_states = []
                yield self