import time
import numpy
import re
from numpy import random

def pseudo_random_seed(hyperseed=0):
    '''
    Build a seed by mixing the wall clock with numpy's global random state.

    The result is ``time.time_ns()`` plus an integer drawn from numpy's global
    RNG (``random.random() * 100000000``, truncated) plus ``hyperseed``, reduced
    modulo ``2**32``. For an integer ``hyperseed`` it is an int in ``[0, 2**32)``.
    Note that each call consumes one draw from numpy's global RNG.
    '''
    nanoseconds = time.time_ns()
    rng_noise = int(random.random() * 100000000)
    modulus = 1 << 32  # == 4294967296
    return (nanoseconds + rng_noise + hyperseed) % modulus

class RandomFourier(object):
    '''
    Random smooth trajectory in ``ndim`` dimensions built from a short Fourier series.

    The series is a constant offset plus ``n_items`` (uniform in ``[1, max_item]``)
    sine/cosine terms. Each term's frequency is an integer drawn from
    ``[1, max_order]`` jittered by standard-normal noise. Its amplitudes are
    ``normal * exponential(box_size / sqrt(n_items))`` per dimension and per
    (sin, cos) component.

    Calling the instance with a step ``t`` evaluates the series at
    ``x = t / max_steps`` and returns an array of shape ``(ndim,)``.
    '''
    def __init__(self,
                 ndim,
                 max_order=16,
                 max_item=5,
                 max_steps=1000,
                 box_size=2):
        n_items = random.randint(1, max_item + 1)
        # Order-0 term: sin(0) = 0 and cos(0) = 1, so it acts as a constant offset.
        self.coeffs = [(0, random.normal(size=(ndim, 2)) * random.exponential(scale=box_size / numpy.sqrt(n_items), size=(ndim, 2)))]
        self.max_steps = max_steps
        for _ in range(n_items):
            # Sample a * sin(n x) + b * cos(n x) with a jittered (possibly non-integer) frequency n
            order = random.randint(1, max_order + 1) + random.normal(scale=1.0)
            factor = random.normal(size=(ndim, 2)) * random.exponential(scale=box_size / numpy.sqrt(n_items), size=(ndim, 2))
            self.coeffs.append((order, factor))

    def __call__(self, t):
        '''Evaluate the series at step ``t``; returns an array of shape ``(ndim,)``.'''
        x = t / self.max_steps
        y = 0
        for order, coeff in self.coeffs:
            # coeff[:, 0] scales the sine part, coeff[:, 1] the cosine part
            y += coeff[:, 0] * numpy.sin(order * x) + coeff[:, 1] * numpy.cos(order * x)
        return y


class RandomGoal(object):
    '''
    Randomly sampled goal (or pitfall) that scores state transitions by distance.

    Distance thresholds are multiplied by ``sqrt(ndim)``, presumably so they stay
    comparable across dimensionalities.

    Args:
        ndim: dimensionality of the goal position.
        type: ``'static'`` for a fixed position drawn uniformly in
            ``[-box_size, box_size)^ndim``. ``'fourier'`` for a position that
            follows a ``RandomFourier`` trajectory indexed by the step ``t``.
        reward_type: any combination of the letters ``'f'`` (field),
            ``'t'`` (trigger) and ``'p'`` (potential), e.g. ``'ft'`` or ``'pt'``.
        repetitive_position: optional iterable of existing positions. A static
            goal must be at least ``repetitive_distance * sqrt(ndim)`` from each.
        repetitive_distance: unscaled minimum separation for the check above.
        is_pitfall: if True, the returned reward is negated and the shaped
            reward is 0.
        max_try: maximum number of placement attempts for a static goal.
        box_size: half-width of the sampling box; also scales the reward thresholds.

    Raises:
        RuntimeError: if no non-overlapping static position is found within
            ``max_try`` attempts.
        ValueError: if ``type`` is neither ``'static'`` nor ``'fourier'``.
    '''
    def __init__(self,
                 ndim,
                 type='static',
                 reward_type='p',
                 repetitive_position=None,
                 repetitive_distance=0.2,
                 is_pitfall=False,
                 max_try=10000,
                 box_size=2):
        eff_factor = numpy.sqrt(ndim)
        self.reward_type = reward_type
        self.is_pitfall = is_pitfall
        if(type == 'static'):
            position = self._sample_static_position(
                ndim, repetitive_position, repetitive_distance * eff_factor, max_try, box_size)
            self.position = lambda t: position
        elif(type == 'fourier'):
            self.position = RandomFourier(ndim)
        else:
            raise ValueError(f"Invalid goal type: {type}")
        self.activate()

        self.has_field_reward = False
        self.has_trigger_reward = False
        self.has_potential_reward = False

        if('f' in self.reward_type): # Field Rewards
            self.field_reward = random.uniform(0.2, 0.8)
            self.field_threshold = random.exponential(box_size / 2) * eff_factor
            self.has_field_reward = True
        if('t' in self.reward_type): # Trigger Rewards
            self.trigger_reward = max(random.exponential(5.0), 1.0)
            self.trigger_threshold = random.uniform(0.20, 0.50) * eff_factor
            if(is_pitfall):
                # Pitfalls get a wider trigger radius
                self.trigger_threshold += box_size / 4
            self.trigger_rs_terminal = self.trigger_reward
            self.trigger_rs_threshold = 3 * box_size * eff_factor
            self.trigger_rs_potential = self.trigger_reward * self.trigger_rs_threshold / box_size
            self.has_trigger_reward = True
        if('p' in self.reward_type): # Potential Rewards
            self.potential_reward = max(random.exponential(2.0), 0.5)
            self.potential_threshold = random.uniform(box_size/2, box_size) * eff_factor
            self.has_potential_reward = True

    @staticmethod
    def _sample_static_position(ndim, repetitive_position, min_dist, max_try, box_size):
        '''Draw a uniform position at least ``min_dist`` from every entry of ``repetitive_position``.'''
        for _ in range(max_try):
            position = random.uniform(low=-box_size, high=box_size, size=(ndim, ))
            if(repetitive_position is None):
                return position
            if all(numpy.linalg.norm(pos - position) >= min_dist for pos in repetitive_position):
                # Success on any attempt (including the last) is accepted.
                return position
        raise RuntimeError(f"Failed to generate goal position after {max_try} tries.")

    def activate(self):
        self.is_activated = True

    def deactivate(self):
        self.is_activated = False

    def __call__(self, sp, sn, t=0, need_reward_shaping=False):
        '''
        Score the transition from previous state ``sp`` to next state ``sn`` at step ``t``.

        ``sp`` and ``sn`` are compared to the goal position with the Euclidean
        norm. They are assumed to have the position's shape (TODO confirm with callers).

        Returns:
            ``(reward, done, info)``. ``done`` is True once ``sn`` is inside the
            trigger radius. ``info['shaped_reward']`` holds the reward plus the
            optional trigger shaping, and is 0 for pitfalls. A deactivated goal
            returns ``(0.0, False, {'shaped_reward': 0.0})``.
        '''
        if(not self.is_activated):
            # Keep the info schema identical to the active path.
            return 0.0, False, {'shaped_reward': 0.0}
        reward = 0
        shaped_reward = 0
        done = False
        cur_pos = self.position(t)
        dist = numpy.linalg.norm(sn - cur_pos)
        distp = numpy.linalg.norm(sp - cur_pos)
        if(self.has_field_reward):
            # Gaussian bump around the goal, cut off at 3 thresholds
            if(dist <= 3.0 * self.field_threshold):
                k = dist / self.field_threshold
                reward += self.field_reward * numpy.exp(- k ** 2)
        if(self.has_trigger_reward):
            if(dist <= self.trigger_threshold):
                reward += self.trigger_reward
                if(need_reward_shaping):
                    # Zero as long as trigger_rs_terminal == trigger_reward (set in __init__)
                    shaped_reward += self.trigger_rs_terminal - self.trigger_reward
                done = True
            if(need_reward_shaping):
                if(dist <= self.trigger_rs_threshold):
                    # Reward progress toward the goal inside the shaping radius
                    shaped_reward += self.trigger_rs_potential * (min(distp, self.trigger_rs_threshold) - dist) / self.trigger_rs_threshold
        if(self.has_potential_reward):
            if(dist <= self.potential_threshold):
                # Positive when moving closer, negative when moving away
                reward += self.potential_reward * (min(distp, self.potential_threshold) - dist) / self.potential_threshold
        shaped_reward += reward
        if(self.is_pitfall):
            reward *= -1
            shaped_reward = 0
        return reward, done, {'shaped_reward': shaped_reward}