import base64
import time
import json
from typing import Type, AnyStr, Any

from harl import constants
import numpy as np
import dill
from dataclasses import dataclass
from dataclass_wizard import JSONWizard
from dataclass_wizard.abstractions import W
from dataclass_wizard.type_def import JSONObject, Encoder


@dataclass
class Skill(JSONWizard):
    """A callable skill bundled with its source code and an embedding vector.

    JSON serialization is customized so the whole object round-trips:

    * ``skill_function`` is pickled with ``dill`` and stored as a hex string.
    * ``skill_embedding`` is stored as base64 of its raw bytes. ``float_type``
      records the dtype needed to decode it. Only ``'float32'`` and
      ``'float64'`` are supported. The shape is not stored, so the embedding
      is always decoded as a 1-D array.

    Calling a ``Skill`` instance forwards all arguments to ``skill_function``.
    """

    skill_name: str
    skill_function: Any
    skill_embedding: np.ndarray
    skill_code: str
    skill_code_base64: str
    float_type: str = 'float32'

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped skill function and return its result."""
        return self.skill_function(*args, **kwargs)

    @classmethod
    def from_dict(cls: Type[W], o: JSONObject) -> W:
        """Build a ``Skill`` from a dict in the format produced by ``to_dict``.

        Args:
            o: Mapping with the keys ``skill_name``, ``skill_function`` (hex
                string of a dill pickle), ``skill_embedding`` (base64 of raw
                array bytes), ``skill_code`` and ``skill_code_base64``.
                It may also contain ``float_type``, which defaults to
                ``'float32'``.

        Returns:
            A new instance. Its embedding is a writable 1-D array.

        Raises:
            ValueError: If ``float_type`` is neither ``'float32'`` nor
                ``'float64'``.
            KeyError: If a required key is missing from ``o``.

        Warning:
            ``dill.loads`` can execute arbitrary code embedded in the payload.
            Only call this on data from a trusted source.
        """
        # Validate the cheap metadata first, so an invalid payload is rejected
        # before its pickle is ever executed.
        float_type = o.get('float_type', 'float32')
        if float_type == 'float32':
            dtype = np.float32
        elif float_type == 'float64':
            dtype = np.float64
        else:
            raise ValueError(f"Unsupported float type: {float_type}")

        embedding_bytes = base64.b64decode(o['skill_embedding'])
        # np.frombuffer returns a read-only view over the immutable bytes
        # object. Copying makes the embedding writable, like one that was
        # constructed directly.
        skill_embedding = np.frombuffer(embedding_bytes, dtype=dtype).copy()

        # SECURITY: unpickling runs code from the payload (trusted input only).
        skill_function = dill.loads(bytes.fromhex(o['skill_function']))

        return cls(
            skill_name=o['skill_name'],
            skill_function=skill_function,
            skill_embedding=skill_embedding,
            skill_code=o['skill_code'],
            skill_code_base64=o['skill_code_base64'],
            float_type=float_type
        )

    def to_dict(self) -> JSONObject:
        """Serialize this skill to a JSON-compatible dict.

        As a side effect, ``self.float_type`` is updated to match the
        embedding's dtype. The instance then agrees with what was written.

        Returns:
            A dict in the format accepted by ``from_dict``.

        Raises:
            ValueError: If the embedding dtype is not float32 or float64
                (in native byte order).
        """
        # Check the dtype before doing the potentially expensive pickling.
        embedding_dtype = self.skill_embedding.dtype
        if embedding_dtype == np.float32:
            float_type = 'float32'
        elif embedding_dtype == np.float64:
            float_type = 'float64'
        else:
            raise ValueError(f"Unsupported float type: {self.skill_embedding.dtype}")
        self.float_type = float_type

        # Encode the function as a hex string of its dill pickle.
        skill_function_hex = dill.dumps(self.skill_function).hex()
        skill_embedding_base64 = base64.b64encode(
            self.skill_embedding.tobytes()
        ).decode('utf-8')

        return {
            'skill_name': self.skill_name,
            'skill_function': skill_function_hex,
            'skill_embedding': skill_embedding_base64,
            'skill_code': self.skill_code,
            'skill_code_base64': self.skill_code_base64,
            'float_type': self.float_type
        }

    def to_json(self: W, *,
                encoder: Encoder = json.dumps,
                **encoder_kwargs) -> AnyStr:
        """Serialize this skill to a JSON string.

        Args:
            encoder: Callable that turns the ``to_dict`` result into a
                string. Defaults to ``json.dumps``.
            **encoder_kwargs: Extra keyword arguments forwarded to ``encoder``
                (e.g. ``indent``).
        """
        return encoder(self.to_dict(), **encoder_kwargs)

    @classmethod
    def from_json(cls: Type[W], s: AnyStr, *,
                  decoder: Any = json.loads,
                  **decoder_kwargs) -> W:
        """Deserialize a skill from a JSON string produced by ``to_json``.

        Args:
            s: JSON text encoding a single skill object.
            decoder: Callable that parses ``s`` into a dict. Defaults to
                ``json.loads``.
            **decoder_kwargs: Extra keyword arguments forwarded to ``decoder``.

        Warning:
            This goes through ``from_dict``, which unpickles code. Only call
            it on data from a trusted source.
        """
        return cls.from_dict(decoder(s, **decoder_kwargs))


def post_skill_wait(wait_time=constants.DEFAULT_POST_ACTION_WAIT_TIME):
    """Pause so a just-executed skill has time to finish (e.g. an animation).

    Args:
        wait_time: Seconds to block, passed straight to ``time.sleep``. The
            default is read from ``constants.DEFAULT_POST_ACTION_WAIT_TIME``
            once, when this function is defined.
    """
    pause_seconds = wait_time
    time.sleep(pause_seconds)
