import os
from os.path import join, abspath
import rlkit
from math import sqrt, log
import numpy as np

"""
Debug mode will
* skip confirmation when replacing directories
* change the data directory to ./tmp
* turn off wandb logging
"""
# Start from the non-debug default; a DEBUG value exported by the optional
# rlkit.conf_private module replaces it whenever that import succeeds.
DEBUG = False
try:
    from rlkit.conf_private import DEBUG
except ImportError:
    # No private config (or it does not define DEBUG): keep the default.
    pass


# Environment-name groups reused by the checkpoint tables below.  They are
# tuples on purpose: every ``envs`` attribute copies them into its own list,
# so mutating one table never affects another.
_MUJOCO_V2_ENVS = (
    "halfcheetah-medium-expert-v2",
    "halfcheetah-medium-replay-v2",
    "halfcheetah-medium-v2",
    "hopper-medium-expert-v2",
    "hopper-medium-replay-v2",
    "hopper-medium-v2",
    "walker2d-medium-expert-v2",
    "walker2d-medium-replay-v2",
    "walker2d-medium-v2",
)
_ANTMAZE_V0_ENVS = (
    "antmaze-umaze-v0",
    "antmaze-umaze-diverse-v0",
    "antmaze-medium-diverse-v0",
    "antmaze-medium-play-v0",
)
_ANTMAZE_LARGE_V0_ENVS = (
    "antmaze-large-play-v0",
    "antmaze-large-diverse-v0",
)


def _float_steps_to_zero(lowest, step):
    """Return a new list of floats ``lowest, lowest + step, ..., 0.0``."""
    return [float(value) for value in range(lowest, 1, step)]


def _multiples(step, last):
    """Return a new list of ints ``step, 2 * step, ..., last``."""
    return list(range(step, last + 1, step))


class CheckpointParams:
    """Lookup tables describing stored model checkpoints.

    Each nested class groups the settings of one model family: the
    environments and seeds it covers and, where present, the ``itrs`` table,
    the ``validation_optimal_epochs`` table, the ``key``/``file`` names and
    the ``path`` component(s).
    """

    checkpoint_path = "/local/home/<USER>/final-models"

    class Q:
        envs = list(_MUJOCO_V2_ENVS)
        seeds = list(range(8))
        # Non-positive float values ending at 0.0, one list per environment.
        itrs = {
            "halfcheetah-medium-expert-v2": _float_steps_to_zero(-350, 50),
            "halfcheetah-medium-replay-v2": _float_steps_to_zero(-1900, 100),
            "halfcheetah-medium-v2": _float_steps_to_zero(-350, 50),
            "hopper-medium-expert-v2": _float_steps_to_zero(-350, 50),
            "hopper-medium-replay-v2": _float_steps_to_zero(-350, 50),
            "hopper-medium-v2": _float_steps_to_zero(-350, 50),
            "walker2d-medium-expert-v2": _float_steps_to_zero(-350, 50),
            "walker2d-medium-replay-v2": _float_steps_to_zero(-1400, 100),
            "walker2d-medium-v2": _float_steps_to_zero(-600, 100),
        }
        validation_optimal_epochs = {
            "halfcheetah-medium-expert-v2": 300,
            "halfcheetah-medium-replay-v2": 2000,
            "halfcheetah-medium-v2": 250,
            "hopper-medium-expert-v2": 300,
            "hopper-medium-replay-v2": 400,
            "hopper-medium-v2": 200,
            "walker2d-medium-expert-v2": 300,
            "walker2d-medium-replay-v2": 1100,
            "walker2d-medium-v2": 600,
        }
        key = "trainer/qfs"
        path = "q-normalize-env-256-256-256"

    class Q_IQN:
        envs = [*_MUJOCO_V2_ENVS, *_ANTMAZE_V0_ENVS]
        seeds = list(range(10))
        # Positive integer values, one list per environment.
        itrs = {
            "hopper-medium-expert-v2": _multiples(50, 400),
            "halfcheetah-medium-expert-v2": _multiples(50, 400),
            "hopper-medium-v2": _multiples(50, 400),
            "halfcheetah-medium-v2": _multiples(50, 400),
            "walker2d-medium-expert-v2": _multiples(50, 400),
            "hopper-medium-replay-v2": _multiples(50, 400),
            "walker2d-medium-replay-v2": _multiples(100, 1500),
            "halfcheetah-medium-replay-v2": _multiples(100, 1500),
            "walker2d-medium-v2": _multiples(100, 1000),
            "antmaze-umaze-v0": _multiples(100, 500),
            "antmaze-umaze-diverse-v0": _multiples(100, 500),
            "antmaze-medium-diverse-v0": _multiples(100, 500),
            "antmaze-medium-play-v0": _multiples(100, 500),
        }
        validation_optimal_epochs = {
            "halfcheetah-medium-expert-v2": 400,
            "halfcheetah-medium-replay-v2": 1500,
            "halfcheetah-medium-v2": 200,
            "hopper-medium-expert-v2": 300,
            "hopper-medium-replay-v2": 400,
            "hopper-medium-v2": 400,
            "walker2d-medium-expert-v2": 400,
            "walker2d-medium-replay-v2": 1100,
            "walker2d-medium-v2": 600,
            "antmaze-umaze-v0": 500,
            "antmaze-umaze-diverse-v0": 500,
            "antmaze-medium-diverse-v0": 500,
            "antmaze-medium-play-v0": 500,
        }
        key = "trainer/qfs"
        path = "q-iqn-normalize-env-256-256-256"

    class Q_IQL:
        envs = [*_ANTMAZE_V0_ENVS, *_ANTMAZE_LARGE_V0_ENVS, *_MUJOCO_V2_ENVS]
        seeds = list(range(10))
        path = "iql-models"

    class SG:
        envs = [*_ANTMAZE_V0_ENVS, *_MUJOCO_V2_ENVS]
        seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        file = "params.pt"
        key = "trainer/policy"
        path = "sg-no-ai-tricks-256-256-256"

    class MG2:
        envs = [
            "halfcheetah-medium-expert-v2",
            "hopper-medium-expert-v2",
            "hopper-medium-replay-v2",
            "walker2d-medium-expert-v2",
        ]
        seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        file = "params.pt"
        key = "trainer/policy"
        path = "mg-2-gaussian-bc-no-ai-tricks-256-256-256"

    class MG4:
        # The MuJoCo entries follow a different order than _MUJOCO_V2_ENVS,
        # so they are spelled out to keep the list order unchanged.
        envs = [
            *_ANTMAZE_V0_ENVS,
            "walker2d-medium-expert-v2",
            "hopper-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
            "hopper-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "hopper-medium-v2",
            "halfcheetah-medium-v2",
            "walker2d-medium-v2",
        ]
        seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        key = "trainer/policy"
        file = "params.pt"
        path = "mg-4-gaussian-bc-no-ai-tricks-256-256-256"

    class MGTrunc:
        envs = [*_ANTMAZE_V0_ENVS, *_ANTMAZE_LARGE_V0_ENVS]
        # Seeds 0-9 with seed 5 absent.
        seeds = [0, 1, 2, 3, 4, 6, 7, 8, 9]
        key = "trainer/policy"
        file = "params.pt"
        path1 = "mg-10-gaussian-trunc-ai-tricks-256-256-256-256"
        path2 = "4-gaussian-tanh-before"


class Log:
    """Filesystem locations and the overwrite policy used for experiment logs."""

    # Two directory levels above the directory containing the rlkit package.
    repo_dir = abspath(join(os.path.dirname(rlkit.__file__), os.pardir, os.pardir))

    # A private config may relocate the root; otherwise fall back to repo_dir.
    try:
        from rlkit.conf_private import rootdir
    except ImportError:
        rootdir = repo_dir

    # Debug runs write to "<rootdir>/tmp", regular runs to "<rootdir>/data".
    if DEBUG:
        basedir = join(rootdir, "tmp")
    else:
        basedir = join(rootdir, "data")

    # What to do when a previous experiment already exists in the data folder
    #   REPLACE:  Replace the old experiment
    #   CONTINUE: Load the old experiment and continue training
    #   EXIT:     Ask the researcher to deal with it
    conflict_policy = "REPLACE"


# Path "<Log.rootdir>/models"; presumably where ensemble models are stored
# (inferred from the name only -- confirm against the callers).
ENSEMBLE_MODEL_ROOT = os.path.join(Log.rootdir, "models")
# Boolean switch, enabled by default; its consumer is not visible in this file.
DISPLAY_WELCOME = True


class GridSearch:
    """Named hyper-parameter grids.

    Each nested class is one grid; its attributes are lists of candidate
    values.  How a grid is expanded into runs is decided by the caller.
    """

    class testing:
        delta = [1, 2]
        beta_LB = [0]

    class full:
        # delta = sqrt(-2 * ln(x)).  For x == 1 the product ``log(x) * -2``
        # is negative zero and sqrt(-0.0) stays -0.0, which would be shown
        # as "-0.0" wherever the value is formatted.  abs() clears the sign
        # of that zero and leaves every other (already positive) product
        # untouched, so the remaining values are bit-for-bit unchanged.
        delta = [sqrt(abs(log(x) * -2)) for x in [1, 0.99, 0.9, 0.8, 0.3]]
        beta_LB = [0.1, 0.5, 2]

    class mg:
        beta_LB = [0.1, 1.0]

    class pac:
        beta_LB = [1.0, 0.1]

    class ensemble_size:
        ensemble_size = [1, 2, 5, 10]

    class q_trained_epochs:
        q_trained_epochs = [1000, 1250, 1500, 1750, 2000]

    class k_fold:
        fold_idx = [1, 2, 3]

    class trq_delta:
        delta = [1e-2, 1e-3, 1e-4]

    class deltas:
        # Each entry is a two-element [low, high] pair.
        delta_range = [
            [0.5, 1.0],
            [0, 1.5],
            [0.5, 1.5],
            [1, 2],
        ]

    class fine_grained_deltas:
        # Single-element lists for 0.5, 0.75, ..., 2.25 (the stop value 2.5
        # is excluded by np.arange); elements are NumPy floats.
        delta = [[round(i, 3)] for i in np.arange(0.5, 2.5, 0.25)]


class Parallel:
    """Named (seeds, envs) presets.

    Each nested class pairs an iterable of seeds with a list of environment
    names.  ``seeds`` is a ``range`` in most presets and a ``list`` where the
    seed set is not contiguous.
    """

    class small:
        # Seeds 0-9 with seed 5 absent.
        seeds = list(range(5)) + list(range(6, 10))
        envs = [
            "antmaze-large-diverse-v0",
            "antmaze-large-play-v0",
        ]

    class single:
        seeds = range(1)
        envs = [
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
        ]

    class one_wide:
        seeds = range(1)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
        ]

    class five_wide:
        seeds = range(5)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
        ]

    class no_expert:
        seeds = range(5)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
        ]

    class wide:
        seeds = range(10)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
        ]

    class all_wide:
        seeds = range(7)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
            "antmaze-umaze-diverse-v0",
            "antmaze-medium-diverse-v0",
            "antmaze-umaze-v0",
            # Was "-v2": every other AntMaze entry of this preset is "-v0",
            # and the checkpoint tables in this file only list the "-v0"
            # variant of medium-play, so the version suffix is aligned.
            "antmaze-medium-play-v0",
            "antmaze-large-play-v0",
            "antmaze-large-diverse-v0",
        ]

    class medium:
        seeds = range(3)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
        ]

    class full:
        seeds = range(6, 10)
        envs = [
            "halfcheetah-medium-expert-v2",
            "halfcheetah-medium-replay-v2",
            "halfcheetah-medium-v2",
            "hopper-medium-expert-v2",
            "hopper-medium-replay-v2",
            "hopper-medium-v2",
            "walker2d-medium-expert-v2",
            "walker2d-medium-replay-v2",
            "walker2d-medium-v2",
        ]

    class mix:
        seeds = range(1)
        envs = [
            "hopper-medium-v2",
            "walker2d-medium-v2",
            "halfcheetah-medium-v2",
        ]

    class medium_replay:
        seeds = range(2)
        envs = [
            "hopper-medium-replay-v2",
            "walker2d-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
        ]

    class medium_expert:
        seeds = range(5, 10)
        envs = [
            "hopper-medium-expert-v2",
            "walker2d-medium-expert-v2",
            "halfcheetah-medium-expert-v2",
        ]

    class hopper:
        seeds = range(6, 7)
        envs = [
            "hopper-medium-replay-v2",
            "hopper-medium-expert-v2",
            "hopper-medium-v2",
        ]

    class hopper_medium:
        seeds = range(8)
        envs = [
            "hopper-medium-v2",
        ]

    class ant_maze:
        seeds = range(1)
        envs = ["antmaze-medium-diverse-v0", "antmaze-umaze-diverse-v0"]

    class ant_maze_pen:
        # This preset uses the "-v2" AntMaze variants throughout.
        seeds = range(3)
        envs = [
            "antmaze-umaze-v2",
            "antmaze-umaze-diverse-v2",
            "antmaze-medium-diverse-v2",
            "antmaze-medium-play-v2",
            "pen-cloned-v1",
        ]

    class hard:
        seeds = range(5)
        envs = [
            "hopper-medium-replay-v2",
            "halfcheetah-medium-replay-v2",
            "hopper-medium-v2",
            "halfcheetah-medium-v2",
        ]

    class expert_and_random:
        seeds = range(3)
        envs = [
            "hopper-expert-v2",
            "walker2d-expert-v2",
            "halfcheetah-expert-v2",
            "hopper-random-v2",
            "walker2d-random-v2",
            "halfcheetah-random-v2",
        ]



# Optional per-user overrides: every public name exported by
# rlkit.conf_private replaces the same-named definition made above.
# Values that were already computed from earlier definitions (for example
# Log.basedir, which read DEBUG when the class body ran) are not recomputed.
try:
    from rlkit.conf_private import *
except ImportError:
    # The private config is optional, so its absence is only reported.
    # NOTE(review): an ImportError raised *inside* conf_private (e.g. a
    # missing dependency) ends up here too and prints the same message.
    print("No personal conf_private.py found.")
