import torch 
import torch.nn as nn 
import copy 
import itertools 

from gfn.utils import Environment 

class LogReward(nn.Module):
    """Fixed, seeded random log-reward over token sequences.

    Every token id in ``[0, src_size]`` is assigned a value drawn from
    U[-1, 1) (generator seeded with ``seed``), and every position in
    ``[0, seq_size]`` a weight drawn from U[-1, 1) (generator seeded with
    ``seed + 1``). Token id ``src_size`` is the padding id: its value is
    forced to 0 and it is also masked out in :meth:`forward`.

    The log-reward of a row is ``sum_p val[state[p]] * pos_val[p]`` over the
    row's non-padding positions.
    """

    def __init__(self, src_size, seq_size, seed, device='cpu'):
        super().__init__()
        self.device = device
        self.src_size = src_size
        self.seq_size = seq_size

        # Per-token values; the final slot belongs to the padding id.
        self.val = self._signed_uniform(self.src_size + 1, seed)
        self.val[-1] = 0.
        # Per-position weights (one extra slot for the trailing padding column).
        self.pos_val = self._signed_uniform(self.seq_size + 1, seed + 1)

    def _signed_uniform(self, length, seed):
        """Draw ``length`` samples from U[-1, 1) using a fresh generator seeded with ``seed``."""
        gen = torch.Generator(device=self.device)
        gen.manual_seed(seed)
        return torch.rand(length, device=self.device, generator=gen) * 2 - 1

    @torch.no_grad()
    def forward(self, batch_state):
        """Return one log-reward per row of ``batch_state.state``.

        ``batch_state.state`` is an integer tensor indexed as
        ``(row, position)``; its last dimension must broadcast against
        ``pos_val`` (length ``seq_size + 1``). Positions equal to
        ``src_size`` contribute nothing.
        """
        tokens = batch_state.state
        not_padding = tokens.ne(self.src_size).long()
        contributions = self.val[tokens] * not_padding * self.pos_val
        return contributions.sum(dim=1)

class Sequences(Environment):
    """Environment that builds token sequences one position at a time.

    A state is a ``(batch_size, seq_size + 1)`` long tensor. Positions that
    have not been filled hold ``src_size``, which is both the padding token
    and the id of the stop action. Writes only ever target columns
    ``< seq_size`` (the write cursor ``curr_idx`` is clamped to ``max_idx``),
    so the trailing column always holds ``src_size``.
    """

    def __init__(self, seq_size, src_size, batch_size, log_reward, device='cpu'):
        # NOTE(review): the second argument is presumably the state width
        # (seq_size tokens + one padding column) -- confirm against Environment.
        super().__init__(batch_size, seq_size + 1, log_reward, device)
        self.seq_size = seq_size
        self.src_size = src_size
        # Every position starts as ``src_size`` (the padding / end-of-sequence token).
        self.state = torch.full(
            (self.batch_size, self.seq_size + 1), self.src_size,
            dtype=torch.long, device=self.device)
        # Next position to write in each row, clamped to ``max_idx``.
        self.curr_idx = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device)
        self.max_idx = torch.full_like(self.curr_idx, self.seq_size - 1)

        # All actions (tokens and stop) are allowed from the empty state.
        self.forward_mask = torch.ones((self.batch_size, self.src_size + 1), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, 1), device=self.device)
        self.traj_size = torch.ones((self.batch_size,), device=self.device)

    @torch.no_grad()
    def get_children(self):
        """Yield one deep copy of the environment per action, with that action applied.

        Actions are enumerated over ``0..src_size`` (``src_size`` = stop).
        Rows that are already stopped receive the stop action instead of the
        enumerated one. ``self`` itself is not modified.
        """
        ones = torch.ones((self.batch_size,), device=self.device)
        for action in torch.arange(self.src_size + 1, device=self.device):
            child = copy.deepcopy(self)
            batch_actions = action * ones
            # Stopped rows keep emitting the stop action.
            batch_actions = batch_actions * (1 - self.stopped) + self.src_size * self.stopped
            child.apply(batch_actions.type(torch.long))
            yield child

    @torch.no_grad()
    def apply(self, indices):
        """Apply one forward action per row.

        Args:
            indices: integer tensor of shape ``(batch_size,)``. A value in
                ``[0, src_size)`` writes that token at the row's
                ``curr_idx``; ``src_size`` is the stop action.

        Returns:
            ``self.stopped < 2.``, which is all-True because ``stopped`` is
            reset to 0/1 just before.
        """
        assert ((indices >= 0) & (indices <= self.src_size)).all(), \
            'action ids must lie in [0, src_size]'
        is_stop_action = (indices == self.src_size)
        indices_non_stop = indices[~is_stop_action]
        batchid_non_stop = self.batch_ids[~is_stop_action]
        curridx_non_stop = self.curr_idx[~is_stop_action]

        self.state[batchid_non_stop, curridx_non_stop] = indices_non_stop

        # Only the stop action stays legal for rows that just stopped and for
        # rows that just wrote into the last token column (seq_size - 1);
        # curr_idx still holds the pre-action position at this point.
        umask = self.forward_mask.clone()
        umask[is_stop_action, :-1] = 0.
        umask[self.curr_idx == self.seq_size - 1, :-1] = 0.
        self.forward_mask = umask

        self.stopped = is_stop_action.long()
        self.curr_idx += (~is_stop_action).long()
        # traj_size is computed from the cursor before the clamp below.
        self.traj_size = self.curr_idx + 1
        self.curr_idx = torch.minimum(self.curr_idx, self.max_idx)
        self.is_initial[:] = (self.curr_idx == 0).long()
        self.backward_mask = (1 - self.is_initial).view(-1, 1)
        return (self.stopped < 2.)

    @torch.no_grad()
    def backward(self, indices):
        """Undo one step for every non-initial row.

        ``indices`` is ignored; the undone action is read from the state.
        Rows flagged as stopped keep their cursor, so the value read back is
        whatever sits at ``curr_idx``. That is the padding/stop id for a
        non-full row. For a full row it is the last token, because
        ``curr_idx`` is clamped at ``seq_size - 1``. Other non-initial rows
        step the cursor back by one first. The value at the (updated) cursor
        is then cleared to ``src_size`` for every non-initial row.

        Returns:
            Long tensor ``(batch_size,)`` holding, per row, the value found at
            the updated ``curr_idx`` before it was cleared.
        """
        is_non_initial = ~(self.is_initial == 1)
        self.curr_idx -= (is_non_initial & (self.stopped != 1)).long()

        batchid_non_initial = self.batch_ids[is_non_initial]
        curridx_non_initial = self.curr_idx[is_non_initial]

        forward_actions = self.state[self.batch_ids, self.curr_idx].clone()
        self.state[batchid_non_initial, curridx_non_initial] = self.src_size
        # After a backward step every token action is allowed again.
        self.forward_mask[:, :-1] = 1

        self.stopped[:] = 0.
        self.is_initial = (self.state == self.src_size).all(dim=1).long()
        self.backward_mask = (1 - self.is_initial).view(-1, 1)
        self.traj_size = self.curr_idx + 1

        return forward_actions

    @torch.no_grad()
    def list_all_states(self, max_batch_size=None):
        """Enumerate every non-empty sequence, yielding ``self`` once per batch.

        All sequences of length ``1..seq_size`` over the tokens
        ``range(src_size)`` are written into ``self.state`` in batches of at
        most ``max_batch_size`` rows (default: everything in one batch).
        Before each yield the environment is reset to that batch:
        ``batch_size``, ``max_idx``, the bookkeeping in
        :meth:`_update_when_batch_size_changes` (which marks every row as
        stopped) and ``curr_idx`` are refreshed. Progress is printed as
        ``seen/total``.

        The same object is yielded every time and is overwritten on the next
        iteration; copy anything that must outlive it.
        """
        # Number of non-empty sequences of length <= d over w tokens:
        # w + w^2 + ... + w^d. The sum is computed with integers so the count
        # is exact. The closed form w (w^d - 1) / (w - 1) yields a float,
        # which torch.ones rejects as a size, and divides by zero when w == 1.
        total_states = sum(self.src_size ** size for size in range(1, self.seq_size + 1))

        if max_batch_size is None:
            max_batch_size = total_states
        max_batch_size = int(max_batch_size)

        # Staging buffer; unfilled positions hold the padding token.
        initial_state = torch.ones(
            (max_batch_size, self.seq_size + 1), device=self.device, dtype=int) * self.src_size
        symbols = range(self.src_size)

        idx = 0
        seen_states = 0
        lengths = []
        for size in range(1, self.seq_size + 1):
            for seq in itertools.product(symbols, repeat=size):
                initial_state[idx, :size] = torch.tensor(
                    seq, dtype=self.state.dtype, device=self.state.device)
                idx += 1
                seen_states += 1
                lengths.append(size)
                if idx >= max_batch_size or seen_states == total_states:
                    self.state = initial_state[:idx].clone()
                    self.batch_size = self.state.shape[0]
                    self.max_idx = torch.ones(
                        (self.batch_size,), device=self.device, dtype=int) * (self.seq_size - 1)
                    self._update_when_batch_size_changes()
                    # The cursor sits after the last token, clamped like in apply().
                    self.curr_idx = torch.minimum(
                        torch.tensor(lengths, device=self.device, dtype=int), self.max_idx)
                    print(f'{seen_states}/{total_states}')
                    # Fresh buffer for the next batch; self.state is a clone and
                    # is not affected by refilling.
                    initial_state = torch.ones_like(self.state) * self.src_size
                    lengths = []
                    idx = 0
                    yield self

    def _update_when_batch_size_changes(self):
        """Rebuild per-row bookkeeping after ``self.batch_size`` changed.

        Every row is marked as stopped and non-initial, and only the stop
        action is left enabled in ``forward_mask``.
        NOTE(review): ``traj_size`` is not resized here -- confirm it is not
        read after a batch-size change.
        """
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        self.stopped = torch.ones((self.batch_size), device=self.device)
        self.is_initial = torch.zeros((self.batch_size,), device=self.device)

        self.forward_mask = torch.zeros((self.batch_size, self.src_size + 1), device=self.device)
        self.forward_mask[:, -1] = 1.
        self.backward_mask = (1 - self.is_initial).view(-1, 1)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of ``batch_state`` to this environment.

        Delegates to ``Environment.merge`` and then stacks the state tensors.
        NOTE(review): ``curr_idx``, ``max_idx`` and the masks are not
        concatenated here -- presumably handled by the base class; confirm.
        """
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])

    @property
    def unique_input(self):
        """The raw state tensor, used as this environment's input representation."""
        return self.state
