import argparse
import json
import os
from itertools import product
import torch
import ray
from lift.expert_policy import (
    CoordinateWalkPolicy,
    OptimalPolicy,
)
from lift.policy_augmentations import (
    GaussianAugmentation,
    VoidAugmentation,
    IORLAugmentation,
    LIFTAugmentation,
    RandomAugmentation,
    StepSizeAugmentation,
)
from lift.evaluation import Evaluation
from lift.environments import LensEnv, LensPositioning, PositionOnly, LightTunnel



def _get_experiment_name(config):
    return "_".join(
        [
            config["policy"],
            str(config["policy_step_length"]),
            config['augmentor'],
            str(config['n_augmentations']),
            str(config['n_episodes_per_collection']),
        ]
    )


def _make_gaussian(*pos_args, **kw_args):
    """Create a ``GaussianAugmentation``, forwarding every argument unchanged."""
    augmentation = GaussianAugmentation(*pos_args, **kw_args)
    return augmentation


def _make_random(*pos_args, **kw_args):
    """Create a ``RandomAugmentation`` with ``p`` fixed to 0.4.

    All other arguments are forwarded unchanged. Supplying ``p`` as a keyword
    raises ``TypeError`` (duplicate keyword argument).
    """
    augmentation = RandomAugmentation(*pos_args, **kw_args, p=0.4)
    return augmentation


def _make_step(*pos_args, **kw_args):
    """Create a ``StepSizeAugmentation`` with ``p`` fixed to 0.4.

    All other arguments are forwarded unchanged. Supplying ``p`` as a keyword
    raises ``TypeError`` (duplicate keyword argument).
    """
    augmentation = StepSizeAugmentation(*pos_args, **kw_args, p=0.4)
    return augmentation


def _make_void(*pos_args, **kw_args):
    """Create a ``VoidAugmentation``, forwarding every argument unchanged."""
    augmentation = VoidAugmentation(*pos_args, **kw_args)
    return augmentation


def _make_iorl_simple(*pos_args, **kw_args):
    """Create an ``IORLAugmentation`` with a small, fixed CPU configuration.

    Caller arguments are forwarded first. The hyperparameters below are then
    added as keywords, so passing any of them again raises ``TypeError``.
    """
    fixed_settings = {
        'learning_rate': 0.00001,
        'train_epochs': 50,
        'min_data': 100,
        # presumably the observation vector length of the target env — confirm
        'observation_shape': (5,),
        'hidden_dim': 4,
        'device': 'cpu',
    }
    return IORLAugmentation(*pos_args, **kw_args, **fixed_settings)


def _build_lift_augmentor(
    train_epochs=10,
    shortcuts=False,
    p=0.4,
    min_data=50,
    model='CQL',
    train_once=True,
    step_norm_factor=1,
):
    def _make_augmentor(*args, **kwargs):
        return LIFTAugmentation(
            *args,
            **kwargs,
            p=p,
            train_once=train_once,
            step_norm_factor=step_norm_factor,
            min_data=min_data,  # number of trajectories
            train_epochs=train_epochs,
            use_shortcuts=shortcuts,
            model_class=model,
        )

    return _make_augmentor


def get_augmentor_factories(experiment_name, n_actions=5):
    """Return an ordered ``{augmentor name: factory}`` mapping for a benchmark.

    ``experiment_name`` selects the set:
      * ``'effect_of_p'`` – four LIFT variants differing only in ``p``;
      * ``'augment'``     – baseline augmentors plus one LIFT variant;
      * anything else     – ``'void'`` and a shortcut LIFT variant.

    ``n_actions`` only affects ``min_data`` passed to the LIFT factories
    (50 when ``n_actions == 2``, otherwise 100). Dictionary insertion order
    is significant: callers iterate the keys to schedule runs.
    """
    lift_min_data = 50 if n_actions == 2 else 100

    def lift(p, train_epochs=50):
        # Shared LIFT configuration used by every benchmark below.
        return _build_lift_augmentor(
            train_epochs=train_epochs,
            train_once=True,
            min_data=lift_min_data,
            shortcuts=True,
            step_norm_factor=-1,
            p=p,
        )

    if experiment_name == 'effect_of_p':
        return {
            'lift_p02': lift(0.2),
            'lift_p04': lift(0.4),
            'lift_p06': lift(0.6),
            'lift_p08': lift(0.8),
        }

    if experiment_name == 'augment':
        return {
            'gaussian': _make_gaussian,
            'random': _make_random,
            'void': _make_void,
            'lift': lift(0.4),
            'step_size': _make_step,
            'iorl': _make_iorl_simple,
        }

    return {
        'void': _make_void,
        'lift_sc': lift(0.6, train_epochs=20),
    }

def load_env(config):
    """Construct a new environment instance selected by ``config["env"]``.

    Supported values:
      * ``"lp"`` – ``LensPositioning`` wrapping ``LensEnv`` (100 steps/episode);
      * ``"po"`` – ``PositionOnly`` (1000 steps/episode);
      * ``"lt"`` – ``LightTunnel`` (100 steps/episode).

    Every variant reads ``config['n_actions']`` and ``config["distortion"]``
    and uses a movement noise of 0.1 and a goal threshold of 0.01.

    Raises:
        NotImplementedError: if ``config["env"]`` is none of the above.
    """
    env_key = config["env"]

    if env_key == "lp":
        return LensPositioning(
            env_cls=LensEnv,
            noise_objects=0.0,
            noise_movement=0.1,
            n_actions=config['n_actions'],
            reward='score',
            score='distance',
            distortion=config["distortion"],
            score_goal_threshold=0.01,
            max_episode_steps=100,
            sample_count=256,
            ref_pattern="siemens",
            hr=False,
            config=None,
            width=200,
            height=200,
        )

    if env_key == "po":
        return PositionOnly(
            n_actions=config['n_actions'],
            distortion=config["distortion"],
            noise_movement=0.1,
            score_goal_threshold=0.01,
            max_episode_steps=1000,
        )

    if env_key == 'lt':
        return LightTunnel(
            n_actions=config['n_actions'],
            distortion=config["distortion"],
            noise_movement=0.1,
            score_goal_threshold=0.01,
            max_episode_steps=100,
        )

    raise NotImplementedError(f"Environment {config['env']} is not implemented.")


def load_policy(config, env):
    """Create the expert policy named by ``config["policy"]`` for ``env``.

    Only the requested policy is constructed.

    Args:
        config: reads ``env`` (must be ``"lp"``, ``"po"`` or ``"lt"``),
            ``policy`` (``"coordinate_walk"`` or ``"optimal"``) and
            ``policy_step_length``.
        env: environment instance handed to the policy constructor.

    Returns:
        A ``CoordinateWalkPolicy`` (step length used as ``initial_step_length``)
        or an ``OptimalPolicy`` (step length used as ``max_step_length``).

    Raises:
        NotImplementedError: if the environment is unsupported or the policy
            name is unknown.
    """
    supported_envs = ("lp", "po", "lt")
    if config["env"] not in supported_envs:
        # Report the real cause instead of blaming the policy name.
        raise NotImplementedError(
            f"Environment {config['env']} has no expert policies implemented."
        )

    policy_name = config["policy"]

    if policy_name == 'coordinate_walk':
        return CoordinateWalkPolicy(
            env=env,
            initial_step_length=config["policy_step_length"],
        )

    if policy_name == "optimal":
        return OptimalPolicy(
            env=env,
            max_step_length=config["policy_step_length"],
        )

    raise NotImplementedError(f"Policy {config['policy']} is not implemented ")


@ray.remote
def run(config, policy_step_length=None, policy=None, augmentor=None, distortion="movement"):
    """Run one expert-policy / augmentor experiment and write its results to disk.

    Each non-``None`` keyword argument overrides the matching ``config`` entry.
    ``distortion`` defaults to ``"movement"``, so ``config["distortion"]`` is
    replaced unless ``distortion=None`` is passed explicitly.

    Args:
        config: experiment settings. Reads ``env``, ``n_actions``,
            ``distortion``, ``policy``, ``policy_step_length``, ``augmentor``,
            ``n_augmentations``, ``n_episodes_per_collection``,
            ``n_collections``, ``path`` and, optionally, ``benchmark``
            (defaults to ``"default"``). The caller's dict is not modified.
        policy_step_length: override for ``config["policy_step_length"]``.
        policy: override for ``config["policy"]``.
        augmentor: override for ``config["augmentor"]`` (a key of the
            benchmark's augmentor factory table).
        distortion: override for ``config["distortion"]``.

    Side effects:
        Writes ``<name>_scores.json`` and one ``<name>_data_<i>.pkl`` per
        returned dataset into
        ``$WORKING_DIR/<path>/<env>_<n_actions>_<distortion>``
        (``WORKING_DIR`` defaults to ``"."``).
    """
    # Copy so the overrides below never leak back into the caller's dict.
    config = dict(config)
    overrides = {
        "distortion": distortion,
        "policy": policy,
        "policy_step_length": policy_step_length,
        "augmentor": augmentor,
    }
    config.update({key: value for key, value in overrides.items() if value is not None})

    working_dir = os.environ.get("WORKING_DIR", ".")

    env = load_env(config)
    # Second, independent instance passed to the augmentor as ``env_eval``.
    env_eval = load_env(config)

    expert = load_policy(config, env)

    # Build the factory table here instead of reading a module-level
    # ``augmentor_factories``: that name is only assigned when this file runs
    # as a script, so importing ``run`` elsewhere would raise NameError.
    factories = get_augmentor_factories(
        config.get("benchmark", "default"), config["n_actions"]
    )
    augmentor_obj = factories[config["augmentor"]](
        expert, env, env_eval=env_eval, n_augmentations=config["n_augmentations"]
    )

    evaluation = Evaluation(
        env=env,
        expert_policy=expert,
        n_episodes_per_collection=config["n_episodes_per_collection"],
        n_collection_runs=config["n_collections"],
        max_transitions=100_000,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )

    scores, datasets = evaluation.evaluate_augmentor(augmentor_obj, train_models=False)

    save_path = os.path.join(
        working_dir,
        config["path"],
        f"{config['env']}_{config['n_actions']}_{config['distortion']}",
    )
    os.makedirs(save_path, exist_ok=True)

    name = _get_experiment_name(config)

    path_scores = os.path.join(save_path, f"{name}_scores.json")
    with open(path_scores, "w", encoding="utf-8") as f:
        json.dump(scores, f)

    for i, dataset in enumerate(datasets):
        dataset.save(os.path.join(save_path, f"{name}_data_{i}.pkl"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (flag, default, type) for each command-line option.
    # NOTE: --augmentor is parsed, but every task below overrides it per run.
    cli_options = (
        ('--env', "po", str),
        ('--n_actions', 5, int),
        ('--augmentor', "void", str),
        ('--n_augmentations', 20, int),
        ('--n_collections', 3, int),
        ('--n_episodes_per_collection', 500, int),
        ('--path', "eval", str),
        ('--benchmark', "default", str),
    )
    for flag, default_value, value_type in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type)

    config = vars(parser.parse_args())

    # address='auto' attaches to an already running Ray cluster.
    ray.init(address='auto')

    if config['benchmark'] != 'default':
        step_lengths = [0.025]
        distortions = ["movement"]
    else:
        step_lengths = [0.0125, 0.025, 0.05, 0.1]
        # NOTE(review): 'sinusodial' must match the distortion name the
        # environments accept — confirm the spelling there before changing it.
        distortions = ["movement", 'regional_rotation', 'scale', 'rotation', 'sinusodial', 'sqrt']

    # Only the "po" environment is scheduled without a GPU.
    num_gpus = 0 if config['env'] == 'po' else 1
    # Kept at module scope: code elsewhere in this file may read it as a global.
    augmentor_factories = get_augmentor_factories(config['benchmark'], config['n_actions'])

    # One Ray task per (step length, augmentor, distortion) combination.
    tasks = [
        run.options(num_cpus=1, num_gpus=num_gpus).remote(
            config=config,
            policy_step_length=length,
            policy="coordinate_walk",
            augmentor=augmentor_name,
            distortion=distortion,
        )
        for length, augmentor_name, distortion in product(
            step_lengths, list(augmentor_factories), distortions
        )
    ]

    # Block until every scheduled experiment has finished.
    ray.get(tasks)
