import pickle
import string

import gymnasium as gym
import numpy as np
import pandas as pd
import torch

# from sehstr import gbr_proxy
# from rdkit import Chem
# from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
# from rdkit.DataStructs import FingerprintSimilarity

class Environment(gym.Env):
    """SEHstringMDP environment from the GFlowNets paper, modified to make the reward sparse.

    A state is a string of at most ``force_stop_len`` symbols, one symbol per
    building block read from ``block_18.json``.  With ``A = len(self.alphabet)``:

    * actions ``0 .. A-1`` prepend ``alphabet[a]``,
    * actions ``A .. 2A-1`` append ``alphabet[a - A]``,
    * action ``2A`` stops the episode.

    Observations are the concatenated one-hot encodings of the symbols,
    zero-padded to ``force_stop_len * A`` entries.  The reward is only paid on
    termination and is looked up in a precomputed table that is rescaled and
    then sparsified (everything below the top ``mode_percentile`` is floored).
    """

    metadata = {"render_modes": []}

    def __init__(self, data_folder = 'gym_envs/sehstr'):
        """Load the block definitions and the precomputed reward table.

        Args:
            data_folder: directory containing ``block_18.json`` and
                ``sehstr_gbtr_allpreds.pkl``.
        """
        self.__init_from_blocks_file(f'{data_folder}/block_18.json')
        # One distinct printable symbol per block.  The first 62 symbols are
        # digits, lowercase, and uppercase letters, followed by punctuation.
        symbols = (string.digits + string.ascii_lowercase
                   + string.ascii_uppercase + string.punctuation)
        assert len(self.blocks) <= len(symbols)
        self.alphabet = symbols[:len(self.blocks)]
        self.alphabet_set = set(self.alphabet)
        self.force_stop_len = 6

        self.observation_space = gym.spaces.Box(shape=(len(self.alphabet) * self.force_stop_len,), low=0, high=1, dtype=np.float32)

        # Actions: prepend each symbol (A of them), append each symbol (A of them), and stop.
        self.action_space = gym.spaces.Discrete(2 * len(self.alphabet) + 1)

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted data files.
        with open(f'{data_folder}/sehstr_gbtr_allpreds.pkl', 'rb') as f:
            self.rewards = pickle.load(f)

        # Scale rewards: clip from below, sharpen with an exponent, normalise the max to SCALE_REWARD_MAX.
        py = np.array(list(self.rewards))
        self.SCALE_REWARD_MAX = 10
        self.SCALE_MIN = 1e-3
        self.REWARD_EXP = 6
        # Value assigned to every non-mode entry by the sparsification below.
        self.SPARSE_REWARD_FLOOR = 1e-10

        py = np.maximum(py, self.SCALE_MIN)
        py = py ** self.REWARD_EXP
        # ndarray.max is vectorised; the builtin max iterated the table in Python.
        self.scale = self.SCALE_REWARD_MAX / py.max()
        py = py * self.scale

        # scaled_rewards aliases py, so the in-place sparsification below also updates py.
        self.scaled_rewards = py

        # Modes are the top `mode_percentile` fraction of table entries.
        mode_percentile = 0.001
        self.mode_r_threshold = np.percentile(py, 100 * (1 - mode_percentile))

        # Make the reward sparser: floor every non-mode entry.
        self.scaled_rewards[py < self.mode_r_threshold] = self.SPARSE_REWARD_FLOOR

        # Initial state is the empty string.
        self.state = ''
        # Position i of the string contributes digit * A**i to the reward-table index.
        self.powers = np.array([len(self.alphabet)**i for i in range(self.force_stop_len)])

    def __init_from_blocks_file(self, blocks_file):
        """Read the block table and check that every block has exactly two entries in ``block_r``."""
        self.blocks = pd.read_json(blocks_file)
        self.block_smi = self.blocks['block_smi'].to_list()
        self.block_rs = self.blocks['block_r'].to_list()
        self.block_nrs = np.asarray([len(r) for r in self.block_rs])

        assert all(nr == 2 for nr in self.block_nrs)

    def reset(self, seed=None, options = None):
        """Reset to the empty string and return ``(observation, info)``."""
        super().reset(seed=seed)

        self.state = ''

        return self._get_obs(), {}

    def _obs_to_index(self, obs):
        """Map one or more one-hot observations to flat indices into ``self.scaled_rewards``.

        Each of the ``force_stop_len`` positions is decoded with ``argmax``.
        An all-zero (empty) position decodes to symbol 0, so a string shorter
        than ``force_stop_len`` shares its index with the same string padded by
        ``self.alphabet[0]``.

        Returns:
            Integer array with one index per observation.
        """
        reshaped = np.asarray(obs).reshape(-1, self.force_stop_len, len(self.alphabet))
        digits = np.argmax(reshaped, axis=-1)
        return digits.dot(self.powers)

    def step(self, action):
        """Apply ``action``, an integer in ``{0, ..., 2 * len(self.alphabet)}``.

        ``a < len(alphabet)`` prepends ``alphabet[a]``.
        ``len(alphabet) <= a < 2 * len(alphabet)`` appends ``alphabet[a - len(alphabet)]``.
        ``2 * len(alphabet)`` stops the episode.

        Returns:
            Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
            The reward is non-zero only on the terminating step.
        """
        n_symbols = len(self.alphabet)
        # Stopping, or trying to grow a string that is already at full length,
        # terminates the episode without changing the state.
        if action == 2 * n_symbols or len(self.state) >= self.force_stop_len:
            obs = self._get_obs()
            index = int(self._obs_to_index(obs)[0])
            return obs, self.scaled_rewards[index], True, False, {'augmented_rew': 1}
        if action < n_symbols:
            self.state = self.alphabet[action] + self.state
        else:
            self.state += self.alphabet[action - n_symbols]

        return self._get_obs(), 0, False, False, {}

    def symbol_ohe(self, symbol):
        """Return the one-hot vector (length ``len(self.alphabet)``) for ``symbol``."""
        zs = np.zeros(len(self.alphabet))
        zs[self.alphabet.index(symbol)] = 1.0
        return zs

    def _get_obs(self):
        """Encode ``self.state`` as concatenated one-hots, zero-padded to ``force_stop_len`` positions."""
        x_ft = np.concatenate([self.symbol_ohe(c) for c in self.state] +
                              [np.zeros(len(self.alphabet))]*(self.force_stop_len - len(self.state)))
        return x_ft.astype(np.float32)

    def get_state(self, obs):
        """Return ``obs`` unchanged; the observation itself is used as the state."""
        return obs

    def get_forward_action_masks(self, state):
        """Boolean mask of valid forward actions (last axis has ``2 * len(alphabet) + 1`` entries).

        A state is full when its one-hot entries sum to ``force_stop_len``.
        Full states may only take the stop action (last index); every other
        state may take any action.  Accepts a torch tensor or a numpy array
        and returns the same kind.
        """
        if isinstance(state, torch.Tensor):
            selected_ind2 = torch.sum(state, dim=-1) == self.force_stop_len  # maximum reached
            action_dim = 2 * len(self.alphabet) + 1
            # Create a mask tensor with shape state.shape[:-1] + (action_dim,)
            masks = torch.ones(state.size()[:-1] + (action_dim,), device=state.device, dtype=torch.bool)
            # When force_stop condition is met, disable all actions except force-stop (last index)
            masks[..., :-1] &= ~selected_ind2.unsqueeze(-1)
        else:
            selected_ind2 = np.sum(state, axis=-1) == self.force_stop_len
            action_dim = 2 * len(self.alphabet) + 1
            masks = np.ones(state.shape[:-1] + (action_dim,), dtype=bool)
            masks[selected_ind2, :-1] = False

        return masks

    def get_backward_action_masks(self, state):
        """Boolean mask of valid backward actions (torch only; last axis has ``2 * len(alphabet)`` entries).

        Two entries are set per state:

        * ``s``, where ``s`` is the symbol of the first character (undoes a prepend);
        * ``len(alphabet) + t``, where ``t`` is the symbol of the last character (undoes an append).

        For an all-zero (empty) state the result is not meaningful.
        """
        action_dim = 2 * len(self.alphabet)
        masks = torch.zeros(state.shape[:-1] + (action_dim,), device=state.device, dtype=torch.bool)

        # The first / last 1 in the flattened one-hot, taken modulo A, is the symbol
        # of the first / last character.
        leftmost_indices = (state == 1).float().argmax(dim=-1) % len(self.alphabet)
        rightmost_indices = (state.size(-1) - 1 - (state.flip(dims=[-1]) == 1).float().argmax(dim=-1)) % len(self.alphabet)

        # Index tensors have shape state.shape[:-1]; unsqueeze so scatter_ writes one entry each.
        masks = masks.scatter_(-1, leftmost_indices.unsqueeze(-1), True)
        masks = masks.scatter_(-1, (rightmost_indices + len(self.alphabet)).unsqueeze(-1), True)

        return masks

    def get_error(self, samples):
        """Pearson correlation between the empirical distribution of ``samples`` and the reward table.

        Args:
            samples: one-hot observations reshapeable to
                ``(N, force_stop_len, len(alphabet))``.

        Returns:
            Correlation coefficient between ``self.scaled_rewards`` and the
            per-index visit counts of ``samples``.
        """
        indices = self._obs_to_index(samples)
        sample_dist = np.bincount(indices, minlength=len(self.scaled_rewards))
        return np.corrcoef(self.scaled_rewards, sample_dist)[0][1]

    def get_mode_coverage(self, samples):
        """Fraction of reward-table modes visited at least once by ``samples``.

        Modes are the entries of ``self.scaled_rewards`` above the sparsity floor
        set in ``__init__``.
        """
        indices = self._obs_to_index(samples)
        visited = np.zeros(self.scaled_rewards.shape, dtype=bool)
        visited[np.unique(indices)] = True
        high_reward_mask = self.scaled_rewards > self.SPARSE_REWARD_FLOOR
        return np.sum(visited & high_reward_mask) / np.sum(high_reward_mask)

    def render(self):
        """Rendering is not supported."""
        pass

if __name__ == "__main__":
    # =========================================================================
    # Sparsity check: roll out uniformly random valid actions and compare how
    # many terminal rewards exceed 1e-3 against how many reward-table entries
    # exceed that level.
    np.random.seed(42)
    from tqdm import trange

    num_samples = 100_000
    env = Environment(data_folder='sehstr')

    def random_rollout(environment):
        """Play one episode with uniformly random valid actions and return the final reward."""
        terminated = False
        reward = None
        while not terminated:
            legal = np.flatnonzero(environment.get_forward_action_masks(environment._get_obs()))
            chosen = np.random.choice(legal)
            _, reward, terminated, _, _ = environment.step(chosen)
        return reward

    final_rewards = []
    for _ in trange(num_samples):
        final_rewards.append(random_rollout(env))
        env.reset()

    final_rewards = np.asarray(final_rewards)
    print(np.sum(env.unwrapped.scaled_rewards > 1e-3))
    print(np.sum(final_rewards > 1e-3))