import math
import torch
import os
import sys

import imageio
import numpy as np


class skillsampler():
    """Sample skill radii and optionally adapt the sampling range.

    The sampler keeps a ``[lower, upper]`` integer radius range parsed from
    ``args.radius_bound`` (a comma-separated string such as ``"10,20"``).
    ``sample`` draws a radius from that range; ``update_bound`` widens or
    narrows the range based on evaluation rollouts of the current policy.
    """

    # Hard floor for the adaptive lower bound. It also keeps the radius away
    # from 0, where ``np.pi / (2 * radius)`` in the pseudo-reward would divide
    # by zero.
    _MIN_RADIUS = 5
    # Number of evaluation episodes rolled out per bound in ``update_bound``.
    _NUM_TEST_EPISODES = 5

    def __init__(self, env, args):
        """Store the environment and read the sampler settings from ``args``.

        Args:
            env: Environment used by ``update_bound``; it must provide
                ``reset()``, ``step(action)`` returning a 4-tuple, and
                ``_max_episode_steps``.
            args: Object with the attributes ``radius_bound``,
                ``num_intervals``, ``use_adaptive_sampling``,
                ``radius_input_dim``, ``metra_skill_dim`` and ``pos_dim``.
        """
        parsed_bound = np.array([int(item) for item in args.radius_bound.split(',')])
        # Keep the configured range untouched; the working copy is mutated
        # in place by ``update_bound``.
        self.radius_bound_init = parsed_bound
        self.radius_bound = parsed_bound.copy()
        self.num_intervals = args.num_intervals
        self.use_adaptive_sampling = args.use_adaptive_sampling
        self.radius_input_dim = args.radius_input_dim
        self.metra_dim = args.metra_skill_dim
        self.pos_dim = args.pos_dim
        self.env = env
        # Hysteresis flags: a bound may only be pulled back towards its
        # initial value after it has been pushed outwards at least once.
        # NOTE(review): the "06" in the names does not match the 0.9 threshold
        # used in ``update_bound``; the names are kept so external readers of
        # these attributes keep working.
        self.has_exceeded_06_lower = False
        self.has_exceeded_06_upper = False

    def sample(self, current_index=0, eval=False):
        """Pick a radius from the current range and return it with its encoding.

        Args:
            current_index: Index into the evenly spaced candidate radii; only
                used when ``eval`` is true.
            eval: When true, pick the candidate at ``current_index``;
                otherwise pick one candidate uniformly at random.

        Returns:
            Tuple ``(radius_value, pos_enc)``.
        """
        # Same logic as the module-level helper; delegate instead of
        # duplicating it.
        return generate_random_radius(
            self.radius_bound,
            self.radius_input_dim,
            self.num_intervals,
            current_index=current_index,
            eval=eval,
        )

    def _evaluate_radius(self, policy, psi, radius):
        """Return the mean per-episode pseudo-reward of ``policy`` at ``radius``."""
        # Both values depend only on the radius, so compute them once.
        radius_input = positional_encoding(radius, self.radius_input_dim)
        target = radius * np.sin(np.pi / (2 * radius))

        sum_reward = 0
        for _ in range(self._NUM_TEST_EPISODES):
            state = self.env.reset()
            metra_skill = generate_skill_cont(self.metra_dim)
            state = np.concatenate([state, radius_input, metra_skill])

            episode_reward = 0
            done = False
            while not done:
                action = policy.select_action(state, evaluate=True)
                # The environment's own reward is ignored; only the
                # pseudo-reward below is accumulated.
                next_state, _reward, done, _info = self.env.step(action)
                next_state = np.concatenate([next_state, radius_input, metra_skill])

                # psi is evaluated on the state with its first ``pos_dim``
                # entries dropped.
                psi_diff = (
                    psi.forward_np(next_state[self.pos_dim:])
                    - psi.forward_np(state[self.pos_dim:])
                )
                # Peaks at 1 when the psi-space step length equals ``target``.
                episode_reward += np.exp(-10 * (target - np.linalg.norm(psi_diff)) ** 2)
                state = next_state

            sum_reward += episode_reward

        return sum_reward / self._NUM_TEST_EPISODES

    def update_bound(self, policy, psi):
        """Evaluate the policy at both bounds and adapt the radius range.

        When adaptive sampling is disabled (``use_adaptive_sampling`` is
        falsy) nothing is evaluated and the rewards are reported as zero.

        Otherwise each bound is evaluated over ``_NUM_TEST_EPISODES`` episodes:

        * mean reward above 90% of ``env._max_episode_steps`` pushes the bound
          outwards by 1 (the lower bound is never pushed below its floor);
        * once a bound has been pushed outwards at least once, mean reward
          below 40% pulls it back by 1, never past its initial value.

        Args:
            policy: Object exposing ``select_action(state, evaluate=True)``.
            psi: Object exposing ``forward_np(array)``.

        Returns:
            Tuple ``(radius_bound, lower_reward, upper_reward)``.
            ``radius_bound`` is the sampler's live array, not a copy.
        """
        average_reward = np.zeros(2)

        if not self.use_adaptive_sampling:
            return self.radius_bound, average_reward[0], average_reward[1]

        for idx, radius in enumerate(self.radius_bound):
            average_reward[idx] = self._evaluate_radius(policy, psi, radius)

        max_steps = self.env._max_episode_steps
        expand_threshold = 0.9 * max_steps
        shrink_threshold = 0.4 * max_steps
        # If the configured lower bound is already below the hard floor, use
        # it as the floor so an "expand" step can never raise the lower bound.
        lower_floor = min(self._MIN_RADIUS, self.radius_bound_init[0])

        if average_reward[0] > expand_threshold:
            self.has_exceeded_06_lower = True
            self.radius_bound[0] = max(self.radius_bound[0] - 1, lower_floor)

        if average_reward[1] > expand_threshold:
            self.has_exceeded_06_upper = True
            self.radius_bound[1] += 1

        if self.has_exceeded_06_lower and average_reward[0] < shrink_threshold:
            self.radius_bound[0] = min(self.radius_bound[0] + 1, self.radius_bound_init[0])

        if self.has_exceeded_06_upper and average_reward[1] < shrink_threshold:
            self.radius_bound[1] = max(self.radius_bound[1] - 1, self.radius_bound_init[1])

        return self.radius_bound, average_reward[0], average_reward[1]
        
        
def generate_skill_disc(dim, eval=-1):
    """
    Build a unit-norm skill vector whose entries all share one sign.

    Every entry equals ``sign / sqrt(dim)``.
    ``eval == 1`` forces the positive vector and ``eval == 0`` the negative one;
    any other value (default -1) picks the sign with a fair coin flip.
    """
    if eval == 1:
        direction = 1
    elif eval == 0:
        direction = -1
    else:
        # One uniform draw: below 0.5 -> positive, otherwise negative.
        direction = -1 if np.random.rand() >= 0.5 else 1

    return np.full(dim, direction / np.sqrt(dim))

def generate_skill_disc_theta(dim=2, eval=-1, num_interval=16):
    """
    Return a 2-D unit vector pointing along one of several evenly spaced angles.

    The full circle ``[0, 2*pi)`` is divided into ``num_interval`` equally
    spaced angles and the vector ``(cos(angle), sin(angle))`` is returned for
    the selected one.

    Args:
        dim: Unused; kept for call compatibility. The result always has
            two components.
        eval: Angle index. ``-1`` (default) samples an index uniformly at
            random; any other value is wrapped with ``eval % num_interval``.
        num_interval: Number of discrete angles (default 16).

    Raises:
        ValueError: If ``num_interval`` is smaller than 1.
    """
    if num_interval < 1:
        raise ValueError("num_interval must be at least 1")

    angles = np.linspace(0, 2 * np.pi, num_interval, endpoint=False)

    # -1 is the "sample randomly" sentinel; everything else is an index.
    if eval != -1:
        idx = eval % num_interval
    else:
        idx = np.random.randint(num_interval)

    angle = angles[idx]
    return np.array([np.cos(angle), np.sin(angle)])

def generate_skill_cont(dim):
    """
    Draw a random unit-norm vector of length ``dim``.

    Components are sampled from a standard normal distribution and the
    vector is divided by its Euclidean norm.
    """
    sample = np.random.normal(0, 1, dim)
    length = np.linalg.norm(sample)

    # Redraw until the norm is safely above zero so the division is stable.
    while not length > 1e-6:
        sample = np.random.normal(0, 1, dim)
        length = np.linalg.norm(sample)

    return sample / length


def normalize_vector(v):
    """Scale ``v`` to unit Euclidean norm; a zero-norm ``v`` is returned as is."""
    length = np.linalg.norm(v)
    # Dividing by a zero norm is avoided by handing back the input unchanged.
    return v / length if length != 0 else v


def add_noise_to_skill(skill, noise_scale, step):
    """
    Perturb ``skill`` with a step-dependent offset and re-normalize it.

    The offset is ``(noise_scale, -noise_scale)`` when ``step % 4`` is 0 or 1
    and ``(-noise_scale, noise_scale)`` otherwise, so it flips every two
    steps. The offset has two entries, so ``skill`` must be addable to a
    length-2 tuple (presumably a 2-element NumPy array; confirm with callers).
    """
    phase = step % 4
    if phase in (0, 1):
        offset = (noise_scale, -noise_scale)
    else:
        offset = (-noise_scale, noise_scale)
    return normalize_vector(skill + offset)



def generate_random_radius(radius_bound, dim, num_intervals, current_index = 0, eval=False):
    """
    Pick a radius from an evenly spaced integer grid and encode it.

    The grid holds ``num_intervals`` integer values spanning
    ``radius_bound[0]`` to ``radius_bound[1]``. With ``eval`` true the value
    at ``current_index`` is used; otherwise one grid value is drawn at random.

    Returns ``(radius_value, pos_enc)`` where ``pos_enc`` is the
    ``dim``-sized positional encoding of the chosen radius.
    """
    lower, upper = radius_bound[0], radius_bound[1]
    candidates = np.linspace(lower, upper, num=num_intervals, dtype=int)

    radius_value = candidates[current_index] if eval else np.random.choice(candidates)

    return radius_value, positional_encoding(radius_value, dim)


def positional_encoding(R, dim):
    """
    Encode the scalar ``R`` as a ``dim``-sized sinusoidal feature vector.

    Index ``i`` uses the frequency ``10000 ** (-2 * (i // 2) / dim)``; even
    indices hold the sine and odd indices the cosine of ``R * frequency``.
    """
    index = np.arange(dim)
    inv_freq = np.power(10000.0, -2 * (index // 2) / dim)

    even = slice(0, None, 2)
    odd = slice(1, None, 2)

    encoding = np.zeros(dim)
    encoding[even] = np.sin(R * inv_freq[even])
    encoding[odd] = np.cos(R * inv_freq[odd])
    return encoding

def compute_scheduled_weight(cur_epi, sat_episode, start_weight, end_weight):
    """
    Linearly move a weight from ``start_weight`` to ``end_weight``.

    The weight changes in proportion to ``cur_epi / sat_episode`` and stays
    at ``end_weight`` once ``cur_epi`` reaches ``sat_episode``.
    """
    # Schedule saturated: hold the final value.
    if cur_epi >= sat_episode:
        return end_weight

    progress = cur_epi / sat_episode
    return start_weight + (end_weight - start_weight) * progress
