import numpy as np

# Skill-labeling functions generated by ChatGPT (e.g. a "moving north" check). State-based inputs: cheetah, ant.
class GPTOutputSkiller:
    """Label state-based trajectories with a binary skill indicator.

    The active rule is the "cheetah flip" check: a state is labelled 1 while
    the absolute value of state feature ``ANGLE_INDEX`` is at or below
    ``ANGLE_THRESHOLD`` and 0 once it exceeds that threshold.
    """

    # Index of the state feature that is thresholded (called the
    # "front tip angle" by the original author).
    ANGLE_INDEX = 2
    # Approximately 90 degrees in radians.
    ANGLE_THRESHOLD = 1.57

    def __init__(self, input_type='state'):
        """Create a labeler for state inputs.

        Args:
            input_type: kind of observation to label; only ``'state'`` is
                supported.

        Raises:
            NotImplementedError: if ``input_type`` is not ``'state'``.
        """
        self.obs_to_goal_dim = [0, 2]
        self.input_type = input_type
        if input_type != 'state':
            raise NotImplementedError(
                "GPTOutputSkiller only supports input_type='state', "
                "got {!r}".format(input_type)
            )

    def label_states(self, states):
        """Return one binary label per state.

        Args:
            states: batch of trajectories, expected to be shaped
                (traj_batch_size, eps_length, state_dim). Each trajectory
                must support ``trajectory[:, k]`` indexing and all
                trajectories must have the same length.

        Returns:
            A float ndarray of shape (traj_batch_size, eps_length) holding
            1 where ``abs(state[ANGLE_INDEX]) <= ANGLE_THRESHOLD`` and 0
            where the threshold is exceeded. An empty batch yields an empty
            array instead of raising.
        """
        # Alternative labelers kept for reference only (not executed).
        #
        # ant: North
        #     labels = np.zeros((len(states), len(states[0])))
        #     for n_batch in range(len(states)):
        #         y_coords = states[n_batch][:, 1]
        #         labels[n_batch, :] = np.where(y_coords > 0, 1, 0)
        #     return labels
        #
        # ant: flip (pitch term computed from quaternion features 3..6)
        #     labels = np.zeros((len(states), len(states[0])))
        #     for n_batch in range(len(states)):
        #         qx = states[n_batch][:, 3]
        #         qy = states[n_batch][:, 4]
        #         qz = states[n_batch][:, 5]
        #         qw = states[n_batch][:, 6]
        #         sinp = 2.0 * (qw * qy - qz * qx)
        #         # label 0 when |sinp| is close to 1 (pitch near +/-90 deg)
        #         labels[n_batch, :] = np.where(np.abs(sinp) > 0.99, 0, 1)
        #     return labels
        #
        # An earlier cheetah variant also thresholded a second angle
        # (feature 3 or 4) and combined both checks with np.logical_or.

        # len() is used instead of truthiness because ``states`` may be an
        # ndarray, whose truth value is ambiguous.
        if len(states) == 0:
            eps_length = np.shape(states)[1] if np.ndim(states) >= 2 else 0
            return np.zeros((0, eps_length))

        labels = np.zeros((len(states), len(states[0])))
        for row, trajectory in enumerate(states):
            angle = trajectory[:, self.ANGLE_INDEX]
            # cheetah flip: 0 once the angle magnitude passes the threshold.
            labels[row, :] = np.where(
                np.abs(angle) > self.ANGLE_THRESHOLD, 0, 1
            )
        return labels

from fm import CLIP
class CLIPSkillerPixel():
    """Label pixel trajectories with rewards derived from a ``CLIP`` model.

    Each trajectory is reshaped into stacked frames, the second frame of
    every stack is passed to ``self.fm.get_batch_answers`` and the answers
    are converted into per-state rewards according to ``reward_type``.
    """

    def __init__(self, input_type='pixel', reward_type='onehot', start_coef=0.1, end_coef=0.1, decay_rate=0.001):
        """Build the labeler and its underlying ``CLIP`` model.

        Args:
            input_type: kind of observation; only ``'pixel'`` is supported.
            reward_type: ``'onehot'`` or ``'prob'``; forwarded to ``CLIP``
                and used to pick the reward rule in ``label_states``.
            start_coef: initial value of the decaying coefficient.
            end_coef: value the coefficient decays towards.
            decay_rate: exponential decay rate per ``get_coef`` call.

        Raises:
            NotImplementedError: if ``input_type`` is not ``'pixel'``.
        """
        self.input_type = input_type
        if input_type != 'pixel':
            raise NotImplementedError
        self.skill_reward_type = reward_type
        self.fm = CLIP(self.skill_reward_type)

        # Parameters of the exponential schedule evaluated by get_coef().
        # Author's note: a rate of 0.005 decays to end_coef in ~1200 steps.
        self.start_coef = start_coef
        self.end_coef = end_coef
        self.decay_rate = decay_rate
        self.decay_cntr = 0

    def label_states(self, states):
        """Return a reward label for every state in the batch.

        Args:
            states: batch of trajectories, expected to be shaped
                (traj_batch_size, eps_length, state_dim); each trajectory
                must be reshapeable to (-1, 3, 64, 64, 3).

        Returns:
            A float ndarray of shape (traj_batch_size, eps_length).

        Raises:
            NotImplementedError: if the reward type is neither ``'onehot'``
                nor ``'prob'`` (raised while processing the first
                trajectory).
        """
        n_traj = len(states)
        rewards = np.zeros((n_traj, len(states[0])))

        # The coefficient is not used by the active reward rule (an earlier
        # variant used it as the non-matching reward), but the call is kept
        # because it advances the decay counter once per labeling call.
        self.get_coef()

        frame_stack = 3
        height, width, channels = 64, 64, 3
        for idx in range(n_traj):
            stacked = states[idx].reshape(-1, frame_stack, height, width, channels)
            # Only the second frame of each stack is sent to the model.
            second_frames = stacked[:, 1]
            answers = self.fm.get_batch_answers(second_frames)
            rewards[idx, :] = self._answers_to_reward(answers)
        return rewards

    def _answers_to_reward(self, answers):
        # Convert model answers into rewards according to the reward type.
        if self.skill_reward_type == 'prob':
            return answers
        if self.skill_reward_type != 'onehot':
            raise NotImplementedError
        return np.where(answers == 1, 1, 0)

    def get_coef(self):
        """Return the current decayed coefficient and advance the counter.

        The value is ``end_coef + (start_coef - end_coef) *
        exp(-decay_rate * decay_cntr)``, evaluated before ``decay_cntr`` is
        incremented.
        """
        span = self.start_coef - self.end_coef
        decay = np.exp(-self.decay_rate * self.decay_cntr)
        self.decay_cntr += 1
        return self.end_coef + span * decay
