import copy
import time

import numpy
import ray
import torch

from models import muzero_models
import fcntl, hashlib

# Shared cosine-similarity module comparing tensors along dim 1
# (eps is PyTorch's default, spelled out for clarity).
cosine_similarity_loss_fn = torch.nn.CosineSimilarity(dim=1, eps=1e-8)


class SystemMutex:
    """
    Cross-process exclusive lock backed by ``fcntl.flock`` on a lock file.

    The lock file path is ``<lock_dir>/.lock-<md5(name)>.lck``, so every
    process that uses the same ``name`` and ``lock_dir`` contends for the
    same lock. Use it as a context manager::

        with SystemMutex("resource") as mutex:
            ...

    Args:
        name: identifier of the lock; hashed to build the lock-file name.
        lock_dir: directory holding the lock file (defaults to "/tmp").
    """

    def __init__(self, name, lock_dir="/tmp"):
        self.name = name
        self.lock_dir = lock_dir

    def __enter__(self):
        """Block until the exclusive lock is held, then return ``self``."""
        lock_id = hashlib.md5(self.name.encode("utf8")).hexdigest()
        fp = open(f"{self.lock_dir}/.lock-{lock_id}.lck", "wb")
        try:
            fcntl.flock(fp.fileno(), fcntl.LOCK_EX)
        except BaseException:
            # Don't leak the descriptor if acquiring the lock fails
            # (e.g. interrupted while blocking).
            fp.close()
            raise
        self.fp = fp
        return self

    def __exit__(self, _type, value, tb):
        """Release the lock and close the lock file; exceptions propagate."""
        try:
            fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN)
        finally:
            # Always close the file, even if unlocking raised.
            self.fp.close()


@ray.remote
class Trainer:
    """
    Ray actor that trains the MuZero network in a dedicated process and
    publishes weights, optimizer state and training statistics to the
    shared storage actor.
    """

    # Seconds to pause after asking the replay buffer to flush; presumably
    # gives the buffer actor time to rebuild before it is polled — confirm.
    _BUFFER_FLUSH_WAIT_SECONDS = 100
    # Minimum "replay_buffer_size" reported by shared storage before batches
    # are requested again.
    _MIN_REPLAY_BUFFER_SIZE = 20
    # Polling interval while waiting for the replay buffer to refill.
    _BUFFER_POLL_SECONDS = 0.1
    # Polling interval while throttling training against self-play.
    _RATIO_POLL_SECONDS = 0.5

    def __init__(
        self,
        initial_checkpoint,
        config,
        device=None,
        skip_load_optimizer=False,
    ):
        """
        Build the model and optimizer and optionally restore optimizer state.

        Args:
            initial_checkpoint: mapping read for "training_step" and
                "optimizer_state" (the latter may be None).
            config: configuration object (seed, network, model_config,
                optimizer settings, loss choices, schedules, ...).
            device: torch device for single-device training. Defaults to CUDA
                when available, else CPU. Not used when more than one GPU is
                visible: the model is then wrapped in DataParallel and moved
                with ``.cuda()``.
            skip_load_optimizer: if True, never restore the optimizer state.

        Raises:
            NotImplementedError: if ``config.optimizer`` is not one of
                "SGD", "Adam", "AdamW", "RMSprop".
        """
        self.config = config

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Initialize the network
        self.model = muzero_models.MuZeroLinesModel(
            self.config.network, **self.config.model_config
        )
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Presumably param-group dicts tagged with a "type" key
        # ("fast"/"slow"), as read by update_lr — confirm in muzero_models.
        parameters = self.model.get_parameters()

        if torch.cuda.device_count() > 1:
            print("Trainer has", torch.cuda.device_count(), "GPUs in total!")
            self.multi_gpu_training = True
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        else:
            self.multi_gpu_training = False
            self.model.to(device)

        print("Starting trainer, cuda: ", torch.cuda.is_available())
        self.model.train()

        self.training_step = initial_checkpoint["training_step"]

        if "cuda" not in str(next(self.model.parameters()).device):
            print("You are not training on GPU.\n")

        self.optimizer = self._build_optimizer(parameters)

        if skip_load_optimizer:
            print("Skipping loading optimizer dict")
        elif initial_checkpoint["optimizer_state"] is not None:
            print("Loading optimizer...\n")
            try:
                self.optimizer.load_state_dict(
                    copy.deepcopy(initial_checkpoint["optimizer_state"])
                )
            except ValueError as e:
                # e.g. the saved state was produced by a different optimizer
                # or param-group layout; training continues from scratch.
                print("Error occured while loading optimizer")
                print(e)
                print("Continue without loading it..")

        loss_functions = {"cel": self.cel_loss_function, "mse": self.mse_loss_function}

        assert self.config.value_loss in loss_functions, (
            f"Unknown value_loss {self.config.value_loss!r}, "
            f"expected one of {list(loss_functions)}"
        )
        assert self.config.reward_loss in loss_functions, (
            f"Unknown reward_loss {self.config.reward_loss!r}, "
            f"expected one of {list(loss_functions)}"
        )

        self.value_loss_fn = loss_functions[self.config.value_loss]
        self.reward_loss_fn = loss_functions[self.config.reward_loss]

    def _build_optimizer(self, parameters):
        """Create the torch optimizer selected by ``config.optimizer``."""
        cfg = self.config
        if cfg.optimizer == "SGD":
            return torch.optim.SGD(
                parameters,
                lr=cfg.lr_init,
                momentum=cfg.momentum,
                weight_decay=cfg.weight_decay,
            )
        if cfg.optimizer == "Adam":
            return torch.optim.Adam(
                parameters,
                lr=cfg.lr_init,
                weight_decay=cfg.weight_decay,
            )
        if cfg.optimizer == "AdamW":
            return torch.optim.AdamW(
                parameters,
                lr=cfg.lr_init,
                weight_decay=cfg.weight_decay,
            )
        if cfg.optimizer == "RMSprop":
            # NOTE: weight_decay is intentionally not passed here (as before).
            return torch.optim.RMSprop(
                parameters,
                lr=cfg.lr_init,
                momentum=cfg.momentum,
                eps=cfg.epsilon,
                alpha=cfg.alpha,
            )
        raise NotImplementedError(
            f"{cfg.optimizer} is not implemented. You can change the optimizer manually in trainer.py."
        )

    def continuous_update_weights(self, replay_buffer, shared_storage):
        """
        Train until ``config.training_steps`` is reached or shared storage
        reports "terminate".

        Each iteration:
        - trains on one batch while the next batch is prefetched;
        - sends new priorities to the replay buffer when PER is enabled;
        - publishes weights and optimizer state every
          ``checkpoint_interval`` steps;
        - logs losses and learning rates;
        - throttles against self-play;
        - re-flushes the replay buffer whenever the scheduled negative/end
          reward percentages change.

        Args:
            replay_buffer: replay buffer actor handle.
            shared_storage: shared storage actor handle.
        """
        negative_reward_percentage = self.config.get_negative_reward_percentage(
            self.training_step
        )
        end_reward_percentage = self.config.get_end_reward_percentage(
            self.training_step
        )
        next_batch = self._flush_replay_buffer(
            replay_buffer,
            shared_storage,
            negative_reward_percentage,
            end_reward_percentage,
        )

        # Training loop
        while self.training_step < self.config.training_steps and not ray.get(
            shared_storage.get_info.remote("terminate")
        ):
            index_batch, batch = ray.get(next_batch)
            # Prefetch the next batch while training on the current one.
            next_batch = replay_buffer.get_batch.remote(shared_storage)

            self.update_lr()
            (
                priorities,
                total_loss,
                value_loss,
                reward_loss,
                policy_loss,
                consistency_loss,
            ) = self.update_weights(batch)

            if self.config.PER:
                # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
                replay_buffer.update_priorities.remote(priorities, index_batch)

            if self.training_step % self.config.checkpoint_interval == 0:
                self._publish_checkpoint(shared_storage)

            slow_lr, fast_lr = self._current_learning_rates()
            shared_storage.set_info.remote(
                {
                    "training_step": self.training_step,
                    "slow_lr": slow_lr,
                    "fast_lr": fast_lr,
                    "total_loss": total_loss,
                    "value_loss": value_loss,
                    "reward_loss": reward_loss,
                    "policy_loss": policy_loss,
                    "consistency_loss": consistency_loss,
                }
            )

            self._throttle(shared_storage)

            new_negative_reward_percentage = self.config.get_negative_reward_percentage(
                self.training_step
            )
            new_end_reward_percentage = self.config.get_end_reward_percentage(
                self.training_step
            )
            if (
                new_negative_reward_percentage != negative_reward_percentage
                or new_end_reward_percentage != end_reward_percentage
            ):
                negative_reward_percentage = new_negative_reward_percentage
                end_reward_percentage = new_end_reward_percentage
                # The prefetched batch predates the flush; replace it.
                next_batch = self._flush_replay_buffer(
                    replay_buffer,
                    shared_storage,
                    negative_reward_percentage,
                    end_reward_percentage,
                )

    def _flush_replay_buffer(
        self,
        replay_buffer,
        shared_storage,
        negative_reward_percentage,
        end_reward_percentage,
    ):
        """
        Call ``replay_buffer.fluss_buffer`` with the given percentages.

        Then wait a fixed delay, poll until shared storage reports at least
        ``_MIN_REPLAY_BUFFER_SIZE`` entries, and return a future for the
        next batch.
        """
        replay_buffer.fluss_buffer.remote(
            shared_storage,
            negative_reward_percentage,
            end_reward_percentage,
        )
        time.sleep(self._BUFFER_FLUSH_WAIT_SECONDS)

        # Wait for the replay buffer to be filled
        while (
            ray.get(shared_storage.get_info.remote("replay_buffer_size"))
            < self._MIN_REPLAY_BUFFER_SIZE
        ):
            time.sleep(self._BUFFER_POLL_SECONDS)

        return replay_buffer.get_batch.remote(shared_storage)

    def _publish_checkpoint(self, shared_storage):
        """Push weights and CPU optimizer state; optionally request a save."""
        # DataParallel wraps the real model in ``.module``.
        model = self.model.module if self.multi_gpu_training else self.model
        weights = copy.deepcopy(model.get_weights())

        shared_storage.set_info.remote(
            {
                "weights": weights,
                "optimizer_state": copy.deepcopy(
                    muzero_models.dict_to_cpu(self.optimizer.state_dict())
                ),
            }
        )
        if self.config.save_model:
            shared_storage.save_checkpoint.remote()

    def _current_learning_rates(self):
        """
        Return ``(slow_lr, fast_lr)`` from the first two param groups.

        Assumes param group 0 is either "slow" or "fast" and group 1 is the
        other one.
        """
        groups = self.optimizer.param_groups
        if groups[0]["type"] == "slow":
            return groups[0]["lr"], groups[1]["lr"]
        return groups[1]["lr"], groups[0]["lr"]

    def _throttle(self, shared_storage):
        """
        Apply the configured training delay and training/self-play ratio.

        The ratio wait ends early if training is finished or terminated.
        """
        if self.config.training_delay:
            time.sleep(self.config.training_delay)
        if self.config.ratio:
            while (
                self.training_step
                / max(1, ray.get(shared_storage.get_info.remote("num_played_steps")))
                > self.config.ratio
                and self.training_step < self.config.training_steps
                and not ray.get(shared_storage.get_info.remote("terminate"))
            ):
                time.sleep(self._RATIO_POLL_SECONDS)

    def update_weights(self, batch):
        """
        Perform one training step.

        Args:
            batch: tuple from the replay buffer.
                With ``config.use_consistency_loss`` it has 9 fields:
                (observations, observations at td steps, actions,
                target values, target rewards, target policies,
                policy masks, importance-sampling weights,
                gradient scales).
                Without it, it has 8 fields (no td-step observations).

        Returns:
            Tuple ``(priorities, total_loss, value_loss, reward_loss,
            policy_loss, consistency_loss)``.
            - ``priorities``: numpy array shaped like the target values,
              holding ``|predicted value - target value| ** PER_alpha``
              for each unroll step.
            - The other entries are Python floats for logging.
            - ``consistency_loss`` is 0 when the consistency loss is
              disabled.
        """
        if self.config.use_consistency_loss:
            (
                observation_batch,
                observation_at_td_steps_batch,
                action_batch,
                target_value,
                target_reward,
                target_policy,
                policy_masks_batch,
                weight_batch,
                gradient_scale_batch,
            ) = batch
        else:
            (
                observation_batch,
                action_batch,
                target_value,
                target_reward,
                target_policy,
                policy_masks_batch,
                weight_batch,
                gradient_scale_batch,
            ) = batch

        # Keep values as scalars for calculating the priorities for the prioritized replay
        target_value_scalar = numpy.array(target_value, dtype="float32")
        priorities = numpy.zeros_like(target_value_scalar)

        device = next(self.model.parameters()).device

        if self.config.PER:
            weight_batch = torch.tensor(weight_batch.copy()).float().to(device)
        observation_batch = {
            k: torch.tensor(v).to(device) for k, v in observation_batch.items()
        }
        if self.config.use_consistency_loss:
            observation_at_td_steps_batch = {
                k: torch.tensor(v).to(device)
                for k, v in observation_at_td_steps_batch.items()
            }
        else:
            observation_at_td_steps_batch = None

        action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(-1)
        target_value = torch.tensor(target_value).float().to(device)
        target_reward = torch.tensor(target_reward).float().to(device)
        target_policy = torch.tensor(target_policy).float().to(device)
        policy_masks_batch = torch.tensor(policy_masks_batch).float().to(device)
        gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
        # observation_batch: dict of key -> tensor (per-key shapes depend on the model)
        # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
        # target_value: batch, num_unroll_steps+1
        # target_reward: batch, num_unroll_steps+1
        # target_policy: batch, num_unroll_steps+1, len(action_space)
        # gradient_scale_batch: batch, num_unroll_steps+1

        target_value = muzero_models.scalar_to_support(
            target_value,
            self.config.support_size,
            self.config.support_scaling_factor_value,
        )
        target_reward = muzero_models.scalar_to_support(
            target_reward,
            self.config.support_size,
            self.config.support_scaling_factor_reward,
        )
        # target_value: batch, num_unroll_steps+1, 2*support_size+1
        # target_reward: batch, num_unroll_steps+1, 2*support_size+1

        predictions, consistency_loss_per_sample, size_per_sample = self.model(
            observation_batch,
            action_batch,
            observation_at_td_steps_batch=observation_at_td_steps_batch,
            use_consistency_loss=self.config.use_consistency_loss,
        )

        # NOTE(review): assumes the model returns tensors here even when the
        # consistency loss is disabled — confirm.
        consistency_loss = torch.sum(
            consistency_loss_per_sample * size_per_sample
        ) / torch.sum(size_per_sample)

        def scale_gradient(tensor, scale):
            # ``scale`` is bound per call, so each hook keeps its own unroll
            # step's value. A lambda created directly in the loop would
            # late-bind ``i`` and use the last step's scale for every hook,
            # because hooks only run during ``loss.backward()``.
            tensor.register_hook(lambda grad: grad / scale)

        # Compute losses
        value_loss = reward_loss = policy_loss = None
        for i, (value, reward, policy_logits) in enumerate(predictions):
            (
                current_value_loss,
                current_reward_loss,
                current_policy_loss,
            ) = self.loss_function(
                value.squeeze(-1),
                reward.squeeze(-1),
                policy_logits,
                target_value[:, i],
                target_reward[:, i],
                # Targets beyond policy_logits.shape[1] should be zero padding.
                target_policy[:, i, : policy_logits.shape[1]],
            )

            if i == 0:
                # Ignore the reward loss for the first batch step.
                # NOTE(review): the value loss is masked by the policy mask
                # only at step 0 — confirm this asymmetry is intended.
                value_loss = current_value_loss * policy_masks_batch[:, 0]
                policy_loss = current_policy_loss * policy_masks_batch[:, 0]
                # A tensor (not int 0) so .mean() works with no unroll steps.
                reward_loss = torch.zeros_like(value_loss)
            else:
                # Scale gradient by the number of unroll steps
                step_scale = gradient_scale_batch[:, i]
                scale_gradient(current_value_loss, step_scale)
                scale_gradient(current_reward_loss, step_scale)
                scale_gradient(current_policy_loss, step_scale)

                value_loss = value_loss + current_value_loss
                reward_loss = reward_loss + current_reward_loss
                policy_loss = policy_loss + current_policy_loss * policy_masks_batch[:, i]

            # Compute priorities for the prioritized replay (See paper appendix Training)
            priorities[:, i] = self._value_priority(value, target_value_scalar[:, i])

        loss = (
            value_loss * self.config.value_loss_weight
            + reward_loss * self.config.reward_loss_weight
            + consistency_loss * self.config.consistency_loss_weight
            + policy_loss
        )
        if self.config.PER:
            # Correct PER bias by using importance-sampling (IS) weights
            loss *= weight_batch
        loss = loss.mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.training_step += 1

        return_consistency_loss = (
            consistency_loss.mean().item() if self.config.use_consistency_loss else 0
        )

        return (
            priorities,
            # For log purpose
            loss.item(),
            value_loss.mean().item(),
            reward_loss.mean().item(),
            policy_loss.mean().item(),
            return_consistency_loss,
        )

    def _value_priority(self, value, target_value_scalar):
        """
        Return ``|predicted value - target| ** PER_alpha`` for one unroll step.

        ``value`` is the model's support-encoded value output; it is
        converted back to scalars before comparison.
        """
        pred_value_scalar = (
            muzero_models.support_to_scalar(
                value,
                self.config.support_size,
                self.config.support_scaling_factor_value,
            )
            .detach()
            .cpu()
            .numpy()
            .squeeze()
        )
        return (
            numpy.abs(pred_value_scalar - target_value_scalar)
            ** self.config.PER_alpha
        )

    def update_lr(self):
        """
        Update the learning rate with exponential decay.

        lr = lr_init * lr_decay_rate ** (training_step / lr_decay_steps)

        "fast" groups get ``lr``; "slow" groups get
        ``lr * slow_parameter_update_weight``.

        Raises:
            NotImplementedError: for any other param-group "type".
        """
        lr = self.config.lr_init * self.config.lr_decay_rate ** (
            self.training_step / self.config.lr_decay_steps
        )
        for param_group in self.optimizer.param_groups:
            group_type = param_group["type"]
            if group_type == "fast":
                param_group["lr"] = lr
            elif group_type == "slow":
                param_group["lr"] = lr * self.config.slow_parameter_update_weight
            else:
                raise NotImplementedError(
                    f"Unknown parameter group type {group_type!r}"
                )

    def loss_function(
        self,
        value,
        reward,
        policy_logits,
        target_value,
        target_reward,
        target_policy,
    ):
        """
        Return per-sample ``(value_loss, reward_loss, policy_loss)``.

        The value and reward losses use the configured loss functions. The
        policy loss is cross-entropy between ``target_policy`` and
        softmax(``policy_logits``).
        """
        value_loss = self.value_loss_fn(target_value, value)
        reward_loss = self.reward_loss_fn(target_reward, reward)
        policy_loss = self.cel_loss_function(target_policy, policy_logits)
        return value_loss, reward_loss, policy_loss

    def cel_loss_function(self, target, prediction):
        """Per-sample cross-entropy of ``target`` vs softmax(``prediction``) over dim 1."""
        return (-target * torch.nn.functional.log_softmax(prediction, dim=1)).sum(1)

    def mse_loss_function(self, target, prediction):
        """Per-sample squared error of ``target`` vs softmax(``prediction``) over dim 1."""
        return ((target - torch.nn.functional.softmax(prediction, dim=1)) ** 2).sum(1)
