import os
import xml.etree.ElementTree as ET
from typing import List

import gym

from omegaconf import OmegaConf
import yaml
import itertools
import numpy as np
import hydra



from stable_baselines3 import PPO, DQN, RDQN ,RDQN2 , SAC , RSAC
def make_env(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path="../../",
):
    '''Build an environment whose physical parameters have been modified.

    Dispatches on ``env_name`` to the matching ``make_*`` builder and forwards
    every argument to it unchanged.

    Parameters
    ----------
    env_name : str
        Environment id; one of "HalfCheetah-v2", "InvertedPendulum-v2",
        "Walker2d-v2", "Hopper-v2", "Ant-v2", "HumanoidStandup-v2".
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Random seed.
    alg_name : str
        Name of the algorithm.
    base_xml_file : str
        MuJoCo XML file name.
    xml_name : str
        Prefix used for generated XML files.
    gym_path : str
        Path of the gym library.

    Returns
    -------
    env
        Whatever the selected builder returns.

    Raises
    ------
    NotImplementedError
        If ``env_name`` is not one of the supported ids.
    '''
    supported_ids = (
        "HalfCheetah-v2",
        "InvertedPendulum-v2",
        "Walker2d-v2",
        "Hopper-v2",
        "Ant-v2",
        "HumanoidStandup-v2",
    )
    if env_name not in supported_ids:
        raise NotImplementedError()

    builder_for = {
        "HalfCheetah-v2": make_halfcheetah,
        "InvertedPendulum-v2": make_invertedpendulum,
        "Walker2d-v2": make_walker,
        "Hopper-v2": make_hopper,
        "Ant-v2": make_ant,
        "HumanoidStandup-v2": make_humanoidstandup,
    }
    build = builder_for[env_name]
    return build(
        env_name,
        change_param_names,
        change_param_values,
        seed,
        alg_name,
        base_xml_file,
        xml_name,
        gym_path,
    )


def make_halfcheetah(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create a HalfCheetah environment with modified physical parameters.

    Friction is edited in a copy of the base XML file, which is written next
    to the gym assets and loaded with ``gym.make``. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    worldfriction : first component of ``friction`` on every ``<geom>``
        inside a ``<default>`` element of the XML.
    torsomass : ``body_mass[1]``
    backthighmass : ``body_mass[2]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Random seed (only used in the generated XML file name).
    alg_name : str
        Algorithm name (only used in the generated XML file name).
    base_xml_file : str
        XML file name, relative to ``{gym_path}gym/envs/mujoco/assets/``.
    xml_name : str
        Suffix used in the generated XML file name.
    gym_path : str
        Path prefix of the gym library (must end with a separator).

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"
    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        # Only friction lives in the XML; masses are set on the loaded model
        # below. Checking the name first means default geoms without a
        # ``friction`` attribute no longer cause a KeyError when friction
        # is not being changed.
        if change_param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                friction = geom.attrib["friction"].split()
                friction[0] = str(change_param_value)  # world friction
                geom.attrib["friction"] = " ".join(friction)
                change_flags[change_param_name] = True

    fname = f"{assets_dir}{env_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{xml_name}.xml"
    tree.write(fname)
    # Only the basename is passed to gym.make. This presumes gym_path points at
    # the gym installation whose assets directory gym.make searches -- verify.
    env = gym.make(env_name, xml_file=os.path.basename(fname))

    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "backthighmass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env


def make_invertedpendulum(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create an InvertedPendulum environment with modified body masses.

    The model is loaded from ``base_xml_file`` as-is. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    polemass : ``body_mass[2]``
    cartmass : ``body_mass[1]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Unused here; kept for a uniform builder signature.
    alg_name : str
        Unused here; kept for a uniform builder signature.
    base_xml_file : str
        MuJoCo XML file; only its basename is passed to ``gym.make``.
    xml_name : str
        Unused here; kept for a uniform builder signature.
    gym_path : str
        Unused here; kept for a uniform builder signature.

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    env = gym.make(env_name, xml_file=os.path.basename(base_xml_file))
    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "polemass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "cartmass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env


def make_walker(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create a Walker2d environment with modified physical parameters.

    Friction is edited in a copy of the base XML file, which is written next
    to the gym assets and loaded with ``gym.make``. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    worldfriction : first component of ``friction`` on every ``<geom>``
        inside a ``<default>`` element of the XML.
    torsomass : ``body_mass[1]``
    thighmass : ``body_mass[2]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Random seed (only used in the generated XML file name).
    alg_name : str
        Algorithm name (only used in the generated XML file name).
    base_xml_file : str
        XML file name, relative to ``{gym_path}gym/envs/mujoco/assets/``.
    xml_name : str
        Suffix used in the generated XML file name.
    gym_path : str
        Path prefix of the gym library (must end with a separator).

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"
    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        # Only friction lives in the XML; masses are set on the loaded model
        # below. Checking the name first means default geoms without a
        # ``friction`` attribute no longer cause a KeyError when friction
        # is not being changed.
        if change_param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                friction = geom.attrib["friction"].split()
                friction[0] = str(change_param_value)  # world friction
                geom.attrib["friction"] = " ".join(friction)
                change_flags[change_param_name] = True

    fname = f"{assets_dir}{env_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{xml_name}.xml"
    tree.write(fname)
    # Only the basename is passed to gym.make. This presumes gym_path points at
    # the gym installation whose assets directory gym.make searches -- verify.
    env = gym.make(env_name, xml_file=os.path.basename(fname))

    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "thighmass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env


def make_hopper(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create a Hopper environment with modified physical parameters.

    NOTE: ``make_hopper`` is defined a second time later in this module. The
    later definition is the one bound at import time, so this one is
    shadowed.

    Friction is edited in a copy of the base XML file, which is written next
    to the gym assets and loaded with ``gym.make``. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    worldfriction : first component of ``friction`` on every ``<geom>``
        inside a ``<default>`` element of the XML.
    torsomass : ``body_mass[1]``
    thighmass : ``body_mass[2]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Random seed (only used in the generated XML file name).
    alg_name : str
        Algorithm name (only used in the generated XML file name).
    base_xml_file : str
        XML file name, relative to ``{gym_path}gym/envs/mujoco/assets/``.
    xml_name : str
        Suffix used in the generated XML file name.
    gym_path : str
        Path prefix of the gym library (must end with a separator).

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"
    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        # Only friction lives in the XML; masses are set on the loaded model
        # below. Checking the name first means default geoms without a
        # ``friction`` attribute no longer cause a KeyError when friction
        # is not being changed.
        if change_param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                friction = geom.attrib["friction"].split()
                friction[0] = str(change_param_value)  # world friction
                geom.attrib["friction"] = " ".join(friction)
                change_flags[change_param_name] = True

    fname = f"{assets_dir}{env_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{xml_name}.xml"
    tree.write(fname)
    # Only the basename is passed to gym.make. This presumes gym_path points at
    # the gym installation whose assets directory gym.make searches -- verify.
    env = gym.make(env_name, xml_file=os.path.basename(fname))

    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "thighmass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env


def make_ant(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create an Ant environment with modified body masses.

    The model is loaded from ``base_xml_file`` as-is. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    torsomass : ``body_mass[1]``
    frontleftlegmass : ``body_mass[2]``
    frontrightlegmass : ``body_mass[4]``

    NOTE(review): confirm that body index 4 really is the front-right leg in
    the XML being loaded (e.g. by inspecting the model's body names). In a
    depth-first body order it may be a different body of the front-left leg.

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Unused here; kept for a uniform builder signature.
    alg_name : str
        Unused here; kept for a uniform builder signature.
    base_xml_file : str
        MuJoCo XML file; only its basename is passed to ``gym.make``.
    xml_name : str
        Unused here; kept for a uniform builder signature.
    gym_path : str
        Unused here; kept for a uniform builder signature.

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    env = gym.make(env_name, xml_file=os.path.basename(base_xml_file))
    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "frontleftlegmass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "frontrightlegmass":
            model.body_mass[4] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env


def make_humanoidstandup(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create a HumanoidStandup environment with modified body masses.

    The model is loaded from ``base_xml_file`` as-is. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    torsomass : ``body_mass[1]``
    rightfootmass : ``body_mass[6]``
    leftthighmass : ``body_mass[7]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Unused here; kept for a uniform builder signature.
    alg_name : str
        Unused here; kept for a uniform builder signature.
    base_xml_file : str
        MuJoCo XML file; only its basename is passed to ``gym.make``.
    xml_name : str
        Unused here; kept for a uniform builder signature.
    gym_path : str
        Unused here; kept for a uniform builder signature.

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the masses applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    env = gym.make(env_name, xml_file=os.path.basename(base_xml_file))
    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "rightfootmass":
            model.body_mass[6] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "leftthighmass":
            model.body_mass[7] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env




def make_hopper(
    env_name,
    change_param_names,
    change_param_values,
    seed,
    alg_name,
    base_xml_file,
    xml_name,
    gym_path,
):
    '''Create a Hopper environment with modified physical parameters.

    Friction is edited in a copy of the base XML file, which is written next
    to the gym assets and loaded with ``gym.make``. Masses are then
    overwritten on the loaded model.

    Supported parameter names
    -------------------------
    worldfriction : first component of ``friction`` on every ``<geom>``
        inside a ``<default>`` element of the XML.
    torsomass : ``body_mass[1]``
    thighmass : ``body_mass[2]``

    Parameters
    ----------
    env_name : str
        Environment id passed to ``gym.make``.
    change_param_names : List[str]
        Names of the physical parameters to change.
    change_param_values : List[float]
        New values, matched positionally with ``change_param_names``.
    seed : int
        Random seed (only used in the generated XML file name).
    alg_name : str
        Algorithm name (only used in the generated XML file name).
    base_xml_file : str
        XML file name, relative to ``{gym_path}gym/envs/mujoco/assets/``.
    xml_name : str
        Suffix used in the generated XML file name.
    gym_path : str
        Path prefix of the gym library (must end with a separator).

    Returns
    -------
    env
        The environment returned by ``gym.make`` with the changes applied.

    Raises
    ------
    RuntimeError
        If a requested parameter name was never applied (e.g. unsupported).
    '''
    change_flags = {name: False for name in change_param_names}

    assets_dir = f"{gym_path}gym/envs/mujoco/assets/"
    tree = ET.parse(f"{assets_dir}{base_xml_file}")
    root = tree.getroot()
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        # Only friction lives in the XML; masses are set on the loaded model
        # below. Checking the name first means default geoms without a
        # ``friction`` attribute no longer cause a KeyError when friction
        # is not being changed.
        if change_param_name != "worldfriction":
            continue
        for default in root.iter("default"):
            for geom in default.iter("geom"):
                friction = geom.attrib["friction"].split()
                friction[0] = str(change_param_value)  # world friction
                geom.attrib["friction"] = " ".join(friction)
                change_flags[change_param_name] = True

    fname = f"{assets_dir}{env_name}_{len(change_param_names)}_new_{alg_name}_{seed}_{xml_name}.xml"
    tree.write(fname)
    # Only the basename is passed to gym.make. This presumes gym_path points at
    # the gym installation whose assets directory gym.make searches -- verify.
    env = gym.make(env_name, xml_file=os.path.basename(fname))

    # Use ``unwrapped`` rather than ``env.env``. gym.make stacks wrappers, and
    # gymnasium deprecated (0.29) and then removed (1.0) forwarding attribute
    # access such as ``.model`` through wrappers.
    model = env.unwrapped.model
    for change_param_name, change_param_value in zip(
        change_param_names, change_param_values
    ):
        if change_param_name == "torsomass":
            model.body_mass[1] = change_param_value
            change_flags[change_param_name] = True
        if change_param_name == "thighmass":
            model.body_mass[2] = change_param_value
            change_flags[change_param_name] = True

    not_changed = [name for name, changed in change_flags.items() if not changed]
    if not_changed:
        raise RuntimeError(f"physical parameters not changed: {not_changed}")
    return env

# Path of the environment YAML config. Only the single assignment below is
# effective; the other candidate configs are listed here for reference.
# All files live in:
#   /Users/pierreclavier/Documents/These/Code/Expectile/Expectile_RL/test_expectile/configs/environment/
#
# Hopper:
#   Hopperv2-1_2.yaml, Hopperv2-1_3.yaml, Hopperv2-2_3.yaml,
#   Hopperv2-3_3_3_4.yaml
# Ant:
#   Antv2-1_3.yaml, Antv2-2_3_3.yaml, Antv2-3_3_3_3.yaml
# Walker2d:
#   Walker2dv2-1_4.yaml, Walker2dv2_2_4_5.yaml, Walker2dv2-3_4_5_6.yaml
# HalfCheetah:
#   HalfCheetahv2-1_3.yaml, HalfCheetahv2-1_4.yaml, HalfCheetahv2-2_4_7.yaml,
#   HalfCheetahv2-3_4_7_4.yaml (selected)
confif_path = (
    "/Users/pierreclavier/Documents/These/Code/Expectile/Expectile_RL/"
    "test_expectile/configs/environment/HalfCheetahv2-3_4_7_4.yaml"
)





#import gym
import gymnasium as gym
from stable_baselines3 import PPO, DQN, RDQN ,RDQN2 , SAC , RSAC, TD3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder, DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
import torch
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from stable_baselines3.common.vec_env import (VecEnv, VecMonitor,
                                              is_vecenv_wrapped)
import matplotlib.pyplot as plt
import numpy as np 
from tqdm import tqdm
import warnings
from utils import evaluate_model_cartpole, BasicWrapper, BasicWrapper_gravity, BasicWrapper_masscart, BasicWrapper_masspole ,BasicWrapper_mountainforce ,BasicWrapper_mountaingravity , BasicWrapper_mountainspeed, linear_schedule,Mass_Wrapper_random,evaluate_model_mujoco
import json 
import argparse
from game import Game

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from utils import evaluate_model_mujoco , SuperEvalCallback




def collect_argparser(argv=None):
    '''Parse the command-line options of this script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, exactly as before. Passing a list allows calling
        this from code or tests.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log_dir",
        default="./",
        type=str,
        help="directory where results are saved",
    )
    parser.add_argument(
        "--name_exp",
        type=str,
        default="mujoco",
        help="name of the experiment",
    )
    parser.add_argument(
        "--env", type=str, default="Hopper-v2", help="env id for simulation"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="TD3",
        help="name of the model used for simulation (default: TD3)",
    )
    parser.add_argument(
        "--learning_rate", default=7e-4, type=float, help="initial learning rate"
    )
    parser.add_argument(
        "--expectile", default=0.5, type=float, help="expectile level"
    )
    parser.add_argument(
        "--total_timesteps",
        default=50000,
        type=int,
        help="total number of training timesteps",
    )
    parser.add_argument(
        "--n_envs",
        default=8,
        type=int,
        help="number of parallel environments",
    )
    parser.add_argument(
        "--seed",
        default=234,
        type=int,
        help="random seed",
    )
    parser.add_argument(
        "--n_expectation",
        default=1,
        type=int,
        help="N of the N-expectation",
    )
    parser.add_argument(
        "--evaluate_num",
        type=int,
        default=30,
        help="number of evaluation episodes",
    )

    return parser.parse_args(argv)



def evaluate(config_path, evaluate_num, seed, model_path=None):
    '''Evaluate a saved TD3 policy over a grid of perturbed environments.

    The YAML config at ``config_path`` defines the grid. For each of the
    ``dim_evaluate_point`` varied parameters, ``num_eval_point`` values are
    spaced linearly between ``change_param_min`` and ``change_param_max``,
    and every combination (Cartesian product) is evaluated.

    Parameters
    ----------
    config_path : str
        Path of the environment YAML config. Keys read: ``num_eval_point``,
        ``dim_evaluate_point``, ``change_param_min``, ``change_param_max``,
        ``env_name``, ``change_param_names``, ``xml_file``, ``xml_name``.
        For a single varied parameter the min/max may be scalars.
    evaluate_num : int
        Number of evaluation episodes per grid point.
    seed : int
        Base seed; grid point ``i`` is built with seed ``seed + i``.
    model_path : str
        Directory containing ``best_model.zip``. Required.

    Returns
    -------
    tuple
        ``(final_mean, final_min)``: the mean and the minimum, over grid
        points, of the per-point mean reward. Both are also logged to wandb.

    Raises
    ------
    ValueError
        If ``model_path`` is None or the grid would be empty.
    NotImplementedError
        If ``xml_name`` has no matching environment builder.
    '''
    if model_path is None:
        raise ValueError("model_path is required: directory containing best_model.zip")
    best_model = TD3.load(model_path + "/best_model.zip")

    # Read the config passed by the caller (previously the module-level
    # ``confif_path`` was read, silently ignoring this argument).
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    num_eval_point = config["num_eval_point"]
    dim_evaluate_point = config["dim_evaluate_point"]
    if num_eval_point < 1 or dim_evaluate_point < 1:
        raise ValueError(
            f"empty evaluation grid: num_eval_point={num_eval_point}, "
            f"dim_evaluate_point={dim_evaluate_point}"
        )
    config["seed"] = seed

    # atleast_1d lets a 1-D config give scalar bounds while higher-dim configs
    # give lists; the first ``dim_evaluate_point`` entries are used.
    param_mins = np.atleast_1d(config["change_param_min"])
    param_maxs = np.atleast_1d(config["change_param_max"])
    axes = [
        np.linspace(param_mins[i], param_maxs[i], num_eval_point)
        for i in range(dim_evaluate_point)
    ]
    # Every grid point is a tuple with one value per changed parameter, so the
    # make_* builders can zip it with change_param_names (in the 1-D case the
    # old code produced bare scalars, which are not iterable).
    evaluate_points = list(itertools.product(*axes))

    builders = {
        "hopper": make_hopper,
        "half_cheetah": make_halfcheetah,
        "walker2d": make_walker,
        "ant": make_ant,
    }
    xml_name = config["xml_name"]
    if xml_name not in builders:
        raise NotImplementedError(f"no environment builder for xml_name={xml_name!r}")
    build_env = builders[xml_name]

    means_rewards = []
    mins_rewards = []
    quantiles10 = []
    quantiles25 = []
    for count, evaluate_point in enumerate(evaluate_points):
        envs = build_env(
            config["env_name"],
            config["change_param_names"],
            evaluate_point,
            config["seed"] + count,
            "M2TD3",
            config["xml_file"],
            xml_name,
            "./",
        )
        try:
            mean_reward, _std_reward, min_reward, quantile10, quantile25 = evaluate_model_mujoco(
                best_model,
                envs,
                n_eval_episodes=evaluate_num,
                deterministic=True,
                name=config["change_param_names"],
            )
        finally:
            # One environment is built per grid point; release each one.
            envs.close()
        means_rewards.append(mean_reward)
        mins_rewards.append(min_reward)
        quantiles10.append(quantile10)
        quantiles25.append(quantile25)

    final_mean = np.mean(means_rewards)
    # Worst case over the grid of the per-point mean reward.
    final_min = np.min(means_rewards)

    wandb.log({"final/final_mean": final_mean, "final/final_min": final_min})
    return final_mean, final_min



def main(args=None):
    '''Train a TD3 agent on ``args.env`` and log the run to Weights & Biases.

    Builds ``args.n_envs`` environments with ``make_vec_env`` (DummyVecEnv).
    The wandb project is named after the environment id. Training uses a
    WandbCallback plus a SuperEvalCallback that evaluates every 4000 steps on
    the training vec env.

    Not decorated with ``@hydra.main``. Hydra parses ``sys.argv`` itself,
    which rejects the argparse flags used here. It also calls the task
    function with a config argument, which this function does not use.

    Parameters
    ----------
    args : argparse.Namespace, optional
        Parsed command-line options; parsed from ``sys.argv`` when None.
        ``args.name_exp`` is overwritten with ``args.env``.

    Returns
    -------
    TD3
        The trained model.
    '''
    if args is None:
        args = collect_argparser()

    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": args.total_timesteps,
        "env": args.env,
    }

    # The wandb project is always named after the environment id.
    args.name_exp = args.env
    seed = np.random.randint(0, 1000000)
    run = wandb.init(
        project=args.name_exp,
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=False,  # auto-upload the videos of agents playing the game
        save_code=False,  # optional
    )
    wandb_callback = WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    )

    vec_env = make_vec_env(
        config["env"],
        n_envs=args.n_envs,
        seed=args.seed + seed,
        vec_env_cls=DummyVecEnv,
    )
    supercallback = SuperEvalCallback(
        eval_env=vec_env,
        eval_freq=4000,
        total_timesteps=args.total_timesteps,
        verbose=1,
        best_model_save_path="./model/" + f"best_model{run.id}",
    )
    callback = CallbackList([wandb_callback, supercallback])

    model = TD3(
        config["policy_type"],
        vec_env,
        verbose=1,
        # No trailing space: the old value created a directory "runs/<id> ".
        tensorboard_log=f"runs/{run.id}",
        batch_size=256,
        learning_starts=10000,
        gamma=0.98,
        train_freq=16,
        tau=0.005,
        target_update_interval=1,
        gradient_steps=1,
        # NOTE(review): --learning_rate (default 7e-4) is not used; kept at 3e-4.
        learning_rate=3e-4,
        buffer_size=300000,
        # Default of --expectile is 0.5, the previously hard-coded value.
        expectile=args.expectile,
    )
    model.learn(
        total_timesteps=args.total_timesteps,
        callback=callback,
        progress_bar=True,
    )
    return model


if __name__ == "__main__":
    args = collect_argparser()
    main(args)

