import argparse
import os
import pickle
import random
import torch
from run_experiment import collect_eval

# from utils import (
#     build_darkroom_data_filename,
#     build_darkroom_model_filename,
# )
import numpy as np

import time
# import common_args
# from dpt_envs import darkroom_env, bandit_env#, pendulum_env

#build_pendulum_data_filename,

# Restrict this process to physical GPU 2. This must happen before CUDA is
# initialised; PyTorch initialises CUDA lazily (on the first torch.cuda query
# below), so setting it here, after `import torch`, is still early enough.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # current_device() / get_device_name() raise when no CUDA device is
    # visible (CPU-only host, or GPU index 2 absent), so only query them then.
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print(torch.cuda.device_count())

# walker walk
# reacher easy

# Root under which per-experiment result directories are looked up.
BASE_RESULTS_DIRECTORY = os.path.join(os.path.dirname(os.path.realpath(__file__)), "results")



def _str2bool(value):
    """Parse a boolean command-line value such as ``True``/``false``/``1``/``no``.

    ``type=bool`` must not be used with argparse: it calls ``bool(text)``, and
    every non-empty string -- including ``"False"`` -- is truthy, so
    ``--use_cuda False`` used to silently mean True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _metric_ranges(domain_name, task_name):
    """Return the (first, second) grids of robustness values for a domain/task.

    Each grid is a 5-point ``np.linspace``. Unsupported combinations yield two
    empty arrays.
    """
    if domain_name == "cartpole_two_players":
        if task_name in ("balance_vs_adversary", "balance_sparse_vs_adversary"):
            first, second = (1, 20), (0.5, 1.5)
        elif task_name in ("swingup_vs_adversary", "swingup_sparse_vs_adversary"):
            first, second = (0.05, 0.15), (0.5, 1.5)
        else:
            return np.array([]), np.array([])
    else:
        # (start, stop) bounds of the first and second metric per domain.
        bounds = {
            "acrobot_two_players": ((0.95, 1.05), (0.95, 1.05)),
            "ball_in_cup_two_players": ((0.01, 0.11), (0.01, 0.11)),
            "cheetah_two_players": ((3, 9), (0.1, 1.9)),
            "hopper_two_players": ((1, 9), (0.1, 1.9)),
            "pendulum_two_players": ((0.05, 0.15), (0.5, 1.5)),
            "quadruped_two_players": ((65, 75), (0.95, 1.05)),
            "reacher_two_players": ((0.02, 0.06), (0.02, 0.05)),
            "walker_two_players": ((5, 15), (0.1, 1.3)),
        }
        if domain_name not in bounds:
            return np.array([]), np.array([])
        first, second = bounds[domain_name]
    return np.linspace(first[0], first[1], 5), np.linspace(second[0], second[1], 5)


def main():
    """Build an in-context trajectory dataset from trained adversarial-RL runs.

    For every point of a 2-D grid of robustness values (chosen per domain/task
    by ``_metric_ranges``), obtain a ``(model, mdp)`` pair from
    ``collect_eval``, roll the model out ``n_hists`` times with ``rollin_mdp``,
    pair each rollout with fresh query states labelled by ``model.draw_action``,
    and append the list of trajectory dicts as one pickle record to
    ``datasets/trajs_icml_<exp_name>_train.pkl``.
    """
    parser = argparse.ArgumentParser(
        "Curriculum Adversarial RL Experiment Runner",
        description="Launches an adversarial RL experiment",
    )

    # NOTE(review): --seeds is parsed but unused; collection always uses seed=0.
    parser.add_argument("--seeds", type=int, default=1)
    parser.add_argument(
        "--algorithm",
        type=str,
        default="baseline",
        choices=[
            "fix_rarl",
            "baseline",
            "rarl",
            "mas",
            "sgld",
            "force",
            "qarl_spdl",
            "qarl_linear",
            "qarl_point",
            "qarl_single",
            "fixed_force",
            "fixed_temp",
            "random",
        ],
    )
    parser.add_argument(
        "--domain_name",
        type=str,
        default="reacher_two_players",
        choices=[
            "acrobot_two_players",
            "ball_in_cup_two_players",
            "cartpole_two_players",
            "cheetah_two_players",
            "hopper_two_players",
            "pendulum_two_players",
            "quadruped_two_players",
            "reacher_two_players",
            "walker_two_players",
        ],
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="hard_vs_adversary",
        choices=[
            "balance_vs_adversary",
            "balance_sparse_vs_adversary",
            "swingup_vs_adversary",
            "swingup_sparse_vs_adversary",
            "catch_vs_adversary",
            "walk_vs_adversary",
            "run_vs_adversary",
            "hop_vs_adversary",
            "reach_goal_vs_adversary",
            "reach_goal_vs_wind_adversary",
            "easy_vs_adversary",
            "hard_vs_adversary",
        ],
    )
    parser.add_argument("--horizon", type=int, default=500)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--bool_render", type=_str2bool, default=False)
    parser.add_argument("--new_adv_max_force", type=float, default=None)
    parser.add_argument("--use_cuda", type=_str2bool, default=True)
    parser.add_argument("--n_total_iterations", type=int, default=200)
    parser.add_argument("--n_evaluation_episodes", type=int, default=20)
    parser.add_argument(
        "--exp_name",
        type=str,
        default="reacher_hard_baseline",
        help="Experiment directory under results/ to load runs from; also names "
        "the output dataset. Should match --domain_name/--task_name.",
    )
    args = parser.parse_args()

    exp_name = args.exp_name
    dataset_name = exp_name + "_train"

    first_metric_range, second_metric_range = _metric_ranges(
        args.domain_name, args.task_name
    )
    if first_metric_range.size == 0 or second_metric_range.size == 0:
        print(
            f"No robustness grid defined for {args.domain_name}/{args.task_name}; "
            "nothing to collect."
        )
        return

    start_time = time.time()
    np.random.seed(0)
    random.seed(0)

    n_hists = 2000  # rollouts (contexts) per grid point
    n_queries_per_hist = 3  # query states paired with each rollout

    # Identical for every grid point, so build it once.
    experiment_args = {
        "domain_name": args.domain_name,
        "task_name": args.task_name,
        "horizon": args.horizon,
        "gamma": args.gamma,
        "bool_render": args.bool_render,
        "new_adv_max_force": args.new_adv_max_force,
        "use_cuda": args.use_cuda,
        "n_total_iterations": args.n_total_iterations,
        "n_evaluation_episodes": args.n_evaluation_episodes,
    }

    train_filepath = "datasets/trajs_icml_{}.pkl".format(dataset_name)
    os.makedirs("datasets", exist_ok=True)

    for first_metric_value in first_metric_range:
        for second_metric_value in second_metric_range:
            # The path embeds the raw float values; presumably it must match
            # the directory layout collect_eval expects, so keep this format.
            results_dir = (
                BASE_RESULTS_DIRECTORY
                + f"/{exp_name}/algorithm___{args.algorithm}"
                f"/first{first_metric_value}_second{second_metric_value}"
            )
            model, mdp = collect_eval(
                algorithm=args.algorithm, seed=0, results_dir=results_dir, **experiment_args
            )

            trajs = []
            for _hist in range(n_hists):
                (
                    context_states,
                    context_actions,
                    context_next_states,
                    context_rewards,
                ) = rollin_mdp(mdp, model, args.horizon)
                for _query in range(n_queries_per_hist):
                    query_state = mdp.reset()
                    # The label is the model's own action for the query state.
                    optimal_action = model.draw_action(query_state)
                    trajs.append(
                        {
                            "query_state": query_state,
                            "optimal_action": optimal_action,
                            "context_states": context_states,
                            "context_actions": context_actions,
                            "context_next_states": context_next_states,
                            "context_rewards": context_rewards,
                        }
                    )

            # Append mode: the file holds one pickle record per grid point (and
            # keeps growing across reruns), so readers must call pickle.load
            # repeatedly until EOFError.
            with open(train_filepath, "ab") as file:
                pickle.dump(trajs, file)
            print(f"Saved {len(trajs)} trajectories for {first_metric_value} and {second_metric_value}")

        # Cumulative elapsed time, reported after each first-metric value.
        print(f"Data collection time: {time.time() - start_time}")
                



            



def rollin_mdp(mdp, model, horizon, noise_std=0.02):
    """Roll ``model`` out in ``mdp`` for exactly ``horizon`` steps and record it.

    Args:
        mdp: environment with ``reset() -> state`` and
            ``step(action) -> (next_state, reward, absorbing, info)``.
        model: policy with ``draw_action(state) -> action``.
        horizon: number of steps taken. The ``absorbing`` flag is not checked,
            so the rollout always has this length.
        noise_std: standard deviation of the zero-mean Gaussian noise added to
            each recorded action (default 0.02, the former hard-coded value).

    Returns:
        ``(states, actions, next_states, rewards)`` as NumPy arrays, each with
        ``horizon`` entries along the first axis.
    """
    states, actions, next_states, rewards = [], [], [], []

    state = mdp.reset()
    for _step in range(horizon):
        action = model.draw_action(state)
        next_state, reward, _absorbing, _info = mdp.step(action)
        # NOTE(review): terminal states are not handled; the rollout keeps
        # stepping after `absorbing` is True. Confirm this is intended.

        # Noise is drawn after mdp.step so the global RNG stream is consumed in
        # the same order as before.
        noise = np.random.normal(loc=0.0, scale=noise_std, size=np.shape(action))
        # NOTE(review): the environment was stepped with the clean `action`,
        # but the noisy version is what gets recorded, so the stored
        # (state, action, next_state) transitions are not exactly consistent.
        # Confirm this is the intended augmentation.
        states.append(state)
        actions.append(action + noise)
        next_states.append(next_state)
        rewards.append(reward)
        state = next_state

    return (
        np.array(states),
        np.array(actions),
        np.array(next_states),
        np.array(rewards),
    )







if __name__ == "__main__":
    # Entry point when executed as a script: parse CLI args and collect data.
    main()

 