import torch 
import torch.nn as nn 
import copy 
import itertools 
import pickle 
import os 
import math 

# Absolute path of the directory containing this module; the TF-Bind dataset
# pickles loaded by LogRewardTFN are located relative to it.
HERE = os.path.abspath(os.path.dirname(__file__))

from sal.utils import Environment 

# How length-`max_depth` prefixes are split across compute nodes:
#   'partition' -- contiguous ranges of the prefix index;
#   'modulo'    -- prefix index modulo the number of nodes.
ASSIGNMENT_TYPE = "partition"


def _partition_bin_size(num_states, num_nodes):
    """Return the width of each node's index range under 'partition' assignment.

    The last node additionally owns the remainder of the integer division.

    Raises:
        ValueError: if the width would be 0, i.e. some nodes own no state.
    """
    bin_size = (num_states - 1) // num_nodes
    if bin_size <= 0:
        raise ValueError(
            f"Cannot partition {num_states} states across {num_nodes} nodes: "
            "some nodes would receive no state."
        )
    return bin_size


def state_to_node(state: 'Sequences', num_compute_nodes, assignment_type: str = ASSIGNMENT_TYPE):
    """Assign every state of the batch to a compute node.

    Each row of ``state.state`` is read as a little-endian base-``src_size``
    number: position ``k`` has weight ``src_size ** k`` and padding tokens
    (``src_size``) contribute 0. That index is then mapped to a node in
    ``[0, num_compute_nodes)``, using the same scheme as ``sample_state_on_depth``.

    The assignment is computed only when ``curr_idx == max_depth``:

    * ``curr_idx < max_depth``: the prefix is still too short; returns ``None``.
    * ``curr_idx > max_depth``: returns the cached ``state.node_indices``.

    Args:
        state: batched environment. Reads ``state``, ``curr_idx``,
            ``max_depth``, ``seq_size``, ``src_size`` and ``device``; writes
            ``node_indices``.
        num_compute_nodes: number of compute nodes.
        assignment_type: ``'partition'`` or ``'modulo'``.

    Returns:
        A long tensor with one node index per batch row, or ``None``.

    Raises:
        ValueError: unknown ``assignment_type``, or too many nodes for the
            ``'partition'`` scheme to give each a non-empty range.
    """
    if state.curr_idx > state.max_depth:
        return state.node_indices
    if state.curr_idx < state.max_depth:
        return None

    # Little-endian place values; padding positions are zeroed by `mask`.
    basis = state.src_size ** torch.arange(state.seq_size, device=state.device).view(1, -1)
    mask = (state.state != state.src_size)
    weighted = state.state * mask * basis

    if assignment_type == 'modulo':
        # Reduce each term before summing to keep the partial sums small.
        state.node_indices = (weighted % num_compute_nodes).sum(dim=1) % num_compute_nodes
    elif assignment_type == 'partition':
        indices = weighted.sum(dim=1)
        bin_size = _partition_bin_size(state.src_size ** state.curr_idx, num_compute_nodes)
        # Integer floor division keeps bin boundaries exact for large indices;
        # the clamp hands the remainder of the division to the last node.
        state.node_indices = torch.div(
            indices, bin_size, rounding_mode='floor'
        ).clamp(max=num_compute_nodes - 1)
    else:
        raise ValueError(f"Unknown assignment_type: {assignment_type!r}")

    return state.node_indices


def sample_state_on_depth(
        state,
        depth: int,
        node_idx: int,
        num_nodes: int,
        inplace: bool = False,
        assignment_type: str = ASSIGNMENT_TYPE
    ):
    """Sample a batch of length-``depth`` prefixes owned by compute node ``node_idx``.

    A prefix index is drawn uniformly among the indices that
    ``state_to_node`` assigns to ``node_idx``. It is then decoded into its
    little-endian base-``src_size`` digits. The remaining
    ``seq_size - depth`` positions are filled with the padding token
    ``src_size``.

    Args:
        state: batched environment. Reads ``src_size``, ``seq_size``,
            ``batch_size`` and ``device``.
        depth: prefix length.
        node_idx: compute node whose prefixes are sampled.
        num_nodes: total number of compute nodes.
        inplace: if true, also store the result in ``state.state``, advance
            ``state.curr_idx`` by ``depth`` and set
            ``state.max_trajectory_length``.
        assignment_type: ``'partition'`` or ``'modulo'``.

    Returns:
        ``(batch_size, seq_size)`` long tensor of padded states.

    Raises:
        ValueError: unknown ``assignment_type``, or ``node_idx`` owns no prefix.
    """
    num_states = state.src_size ** depth

    if assignment_type == 'modulo':
        # Valid indices are node_idx, node_idx + num_nodes, ... while < num_states.
        num_candidates = (num_states - 1 - node_idx) // num_nodes + 1
        if num_candidates <= 0:
            raise ValueError(
                f"Node {node_idx} owns no state of depth {depth} under 'modulo' assignment."
            )
        indices = num_nodes * torch.randint(
            num_candidates, (state.batch_size,), device=state.device
        ) + node_idx
    elif assignment_type == 'partition':
        bin_size = _partition_bin_size(num_states, num_nodes)
        low, high = bin_size * node_idx, bin_size * (node_idx + 1)
        if node_idx == num_nodes - 1:
            # The last node owns every remaining index, including num_states - 1
            # (randint's upper bound is exclusive).
            high = num_states
        indices = torch.randint(low=low, high=high, size=(state.batch_size,), device=state.device)
    else:
        raise ValueError(f"Unknown assignment_type: {assignment_type!r}")

    # Exact integer decoding of the base-`src_size` digits (position k has weight src_size ** k).
    place_values = state.src_size ** torch.arange(depth, device=state.device)
    digits = torch.div(
        indices.view(-1, 1), place_values.view(1, -1), rounding_mode='floor'
    ) % state.src_size

    # Pad the sequence with the padding token beyond the sampled prefix.
    padded_states = torch.full(
        (state.batch_size, state.seq_size), state.src_size, dtype=torch.long, device=state.device
    )
    padded_states[:, :depth] = digits

    if inplace:
        state.state = padded_states
        state.curr_idx += depth
        state.max_trajectory_length = state.seq_size - depth

    return padded_states

class LogReward(nn.Module):
    """Random additive log-reward over (token, position) pairs.

    ``val`` holds one weight per token plus a final zero entry for the padding
    token ``src_size``. ``pos_val`` holds one weight per position. Both are
    drawn as ``sqrt(2) * (U[0, 1) - 0.5)`` from generators seeded with
    ``seed`` and ``seed + 1`` respectively.
    """

    def __init__(self, src_size, seq_size, seed, device='cpu'):
        super().__init__()
        self.device = device
        self.src_size = src_size
        self.seq_size = seq_size

        scale = math.sqrt(2)

        token_gen = torch.Generator(device=self.device)
        token_gen.manual_seed(seed)
        self.val = scale * (torch.rand(self.src_size + 1, device=self.device, generator=token_gen) - .5)
        # Last entry corresponds to the padding token and must not contribute.
        self.val[-1] = 0.

        position_gen = torch.Generator(device=self.device)
        position_gen.manual_seed(seed + 1)
        self.pos_val = scale * (torch.rand(self.seq_size, device=self.device, generator=position_gen) - .5)

    @torch.no_grad()
    def forward(self, batch_state, **kwargs):
        """Return ``2 * sum_k val[token_k] * pos_val[k]`` over non-padding positions."""
        tokens = batch_state.state
        not_padding = (tokens != self.src_size).long()
        position_weights = not_padding * self.pos_val.view(1, self.seq_size)
        contributions = self.val[tokens] * position_weights
        return 2 * contributions.sum(dim=1)

class LogRewardModel(nn.Module):
    """Log-reward that can delegate states to per-node GFlowNet flow estimates.

    States for which ``depth_mask`` is true use ``log_reward_base``. The others
    use ``gflownets[node].pf.mlp_flows`` of the compute node that
    ``state_to_node`` assigns them to.
    """

    def __init__(self, log_reward_base: LogReward, device: str = 'cpu'):
        super(LogRewardModel, self).__init__()
        self.log_reward = log_reward_base
        self.device = device

    @torch.no_grad()
    def forward(self, batch_state, gflownets):
        """Return one log-reward per batch row.

        Args:
            batch_state: batched ``Sequences``-like environment.
            gflownets: indexable collection with one model per compute node;
                ``len(gflownets)`` is used as the number of nodes.
        """
        num_models = len(gflownets)
        # May be None when curr_idx < max_depth (see state_to_node).
        node_indices = state_to_node(batch_state, num_models)

        log_rewards_base = self.log_reward(batch_state)
        # Start from the base reward so entries no model claims never hold
        # uninitialised memory.
        log_rewards_model = log_rewards_base.clone()

        # NOTE(review): Sequences.apply flags `stopped` once curr_idx reaches
        # max_idx, so curr_idx presumably never exceeds max_idx + 1. That would
        # make this mask all-True and the model branch below unreachable.
        # Confirm the intended condition.
        use_base = batch_state.curr_idx <= batch_state.max_idx + 1
        depth_mask = torch.full_like(log_rewards_base, use_base, dtype=torch.bool)

        # Convert once instead of once per model.
        float_states = batch_state.state.to(torch.get_default_dtype())
        for model_idx in range(num_models):
            model_mask = (node_indices == model_idx) & ~depth_mask
            if model_mask.any():
                log_rewards_model[model_mask] = gflownets[model_idx].pf.mlp_flows(
                    float_states[model_mask]
                ).squeeze(dim=-1)

        return torch.where(depth_mask, log_rewards_base, log_rewards_model)

class LogRewardBits(nn.Module):
    """Reward equal to one minus the normalised L1 distance to the nearest mode.

    Each of the ``num_modes`` modes is built from ``seq_size // 8`` randomly
    chosen 8-token blocks (drawn from the global torch RNG), then zero-padded
    up to ``seq_size``. For 0/1 states the L1 distance is the Hamming distance.
    """

    def __init__(self, src_size, seq_size, num_modes, device='cpu'):
        super(LogRewardBits, self).__init__()
        self.src_size = src_size
        self.seq_size = seq_size
        self.num_modes = num_modes
        self.device = device

        # Building blocks each mode is assembled from.
        blocks = torch.tensor(
            [
                [0, 0, 0, 0, 0, 0, 0, 0],  # '00000000'
                [1, 1, 1, 1, 1, 1, 1, 1],  # '11111111'
                [1, 1, 1, 1, 0, 0, 0, 0],  # '11110000'
                [0, 0, 0, 0, 1, 1, 1, 1],  # '00001111'
                [0, 0, 1, 1, 1, 1, 0, 0],  # '00111100'
            ],
            device=device,
            dtype=torch.long,
        )
        block_len = blocks.shape[1]
        blocks_per_mode = seq_size // block_len

        choice = torch.randint(
            low=0, high=blocks.shape[0], size=(num_modes, blocks_per_mode)
        )
        body = blocks[choice].reshape(num_modes, blocks_per_mode * block_len)
        tail = torch.zeros(
            (num_modes, seq_size - body.shape[1]), device=device, dtype=body.dtype
        )
        self.modes = torch.cat([body, tail], dim=1)

    @torch.no_grad()
    def forward(self, batch_state):
        """Return ``1 - min_m ||state - mode_m||_1 / seq_size`` for every row."""
        states = batch_state.state.to(dtype=torch.get_default_dtype()).view(-1, 1, self.seq_size)
        modes = self.modes.view(1, -1, self.seq_size)
        distances = (states - modes).abs().sum(dim=2)
        assert distances.shape[1] == self.num_modes, distances.shape
        nearest = distances.min(dim=1).values
        return 1 - nearest / self.seq_size
        

class LogRewardTFN(nn.Module):
    """Log-reward looked up in the exhaustive TF-Bind-``n`` oracle table.

    The pickle ``datasets/tfbind{n}-exact-v0-all.pkl`` provides sequences
    ``x`` and scores ``y``. For ``n == 10`` the scores are first passed
    through ``expit(3 * y)``. They are then raised to ``exp`` and rescaled so
    that the largest equals ``max_val``. ``forward`` returns
    ``log(scaled + 1e-3)`` for complete sequences and 0 otherwise.
    """

    def __init__(self, n, max_val=10, exp=3, device='cpu'):
        super(LogRewardTFN, self).__init__()
        self.seq_size = n
        self.device = device
        self.max_val = max_val
        self.exp = exp
        assert self.seq_size in [8, 10]

        # Base-4 index of a sequence; position k has weight 4 ** k
        # (assumes tokens are in {0, 1, 2, 3}).
        self.sequence_to_idx = lambda sequence: (
            sequence * 4 ** torch.arange(self.seq_size, device=self.device)
        ).sum(dim=1)

        # NOTE: pickle.load can execute arbitrary code; only load trusted, local dataset files.
        with open(f'{HERE}/../../datasets/tfbind{n}-exact-v0-all.pkl', 'rb') as f:
            oracle_d = pickle.load(f)

        states, rewards = oracle_d['x'], oracle_d['y']

        if n == 10:
            from scipy.special import expit
            rewards = expit(rewards * 3)

        states = torch.tensor(states, device=device)

        powered = rewards ** exp
        scaled_rewards = torch.tensor(
            max_val * powered / powered.max(), device=device
        ).squeeze(dim=-1).type(torch.get_default_dtype())

        # Scatter (not gather) so that scaled_oracle[sequence_to_idx(x_k)] == reward_k,
        # which is exactly the lookup `forward` performs.
        self.scaled_oracle = self._oracle_by_index(
            self.sequence_to_idx(states), scaled_rewards, 4 ** self.seq_size
        )

    @staticmethod
    def _oracle_by_index(indices, values, size):
        """Return a length-``size`` table with ``table[indices[k]] = values[k]``; unset entries are 0."""
        oracle = torch.zeros(size, dtype=values.dtype, device=values.device)
        oracle[indices] = values
        return oracle

    def forward(self, batch_state):
        """Return ``log(oracle + 1e-3)`` for complete sequences, zeros otherwise."""
        if not batch_state.curr_idx == self.seq_size:
            return torch.zeros((batch_state.batch_size,), device=batch_state.device)
        indices = self.sequence_to_idx(batch_state.state)
        return torch.log(self.scaled_oracle[indices] + 1e-3)

class Sequences(Environment):
    """Batched environment that writes one token per step, left to right.

    ``state`` is a ``(batch_size, seq_size)`` long tensor. Positions not yet
    written hold the padding token ``src_size``. ``curr_idx`` is the next
    position to write and is shared by the whole batch.
    """

    def __init__(self, seq_size, src_size, batch_size, log_reward, device='cpu'):
        super(Sequences, self).__init__(batch_size, seq_size, log_reward, device)
        self.seq_size = seq_size
        self.src_size = src_size
        # Every position starts as the padding token `src_size`.
        self.state = torch.ones((self.batch_size, self.seq_size), dtype=torch.long, device=self.device) * self.src_size

        self.curr_idx = 0
        # Last writable position: `apply` flags the batch as stopped when writing it.
        self.max_idx = self.seq_size - 1
        # Prefix length at which `state_to_node` assigns states to compute nodes.
        self.max_depth = self.max_idx

        self.node_indices = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device)

    @torch.no_grad()
    def apply(self, indices):
        """Write one token per row (``indices``) at ``curr_idx`` and advance."""
        self.state[:, self.curr_idx] = indices
        self.stopped[:] = (self.curr_idx == self.max_idx)
        self.is_initial[:] = 0.
        self.curr_idx += 1

    @torch.no_grad()
    def backward(self, indices):
        """Undo the last ``apply`` and return the tokens it had written.

        ``indices`` is unused; it is kept for interface compatibility.
        """
        self.curr_idx -= 1
        # Clone: the column is reset to padding right below and a view would follow it.
        forward_actions = self.state[:, self.curr_idx].clone()
        # Restore the padding token, as SequencesVocabSize.backward does. The
        # state (and `unique_input`) then no longer exposes the removed token.
        self.state[:, self.curr_idx] = self.src_size
        self.is_initial[:] = (self.curr_idx == 0)
        self.stopped[:] = 0.
        return forward_actions

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of ``batch_state`` to this batch."""
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])

    @property
    def unique_input(self):
        """Tensor fed to the policy: the raw token matrix."""
        return self.state

    def get(self, indices):
        """Return a copy restricted to the rows selected by ``indices``."""
        copy_self = super().get(indices)
        copy_self.state = self.state[indices]
        return copy_self

    @staticmethod
    def create_env_on_depth(config, log_reward, model_idx):
        """Batch of 128 states prefilled with depth-``config.max_depth`` prefixes of node ``model_idx``."""
        env = Sequences(config.seq_size, config.src_size, batch_size=128, log_reward=log_reward, device=config.device)
        sample_state_on_depth(env, depth=config.max_depth, node_idx=model_idx, num_nodes=config.num_models, inplace=True)
        return env

    @staticmethod
    def create_env_maximum_depth(config, log_reward):
        """Batch of 128 states whose trajectories stop after writing position ``config.max_depth``."""
        env = Sequences(config.seq_size, config.src_size, batch_size=128, log_reward=log_reward, device=config.device)
        env.max_idx = env.max_depth = config.max_depth
        return env

    @staticmethod
    def create_env_for_sal(config, log_reward):
        """Full-length batch of 128 states with node assignment at depth ``config.max_depth``."""
        env = Sequences(config.seq_size, config.src_size, batch_size=128, log_reward=log_reward, device=config.device)
        env.max_depth = config.max_depth
        return env
    
class SequencesVocabSize(Sequences):
    """``Sequences`` variant in which each action writes ``vocab_size`` tokens.

    Action ``k`` is row ``k`` of ``self.actions``, which lists all
    ``src_size ** vocab_size`` token tuples in ``itertools.product`` order
    (the last token varies fastest). The batch stops once ``curr_idx``
    reaches ``seq_size``.
    """

    # Fixed size sequences
    def __init__(self, seq_size, src_size, vocab_size, batch_size, log_reward, device='cpu'):
        super(SequencesVocabSize, self).__init__(seq_size, src_size, batch_size, log_reward, device)
        self.vocab_size = vocab_size

        # Build the table on self.device so it can be indexed by on-device actions.
        # reshape keeps it 2-D: cartesian_prod returns its single input unchanged
        # (1-D) when vocab_size == 1.
        self.actions = torch.cartesian_prod(
            *[torch.arange(self.src_size, device=self.device) for _ in range(self.vocab_size)]
        ).reshape(-1, self.vocab_size)

        # Inverse of `self.actions`: the first token of an action is its most
        # significant base-`src_size` digit, matching cartesian_prod's ordering.
        place_values = self.src_size ** torch.arange(
            self.vocab_size - 1, -1, -1, device=self.device
        )
        self.token_to_idx = lambda action: (action * place_values).sum(dim=1).long()

    @torch.no_grad()
    def apply(self, indices):
        """Write the ``vocab_size`` tokens of action ``indices`` at ``curr_idx`` and advance."""
        actions = self.actions[indices]
        self.state[:, self.curr_idx:(self.curr_idx+self.vocab_size)] = actions
        self.curr_idx += self.vocab_size
        self.stopped[:] = (self.curr_idx == self.seq_size)
        self.is_initial[:] = 0.

    @torch.no_grad()
    def backward(self, indices=None):
        """Undo the last ``apply``, reset its tokens to padding and return its action indices."""
        del indices
        self.curr_idx -= self.vocab_size
        # token_to_idx builds a new tensor, so it is safe to overwrite the slice afterwards.
        forward_actions = self.token_to_idx(
            self.state[:, self.curr_idx:(self.curr_idx+self.vocab_size)]
        )
        self.state[:, self.curr_idx:(self.curr_idx+self.vocab_size)] = self.src_size
        self.stopped[:] = 0.
        self.is_initial[:] = (self.curr_idx == 0.)
        return forward_actions

if __name__ == '__main__':
    # Round-trip check: every state sampled for a compute node must be
    # assigned back to that same node by `state_to_node`.
    seq_size, src_size = 10, 6
    depth, num_compute_nodes = 5, 8

    for node_idx in range(num_compute_nodes):
        env = Sequences(seq_size, src_size, batch_size=128, log_reward=None, device='cpu')
        env.max_depth = depth
        sample_state_on_depth(
            env, depth=depth, node_idx=node_idx, num_nodes=num_compute_nodes, inplace=True
        )
        assigned = state_to_node(env, num_compute_nodes)
        assert (assigned == node_idx).all(), (assigned, node_idx)
    