import re
from itertools import product

import numpy as np
import torch
import gym

from dflex.envs import CheetahEnv, HopperEnv, AntEnv, HumanoidEnv, CartPoleSwingUpEnv
from dgym.tl.oa import OmegaAutomaton


class DFlexLTL(gym.Wrapper):
    """Wrap a vectorised dflex environment with an LTL objective.

    The LTL formula is translated into an ``OmegaAutomaton``. Every atomic
    proposition (AP) must have the form ``"<signal> > <number>"`` or
    ``"<signal> < <number>"``, where ``<signal>`` is one of ``signals`` and
    ``signals[i]`` names observation column ``i``. APs are evaluated softly
    (through a sigmoid), so the wrapper tracks a soft automaton state per
    environment. It appends that state to the observation and returns an
    acceptance-based reward in place of the environment reward.
    """

    def __init__(self, env, signals=None, ltl=None):
        """Build the automaton and the tensors used to evaluate it.

        Args:
            env: the wrapped dflex environment. It must expose ``num_envs``,
                ``num_obs``, ``num_actions``, ``episode_length`` and ``device``.
            signals: names of the observation columns, in column order. ``None``
                means no signal can be referenced by an AP.
            ltl: LTL formula string; ``None`` is treated as ``'true'``.

        Raises:
            ValueError: if an AP of the formula is not of the form
                ``"[obs_attr] [>|<] [number]"`` with ``obs_attr`` in ``signals``.
        """
        super().__init__(env)

        self.signals = signals
        self.ltl = ltl
        self.oa = OmegaAutomaton(self.ltl or 'true')
        print('LTL:', self.oa.ltl)
        print('Delta:', self.oa.delta)
        print('Acc:', self.oa.acc)
        self.num_oa_states = self.oa.shape[1]

        self.num_envs = self.env.num_envs
        self.num_obs = self.env.num_obs + self.num_oa_states
        self.num_actions = self.env.num_actions
        self.max_episode_length = self.env.episode_length

        self.num_aps = len(self.oa.aps)  # number of atomic propositions
        # Number of labels; must equal 2**num_aps, the number of rows of
        # ap_label_indicators that is contracted against it.
        self.num_labels = len(self.oa.labels)

        # Observation = env observation (unbounded) + soft OA state in [0, 1].
        self.observation_space = gym.spaces.Box(
            np.array([-np.inf] * self.env.num_obs + [0.0] * self.num_oa_states),
            np.array([np.inf] * self.env.num_obs + [1.0] * self.num_oa_states),
        )

        self._parse_atomic_propositions()
        self._build_automaton_tensors()
        self.oa_state = self._initial_oa_state()

    def _parse_atomic_propositions(self):
        """Turn each AP string into (observation column, sign, bound) tensors."""
        signals = list(self.signals) if self.signals is not None else []
        number_re = r'([+-]?(?:0|[1-9]\d*)?(?:\.\d*)?(?:[eE][+\-]?\d+)?)'
        # Escape signal names so regex metacharacters in them (e.g. '+') are
        # matched literally instead of altering the pattern.
        ap_re = re.compile(
            r'^"\s*(' + '|'.join(re.escape(s) for s in signals) + r')\s*([><])\s*'
            + number_re + r'\s*"$'
        )
        format_msg = ' is not in the format "[obs_attr] [>|<] [number]"'

        attrs_to_ids = {attr: attr_id for attr_id, attr in enumerate(signals)}
        ap_shape = (self.num_aps,)
        ap_ids_to_obs_ids = np.zeros(ap_shape, dtype=int)  # observation column of each AP
        ap_signs = np.zeros(ap_shape, dtype=np.float32)  # '<' -> -1, '>' -> 1
        ap_bounds = np.zeros(ap_shape, dtype=np.float32)  # threshold of each AP
        for ap_id, ap in enumerate(self.oa.aps):
            match = ap_re.match(ap)
            if match is None or match.group(1) not in attrs_to_ids:
                raise ValueError(ap + format_msg)
            attr, sign, number = match.groups()
            try:
                bound = float(number)
            except ValueError:
                # number_re is all-optional, so e.g. an empty number can match.
                raise ValueError(ap + format_msg) from None
            ap_ids_to_obs_ids[ap_id] = attrs_to_ids[attr]
            ap_signs[ap_id] = 1 if sign == '>' else -1
            ap_bounds[ap_id] = bound

        device = self.env.device
        self.ap_ids_to_obs_ids = torch.from_numpy(ap_ids_to_obs_ids).to(device)
        self.ap_signs = torch.from_numpy(ap_signs).to(device)
        self.ap_bounds = torch.from_numpy(ap_bounds).to(device)

        # All 2**num_aps truth assignments, ordered by the number of true APs and
        # then by binary value (AP i contributes 2**i). Row i is used as the truth
        # assignment of self.oa.labels[i] (see calculate_oa_reward), so this order
        # must agree with the automaton's label order.
        indicators = sorted(
            product([0, 1], repeat=self.num_aps),
            key=lambda x: 2**self.num_aps*sum(x) + sum([2**i*xi for i, xi in enumerate(x)]),
        )
        self.ap_label_indicators = torch.tensor(list(indicators)).unsqueeze(0).to(device)

    def _build_automaton_tensors(self):
        """Build per-label transition and reward tensors indexed [label, q_next, q]."""
        tensor_shape = self.num_labels, self.num_oa_states, self.num_oa_states
        transition_tensor = np.zeros(tensor_shape, dtype=np.float32)
        reward_tensor = np.zeros(tensor_shape, dtype=np.float32)
        for q in range(self.num_oa_states):
            for q_next in range(self.num_oa_states):
                for label_id, label in enumerate(self.oa.labels):
                    if self.oa.delta[q][label] == q_next:
                        # Column = source state, so bmm(T, state) gives the next state.
                        transition_tensor[label_id, q_next, q] = 1
                        reward_tensor[label_id, q_next, q] = float(self.oa.acc[q][label]%2)

        self.transition_tensor = torch.from_numpy(transition_tensor).to(self.env.device)
        self.reward_tensor = torch.from_numpy(reward_tensor).to(self.env.device)
        print('Reward tensor:', self.reward_tensor)

    def _initial_oa_state(self):
        """Return a (num_envs, num_oa_states, 1) one-hot tensor at state ``oa.q0``."""
        oa_state = np.zeros((self.num_envs, self.num_oa_states), dtype=np.float32)
        oa_state[:, self.oa.q0] = 1.0
        return torch.tensor(oa_state).unsqueeze(2).to(self.env.device)

    def calculate_oa_reward(self, observation):
        """Advance the soft automaton state and return the acceptance reward.

        Args:
            observation: environment observation indexed as
                ``observation[:, column]`` (presumably ``(num_envs, env.num_obs)``).
                It is not modified.

        Returns:
            Tensor of shape ``(num_envs,)``: the expected value of
            ``acc[q][label] % 2`` under the current soft state and the label
            probabilities.

        Side effect: ``self.oa_state`` is replaced by the next soft state.
        """
        # Column 0 is scaled by 10 before the APs are evaluated, so a bound on
        # signals[0] is compared against 10x the raw value. Work on a copy so the
        # caller's tensor (which step() returns to the agent) is left untouched.
        observation = observation.clone()
        observation[:, 0] *= 10
        ap_values = observation[:, self.ap_ids_to_obs_ids]
        # Scaled signed distance to each threshold; positive when the AP holds.
        diff = 10. * self.ap_signs * (ap_values - self.ap_bounds)
        ap_probs = torch.sigmoid(diff).unsqueeze(1)  # (num_envs, 1, num_aps)

        # Probability of each label, treating the APs as independent.
        label_probs = torch.prod(
            ap_probs*self.ap_label_indicators + (1-ap_probs)*(1-self.ap_label_indicators),
            dim=-1
        )  # (num_envs, num_labels)

        transition_matrix = torch.tensordot(label_probs, self.transition_tensor, dims=1)
        reward_matrix = torch.tensordot(label_probs, self.reward_tensor, dims=1)

        # The reward uses the state *before* the transition.
        reward = torch.sum(torch.bmm(reward_matrix, self.oa_state).squeeze(-1), dim=-1)
        self.oa_state = torch.bmm(transition_matrix, self.oa_state)

        return reward

    def step(self, actions, play=False):
        """Step the env; return ``(augmented_obs, oa_reward, done, extras)``.

        The environment reward is discarded and replaced by
        ``10 * calculate_oa_reward(observation)``. ``extras['obs_before_reset']``
        is extended with the automaton state reached at this step. Environments
        flagged in ``done`` restart their automaton at ``q0`` before the augmented
        observation is built, so its OA part matches the env part.
        """
        if self.env.nan_state_fix and self.oa_state.requires_grad:
            def zero_nonfinite_grad(grad):
                # Tensor hooks must not modify their argument; return a new gradient.
                return torch.nan_to_num(grad, 0.0, 0.0, 0.0)
            self.oa_state.register_hook(zero_nonfinite_grad)

        observation, reward, done, extras = self.env.step(actions, play)

        # NOTE(review): if self.env.step already reset done environments,
        # `observation` holds their post-reset values here, so the last transition
        # of those episodes is labelled from the reset state. Confirm whether
        # extras['obs_before_reset'] should be used for the AP evaluation instead.
        oa_reward = 10. * self.calculate_oa_reward(observation)

        extras['obs_before_reset'] = torch.cat(
            (extras['obs_before_reset'], self.oa_state.clone().squeeze(-1)), dim=-1)

        # Restart the automaton only for environments whose episode ended;
        # torch.where keeps the gradient path for the others.
        done_mask = done.bool().reshape(-1, 1, 1)
        if done_mask.any():
            self.oa_state = torch.where(done_mask, self._initial_oa_state(), self.oa_state)

        augmented_observation = torch.cat((observation, self.oa_state.squeeze(-1)), dim=-1)

        return augmented_observation, oa_reward, done, extras

    def reset(self, env_ids=None, grads=False, force_reset=False):
        """Reset the env and the automaton; return the augmented observation.

        With ``env_ids=None`` every automaton restarts at ``q0``. Otherwise only the
        listed environments are restarted.
        """
        observation = self.env.reset(env_ids, grads, force_reset)

        initial_state = self._initial_oa_state()
        if env_ids is None:
            self.oa_state = initial_state
        else:
            reset_mask = torch.zeros(self.num_envs, dtype=torch.bool, device=initial_state.device)
            reset_mask[env_ids] = True
            self.oa_state = torch.where(reset_mask.reshape(-1, 1, 1), initial_state, self.oa_state)

        augmented_observation = torch.cat((observation, self.oa_state.squeeze(-1)), dim=-1)

        return augmented_observation

    def initialize_trajectory(self):
        """Detach state from the graph and return the env's initial augmented observation."""
        self.clear_grad()
        observation = self.env.initialize_trajectory()
        augmented_observation = torch.cat((observation, self.oa_state.squeeze(-1)), dim=-1)
        return augmented_observation

    def clear_grad(self, checkpoint=None):
        """Cut the autograd history of ``oa_state``, or restore it from ``checkpoint``.

        The call is forwarded to ``env.clear_grad`` in both cases.
        """
        if checkpoint is None:
            with torch.no_grad():
                self.oa_state = self.oa_state.clone()
            self.env.clear_grad()
        else:
            self.oa_state = checkpoint['oa_state']
            self.env.clear_grad(checkpoint)

    def get_checkpoint(self):
        """Return the env checkpoint extended with a copy of ``oa_state``."""
        checkpoint = self.env.get_checkpoint()
        checkpoint['oa_state'] = self.oa_state.clone()
        return checkpoint



class CartPole(DFlexLTL):
    """LTL-wrapped dflex ``CartPoleSwingUpEnv``.

    Signal ``i`` names observation column ``i`` of the cart-pole environment.
    """

    def __init__(self, ltl='true', **kwargs):
        env = CartPoleSwingUpEnv(**kwargs)
        super().__init__(
            env,
            ['position_x', 'velocity_x', 'sin_theta', 'cos_theta', 'angular_vel'],
            ltl,
        )

    def calculate_oa_reward(self, observation):
        """Advance the soft automaton state and return the acceptance reward.

        Unlike the base class, the observation is not rescaled. Instead, the
        sigmoid for APs at positions 2 and 3 is made 10x sharper. Returns a
        ``(num_envs,)`` tensor and replaces ``self.oa_state`` with the next state.
        """
        ap_values = observation[:, self.ap_ids_to_obs_ids]
        logits = 10. * self.ap_signs * (ap_values - self.ap_bounds)
        # NOTE(review): this slice selects APs by their position in self.oa.aps,
        # not observation columns 2:4 (sin_theta / cos_theta). Which signals it
        # sharpens depends on the formula; confirm this is intended.
        logits[:, 2:4] *= 10
        p_true = torch.sigmoid(logits).unsqueeze(1)

        indicators = self.ap_label_indicators
        per_ap = p_true * indicators + (1 - p_true) * (1 - indicators)
        label_probs = per_ap.prod(dim=-1)

        step_matrix = torch.tensordot(label_probs, self.transition_tensor, dims=1)
        payoff_matrix = torch.tensordot(label_probs, self.reward_tensor, dims=1)

        previous_state = self.oa_state
        reward = torch.bmm(payoff_matrix, previous_state).squeeze(-1).sum(dim=-1)
        self.oa_state = torch.bmm(step_matrix, previous_state)
        return reward

class Hopper(DFlexLTL):
    """LTL-wrapped dflex ``HopperEnv``; signal ``i`` names observation column ``i``."""

    def __init__(self, ltl='true', **kwargs):
        env = HopperEnv(**kwargs)
        bodies = ('torso', 'thigh', 'leg', 'foot')
        signals = (
            ['torso_height']
            + [body + '_angle' for body in bodies]
            + ['torso_velocity_x', 'torso_velocity_z']
            + [body + '_angular_velocity' for body in bodies]
        )
        super().__init__(env, signals, ltl)


class Cheetah(DFlexLTL):
    """LTL-wrapped dflex ``CheetahEnv``; signal ``i`` names observation column ``i``."""

    def __init__(self, ltl='true', **kwargs):
        env = CheetahEnv(**kwargs)
        # 'tip' followed by back/front thigh, shin and foot, in that order.
        links = ['tip'] + [
            side + '_' + segment
            for side in ('back', 'front')
            for segment in ('thigh', 'shin', 'foot')
        ]
        signals = (
            ['tip_height']
            + [link + '_angle' for link in links]
            + ['tip_velocity_x', 'tip_velocity_z']
            + [link + '_angular_velocity' for link in links]
        )
        super().__init__(env, signals, ltl)


class Ant(DFlexLTL):
    """LTL-wrapped dflex ``AntEnv``; signal ``i`` names observation column ``i``.

    The remaining 26 observation columns are exposed as ``other_0`` ...
    ``other_25``.
    """

    def __init__(self, ltl='true', **kwargs):
        env = AntEnv(**kwargs)
        signals = [
            'torso_height',
            *['torso_angle_'+str(i) for i in range(4)],
            'torso_velocity_x',
            'torso_velocity_z',
            'torso_velocity_y',
            *['torso_angular_velocity_'+str(i) for i in range(3)],
            # Previously 'other+N': '+' is a regex metacharacter, so such names
            # could not be matched by the AP parser. '_' matches the other names.
            *['other_'+str(i) for i in range(26)],
        ]
        super().__init__(env, signals, ltl)



