from copy import deepcopy
from typing import Dict

import h5py
import torch

from torch import nn

from transfer.envs.metaworld import MW_ACT_LEN


def custom_leaky_relu():
    """Build a fresh ``nn.LeakyReLU`` activation with a negative slope of 0.2."""
    activation = nn.LeakyReLU(negative_slope=0.2)
    return activation


def _h5_to_torch(data):
    """Read all of ``data`` with a full slice and return it as a new torch tensor."""
    # ``data[:]` pulls the complete contents out of the sliceable container first
    values = data[:]
    return torch.tensor(values)

def get_last_q_layer_name(ac):
    """Return the first parameter name of ``ac.q1`` that belongs to its highest-indexed layer.

    Parameter names are split on ``"."`` and their second component is parsed as the
    layer index. The result is the first name (in ``named_parameters()`` order) that
    starts with ``"q.<largest index>"``.
    """
    names = []
    indices = []
    for name, _ in ac.q1.named_parameters():
        names.append(name)
        indices.append(int(name.split(".")[1]))

    prefix = f"q.{max(indices)}"
    return next(filter(lambda candidate: candidate.startswith(prefix), names))


def _sequential_dict_keras_to_torch(
    keras_weights: Dict, layers_num: int, dense_start_idx: int = 0, layer_norm: bool = False, head: Dict = None
):
    """Turn the dense layers of a Keras weight group into a state dict keyed by module index.

    Args:
        keras_weights: mapping from Keras layer names (``dense``, ``dense_<k>``, and a
            ``layer_normalization*`` entry when ``layer_norm`` is set) to their weights.
        layers_num: number of dense layers to convert.
        dense_start_idx: numeric suffix of the first dense layer; index 0 maps to the
            unsuffixed name ``dense``.
        layer_norm: when true, a layer norm is expected at module index 1 and every
            dense layer after the first is shifted by one position.
        head: optional mapping the last dense layer is read from in place of
            ``keras_weights``.

    Returns:
        Dict with ``"<module_idx>.weight"`` / ``"<module_idx>.bias"`` tensors; kernels
        are transposed.
    """
    state = {}
    final_idx = layers_num - 1

    for dense_offset in range(layers_num):
        # Dense layers occupy every other slot; with layer norm, one extra module
        # follows the first dense layer and pushes the later ones down by one.
        slot = 2 * dense_offset + (1 if layer_norm and dense_offset > 0 else 0)

        keras_idx = dense_start_idx + dense_offset
        dense_name = f"dense_{keras_idx}" if keras_idx != 0 else "dense"

        # Only the final dense layer may come from the separate head mapping
        from_head = head is not None and dense_offset == final_idx
        group = (head if from_head else keras_weights)[dense_name]

        state[f"{slot}.weight"] = _h5_to_torch(group["kernel:0"]).T
        state[f"{slot}.bias"] = _h5_to_torch(group["bias:0"])

    if not layer_norm:
        return state

    # the name might be 'layer_normalization' or 'layer_normalization_1' and so on
    norm_key = [key for key in keras_weights.keys() if "layer_normalization" in key][0]
    norm_group = keras_weights[norm_key]
    state["1.weight"] = _h5_to_torch(norm_group["gamma:0"])
    state["1.bias"] = _h5_to_torch(norm_group["beta:0"])
    return state


def _load_keras_critic(weights_file, body_group, head_group, dense_start_idx):
    """Read one critic ``.h5`` file and convert it to a torch state dict.

    Args:
        weights_file: path of the ``.h5`` file to read.
        body_group: name of the group holding the hidden dense layers and the layer norm.
        head_group: name of the group holding the final dense layer.
        dense_start_idx: numeric suffix of the first ``dense_*`` layer in the file.

    Returns:
        State dict with the keys produced by ``_sequential_dict_keras_to_torch``.
    """
    # The conversion copies every dataset into a tensor, so the file can be closed on return.
    with h5py.File(weights_file, "r") as keras_file:
        return _sequential_dict_keras_to_torch(
            keras_file[body_group],
            5,
            layer_norm=True,
            dense_start_idx=dense_start_idx,
            head=keras_file[head_group],
        )


def load_tf_weights(ac, weights_path, use_layer_norm):
    """Load Keras ``.h5`` weights into ``ac`` and build a matching target network.

    Reads ``actor.h5``, ``critic1.h5``, ``critic2.h5``, ``target_critic1.h5`` and
    ``target_critic2.h5`` from ``weights_path``. Every file is opened read-only and
    closed again as soon as its weights have been copied into tensors.

    Args:
        ac: actor-critic module exposing ``pi.net``, ``pi.mu_layer``,
            ``pi.log_std_layer``, ``q1.q`` and ``q2.q``; it is modified in place.
        weights_path: directory containing the ``.h5`` files (``str`` or path-like).
        use_layer_norm: currently not consulted; the conversion always runs with
            ``layer_norm=True``. NOTE(review): confirm whether this flag should be
            forwarded to the conversion.

    Returns:
        Tuple ``(ac, ac_targ)`` where ``ac_targ`` is a deep copy of the loaded ``ac``
        whose critics carry the target-critic weights.
    """
    # Actor
    with h5py.File(f"{weights_path}/actor.h5", "r") as actor_file:
        actor_core_dict = _sequential_dict_keras_to_torch(actor_file["sequential"], 4, layer_norm=True)
        mu_group = actor_file["sequential_1"]["dense_4"]
        mu_dict = {
            "weight": _h5_to_torch(mu_group["kernel:0"]).T,
            "bias": _h5_to_torch(mu_group["bias:0"]),
        }
        log_std_group = actor_file["sequential_2"]["dense_5"]
        log_std_dict = {
            "weight": _h5_to_torch(log_std_group["kernel:0"]).T,
            "bias": _h5_to_torch(log_std_group["bias:0"]),
        }

    ac.pi.net.load_state_dict(actor_core_dict)
    ac.pi.mu_layer.load_state_dict(mu_dict)
    ac.pi.log_std_layer.load_state_dict(log_std_dict)

    # Online critics
    ac.q1.q.load_state_dict(_load_keras_critic(f"{weights_path}/critic1.h5", "sequential_3", "sequential_4", 6))
    ac.q2.q.load_state_dict(_load_keras_critic(f"{weights_path}/critic2.h5", "sequential_7", "sequential_8", 16))

    # The target network starts as a copy of the loaded model; only its critics are overwritten
    ac_targ = deepcopy(ac)

    ac_targ.q1.q.load_state_dict(
        _load_keras_critic(f"{weights_path}/target_critic1.h5", "sequential_5", "sequential_6", 11)
    )
    ac_targ.q2.q.load_state_dict(
        _load_keras_critic(f"{weights_path}/target_critic2.h5", "sequential_9", "sequential_10", 21)
    )

    return ac, ac_targ

def remove_heads(checkpoint, num_heads, num_heads_ckpt):
    """Trim a multi-head checkpoint so that only the first ``num_heads`` heads remain.

    Args:
        checkpoint: dict with ``"ac"`` and ``"ac_targ"`` state dicts; both are
            updated in place.
        num_heads: number of heads to keep.
        num_heads_ckpt: number of heads currently stored in the checkpoint.

    Returns:
        The same ``checkpoint`` object with trimmed policy heads, Q heads and
        ``all_log_alpha``.
    """
    online, target = checkpoint["ac"], checkpoint["ac_targ"]

    # Drop the surplus outputs from the policy heads. Rows are grouped as
    # (MW_ACT_LEN, num_heads_ckpt), so the heads are cut along the second axis.
    # Both networks receive the online network's values.
    for head_name in ["pi.mu_layer", "pi.log_std_layer"]:
        weight_key, bias_key = f"{head_name}.weight", f"{head_name}.bias"

        weight = online[weight_key].view(MW_ACT_LEN, num_heads_ckpt, -1)
        hidden_dim = weight.shape[-1]
        weight = weight[:, :num_heads].reshape(-1, hidden_dim)
        online[weight_key] = weight.clone()
        target[weight_key] = weight.clone()

        bias = online[bias_key].view(MW_ACT_LEN, num_heads_ckpt)
        bias = bias[:, :num_heads].reshape(-1)
        online[bias_key] = bias.clone()
        target[bias_key] = bias.clone()

    # Drop the surplus outputs from the Q heads. Each network keeps its OWN
    # weights and biases, so the target critic is not overwritten by the online one.
    for head_name in ["q1.q.9", "q2.q.9"]:
        for suffix in ("weight", "bias"):
            key = f"{head_name}.{suffix}"
            online[key] = online[key][:num_heads].clone()
            target[key] = target[key][:num_heads].clone()

    # Cloned like the other trimmed entries, so the result is not a view of the old tensor
    online["all_log_alpha"] = online["all_log_alpha"][:num_heads].clone()
    target["all_log_alpha"] = online["all_log_alpha"].clone()
    return checkpoint

def append_new_heads(checkpoint, num_heads, num_heads_ckpt):
    """Grow a multi-head checkpoint from ``num_heads_ckpt`` to ``num_heads`` heads.

    New weights are Xavier-uniform initialised, new biases are zero and new
    ``all_log_alpha`` entries are one. Every new tensor is created from the
    checkpoint tensor it is appended to, so it inherits that tensor's dtype and
    device and the concatenation also works for checkpoints that are not on the CPU.

    Args:
        checkpoint: dict with ``"ac"`` and ``"ac_targ"`` state dicts; both are
            updated in place.
        num_heads: number of heads the checkpoint should have afterwards.
        num_heads_ckpt: number of heads currently stored in the checkpoint.

    Returns:
        The same ``checkpoint`` object with the extended tensors.
    """
    heads_to_add = num_heads - num_heads_ckpt
    online, target = checkpoint["ac"], checkpoint["ac_targ"]

    # Add new outputs to policy heads. Rows are grouped as (MW_ACT_LEN, heads), so
    # the new heads are appended along the second axis before flattening again.
    # Both networks receive the online network's extended values.
    for head_name in ["pi.mu_layer", "pi.log_std_layer"]:
        weight_key, bias_key = f"{head_name}.weight", f"{head_name}.bias"

        head_ckpt = online[weight_key].view(MW_ACT_LEN, num_heads_ckpt, -1)
        hidden_dim = head_ckpt.shape[-1]

        params_to_add = head_ckpt.new_zeros((MW_ACT_LEN, heads_to_add, hidden_dim))
        torch.nn.init.xavier_uniform_(params_to_add)
        head_extended = torch.cat([head_ckpt, params_to_add], dim=1).view(-1, hidden_dim)
        online[weight_key] = head_extended
        target[weight_key] = head_extended.clone()

        bias_ckpt = online[bias_key].view(MW_ACT_LEN, num_heads_ckpt)
        bias_to_add = bias_ckpt.new_zeros((MW_ACT_LEN, heads_to_add))
        bias_extended = torch.cat([bias_ckpt, bias_to_add], dim=1).view(-1)
        print(head_name, head_extended.shape, bias_extended.shape)
        online[bias_key] = bias_extended
        target[bias_key] = bias_extended.clone()

    # Add new outputs to Q heads. Each network keeps its own existing rows; the
    # freshly initialised rows are shared between the online and target critic.
    for head_name in ["q1.q.9", "q2.q.9"]:
        weight_key, bias_key = f"{head_name}.weight", f"{head_name}.bias"

        head_ckpt = online[weight_key]
        hidden_dim = head_ckpt.shape[-1]

        params_to_add = head_ckpt.new_zeros((heads_to_add, hidden_dim))
        torch.nn.init.xavier_uniform_(params_to_add)
        head_extended = torch.cat([head_ckpt, params_to_add], dim=0)
        online[weight_key] = head_extended
        target[weight_key] = torch.cat([target[weight_key], params_to_add], dim=0)

        bias_ckpt = online[bias_key]
        bias_to_add = bias_ckpt.new_zeros((heads_to_add,))
        bias_extended = torch.cat([bias_ckpt, bias_to_add], dim=0)
        print(head_name, head_extended.shape, bias_extended.shape)
        online[bias_key] = bias_extended
        target[bias_key] = torch.cat([target[bias_key], bias_to_add], dim=0)

    log_alpha = online["all_log_alpha"]
    online["all_log_alpha"] = torch.cat([log_alpha, log_alpha.new_ones((heads_to_add, 1))], dim=0)
    target["all_log_alpha"] = online["all_log_alpha"].clone()
    return checkpoint
