import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from src.nn import TransformerBlock


class KVCache:
    """Fixed-capacity key/value cache shared by all layers of a ``Transformer``.

    Keys and values are stored in two tensors of shape
    ``(num_layers, batch_size, max_seq_len, num_heads, head_dim)``. They are NaN-filled on creation,
    so reading a slot that was never written is easy to spot.

    A single integer ``cache_seqlens`` tracks the current length for every layer and batch element.
    """

    def __init__(
            self,
            batch_size,
            max_seq_len,
            num_layers,
            num_heads,
            head_dim,
            device,
            dtype,
    ):
        # Every layer and every sample of the batch is assumed to advance in lock-step and to share
        # one sequence length, e.g. during evaluation on a vectorised environment.
        self.cache_shape = (num_layers, batch_size, max_seq_len, num_heads, head_dim)
        # Two independent NaN-filled buffers: keys first, then values.
        self.k_cache, self.v_cache = (
            torch.full(self.cache_shape, fill_value=torch.nan, dtype=dtype, device=device).detach()
            for _ in range(2)
        )
        self.cache_seqlens = 0

    def __len__(self):
        """Return the number of layers (the leading dimension of the cache)."""
        return self.k_cache.shape[0]

    def __getitem__(self, layer_idx):
        """Return ``(k_layer, v_layer, cache_seqlens)`` for one layer.

        The tensors are indexing views into the shared cache buffers.
        """
        k_layer = self.k_cache[layer_idx]
        v_layer = self.v_cache[layer_idx]
        return k_layer, v_layer, self.cache_seqlens

    def reset(self):
        """Rewind the length counter to zero; the stored tensors are left untouched."""
        self.cache_seqlens = 0

    def update(self):
        """Advance the cache by one position.

        Once the counter reaches ``max_seq_len``, the buffers are shifted one step towards the start
        of the sequence axis (dropping the oldest entry). The freed last slot is set back to NaN and
        the counter is decremented, so it stays at ``max_seq_len - 1``.
        """
        self.cache_seqlens += 1
        window = self.cache_shape[2]
        if self.cache_seqlens != window:
            return

        # Sliding window: drop the oldest position along the sequence axis (dim 2).
        self.k_cache = torch.roll(self.k_cache, -1, dims=2)
        self.v_cache = torch.roll(self.v_cache, -1, dims=2)
        self.cache_seqlens -= 1
        assert self.cache_seqlens >= 0, "negative cache sequence length"
        # Poison the recycled slot so stale data is visible while debugging.
        self.k_cache[:, :, -1] = torch.nan
        self.v_cache[:, :, -1] = torch.nan


class Transformer(nn.Module):
    """Stack of ``TransformerBlock`` layers applied to a sequence of per-step feature vectors.

    Inputs of width ``embedding_dim`` are linearly projected to ``hidden_dim``, passed through
    embedding dropout, and then through ``num_layers`` blocks. An optional ``KVCache`` provides
    per-layer key/value buffers for incremental evaluation.
    """

    def __init__(
            self,
            seq_len: int = 40,
            embedding_dim: int = 64,
            hidden_dim: int = 256,
            num_layers: int = 4,
            num_heads: int = 4,
            attention_dropout: float = 0.5,
            residual_dropout: float = 0.0,
            embedding_dropout: float = 0.1,
            normalize_qk: bool = False,
            pre_norm: bool = True,
    ):
        super().__init__()
        self.emb_drop = nn.Dropout(embedding_dropout)
        self.emb2hid = nn.Linear(embedding_dim, hidden_dim)

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_dim=hidden_dim,
                    num_heads=num_heads,
                    attention_dropout=attention_dropout,
                    residual_dropout=residual_dropout,
                    normalize_qk=normalize_qk,
                    pre_norm=pre_norm,
                )
                for _ in range(num_layers)
            ]
        )
        # seq_len doubles as the KV-cache capacity (see init_cache).
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module):
        """nanoGPT-style initialisation.

        Linear/Embedding weights are drawn from N(0, 0.02). Linear biases are zeroed. LayerNorm is
        reset to the identity transform.

        A LayerNorm built without affine parameters (``elementwise_affine=False``, or ``bias=False``)
        has ``None`` for the missing tensors. Each tensor is therefore checked before it is
        initialised, instead of crashing inside ``zeros_``/``ones_``.
        """
        # taken from the nanoGPT, may be not optimal
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)

    def init_cache(self, batch_size, dtype, device):
        """Create an empty ``KVCache`` sized for this model (capacity ``self.seq_len``)."""
        cache = KVCache(
            batch_size=batch_size,
            max_seq_len=self.seq_len,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_dim=self.hidden_dim // self.num_heads,
            device=device,
            dtype=dtype,
        )
        return cache

    def forward(self, sequence, cache: "KVCache | None" = None):
        """Run the block stack over ``sequence``.

        Args:
            sequence: tensor whose last dimension is ``embedding_dim``; presumably
                ``[batch_size, seq_len, embedding_dim]`` (TODO confirm).
            cache: optional per-layer key/value cache. ``cache[i]`` must yield the extra positional
                arguments passed to block ``i``. When given, ``cache.update()`` is called exactly once
                per call, regardless of the input length. This presumably assumes one new step per
                call during cached evaluation (TODO confirm).

        Returns:
            ``(hidden_states, cache)``, where ``hidden_states`` has last dimension ``hidden_dim``.
        """
        # Explicit None check: do not rely on the truthiness (``__len__``) of the cache object.
        layer_caches = cache if cache is not None else [(None, None, None)] * self.num_layers

        # [batch_size, seq_len, hidden_dim]
        out = self.emb_drop(self.emb2hid(sequence))
        for i, block in enumerate(self.blocks):
            out = block(out, *layer_caches[i])

        if cache is not None:
            cache.update()

        return out, cache


class ADTuples(nn.Module):
    """Transformer policy over ``(prev_action, prev_reward, state)`` step tuples.

    (The ``AD`` prefix presumably stands for Algorithm Distillation.) For each timestep, the previous
    action, previous reward and current state are concatenated into one token. The sequence is run
    through a ``Transformer``, and ``action_head`` maps every position to ``num_actions`` outputs.
    """

    def __init__(
            self,
            num_states: int,
            num_actions: int,
            num_params: int,
            seq_len: int = 200,
            hidden_dim: int = 256,
            num_layers: int = 4,
            num_heads: int = 4,
            attention_dropout: float = 0.5,
            residual_dropout: float = 0.0,
            embedding_dropout: float = 0.1,
            normalize_qk: bool = False,
            pre_norm: bool = True,
            continuous_states: bool = False,
            continuous_actions: bool = False,
            nonlinear_action_head: bool = False
    ):
        """Build the model.

        Args:
            num_states: one-hot size for discrete states, or the state vector width when
                ``continuous_states`` is set.
            num_actions: one-hot size / action vector width, and the output width of ``action_head``.
            num_params: currently unused. It was meant for the parameter-prediction head, which is
                disabled.
            continuous_states / continuous_actions: feed raw vectors instead of one-hot indices.
            nonlinear_action_head: use a 2-layer MLP instead of a single linear action head.
            Remaining arguments are forwarded to ``Transformer``.
        """
        super().__init__()
        self.transformer = Transformer(
            seq_len=seq_len,
            # token = [action (num_actions) | reward (1) | state (num_states)]
            embedding_dim=num_actions + num_states + 1,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            attention_dropout=attention_dropout,
            residual_dropout=residual_dropout,
            embedding_dropout=embedding_dropout,
            normalize_qk=normalize_qk,
            pre_norm=pre_norm,
        )
        self.continuous_states = continuous_states
        self.continuous_actions = continuous_actions

        if nonlinear_action_head:
            self.action_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LeakyReLU(),
                nn.Linear(hidden_dim, num_actions)
            )
        else:
            self.action_head = nn.Linear(hidden_dim, num_actions)

        self.num_states = num_states
        self.num_actions = num_actions
        self.seq_len = seq_len

    def init_cache(self, batch_size, dtype, device):
        """Create an empty KV cache for the underlying transformer."""
        return self.transformer.init_cache(batch_size, dtype, device)

    def forward(
            self,
            states,        # [batch_size, seq_len]
            prev_actions,  # [batch_size, seq_len]
            prev_rewards,  # [batch_size, seq_len]
            cache: "KVCache | None" = None,
            return_repr: bool = False,
    ):
        """Compute per-timestep action outputs.

        Args:
            states: state indices (or vectors if ``continuous_states``).
            prev_actions: previous-action indices (or vectors if ``continuous_actions``).
            prev_rewards: previous rewards, one scalar per step.
            cache: optional KV cache, forwarded to the transformer.
            return_repr: not supported. The representation/parameter head is disabled, so this
                raises instead of silently returning ``None``.

        Returns:
            ``(out, cache)``, where ``out`` is ``[batch_size, seq_len, num_actions]``.

        Raises:
            NotImplementedError: if ``return_repr`` is True.
        """
        if return_repr:
            raise NotImplementedError(
                "return_repr=True is not supported: the params head of ADTuples is disabled"
            )
        # you can use different encodings here, but I use the most simple one,
        # which works just fine for the dark-room, key-to-door
        if not self.continuous_states:
            state_emb = F.one_hot(states, num_classes=self.num_states)
        else:
            state_emb = states
        if not self.continuous_actions:
            action_emb = F.one_hot(prev_actions, num_classes=self.num_actions)
        else:
            action_emb = prev_actions
        reward_emb = prev_rewards.unsqueeze(-1)
        # [batch_size, seq_len, num_actions + 1 + num_states]; torch.cat type-promotes the one-hot
        # (integer) parts together with the reward column.
        sequence = torch.concatenate([action_emb, reward_emb, state_emb], dim=-1)
        embedding, cache = self.transformer(sequence, cache=cache)
        # [batch_size, seq_len, num_actions]
        out = self.action_head(embedding)
        return out, cache


class PolicyHead(nn.Module):
    """Gaussian-policy head: a tanh-squashed mean network plus a state-independent log-std.

    ``forward`` returns ``(mean, log_std)``. ``mean`` has last dimension ``num_actions`` with values
    in ``[-1, 1]``. ``log_std`` is a learnable vector of length ``num_actions``, initialised to
    zeros.
    """

    def __init__(self, num_actions, hidden_dim, use_ln):
        super().__init__()

        def norm_or_identity():
            # optional LayerNorm after each hidden Linear
            return nn.LayerNorm(hidden_dim) if use_ln else nn.Identity()

        # shared feature extractor
        self.trunk = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            norm_or_identity(),
            nn.LeakyReLU(),
        )
        # mean branch, squashed into [-1, 1] by the final Tanh
        self.mean = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            norm_or_identity(),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_actions),
            nn.Tanh(),
        )
        self.log_std = nn.Parameter(torch.zeros(num_actions, dtype=torch.float32))

    def forward(self, embedding):
        """Return ``(mean, log_std)`` for the given embeddings (last dimension ``hidden_dim``)."""
        features = self.trunk(embedding)
        mu = self.mean(features)
        return mu, self.log_std


class ADIQLTuples(nn.Module):
    """Transformer trunk with twin Q heads, target Q heads, a V head and a Gaussian policy head.

    (The name suggests an IQL-style setup on Algorithm-Distillation sequences; this is inferred from
    the name only.) Each timestep token is
    ``[prev_action | prev_reward | prev_done | step | state]``.

    Q heads:
        - Discrete actions: a Q head maps the embedding to ``num_actions`` values.
        - Continuous actions: the action is concatenated to the embedding and a single value is
          returned.

    Target Q heads and the target policy are Polyak-averaged copies, updated by ``update_targets``.
    """

    def __init__(
            self,
            num_states: int,
            num_actions: int,
            seq_len: int = 200,
            hidden_dim: int = 256,
            num_layers: int = 4,
            num_heads: int = 4,
            attention_dropout: float = 0.5,
            residual_dropout: float = 0.0,
            embedding_dropout: float = 0.1,
            normalize_qk: bool = False,
            pre_norm: bool = True,
            continuous_states: bool = False,
            continuous_actions: bool = False,
            use_ln: bool = False,
            detach_pi: bool = True,
            detach_v: bool = False,
    ):
        """Build the model.

        Args:
            num_states: one-hot size for discrete states, or the state vector width when
                ``continuous_states`` is set.
            num_actions: one-hot size / action vector width.
            continuous_states / continuous_actions: feed raw vectors instead of one-hot indices.
                ``continuous_actions`` also selects the Q-head layout (see ``_make_q_head``).
            use_ln: add LayerNorm after the hidden Linear of every head.
            detach_pi / detach_v: stop policy / V gradients from reaching the transformer.
            Remaining arguments are forwarded to ``Transformer``.
        """
        super().__init__()
        self.continuous_states = continuous_states
        self.continuous_actions = continuous_actions
        self.detach_pi = detach_pi
        self.detach_v = detach_v
        self.detach_base = False
        self.transformer = Transformer(
            seq_len=seq_len,
            # token = [action (num_actions) | reward (1) | done (1) | step (1) | state (num_states)]
            embedding_dim=num_actions + num_states + 1 + 1 + 1,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            attention_dropout=attention_dropout,
            residual_dropout=residual_dropout,
            embedding_dropout=embedding_dropout,
            normalize_qk=normalize_qk,
            pre_norm=pre_norm,
        )
        # The Q-head layout must follow the *action* type, because forward() concatenates the action
        # to the embedding exactly when ``continuous_actions`` is set. The previous code keyed this on
        # ``continuous_states``, which crashed whenever the two flags differed.
        self.q_head_1 = self._make_q_head(hidden_dim, num_actions, continuous_actions, use_ln)
        self.q_head_2 = self._make_q_head(hidden_dim, num_actions, continuous_actions, use_ln)

        self.q_head_target_1 = copy.deepcopy(self.q_head_1).requires_grad_(False)
        self.q_head_target_2 = copy.deepcopy(self.q_head_2).requires_grad_(False)

        self.v_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim) if use_ln else nn.Identity(),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.pi = PolicyHead(num_actions, hidden_dim, use_ln)
        self.pi_target = copy.deepcopy(self.pi).requires_grad_(False)

        self.num_states = num_states
        self.num_actions = num_actions
        self.seq_len = seq_len

    @staticmethod
    def _make_q_head(hidden_dim, num_actions, continuous_actions, use_ln):
        """Build one Q head: ``Linear -> (LayerNorm | Identity) -> LeakyReLU -> Linear``.

        - Discrete actions: input ``hidden_dim``, output ``num_actions`` (one Q-value per action).
        - Continuous actions: input ``hidden_dim + num_actions`` (embedding with the action
          appended), output 1.
        """
        in_dim = hidden_dim + num_actions if continuous_actions else hidden_dim
        out_dim = 1 if continuous_actions else num_actions
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim) if use_ln else nn.Identity(),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def init_cache(self, batch_size, dtype, device):
        """Create an empty KV cache for the underlying transformer."""
        return self.transformer.init_cache(batch_size, dtype, device)

    def freeze_base(self):
        """Freeze the transformer parameters and detach its output in ``forward``."""
        self.transformer.requires_grad_(False)
        self.detach_base = True

    @staticmethod
    def _action_distribution(head, inputs):
        """Turn a ``(mean, log_std)`` head output into a Normal, with log-std clamped to [-20, 2]."""
        mean, log_std = head(inputs)
        std = torch.exp(log_std.clamp(-20, 2.0))
        return Normal(mean, std)

    def forward(
            self,
            states,        # [batch_size, seq_len]
            prev_actions,  # [batch_size, seq_len]
            prev_rewards,  # [batch_size, seq_len]
            prev_dones,    # [batch_size, seq_len]
            steps,         # [batch_size, seq_len]
            actions=None,  # [batch_size, seq_len]
            cache: "KVCache | None" = None,
    ):
        """Run the trunk and all heads.

        Args:
            states: state indices (or vectors if ``continuous_states``).
            prev_actions: previous-action indices (or vectors if ``continuous_actions``).
            prev_rewards, prev_dones, steps: one scalar per step each.
            actions: actions to evaluate with the continuous-action Q heads. Defaults to
                ``prev_actions``. Only used when ``continuous_actions`` is set.
            cache: optional KV cache, forwarded to the transformer.

        Returns:
            ``(q1, q2, q1_target, q2_target, v, (pred_actions, target_pred_actions), cache)``.
            - Q outputs are ``[..., num_actions]`` (discrete) or ``[..., 1]`` (continuous).
            - ``v`` is ``[..., 1]``.
            - The two action outputs are ``Normal`` distributions from ``pi`` and ``pi_target``.

            Note: the target Q outputs are not detached from the transformer embedding. Callers
            should detach them if gradients must not flow through them.
        """
        # you can use different encodings here, but I use the most simple one,
        # which works just fine for the dark-room, key-to-door
        if not self.continuous_states:
            state_emb = F.one_hot(states, num_classes=self.num_states)
        else:
            state_emb = states
        if not self.continuous_actions:
            action_emb = F.one_hot(prev_actions, num_classes=self.num_actions)
        else:
            action_emb = prev_actions
        reward_emb = prev_rewards.unsqueeze(-1)
        done_emb = prev_dones.unsqueeze(-1)
        step_emb = steps.unsqueeze(-1)

        # [batch_size, seq_len, num_actions + 1 + 1 + 1 + num_states]
        sequence = torch.concatenate([action_emb, reward_emb, done_emb, step_emb, state_emb], dim=-1)
        embedding, cache = self.transformer(sequence, cache=cache)
        if self.detach_base:
            embedding = embedding.detach()

        states_input = embedding
        if actions is None:
            actions = prev_actions
        if self.continuous_actions:
            states_input = torch.concatenate([states_input, actions], dim=-1)
        # [batch_size, seq_len, num_actions] (discrete) or [batch_size, seq_len, 1] (continuous)
        q1 = self.q_head_1(states_input)
        q2 = self.q_head_2(states_input)
        q1_target = self.q_head_target_1(states_input)
        q2_target = self.q_head_target_2(states_input)

        v = self.v_head(embedding.detach() if self.detach_v else embedding)

        actions_input = embedding.detach() if self.detach_pi else embedding
        pred_actions = self._action_distribution(self.pi, actions_input)
        target_pred_actions = self._action_distribution(self.pi_target, actions_input)

        return q1, q2, q1_target, q2_target, v, (pred_actions, target_pred_actions), cache

    def update_targets(self, tau):
        """Polyak-average the target heads: ``target <- (1 - tau) * target + tau * source``.

        ``.data`` is used, as before, so the update bypasses autograd version tracking.
        """
        pairs = (
            (self.q_head_target_1, self.q_head_1),
            (self.q_head_target_2, self.q_head_2),
            (self.pi_target, self.pi),
        )
        for target, source in pairs:
            for target_param, source_param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data)
