import numpy as np
from gym.envs.mujoco import Walker2dEnv as WalkerEnv_
from . import register_env
import inspect
import os
import tempfile
import xml.etree.ElementTree as ET
import mujoco_py
@register_env('walker-rand-params')
class MultitaskWalkerEnv(WalkerEnv_):
    """Walker2d environment whose physical parameters are randomized per task.

    A task is a dict mapping a parameter name ('body_mass', 'body_inertia',
    'dof_damping') to an array of absolute values shaped like the matching
    attribute of ``self.model``. Sampled values are
    ``init_value * base ** u``, where ``u`` is drawn uniformly and
    independently for every array element (so every body / dof gets its own
    multiplier).

    ``set_task`` rebuilds the MuJoCo model from a modified copy of gym's
    ``walker2d.xml`` (body mass, joint damping) and then writes
    ``body_inertia`` directly into the compiled model.

    NOTE: 'geom_friction' is listed in ``rand_params`` but friction
    randomization is currently disabled: it is not in ``_PARAM_LOG_BASES``
    and ``_create_modified_xml`` does not handle it, so it has no effect.
    """

    # (parameter name, log-space base) in sampling order. The order fixes the
    # sequence of np.random draws, so keep it stable for reproducibility.
    # To enable friction randomization, add ('geom_friction', 1.5) here and a
    # matching branch in _create_modified_xml / set_task.
    _PARAM_LOG_BASES = (
        ('body_mass', 1.5),
        ('body_inertia', 1.5),
        ('dof_damping', 1.3),
    )
    # Upper bound of the training exponent range, as a fraction of
    # log_scale_limit.
    _TRAIN_LOG_FRACTION = 0.7
    # Fixed exponent used by set_test_task, as a fraction of log_scale_limit;
    # it lies above the training upper bound.
    _TEST_LOG_FRACTION = 0.8

    def __init__(self, **kwargs):
        """Build the default walker and record its initial parameters.

        ``kwargs`` is accepted for registry compatibility and ignored.
        """
        super(MultitaskWalkerEnv, self).__init__()
        self.log_scale_limit = 3.0
        self.rand_params = ['body_mass', 'dof_damping', 'body_inertia', 'geom_friction']
        walker_env_file_path = inspect.getfile(WalkerEnv_)
        mujoco_envs_dir = os.path.dirname(walker_env_file_path)
        self.model_path = os.path.join(mujoco_envs_dir, 'assets', 'walker2d.xml')

        self.save_parameters()
        # The freshly built model *is* the default task; without this,
        # get_task() would raise AttributeError before the first set_task().
        self.cur_params = {name: value.copy() for name, value in self.init_params.items()}
        self.tasks = []
        self._goal = []

    def _create_modified_xml(self, task_params):
        """Write a copy of the walker XML with ``task_params`` applied.

        Handles:
          * 'body_mass'   -> ``mass`` of the first <geom> of each named body.
          * 'dof_damping' -> ``damping`` of each named <joint>.

        'body_inertia' is deliberately not handled here. gym's walker2d.xml
        compiles with inertiafromgeom="true" and defines no <inertial>
        elements, so an XML edit would be ignored. ``set_task`` applies it to
        the compiled model instead.

        Args:
            task_params: dict of parameter name -> array indexed like the
                matching ``self.model`` attribute.

        Returns:
            Path of a temporary .xml file. The caller is responsible for
            deleting it.
        """
        tree = ET.parse(self.model_path)
        root = tree.getroot()

        if 'body_mass' in task_params:
            mass_array = task_params['body_mass']
            for body in root.findall(".//body[@name]"):
                # mass_array is indexed by MuJoCo body id, not by XML order.
                body_id = self.model.body_name2id(body.get('name'))
                geom = body.find('geom')
                if geom is not None:
                    geom.set('mass', str(mass_array[body_id]))

        if 'dof_damping' in task_params:
            damping_array = task_params['dof_damping']
            for joint in root.findall(".//joint[@name]"):
                # damping_array is indexed by dof. Look up this joint's first
                # dof rather than assuming XML order == dof order.
                joint_id = self.model.joint_name2id(joint.get('name'))
                dof_adr = self.model.jnt_dofadr[joint_id]
                joint.set('damping', str(damping_array[dof_adr]))

        fd, temp_path = tempfile.mkstemp(suffix='.xml')
        try:
            # Write through the descriptor mkstemp gave us so it is always
            # closed, even if serialization fails.
            with os.fdopen(fd, 'wb') as xml_file:
                tree.write(xml_file)
        except BaseException:
            # Do not leave a half-written temp file behind.
            os.remove(temp_path)
            raise
        return temp_path

    def _sample_param_sets(self, n_tasks, high):
        """Draw ``n_tasks`` task dicts.

        Each dict holds one log-uniform multiplier per array element, with
        exponents drawn from ``np.random.uniform(-log_scale_limit, high)``.
        """
        param_sets = []
        for _ in range(n_tasks):
            new_params = {}
            for name, base in self._PARAM_LOG_BASES:
                if name not in self.rand_params:
                    continue
                exponents = np.random.uniform(
                    -self.log_scale_limit, high, size=getattr(self.model, name).shape)
                new_params[name] = self.init_params[name] * np.power(base, exponents)
            param_sets.append(new_params)
        return param_sets

    def sample_tasks(self, n_tasks):
        """Sample tasks with exponents in [-log_scale_limit, log_scale_limit)."""
        return self._sample_param_sets(n_tasks, self.log_scale_limit)

    def sample_train_tasks(self, n_tasks):
        """Sample training tasks with exponents in [-limit, 0.7 * limit)."""
        return self._sample_param_sets(n_tasks, self._TRAIN_LOG_FRACTION * self.log_scale_limit)

    def set_test_task(self):
        """Activate the single held-out test task and make it ``self.tasks``.

        Every randomized parameter is scaled by ``base ** (0.8 * limit)``,
        which is outside the training exponent range.
        """
        exponent = self._TEST_LOG_FRACTION * self.log_scale_limit
        new_params = {
            name: self.init_params[name] * np.power(base, exponent)
            for name, base in self._PARAM_LOG_BASES
            if name in self.rand_params
        }
        self.set_task(new_params)
        self.tasks = [new_params]

    def set_train_task(self, num_tasks):
        """Replace ``self.tasks`` with ``num_tasks`` freshly sampled training tasks."""
        self.tasks = self.sample_train_tasks(num_tasks)

    def reset_task(self, idx):
        """Activate ``self.tasks[idx]``."""
        self.set_task(self.tasks[idx])

    def set_task(self, task):
        """Rebuild the model/simulation so that it reflects ``task``."""
        temp_xml_path = self._create_modified_xml(task)
        try:
            model = mujoco_py.load_model_from_path(temp_xml_path)
        finally:
            os.remove(temp_xml_path)
        if 'body_inertia' in task:
            # Not expressible through the XML for this model (see
            # _create_modified_xml), so write it into the compiled model.
            model.body_inertia[:] = task['body_inertia']
        sim = mujoco_py.MjSim(model)

        self.model = model
        self.sim = sim
        # MujocoEnv helpers (e.g. get_body_com) read self.data; keep it in
        # sync with the new simulation.
        self.data = sim.data
        self._update_viewers(sim)
        self.cur_params = task
        self._goal = [task]

    def _update_viewers(self, sim):
        """Point every already-created viewer at ``sim``."""
        # Newer gym versions keep per-mode viewers in self._viewers; older ones
        # only have self.viewer.
        viewers = list(getattr(self, '_viewers', {}).values())
        if self.viewer is not None and not any(v is self.viewer for v in viewers):
            viewers.append(self.viewer)
        for viewer in viewers:
            viewer.update_sim(sim)

    def get_task(self):
        """Return the parameter dict of the currently active task."""
        return self.cur_params

    def save_parameters(self):
        """Snapshot the current model's randomizable parameters into ``init_params``."""
        self.init_params = {}
        for name, _ in self._PARAM_LOG_BASES:
            if name in self.rand_params:
                self.init_params[name] = getattr(self.model, name).copy()

    def set_seed(self, seed):
        """Seed the environment via gym's ``seed``."""
        self.seed(seed)

    def get_all_task_idx(self):
        """Return the indices of all tasks in ``self.tasks``."""
        return range(len(self.tasks))
