import math
import torch
import os
import sys

import imageio
import numpy as np


class skillsampler():
    """Sample radius-conditioned skills and optionally adapt the radius range.

    The radius range ``[lo, hi]`` is parsed from ``args.radius_bound`` (a
    comma-separated string of integers).  ``sample`` draws one of
    ``num_intervals`` evenly spaced integer radii from the current range.
    When ``args.use_adaptive_sampling`` is truthy, ``update_bound`` evaluates
    ``policy`` at the two end-point radii and widens or narrows the range
    depending on the average pseudo-return it achieves there.
    """

    # Smallest value the adaptive rule will lower the lower bound to.
    MIN_RADIUS = 5
    # Average return (as a fraction of env._max_episode_steps) above which an
    # end-point bound is pushed outward.
    EXPAND_FRACTION = 0.9
    # Average return fraction below which a previously expanded bound is
    # pulled back toward its initial value.
    SHRINK_FRACTION = 0.4
    # Evaluation episodes run per end-point radius in update_bound.
    NUM_TEST_EPISODES = 5

    def __init__(self, env, args):
        bound = np.array([int(item) for item in args.radius_bound.split(',')])
        # radius_bound is mutated in place by update_bound, so keep an
        # independent copy of the initial range to clamp against.
        self.radius_bound_init = bound
        self.radius_bound = bound.copy()
        self.num_intervals = args.num_intervals
        self.use_adaptive_sampling = args.use_adaptive_sampling
        self.radius_input_dim = args.radius_input_dim
        self.env = env
        # Set once an end-point bound has been expanded; only then may it be
        # pulled back.
        self.has_exceeded_06_lower = False
        self.has_exceeded_06_upper = False

    def sample(self, current_index=0, eval=False):
        """Return ``(radius, positional_encoding_of_radius)``.

        With ``eval=True`` the radius is the ``current_index``-th of the
        ``num_intervals`` evenly spaced integer radii in the current bound;
        otherwise one of them is chosen uniformly at random.
        """
        radius_values = np.linspace(self.radius_bound[0], self.radius_bound[1], num=self.num_intervals, dtype=int)

        if eval:
            radius_value = radius_values[current_index]
        else:
            radius_value = np.random.choice(radius_values)

        pos_enc = positional_encoding(radius_value, self.radius_input_dim)

        return radius_value, pos_enc

    def update_bound(self, policy, psi):
        """Evaluate the end-point radii and adapt ``radius_bound`` in place.

        Args:
            policy: object exposing ``select_action(state, evaluate=True)``.
            psi: object exposing ``forward_np(state)``.

        Returns:
            ``(radius_bound, avg_return_lower, avg_return_upper)``.  When
            adaptive sampling is disabled, the bound is returned unchanged
            and both averages are 0.0.
        """
        average_reward = np.zeros(2)

        # Deliberately `== False` (not `not ...`): only values equal to False
        # (False / 0) disable adaptation, matching the established behavior.
        if self.use_adaptive_sampling == False:  # noqa: E712
            return self.radius_bound, average_reward[0], average_reward[1]

        # Instances created without __init__ (e.g. old pickles) may lack these.
        self.has_exceeded_06_lower = getattr(self, 'has_exceeded_06_lower', False)
        self.has_exceeded_06_upper = getattr(self, 'has_exceeded_06_upper', False)

        for idx, radius in enumerate(self.radius_bound):
            average_reward[idx] = self._average_pseudo_return(policy, psi, radius, self.NUM_TEST_EPISODES)

        self._adjust_bounds(average_reward[0], average_reward[1])

        return self.radius_bound, average_reward[0], average_reward[1]

    def _average_pseudo_return(self, policy, psi, radius_value, num_episodes):
        """Mean undiscounted pseudo-return of ``policy`` conditioned on ``radius_value``."""
        # Both quantities depend only on the radius, so compute them once
        # instead of once per episode / per step.
        radius_input = positional_encoding(radius_value, self.radius_input_dim)
        target_step = radius_value * np.sin(np.pi / (2 * radius_value))

        sum_reward = 0
        for _ in range(num_episodes):
            state = np.concatenate([self.env.reset(), radius_input])
            episode_reward = 0
            done = False

            while not done:
                action = policy.select_action(state, evaluate=True)
                next_state, _, done, _ = self.env.step(action)
                next_state = np.concatenate([next_state, radius_input])

                psi_diff = psi.forward_np(next_state) - psi.forward_np(state)
                # Pseudo-reward is 1 when ||psi(s') - psi(s)|| equals
                # target_step and decays as the gap grows.
                episode_reward += np.exp(-10 * np.linalg.norm(target_step - np.linalg.norm(psi_diff))**2)
                state = next_state

            sum_reward += episode_reward

        return sum_reward / num_episodes

    def _adjust_bounds(self, avg_lower, avg_upper):
        """Widen or narrow ``radius_bound`` in place from the end-point average returns."""
        max_steps = self.env._max_episode_steps
        expand_threshold = self.EXPAND_FRACTION * max_steps
        shrink_threshold = self.SHRINK_FRACTION * max_steps

        if avg_lower > expand_threshold:
            self.has_exceeded_06_lower = True
            # Lower the bound by one but never below MIN_RADIUS.  A plain
            # max(bound - 1, MIN_RADIUS) would *raise* a bound that already
            # sits below MIN_RADIUS, so such a bound is left where it is.
            if self.radius_bound[0] > self.MIN_RADIUS:
                self.radius_bound[0] -= 1

        if avg_upper > expand_threshold:
            self.has_exceeded_06_upper = True
            self.radius_bound[1] += 1

        if self.has_exceeded_06_lower and avg_lower < shrink_threshold:
            # Pull the lower bound back up, but not past its initial value.
            self.radius_bound[0] = min(self.radius_bound[0] + 1, self.radius_bound_init[0])

        if self.has_exceeded_06_upper and avg_upper < shrink_threshold:
            # Pull the upper bound back down, but not below its initial value.
            self.radius_bound[1] = max(self.radius_bound[1] - 1, self.radius_bound_init[1])
        
        
def generate_random_radius(radius_bound, dim, num_intervals, current_index = 0, eval=False):
    """Pick an integer radius from ``radius_bound`` and return it with its encoding.

    Candidates are ``num_intervals`` evenly spaced integers from
    ``radius_bound[0]`` to ``radius_bound[1]``.  In evaluation mode the
    ``current_index``-th candidate is used; otherwise one is drawn uniformly.

    Returns:
        ``(radius, positional_encoding(radius, dim))``.
    """
    low, high = radius_bound[0], radius_bound[1]
    candidates = np.linspace(low, high, num=num_intervals, dtype=int)
    chosen = candidates[current_index] if eval else np.random.choice(candidates)
    return chosen, positional_encoding(chosen, dim)


def positional_encoding(R, dim):
    """Sinusoidal encoding of the scalar ``R`` into a length-``dim`` vector.

    Entry ``i`` uses frequency ``10000 ** (-2 * (i // 2) / dim)``; even
    entries hold ``sin(R * freq)`` and odd entries hold ``cos(R * freq)``.
    """
    idx = np.arange(dim)
    freqs = 10000.0 ** (-2 * (idx // 2) / dim)
    angles = R * freqs
    # Even slots take the sine, odd slots the cosine of the same angle.
    return np.where(idx % 2 == 0, np.sin(angles), np.cos(angles))


def compute_cosine_weight(x, N, max_weight):
    """Cosine warm-up weight: 0 at ``x = 0`` rising to ``max_weight`` at ``x = N``.

    Returns ``max_weight * (1 - cos(pi * x / N)) / 2`` for ``x <= N`` and
    ``max_weight`` for ``x > N``.  A zero-length ramp (``N == 0``) returns
    ``max_weight`` directly instead of dividing by zero.
    """
    if N == 0:
        return 1.0 * max_weight
    if x <= N:
        # np.pi has the same value as torch.pi and keeps this helper numpy-only.
        return (1 - np.cos(np.pi * x / N)) / 2 * max_weight
    return 1.0 * max_weight


def generate_skill_disc(dim, eval_idx = -1):
    """Return an integer one-hot vector of length ``dim``.

    The hot position is ``eval_idx`` when given; the sentinel ``-1`` means
    "pick a position uniformly at random".
    """
    hot = np.random.randint(dim) if eval_idx == -1 else eval_idx
    one_hot = np.zeros(dim, dtype=int)
    one_hot[hot] = 1
    return one_hot


def generate_skill_cont(dim):
    """Draw a random unit-length skill vector of size ``dim``.

    A standard-normal sample is rescaled to unit Euclidean norm.  Samples
    whose norm is not above 1e-6 are rejected and redrawn.

    Raises:
        ValueError: if ``dim`` produces an empty sample (e.g. ``dim == 0``).
            Such a sample can never be normalised, and the rejection loop
            would otherwise spin forever.
    """
    while True:
        vector = np.random.normal(0, 1, dim)
        if vector.size == 0:
            raise ValueError(f"generate_skill_cont requires dim >= 1, got {dim!r}")
        norm = np.linalg.norm(vector)
        if norm > 1e-6:
            return vector / norm


def normalize_vector(v):
    """Scale ``v`` to unit Euclidean norm.

    A zero-norm input is returned as-is (the same object, not a copy).
    """
    magnitude = np.linalg.norm(v)
    return v if magnitude == 0 else v / magnitude


def add_noise_to_skill(skill, noise_scale, step):
    """Offset a 2-element skill by an alternating perturbation and renormalise.

    The offset follows a 4-step cycle: steps with ``step % 4`` in {0, 1} add
    ``(+noise_scale, -noise_scale)``, and the other two add
    ``(-noise_scale, +noise_scale)``.  The result goes through
    ``normalize_vector``.

    ``skill`` may be any array-like that broadcasts against a length-2 offset.
    """
    # Coerce to an array so a tuple skill is added element-wise rather than
    # concatenated with the offset tuple, and a list skill no longer raises.
    skill = np.asarray(skill)
    if step % 4 in (0, 1):
        offset = (noise_scale, -noise_scale)
    else:
        offset = (-noise_scale, noise_scale)
    return normalize_vector(skill + offset)

