import os
import xml.etree.ElementTree as ET
from typing import List
import gymnasium as gym # ICI
from omegaconf import OmegaConf
import yaml
import itertools
import numpy as np
import hydra
from stable_baselines3 import PPO, DQN, RDQN ,RDQN2 , SAC , RSAC, TD3
import wandb
from utils import evaluate_model_cartpole, BasicWrapper, BasicWrapper_gravity, BasicWrapper_masscart, BasicWrapper_masspole ,BasicWrapper_mountainforce ,BasicWrapper_mountaingravity , BasicWrapper_mountainspeed, linear_schedule,Mass_Wrapper_random
from utils import evaluate_model_mujoco
from game import Game
import matplotlib.pyplot as plt
import os
from stable_baselines3.common.utils import ExpWeights

import numpy as np 
from typing import List


def make_env(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create an environment with specified physical parameters.

    Dispatches on ``env_name`` to the matching ``make_*`` builder and forwards
    every argument (including ``exp``) to it.

    Parameters
    ----------
    env_name : str
        Environment id: one of "HalfCheetah-v2", "InvertedPendulum-v2",
        "Walker2d-v2", "Hopper-v2", "Ant-v2", "HumanoidStandup-v2".
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed : int
        Random seed, forwarded to the builder
    alg_name : str
        Name of algorithm, forwarded to the builder
    base_xml_file : str
        Xml file path of mujoco
    xml_name : str
        Prefix xml file
    gym_path : str
        Path of gym library
    exp : float
        Expectile value, forwarded to the builder

    Returns
    -------
    The environment created by the selected builder.

    Raises
    ------
    NotImplementedError
        If ``env_name`` is not a supported id.
    '''
    if env_name == "HalfCheetah-v2":
        builder = make_halfcheetah
    elif env_name == "InvertedPendulum-v2":
        builder = make_invertedpendulum
    elif env_name == "Walker2d-v2":
        builder = make_walker
    elif env_name == "Hopper-v2":
        builder = make_hopper
    elif env_name == "Ant-v2":
        builder = make_ant
    elif env_name == "HumanoidStandup-v2":
        builder = make_humanoidstandup
    else:
        raise NotImplementedError(f"Unsupported environment: {env_name}")

    # ``exp`` must be forwarded explicitly; otherwise every builder falls back
    # to its own default.
    return builder(
        env_name,
        change_param_names,
        change_param_values,
        seed,
        alg_name,
        base_xml_file,
        xml_name,
        gym_path,
        exp=exp,
    )


def make_halfcheetah(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create a HalfCheetah environment with specified physical parameters.

    Supported parameter names:

    * ``"worldfriction"``: first coefficient of the ``friction`` attribute of
      every ``<geom>`` found inside a ``<default>`` element of the base XML;
    * ``"torsomass"``: ``model.body_mass[1]``;
    * ``"backthighmass"``: ``model.body_mass[2]``.

    The (possibly modified) XML is always written to a new file in the gym
    assets directory and passed to ``gym.make``; masses are then set on the
    loaded model.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed : int
        Random seed (only used in the generated XML file name)
    alg_name : str
        Name of algorithm (only used in the generated XML file name)
    base_xml_file : str
        Xml file name, relative to the gym mujoco assets directory
    xml_name : str
        Prefix of the generated xml file
    gym_path : str
        Path of gym library
    exp : float
        Expectile value (only used in the generated XML file name)

    Returns
    -------
    The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}
    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"

    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                # Skip geom defaults that do not declare a friction attribute.
                coeffs = geom.attrib.get("friction", "").split()
                if not coeffs:
                    continue
                coeffs[0] = str(param_value)  # world friction
                geom.attrib["friction"] = " ".join(coeffs)
                change_flags[param_name] = True
    fname = f"{assets_dir}{xml_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{exp}_{xml_name}.xml"
    tree.write(fname)

    env = gym.make(id=env_name, xml_file=fname)
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"torsomass": 1, "backthighmass": 2}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env


def make_invertedpendulum(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create an InvertedPendulum environment with specified physical parameters.

    Supported parameter names:

    * ``"polemass"``: ``model.body_mass[2]``;
    * ``"cartmass"``: ``model.body_mass[1]``.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed : int
        Unused; kept for a uniform builder signature
    alg_name : str
        Unused; kept for a uniform builder signature
    base_xml_file : str
        Xml file path of mujoco; only its basename is passed to ``gym.make``
    xml_name : str
        Unused; kept for a uniform builder signature
    gym_path : str
        Unused; kept for a uniform builder signature
    exp : float
        Unused; kept for a uniform builder signature

    Returns
    -------
    The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}

    # NOTE(review): only the basename is passed; presumably the env resolves it
    # against its bundled assets directory. Confirm for the installed version.
    env = gym.make(env_name, xml_file=os.path.basename(base_xml_file))
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"polemass": 2, "cartmass": 1}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env


def make_walker(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create a Walker2d environment with specified physical parameters.

    Supported parameter names:

    * ``"worldfriction"``: first coefficient of the ``friction`` attribute of
      every ``<geom>`` found inside a ``<default>`` element of the base XML;
    * ``"torsomass"``: ``model.body_mass[1]``;
    * ``"thighmass"``: ``model.body_mass[2]``.

    The (possibly modified) XML is always written to a new file in the gym
    assets directory and passed to ``gym.make``; masses are then set on the
    loaded model.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed : int
        Random seed (only used in the generated XML file name)
    alg_name : str
        Name of algorithm (only used in the generated XML file name)
    base_xml_file : str
        Xml file name, relative to the gym mujoco assets directory
    xml_name : str
        Prefix of the generated xml file
    gym_path : str
        Path of gym library
    exp : float
        Expectile value (only used in the generated XML file name)

    Returns
    -------
    The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}
    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"

    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                # Skip geom defaults that do not declare a friction attribute.
                coeffs = geom.attrib.get("friction", "").split()
                if not coeffs:
                    continue
                coeffs[0] = str(param_value)  # world friction
                geom.attrib["friction"] = " ".join(coeffs)
                change_flags[param_name] = True
    fname = f"{assets_dir}{xml_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{exp}_{xml_name}.xml"
    tree.write(fname)

    env = gym.make(id=env_name, xml_file=fname)
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"torsomass": 1, "thighmass": 2}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env

def make_ant(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create an Ant environment with specified physical parameters.

    The env is created from its default XML; only body masses are changed.

    Supported parameter names:

    * ``"torsomass"``: ``model.body_mass[1]``;
    * ``"frontleftlegmass"``: ``model.body_mass[2]``;
    * ``"frontrightlegmass"``: ``model.body_mass[4]``.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed, alg_name, base_xml_file, xml_name, gym_path, exp
        Unused; kept for a uniform builder signature

    Returns
    -------
    The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}

    env = gym.make(id=env_name)
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"torsomass": 1, "frontleftlegmass": 2, "frontrightlegmass": 4}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env


def make_humanoidstandup(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create a HumanoidStandup environment with specified physical parameters.

    The env is created from its default XML; only body masses are changed.

    Supported parameter names:

    * ``"torsomass"``: ``model.body_mass[1]``;
    * ``"rightfootmass"``: ``model.body_mass[6]``;
    * ``"leftthighmass"``: ``model.body_mass[7]``.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed, alg_name, base_xml_file, xml_name, gym_path, exp
        Unused; kept for a uniform builder signature

    Returns
    -------
    The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}

    env = gym.make(env_name)
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"torsomass": 1, "rightfootmass": 6, "leftthighmass": 7}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env


def make_hopper(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
    exp=0.5,
):
    '''Create a Hopper environment with specified physical parameters.

    Supported parameter names:

    * ``"worldfriction"``: first coefficient of the ``friction`` attribute of
      every ``<geom>`` found inside a ``<default>`` element of the base XML;
    * ``"torsomass"``: ``model.body_mass[1]``;
    * ``"thighmass"``: ``model.body_mass[2]``.

    The (possibly modified) XML is always written to a new file in the gym
    assets directory and passed to ``gym.make``; masses are then set on the
    loaded model.

    Parameters
    ----------
    env_name : str
        Environment name
    change_param_names : List[str]
        Names of physical parameters to be changed
    change_param_values : List[float]
        Values of physical parameters to be changed
    seed : int
        Random seed (only used in the generated XML file name)
    alg_name : str
        Name of algorithm (only used in the generated XML file name)
    base_xml_file : str
        Xml file name, relative to the gym mujoco assets directory
    xml_name : str
        Prefix of the generated xml file
    gym_path : str
        Path of gym library
    exp : float
        Expectile value (only used in the generated XML file name)

    Returns
    -------
    The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If any name in ``change_param_names`` is not supported.
    '''
    change_flags = {name: False for name in change_param_names}
    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"

    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                # Skip geom defaults that do not declare a friction attribute.
                coeffs = geom.attrib.get("friction", "").split()
                if not coeffs:
                    continue
                coeffs[0] = str(param_value)  # world friction
                geom.attrib["friction"] = " ".join(coeffs)
                change_flags[param_name] = True
    fname = f"{assets_dir}{xml_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{exp}_{xml_name}.xml"
    tree.write(fname)

    env = gym.make(id=env_name, xml_file=fname)
    # Reach the model through ``unwrapped``: attribute forwarding through
    # wrapper layers (``env.env.model``) is deprecated in gymnasium.
    model = env.unwrapped.model
    mass_indices = {"torsomass": 1, "thighmass": 2}
    for param_name, param_value in zip(change_param_names, change_param_values):
        if param_name in mass_indices:
            model.body_mass[mass_indices[param_name]] = param_value
            change_flags[param_name] = True

    missing = [name for name, changed in change_flags.items() if not changed]
    if missing:
        env.close()
        raise RuntimeError(f"parameter(s) not changed: {missing}")
    return env


def _evaluation_grid(config):
    """Return the parameter vectors to evaluate, as described by ``config``.

    Reads ``num_eval_point``, ``dim_evaluate_point``, ``change_param_min`` and
    ``change_param_max``.

    * ``dim_evaluate_point == 1``: ``np.linspace`` between the (possibly
      vector) endpoints. The result has one entry per point.
    * ``dim_evaluate_point >= 2``: Cartesian product of ``num_eval_point``
      evenly spaced values per dimension, as a list of tuples.

    Raises
    ------
    ValueError
        If ``dim_evaluate_point`` is smaller than 1.
    """
    num_eval_point = config["num_eval_point"]
    dim_evaluate_point = config["dim_evaluate_point"]
    if dim_evaluate_point == 1:
        return np.linspace(
            config["change_param_min"],
            config["change_param_max"],
            num_eval_point,
        )
    if dim_evaluate_point < 1:
        raise ValueError(f"dim_evaluate_point must be >= 1, got {dim_evaluate_point}")
    axes = [
        np.linspace(
            config["change_param_min"][i],
            config["change_param_max"][i],
            num_eval_point,
        )
        for i in range(dim_evaluate_point)
    ]
    return list(itertools.product(*axes))


def _build_perturbed_env(config, params, seed):
    """Build the env selected by ``config["xml_name"]`` with ``params`` applied.

    Raises
    ------
    ValueError
        If ``config["xml_name"]`` is not a supported environment.
    """
    builders = {
        "hopper": make_hopper,
        "half_cheetah": make_halfcheetah,
        "walker2d": make_walker,
        "ant": make_ant,
        "humanoidstandup": make_humanoidstandup,
    }
    xml_name = config["xml_name"]
    if xml_name not in builders:
        raise ValueError(f"Unsupported xml_name: {xml_name!r}")
    return builders[xml_name](
        config["env_name"],
        config["change_param_names"],
        params,
        seed,
        "M2TD3",
        config["xml_file"],
        xml_name,
        os.getcwd() + "/../../../../",
        exp=config["expectile"],
    )


def evaluate(args, evaluate_num, seed, best_model, best=False, expectile_max=False, auto=False):
    """Evaluate ``best_model`` on a grid of perturbed envs and on the nominal env.

    For every grid point an env is built, evaluated with
    ``evaluate_model_mujoco`` and closed. The mean, minimum, 10% and 25%
    quantiles of the per-point mean rewards, plus the nominal-env statistics,
    are logged to wandb under the ``final/`` prefix (or ``finalbest/`` when
    ``best`` is truthy).

    Parameters
    ----------
    args : dict-like
        Experiment config. Note that ``args["seed"]`` is overwritten with ``seed``.
    evaluate_num : int
        Number of evaluation episodes per environment.
    seed : int
        Base seed; grid point ``i`` is built with ``seed + i``.
    best_model
        Policy passed to ``evaluate_model_mujoco``.
    best, expectile_max, auto : bool
        Forwarded to ``evaluate_model_mujoco``; ``best`` also selects the
        wandb key prefix.

    Raises
    ------
    ValueError
        If the evaluation grid is empty or ``xml_name`` is unsupported.
    """
    config = args
    # NOTE: writes into the caller's config dict.
    config["seed"] = seed
    print(config["change_param_min"][0], config["change_param_max"][0])

    evaluate_points = _evaluation_grid(config)
    if len(evaluate_points) == 0:
        raise ValueError("evaluation grid is empty; check num_eval_point")

    eval_kwargs = dict(
        n_eval_episodes=evaluate_num,
        deterministic=True,
        name=config["change_param_names"],
        best=best,
        expectile_max=expectile_max,
        auto=auto,
    )

    means_rewards = []
    for count, evaluate_point in enumerate(evaluate_points):
        envs = _build_perturbed_env(config, evaluate_point, config["seed"] + count)
        print("evaluate", evaluate_point)
        try:
            mean_reward, _std, _min, _q10, _q25 = evaluate_model_mujoco(
                best_model, envs, value=evaluate_point, **eval_kwargs
            )
        finally:
            envs.close()
        means_rewards.append(mean_reward)

    ### evaluate on nominal
    env_nom = gym.make(config["env_name"])
    try:
        # NOTE(review): ``value`` is the last grid point, not the nominal
        # parameters; kept because how evaluate_model_mujoco uses it is not
        # visible here. Confirm.
        (
            nom_mean_reward,
            nom_std_reward,
            nom_min_reward,
            nom_quantile10,
            nom_quantile25,
        ) = evaluate_model_mujoco(
            best_model, env_nom, value=evaluate_points[-1], **eval_kwargs
        )
    finally:
        env_nom.close()

    # Robustness summary over the grid of per-point mean rewards.
    final_mean = np.mean(means_rewards)
    final_min = np.min(means_rewards)
    final_quant10 = np.quantile(means_rewards, 0.1)
    final_quant25 = np.quantile(means_rewards, 0.25)

    prefix = "finalbest" if best else "final"
    wandb.log(
        {
            f"{prefix}/final_mean": final_mean,
            f"{prefix}/final_min": final_min,
            f"{prefix}/final_quant10": final_quant10,
            f"{prefix}/final_quant25": final_quant25,
            f"{prefix}/final_mean_nom": nom_mean_reward,
            f"{prefix}/final_min_nom": nom_min_reward,
            f"{prefix}/final_quant10_nom": nom_quantile10,
            f"{prefix}/final_quant25_nom": nom_quantile25,
            # Uses the same prefix as the other keys, so the best-model std
            # is not mixed into the "final/" series.
            f"{prefix}/final_std_nom": nom_std_reward,
        }
    )
       


class Wrapper_DR(gym.Wrapper):
    """Domain-randomisation wrapper for the MuJoCo envs built by the make_* helpers.

    On every ``reset`` one parameter vector is drawn uniformly at random (with
    the global NumPy RNG) from the grid described by ``arg``. The wrapped
    environment is then rebuilt with those physical parameters, and the
    previous one is closed.

    Parameters
    ----------
    env : gym.Env
        Initial environment; replaced (and closed) at the first reset.
    arg : dict-like
        Experiment config. Reads ``num_eval_point``, ``dim_evaluate_point``,
        ``change_param_min``, ``change_param_max``, ``change_param_names``,
        ``env_name``, ``xml_file``, ``xml_name``, ``seed`` and ``expectile``.
    seed : int
        Unused; kept for interface compatibility.
    """

    def __init__(self, env, arg, seed):
        super().__init__(env)
        self.config = arg
        self.dim_evaluate_point = arg["dim_evaluate_point"]
        self.evaluate_points = self._build_evaluate_points(arg)
        self.count = 0

    @staticmethod
    def _build_evaluate_points(config):
        """Return the parameter grid: joint linspace for dim 1, Cartesian product otherwise."""
        num_eval_point = config["num_eval_point"]
        dim_evaluate_point = config["dim_evaluate_point"]
        if dim_evaluate_point == 1:
            return np.linspace(
                config["change_param_min"],
                config["change_param_max"],
                num_eval_point,
            )
        if dim_evaluate_point < 1:
            raise ValueError(f"dim_evaluate_point must be >= 1, got {dim_evaluate_point}")
        axes = [
            np.linspace(
                config["change_param_min"][i],
                config["change_param_max"][i],
                num_eval_point,
            )
            for i in range(dim_evaluate_point)
        ]
        return list(itertools.product(*axes))

    def _make_env(self, param):
        """Build the env selected by ``config["xml_name"]`` with ``param`` applied."""
        builders = {
            "hopper": make_hopper,
            "half_cheetah": make_halfcheetah,
            "walker2d": make_walker,
            "ant": make_ant,
            "humanoidstandup": make_humanoidstandup,
        }
        xml_name = self.config["xml_name"]
        if xml_name not in builders:
            raise ValueError(f"Unsupported xml_name: {xml_name!r}")
        return builders[xml_name](
            self.config["env_name"],
            self.config["change_param_names"],
            param,
            self.config["seed"],
            "M2TD3",
            self.config["xml_file"],
            xml_name,
            os.getcwd() + "/../../../../",
            exp=self.config["expectile"],
        )

    def reset(self, seed=None, options=None):
        """Rebuild the env with freshly sampled parameters and reset it.

        ``seed`` and ``options`` are forwarded to the new env's ``reset``.
        """
        param = self.evaluate_points[np.random.randint(len(self.evaluate_points))]
        new_env = self._make_env(param)
        # Release the previous simulator; otherwise one env leaks per episode.
        self.env.close()
        self.env = new_env
        return self.env.reset(seed=seed, options=options)
    
   

class Wrapper_auto(gym.Wrapper):
    """Track the undiscounted return of each episode of the wrapped env.

    ``ep_rew`` holds the sum of rewards of the last finished episode.
    ``previous_ep_rew`` is the value ``ep_rew`` had at the most recent reset.
    """

    def __init__(self, env):
        super().__init__(env)
        self.ep_rew = 0
        self.previous_ep_rew = 0
        # Initialised here so ``step`` before the first ``reset`` does not fail.
        self.rewards = []

    def get_sampled_expectile(self):
        # NOTE(review): this class never assigns ``self.bandits`` (unlike
        # Wrapper_auto_DR), so this raises AttributeError unless a caller sets
        # it externally. Confirm intended usage.
        return self.bandits.sampled_expectile

    def reset(self, seed=None, options=None):
        """Reset the wrapped env, forwarding ``seed``/``options``, and start a new episode."""
        observation, info = self.env.reset(seed=seed, options=options)
        self.rewards = []
        self.previous_ep_rew = self.ep_rew
        return observation, info

    def step(self, action):
        """Step the wrapped env, recording the reward; update ``ep_rew`` at episode end."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.rewards.append(float(reward))
        if terminated or truncated:
            self.ep_rew = sum(self.rewards)
        return observation, reward, terminated, truncated, info



class Wrapper_auto_DR(gym.Wrapper):
    """Domain-randomisation wrapper that also tracks episode returns.

    On every ``reset`` one parameter vector is drawn uniformly at random (with
    the global NumPy RNG) from the grid described by ``arg``. The wrapped
    environment is then rebuilt with those parameters, and the previous one is
    closed. It also owns an ``ExpWeights`` bandit (``self.bandits``) whose
    ``sampled_expectile`` is exposed via ``get_sampled_expectile``, and a
    separately created nominal env (``self.nominal_env``).

    Parameters
    ----------
    env : gym.Env
        Initial environment; replaced (and closed) at the first reset.
    arg : dict-like
        Experiment config. Reads ``num_eval_point``, ``dim_evaluate_point``,
        ``change_param_min``, ``change_param_max``, ``change_param_names``,
        ``env_name``, ``xml_file``, ``xml_name``, ``seed`` and ``expectile``.
    seed : int
        Unused; kept for interface compatibility.
    """

    def __init__(self, env, arg, seed):
        super().__init__(env)
        self.bandits = ExpWeights()
        self.ep_rew = 0
        self.bandits.sampled_expectile = self.bandits.sample()
        self.config = arg
        self.dim_evaluate_point = arg["dim_evaluate_point"]
        self.nominal_env = gym.make(self.config["env_name"])
        self.evaluate_points = self._build_evaluate_points(arg)
        self.count = 0
        # Initialised here so ``step`` before the first ``reset`` does not fail.
        self.previous_ep_rew = 0
        self.rewards = []

    @staticmethod
    def _build_evaluate_points(config):
        """Return the parameter grid: joint linspace for dim 1, Cartesian product otherwise."""
        num_eval_point = config["num_eval_point"]
        dim_evaluate_point = config["dim_evaluate_point"]
        if dim_evaluate_point == 1:
            return np.linspace(
                config["change_param_min"],
                config["change_param_max"],
                num_eval_point,
            )
        if dim_evaluate_point < 1:
            raise ValueError(f"dim_evaluate_point must be >= 1, got {dim_evaluate_point}")
        axes = [
            np.linspace(
                config["change_param_min"][i],
                config["change_param_max"][i],
                num_eval_point,
            )
            for i in range(dim_evaluate_point)
        ]
        return list(itertools.product(*axes))

    def _make_env(self, param):
        """Build the env selected by ``config["xml_name"]`` with ``param`` applied."""
        builders = {
            "hopper": make_hopper,
            "half_cheetah": make_halfcheetah,
            "walker2d": make_walker,
            "ant": make_ant,
            "humanoidstandup": make_humanoidstandup,
        }
        xml_name = self.config["xml_name"]
        if xml_name not in builders:
            raise ValueError(f"Unsupported xml_name: {xml_name!r}")
        return builders[xml_name](
            self.config["env_name"],
            self.config["change_param_names"],
            param,
            self.config["seed"],
            "M2TD3",
            self.config["xml_file"],
            xml_name,
            os.getcwd() + "/../../../../",
            exp=self.config["expectile"],
        )

    def get_sampled_expectile(self):
        return self.bandits.sampled_expectile

    def reset(self, seed=None, options=None):
        """Rebuild the env with new DR parameters, start a new episode and reset it.

        ``seed`` and ``options`` are forwarded to the new env's ``reset``.
        """
        param = self.evaluate_points[np.random.randint(len(self.evaluate_points))]
        new_env = self._make_env(param)
        # Release the previous simulator; otherwise one env leaks per episode.
        self.env.close()
        self.env = new_env
        # Per-episode reward bookkeeping.
        self.rewards = []
        self.previous_ep_rew = self.ep_rew
        return self.env.reset(seed=seed, options=options)

    def step(self, action):
        """Step the wrapped env, recording the reward; update ``ep_rew`` at episode end."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.rewards.append(float(reward))
        if terminated or truncated:
            self.ep_rew = sum(self.rewards)
        return observation, reward, terminated, truncated, info




    

    