import os
import torch

from .base_game import BaseConfig


class MuZeroConfig(BaseConfig):
    """MuZero configuration for the ``cityengine`` dataset.

    Groups the environment/dataset settings (``game_config``), the network
    settings (``model_config``), training and replay-buffer hyper-parameters,
    and a few step-indexed schedules that are queried during training via
    ``get_end_reward_percentage``, ``get_negative_reward_percentage`` and
    ``visit_softmax_temperature_fn``.
    """

    def __init__(self, output_folder_name=""):
        """Populate all configuration attributes.

        Args:
            output_folder_name: Forwarded to ``BaseConfig``. The ``game_name``
                passed alongside it is this module's file name with its last
                three characters (the ``.py`` suffix) stripped.

        Raises:
            ValueError: If ``consistency_loss_weight`` is non-zero while
                ``use_consistency_loss`` is False, or if any step schedule does
                not have exactly one more value than it has step boundaries.
        """
        super().__init__(output_folder_name=output_folder_name, game_name=os.path.basename(__file__)[:-3])

        self.max_step = 91
        self.image_size = 300

        # ----- Dataset config -----#
        # NOTE(review): "segmetnation" looks like a typo, but this path must match
        # the directory on disk — confirm before renaming.
        self.predictions_path = "data/cityengine_segmetnation_predictions/"
        self.base_dataset_dir = "data/cityengine/"

        self.game_config = {
            "dataset_type": "cityengine",
            "dataset_config": {
                "base_dir": self.base_dataset_dir,
                "mode": "train",
                "new_image_size": self.image_size,
                "max_num_vertices": 60,
                "max_num_lines": 120,
                "augment": True,
                "load_segmentation": self.predictions_path,
                "crop_size": 400,
                "insert_intermediate_vertices_probs": 1,
                "min_distance_between_intermediate_vertices": 25,
                "num_intermediate_vertices": (4, 5),
                "return_tree_mask": False,
            },
            "max_step": self.max_step,
            "image_size": self.image_size,
            "reward_weights": [0.2, 0.35, 0.15, 0.15, 0.15],
            "initialize_to_pred_graph": True,
            "min_linestrings_length": 60,
            "min_connected_components_length": 100,
        }
        # Parameters that change for each evaluation dataset; the keys mirror
        # entries of ``dataset_config`` (presumably merged over it — confirm
        # against the consumer).
        self.evaluation_datasets = {
            "val": {
                "mode": "val",
                "augment": False,
            }
        }

        # ----- Model config -----#
        self.network = "enhanced-keypoints-tokens"
        self.use_consistency_loss = False

        self.model_config = {
            "encoder_config": {
                "hidden_size": 256,
                "fc_size": 256,
                "num_layers": 16,
                "layer_norm": True,
                "dropout_rate": 0.15,
            },
            "decoder_config": {
                "hidden_size": 256,
                "fc_size": 256,
                "num_layers": 16,
                "layer_norm": True,
                "dropout_rate": 0.15,
            },
            "resnet_size": 18,
            "embedding_dim": 256,
            "num_input_channels": 3,
            "image_features": 960,  # previously 3840
            "image_features_final_dimension": 256,
            "image_size": self.image_size,
            "use_discrete_embeddings": True,
            "fix_backbone": False,
            "max_seq_length": 120 * 2 + 1,
            # NOTE(review): key spelling ("postion") presumably has to match what
            # the model code reads — confirm before renaming.
            "use_postion_embeddings": True,
            "max_step": self.max_step,
            # Not set in this class; expected to be provided by BaseConfig.
            "support_size": self.support_size,
            "checkpoint": "/local/home/sanagnos/reasoning_rl/model/synthetic_enhanced_random/autoregressive_model.pkl",
            # According to https://arxiv.org/pdf/2111.00210.pdf use the latent
            # space directly without projecting.
            "add_predictor": self.use_consistency_loss,
            "cls_token": False,
        }

        # ----- Training config -----#
        self.num_workers = 10  # Number of simultaneous threads/workers self-playing to feed the replay buffer
        # Maximum number of moves; deprecated, the environment max step is used instead.
        self.max_moves = self.max_step
        self.num_simulations = 50  # Number of future moves self-simulated
        self.save_model = True
        self.training_steps = int(300e3)
        self.batch_size = 200
        self.checkpoint_interval = int(1e3)
        self.value_loss_weight = 25  # previously 15
        self.reward_loss_weight = 70  # previously 50

        self.consistency_loss_weight = 0
        # Explicit check (not `assert`) so it is not stripped under `python -O`.
        if not self.use_consistency_loss and self.consistency_loss_weight != 0:
            raise ValueError(
                "consistency_loss_weight must be 0 when use_consistency_loss is False"
            )

        self.optimizer = "AdamW"
        self.weight_decay = 1e-4
        self.momentum = 0.9
        self.epsilon = 0.02
        self.alpha = 0.99

        # Exponential learning rate schedule
        self.lr_init = 0.002  # Initial learning rate
        self.lr_decay_rate = 0.25  # Set it to 1 to use a constant learning rate
        self.lr_decay_steps = 100e3

        # ----- Replay Buffer config -----#
        self.replay_buffer_size = int(2e4)
        self.num_unroll_steps = 5
        self.td_steps = 5
        self.PER = True
        self.PER_alpha = 1  # How much prioritization is used, 0 corresponding to the uniform case, paper suggests 1
        self.slow_parameter_update_weight = 0.02

        self.value_loss = "mse"
        self.reward_loss = "mse"

        # ----- Step-indexed schedules -----#
        # Each schedule is a list of training-step boundaries plus a list with one
        # more value: values[i] applies while training_step < steps[i] (first
        # match wins), and values[-1] applies once every boundary is passed.
        self.end_reward_percentage_steps = []
        self.end_reward_percentage_values = [0.0]
        self._validate_schedule(
            "end_reward_percentage",
            self.end_reward_percentage_steps,
            self.end_reward_percentage_values,
        )

        self.negative_reward_steps = []
        self.negative_reward_values = [1.0]
        self._validate_schedule(
            "negative_reward",
            self.negative_reward_steps,
            self.negative_reward_values,
        )

        self.temperature_steps = [180000, 240000]
        self.temperature_values = [1.0, 0.5, 0.25]
        self._validate_schedule(
            "temperature",
            self.temperature_steps,
            self.temperature_values,
        )

    @staticmethod
    def _validate_schedule(name, steps, values):
        """Raise ValueError unless ``values`` has exactly one more entry than ``steps``."""
        if len(steps) + 1 != len(values):
            raise ValueError(
                f"{name} schedule: expected {len(steps) + 1} values for "
                f"{len(steps)} step boundaries, got {len(values)}"
            )

    @staticmethod
    def _schedule_value(steps, values, training_step):
        """Return the piecewise-constant schedule value active at ``training_step``.

        Returns ``values[i]`` for the first ``i`` with ``training_step < steps[i]``,
        otherwise ``values[-1]``. ``steps`` is assumed to be in ascending order.
        """
        for boundary, value in zip(steps, values):
            if training_step < boundary:
                return value
        return values[-1]

    def get_end_reward_percentage(self, training_step):
        """Return the end-reward percentage scheduled for ``training_step``."""
        return self._schedule_value(
            self.end_reward_percentage_steps,
            self.end_reward_percentage_values,
            training_step,
        )

    def get_negative_reward_percentage(self, training_step):
        """Return the negative-reward percentage scheduled for ``training_step``."""
        return self._schedule_value(
            self.negative_reward_steps,
            self.negative_reward_values,
            training_step,
        )

    def visit_softmax_temperature_fn(self, training_step):
        """Return the visit-softmax temperature scheduled for ``training_step``."""
        return self._schedule_value(
            self.temperature_steps,
            self.temperature_values,
            training_step,
        )
