import os
from argparse import Namespace

from ray.rllib.policy.policy import PolicySpec
from ray.tune import Experiment

from unrealpose import DEBUG, ROOT_DIR
from unrealpose.control.callbacks import CustomMetricCallback
from unrealpose.control.utils import SHARED_POLICY_ID, shared_policy_mapping_fn
from unrealpose.custom import mappoimbl


NUM_CAMERAS = 3
# False: every camera is mapped onto one shared policy (SHARED_POLICY_ID).
# True: every camera gets its own policy named 'camera_<index>'.
INDEPENDENT = False


def _agent_id_as_policy_id(agent_id, episode=None, worker=None, **kwargs):
    """Return ``agent_id`` unchanged, i.e. each agent uses the policy of the same name.

    The signature accepts both the legacy ``fn(agent_id)`` call and the
    ``fn(agent_id, episode, worker=worker, **kwargs)`` call made by RLlib. A
    ``**kwargs``-only lambda fails on the positional ``episode`` argument.
    """
    return agent_id


if not INDEPENDENT:
    MULTI_AGENT = {
        'policies': {
            SHARED_POLICY_ID: PolicySpec(observation_space=None, action_space=None)
        },
        'policy_mapping_fn': shared_policy_mapping_fn,
    }
else:
    # NOTE(review): assumes the environment names its agents 'camera_0' ...
    # 'camera_{NUM_CAMERAS - 1}', so agent ids double as policy ids — confirm.
    MULTI_AGENT = {
        'policies': {
            f'camera_{c}': PolicySpec(observation_space=None, action_space=None)
            for c in range(NUM_CAMERAS)
        },
        'policy_mapping_fn': _agent_id_as_policy_id,
    }

# ---- Compute resources ----
NUM_CPUS_FOR_DRIVER = 5
TRAINER_GPUS = 0.5  # GPU share reserved for the trainer process
NUM_GPUS_PER_WORKER = 0.5
NUM_ENVS_PER_WORKER = 4  # memory use grows with the number of humans in each environment
NUM_CPUS_PER_WORKER = 1  # a single CPU per worker is sufficient
NUM_EVALUATION_WORKER = 1
NUM_WORKERS = 4 if DEBUG else 15  # scale down the rollout fleet when debugging

# ---- Sampling / batch sizes ----
ROLLOUT_FRAGMENT_LENGTH = 25
NUM_SAMPLING_ITERATIONS = 1
# One train batch = NUM_SAMPLING_ITERATIONS rounds in which every env of every
# worker contributes one fragment of ROLLOUT_FRAGMENT_LENGTH steps.
TRAIN_BATCH_SIZE = (
    ROLLOUT_FRAGMENT_LENGTH * NUM_ENVS_PER_WORKER * NUM_WORKERS * NUM_SAMPLING_ITERATIONS
)
SGD_MINIBATCH_SIZE = TRAIN_BATCH_SIZE // 4  # a quarter of the train batch (floor division)

# Named reward presets. Each preset maps a reward-term name to a float
# (presumably the coefficient applied to that term — confirm in the env).
REWARD_DICT = dict(
    teamonly=dict(team_reward=1.0),
    teamsmalliot=dict(team_reward=1.0, iot_reward=0.1),
    no_touch_humans=dict(team_reward=3.0, anti_collision_reward=1.0),
)

# Environment configuration handed to RLlib as ``config["env_config"]`` by the
# experiment defined below. Many entries follow a ``{'use': bool, 'args': {...}}``
# pattern; presumably each one enables and parameterizes an optional environment
# wrapper (the consumer is not in this file — verify before relying on this).
ENV_CONFIG = {
    "id": 'MultiviewPose-v0',
    'pose_model_config': os.path.join('configs', 'w32_256x192_17j_coco.yaml'),
    "algo": 'AuxMAPPOPartial',

    # ======================== Wrappers ========================
    "use_numerical": False,
    'in_evaluation': False,  # False for training; presumably set to True for evaluation so the env can tell the two apart
    # (preset name, reward terms) pair taken from REWARD_DICT.
    "reward_dict": ("teamonly", REWARD_DICT["teamonly"]),
    "teammate_stats_dim": 9,

    # Scene arguments; the camera count embedded in env_name follows NUM_CAMERAS.
    # NOTE(review): the value 7 for num_humans is repeated in
    # 'truncate_observation' below and in the model's 'max_num_humans';
    # presumably these must stay in sync — confirm.
    'args': Namespace(
        num_humans=7,
        env_name=f'C{NUM_CAMERAS}_6x6_h30_p35_10mx10m_x13y13z3',
    ),
    "ground_truth_observation": {
        'use': True,
        'args': {
            'gt_noise_scale': 20.0  # presumably the scale of noise added to ground-truth observations — confirm units
        }
    },

    'truncate_observation': {
        'use': False,
        'args': {
            'num_observed_humans': 7  # track all humans
        }
    },
    "air_wall_outer": True,
    "air_wall_inner": {
        "use": False,
        "args": {
            "lower_bound": [-200, -200, 0],
            "higher_bound": [200, 200, 500]
        },
    },
    'place_cam': {
        'use': False,
        'args': {
            'num_place_cameras': 1,
            'place_on_slider': False,
        },
    },
    'fix_human': {
        'use': False,
        'args': {
            'random_init_states': False,
            'index_list': None,
        },
    },
    # num_movable_cameras equals the total camera count (NUM_CAMERAS).
    'movable_cam': {
        'use': True,
        'args': {
            'num_movable_cameras': NUM_CAMERAS
        },
    },
    'scale_human': {
        'use': False,
        'args': {
            'scale': 1,
            'index_list': None,
        },
    },

    # NOTE(review): index_list=[0] presumably selects the camera(s) whose
    # rotation is rule-based (camera 0 here) — confirm against the wrapper.
    'rule_based_rot': {
        'use': True,
        'args': {
            'use_gt3d': True,
            'num_rule_cameras': None,
            'index_list': [0]
        },
    },
    'slider': {
        'use': False,
        'args': {},
    },
    'reach_target_done': {
        'use': False,
        'args': {},
    },

    'remove_info': True,
    'convert_multi_discrete': False,
    'force_single_agent': False,

    'done_when_colliding': {
        'use': False,
        'args': {
            'threshold': 50.0,
            'collision_tolerance': 10,
        },
    },

    'communicate_monocular_3d': False,

    # Auxiliary reward terms; 'anti_collision' is the only one switched off here.
    'aux_rewards': {
        'use': True,
        'args': {
            'centering': True,
            'distance': True,
            'obstruction': True,
            'iot': True,
            'anti_collision': False
        }
    },

    'ego_action': {
        'use': True,
    },

    # Camera rotation limits (the values look like degrees — confirm).
    'rot_limit': {
        'use': True,
        'args': {
            'pitch_low': -85.0,
            'pitch_high': 85.0,
            'yaw_low': -360,
            'yaw_high': 360,
        }
    },

    'partial_triangulation': {
        'use': True
    },

    'shapley_reward': True,

    'shuffle_cam_id': {
        'use': True,
    },

    # DEPRECATED
    'sa_reward_logger': False,
}


# Ray Tune experiment spec: trains RLlib's 'PPO' on the multi-camera pose
# environment with the custom "imbl_model" network and the MULTI_AGENT setup.
# NOTE: this assignment rebinds `mappoimbl`, which is also imported from
# `unrealpose.custom` at the top of the file; from here on the name refers to
# this Experiment. The import is presumably kept for its side effects
# (e.g. registering custom components) — confirm before removing it.
mappoimbl = Experiment(
    run='PPO',
    name='mappoimbl',
    stop={"timesteps_total": 1.0E6},
    checkpoint_freq=10,
    checkpoint_at_end=True,
    keep_checkpoints_num=None,  # do not prune old checkpoints
    max_failures=0,
    # Rank checkpoints by the lowest mean 3D MPJPE custom metric
    # (presumably reported by CustomMetricCallback).
    checkpoint_score_attr="min-custom_metrics/mpjpe_3d_mean",
    local_dir=os.path.join(ROOT_DIR, 'ray_results'),
    config={
        # NOTE(review): 'urealpose' differs from the package name 'unrealpose'.
        # It must match the name the env is registered under elsewhere, so it
        # is deliberately left unchanged — confirm it is not a typo.
        'env': 'urealpose-parallel',
        "framework": "torch",
        'callbacks': CustomMetricCallback,
        "env_config": ENV_CONFIG,

        # == Sampling ==
        'horizon': 500,
        'rollout_fragment_length': ROLLOUT_FRAGMENT_LENGTH,
        'batch_mode': 'truncate_episodes',
        # Alternative: 'batch_mode': 'complete_episodes',

        # == Training ==
        "num_cpus_for_driver": NUM_CPUS_FOR_DRIVER,
        "num_gpus": TRAINER_GPUS,
        'num_workers': NUM_WORKERS,
        'num_gpus_per_worker': NUM_GPUS_PER_WORKER,
        "num_envs_per_worker": NUM_ENVS_PER_WORKER,
        'num_cpus_per_worker': NUM_CPUS_PER_WORKER,
        # How many env steps to collect for training per iteration.
        'train_batch_size': TRAIN_BATCH_SIZE,
        "sgd_minibatch_size": SGD_MINIBATCH_SIZE,

        'gamma': 0.99,
        'shuffle_sequences': False,
        "entropy_coeff": 0,
        'num_sgd_iter': 16,
        "vf_loss_coeff": 0.1,
        'vf_clip_param': 1000.0,
        'grad_clip': 50.0,
        # `lr` is only the nominal rate: when `lr_schedule` is given, RLlib
        # follows the schedule instead. Endpoints are (timestep, lr):
        # - 5e-4 until 200k steps;
        # - a step down to 1e-4, held until 400k;
        # - a decay to 5e-5 at 600k, after which the last value is kept.
        # RLlib's PiecewiseSchedule interpolates linearly between endpoints;
        # verify this for the installed Ray version.
        'lr': 5.0E-4,
        'lr_schedule': [
            (0, 5E-4),
            (200E3, 5E-4),
            (200E3, 1E-4),
            (400E3, 1E-4),
            (600E3, 5E-5),
        ],

        'model': {
            'lstm_use_prev_action': False,
            'lstm_use_prev_reward': False,

            "custom_model": "imbl_model",
            "custom_model_config": {
                "num_cameras": NUM_CAMERAS,
                # NOTE(review): mirrors ENV_CONFIG's num_humans=7; keep them in sync.
                "max_num_humans": 7,
                'max_visible_num_humans': 5,
                'prediction_steps': 1,
                "cell_size": 128,
                'actnet_hiddens': [128],
                'vfnet_hiddens': [128],
                'fcnet_hiddens': [128, 128, 128],
                'mdn_hiddens': [128, 128],
                'mdn_num_gaussians': 16,
                'masking_target': True,
                'observation_sorting': False,
                'merge_back': True,
                'coordinate_scale': 500.0,

                "prediction_loss_coeff": 1.0,  # Main Coeff
                "pred_coeff_dict": {  # Sub Coeffs
                    'coeff_cam_pred': 1.0,
                    'coeff_other_cam_pred': 1.0,
                    'coeff_reward_pred': 1.0,

                    'coeff_human_pred': 1.0,
                    'coeff_obstructor_pred': 0.1,

                    # Disabled depth-prediction terms:
                    # 'coeff_cur_depth_pred' : 5.0,
                    # 'coeff_next_depth_pred' : 5.0,
                }
            },

            # == Post-processing LSTM ==
            "use_lstm": False,
            "max_seq_len": ROLLOUT_FRAGMENT_LENGTH,
        },

        'multiagent': MULTI_AGENT
    }
)
