import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
from torch.distributions import Categorical
from algorithms.utils.util import check, init


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    Args:
        n_embd: embedding width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        n_agent: sizes the lower-triangular ``mask`` buffer, which is
            ``(1, 1, n_agent + 1, n_agent + 1)``.
        masked: when True, every position attends only to itself and to
            earlier positions in the sequence.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=False):
        super(SelfAttention, self).__init__()

        assert n_embd % n_head == 0
        self.masked = masked
        self.n_head = n_head

        def make_projection():
            # every projection is a square linear map initialised through init_
            return init_(nn.Linear(n_embd, n_embd))

        # per-head key / query / value projections (computed jointly for all heads)
        self.key = make_projection()
        self.query = make_projection()
        self.value = make_projection()
        # projection applied after the heads are merged back together
        self.proj = make_projection()

        # the lower-triangular mask is registered even when masked is False
        span = n_agent + 1
        lower_triangle = torch.tril(torch.ones(span, span))
        self.register_buffer("mask", lower_triangle.view(1, 1, span, span))

        self.att_bp = None

    def forward(self, key, value, query):
        """Attend from ``query`` over ``key``/``value``.

        The batch size, sequence length and width are all read from ``query``;
        ``key`` and ``value`` are reshaped with those same sizes.

        Returns:
            Tensor with the same (batch, length, width) shape as ``query``.
        """
        batch, seq_len, width = query.size()
        head_dim = width // self.n_head

        def split_heads(tensor):
            # (batch, length, width) -> (batch, n_head, length, head_dim)
            return tensor.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(key))
        q = split_heads(self.query(query))
        v = split_heads(self.value(value))

        # scaled similarity scores: (batch, n_head, length, length)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))

        if self.masked:
            # positions above the diagonal get -inf so softmax assigns them zero weight
            visible = self.mask[:, :, :seq_len, :seq_len]
            scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # weighted sum of values, then put the heads back side by side
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)

        return self.proj(merged)


class EncodeBlock(nn.Module):
    """Transformer encoder block: self-attention then an MLP, each followed by
    a residual addition and a LayerNorm (normalisation is applied after the sum).

    Args:
        n_embd: embedding width of the input and output.
        n_head: number of attention heads.
        n_agent: forwarded to ``SelfAttention`` to size its mask buffer.
        masked: whether the attention layer applies its causal mask.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=True):
        super(EncodeBlock, self).__init__()

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head, n_agent, masked=masked)

        # the feed-forward hidden layer keeps the embedding width (expansion factor 1)
        hidden_dim = 1 * n_embd
        expand = init_(nn.Linear(n_embd, hidden_dim), activate=True)
        contract = init_(nn.Linear(hidden_dim, n_embd))
        self.mlp = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Run ``x`` through the attention and MLP sub-layers; shape is preserved."""
        # self-attention: x serves as key, value and query
        attended = self.attn(x, x, x)
        hidden = self.ln1(x + attended)
        return self.ln2(hidden + self.mlp(hidden))


class ClassifierActEncoder(nn.Module):
    """LayerNorm followed by a stack of ``n_block`` ``EncodeBlock`` layers.

    There is no output head here; ``forward`` returns the representation
    produced by the last block.

    Args:
        n_block: number of stacked encoder blocks.
        n_embd: embedding width.
        n_head: attention heads per block.
        n_agent: forwarded to every block.
        masked: whether the blocks use causal masking.
    """

    def __init__(self, n_block, n_embd, n_head, n_agent, masked):
        super(ClassifierActEncoder, self).__init__()
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.ln = nn.LayerNorm(n_embd)

        layers = []
        for _ in range(n_block):
            layers.append(EncodeBlock(n_embd, n_head, n_agent, masked))
        self.blocks = nn.Sequential(*layers)

    def forward(self, act):
        """Normalise ``act`` and pass it through every encoder block in order."""
        normalised = self.ln(act)
        return self.blocks(normalised)


class AgentClassifier(nn.Module):
    """Classifier producing per-sample logits over ``classifier_data_tag_num`` classes.

    The number of classes is ``classifier_data_tag_num`` when
    ``classifier_use_data_tag`` is True and ``n_agent`` otherwise.

    Depending on the flags, the encoded inputs go through one of:
      * an MLP over ``[obs_embedding, action_embedding]`` (or the action
        embedding alone when ``classifier_only_action`` is True);
      * a GRU over a history of action embeddings with a learned
        end-of-sequence token appended (``classifier_use_gru``);
      * optionally an agent-wise attention encoder applied afterwards
        (``classifier_use_act_enc``).
    When the GRU or the attention encoder is used, a final linear head maps
    the ``n_embd`` features to class logits.

    Args:
        state_dim: stored on the instance; not otherwise used in this class.
        obs_dim: size of the last dimension of ``obs``.
        action_dim: action width (number of discrete actions for 'Discrete').
        n_embd: embedding width used throughout.
        n_agent: number of agents.
        action_type: 'Discrete' selects a bias-free action embedding and
            enables one-hot conversion of index actions in ``forward``.
        device: device the module is moved to and where the GRU initial
            hidden state is created.
        classifier_only_action: ignore observations and classify from actions only.
        classifier_use_gru: use a GRU over the action history instead of the MLP.
        classifier_gru_his_len: history length expected by the GRU path.
        classifier_gru_num_layer: number of GRU layers.
        classifier_use_act_enc: apply ``ClassifierActEncoder`` across agents.
        classifier_act_enc_mask: whether that encoder uses causal masking.
        classifier_use_data_tag: classify data tags instead of agent indices.
        classifier_data_tag_num: number of data tags (used only with the flag above).
        classifier_enc_n_block: number of encoder blocks.
        classifier_enc_n_head: number of attention heads per encoder block.
    """

    def __init__(self, state_dim, obs_dim, action_dim, n_embd, n_agent, action_type, device,
                 classifier_only_action=False, classifier_use_gru=False,
                 classifier_gru_his_len=10, classifier_gru_num_layer=2,
                 classifier_use_act_enc=False, classifier_act_enc_mask=False,
                 classifier_use_data_tag=False, classifier_data_tag_num=3,
                 classifier_enc_n_block=1, classifier_enc_n_head=1
                 ):
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.action_type = action_type
        self.device = device
        # only use action to train classifier
        self.classifier_only_action = classifier_only_action
        # use gru to train multi steps actions
        self.classifier_use_gru = classifier_use_gru
        self.classifier_gru_his_len = classifier_gru_his_len
        self.classifier_gru_num_layer = classifier_gru_num_layer
        # use encoder to train agents-wise classifier
        self.classifier_use_act_enc = classifier_use_act_enc
        self.classifier_act_enc_mask = classifier_act_enc_mask
        self.classifier_use_data_tag = classifier_use_data_tag
        # number of output classes: data tags if enabled, otherwise one class per agent
        self.classifier_data_tag_num = classifier_data_tag_num if classifier_use_data_tag else n_agent
        # the MLP consumes the obs and action embeddings concatenated, or the action embedding alone
        self.net_in_dim = n_embd * 2 if not classifier_only_action else n_embd
        # inner model: a three-layer MLP emitting class logits, or a GRU emitting n_embd features
        self.inner_model = nn.Sequential(
            nn.Linear(self.net_in_dim, n_embd),
            nn.Tanh(),
            nn.Linear(n_embd, n_embd),
            nn.Tanh(),
            nn.Linear(n_embd, self.classifier_data_tag_num),
        ) if not self.classifier_use_gru else nn.GRU(
            input_size=self.n_embd, hidden_size=self.n_embd, num_layers=classifier_gru_num_layer, batch_first=True
        )
        # learned end-of-sequence token appended to the action history (GRU path only):
        # shape (1, 1, n_embd), or (1, 1, 1, n_embd) when the agent encoder is used
        self.end_of_token = nn.Parameter(
            torch.zeros(self.n_embd, dtype=torch.float32).unsqueeze(0).unsqueeze(0) if not self.classifier_use_act_enc else
            torch.zeros(self.n_embd, dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(0),
            requires_grad=True,
        ) if self.classifier_use_gru else None
        self.agent_encoder = ClassifierActEncoder(
            n_block=classifier_enc_n_block, n_embd=n_embd,
            n_head=classifier_enc_n_head, n_agent=n_agent, masked=self.classifier_act_enc_mask,
        ) if self.classifier_use_act_enc else None
        # output head mapping n_embd features to class logits after the GRU or encoder
        self.classifier_head = nn.Linear(n_embd, self.classifier_data_tag_num) if self.classifier_use_gru or self.classifier_use_act_enc else None
        # re-initialise the linear layers created so far; the obs/action encoders below
        # are created afterwards and therefore keep the initialisation given by init_
        self.apply(init_disc)
        # observation encoder (not built when classifying from actions only)
        self.obs_encoder = nn.Sequential(nn.LayerNorm(obs_dim), init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU()) \
            if not classifier_only_action else None
        # action encoder: bias-free for one-hot discrete actions, with bias otherwise
        if action_type == 'Discrete':
            self.action_encoder = nn.Sequential(init_(nn.Linear(action_dim, n_embd, bias=False), activate=True), nn.GELU())
        else:
            self.action_encoder = nn.Sequential(init_(nn.Linear(action_dim, n_embd), activate=True), nn.GELU())
        self.action_ln = nn.LayerNorm(n_embd)
        # move the whole classifier to the requested device
        self.to(device)

    def forward(self, obs, actions):
        """Compute class logits.

        Args:
            obs: observations with last dimension ``obs_dim``. Not read when
                ``classifier_only_action`` is True; encoded but not fed to the
                GRU when ``classifier_use_gru`` is True.
            actions: action features with last dimension ``action_dim``, or,
                for 'Discrete' actions, integer indices with a trailing
                dimension of size 1 (converted to one-hot here). In the GRU
                path the encoded actions must be
                (batch, his_len, n_embd), or (batch, n_agent, his_len, n_embd)
                when the agent encoder is enabled.

        Returns:
            Logits whose last dimension is ``classifier_data_tag_num``.
        """
        # observations are only embedded when the classifier is allowed to use them
        obs = self.obs_encoder(obs) if not self.classifier_only_action else None
        # change action to onehot(if not pre-onehot before forward) before feed into action encoder
        if self.action_type == 'Discrete' and actions.shape[-1] == 1:
            actions = F.one_hot(actions.squeeze(-1).long(), num_classes=self.action_dim).float()
        # embed actions; LayerNorm (action_ln) is applied further down in each branch
        actions = self.action_encoder(actions)
        batch_size = actions.shape[0]
        if self.classifier_use_gru:
            # zero initial hidden state: one sequence per sample, or per (sample, agent)
            hidden = torch.zeros(
                self.classifier_gru_num_layer,
                batch_size if not self.classifier_use_act_enc else batch_size * self.n_agent,
                self.n_embd
            ).to(self.device)
            # h_input: (batch_size, his_len + 1, n_embd) or (batch_size * agent_num, his_len + 1, n_embd)
            h_input = torch.cat([
                actions, self.end_of_token.repeat(batch_size, 1, 1)
            ], dim=1) if not self.classifier_use_act_enc else torch.cat([
                actions, self.end_of_token.repeat(batch_size, self.n_agent, 1, 1)
            ], dim=2).reshape(-1, self.classifier_gru_his_len + 1, self.n_embd)
            h_input = self.action_ln(h_input)
            logits, _ = self.inner_model(h_input, hidden)
            # keep the GRU output at the appended end-of-sequence step:
            # (batch_size, n_embd) or (batch_size * agent_num, n_embd)
            logits = logits[:, -1, :]
        else:
            actions = self.action_ln(actions)
            logits = self.inner_model(torch.cat([obs, actions], dim=-1)) \
                if not self.classifier_only_action else self.inner_model(actions)
        if self.classifier_use_act_enc:
            # NOTE(review): without the GRU, inner_model emits classifier_data_tag_num
            # features, so this reshape only succeeds if that equals n_embd -- confirm
            # that classifier_use_act_enc is always combined with classifier_use_gru.
            # logits: (batch_size, agent_num, n_embd)
            logits = logits.reshape(batch_size, self.n_agent, self.n_embd)
            logits = self.agent_encoder(logits)
        # add a output head after gru or encoder
        if self.classifier_use_gru or self.classifier_use_act_enc:
            logits = self.classifier_head(logits)

        return logits

    def get_loss(self, obs, action, tag):
        """Return the mean cross-entropy between predicted logits and ``tag``.

        Args:
            obs: forwarded to ``forward``.
            action: forwarded to ``forward``.
            tag: class indices, one per logit row once flattened. Any integer
                (or integral-valued) dtype is accepted.
        """
        logits = self.forward(obs, action).reshape(-1, self.classifier_data_tag_num)
        # cross_entropy only accepts int64 class indices, so cast here instead of
        # failing on int32 / float tags coming from the data pipeline
        tag = tag.reshape(-1).long()

        return F.cross_entropy(input=logits, target=tag)

    def get_entropy_reward(self, obs, action):
        """Return the entropy of the predicted class distribution per sample.

        Computed from ``log_softmax`` rather than ``log(softmax(...))``: when a
        probability underflows to zero the latter evaluates ``0 * -inf`` and
        yields NaN, whereas the log-space form keeps the term at zero.
        """
        log_probs = F.log_softmax(self.forward(obs, action), dim=-1)
        probs = log_probs.exp()

        return -torch.sum(probs * log_probs, dim=-1)


class AgentClassifierTrainer:
    """Optimises an agent classifier with Adam and a step learning-rate schedule.

    Args:
        args: configuration object; ``lr``, ``lr_decay_step_size`` and
            ``lr_decay_gamma`` are read from it.
        agent_classifier: module exposing ``parameters()`` and
            ``get_loss(obs, actions, tags)``.
        device: device every batch is moved to before the loss is computed.
    """

    def __init__(self, args, agent_classifier, device=torch.device("cpu")):
        self.args = args
        self.agent_classifier = agent_classifier
        self.device = device

        adam = torch.optim.Adam(agent_classifier.parameters(), lr=args.lr)
        self.optimizer = adam
        # multiply the learning rate by lr_decay_gamma every lr_decay_step_size scheduler steps
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            adam,
            step_size=args.lr_decay_step_size,
            gamma=args.lr_decay_gamma,
        )

    def update(self, obs, actions, tags):
        """Take one gradient step on a batch and return the loss as a Python float.

        The scheduler is stepped on every call, i.e. once per batch.
        """
        target = self.device
        obs, actions, tags = obs.to(target), actions.to(target), tags.to(target)

        batch_loss = self.agent_classifier.get_loss(obs, actions, tags)

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        return batch_loss.item()


def init_(m, gain=0.01, activate=False):
    """Initialise module ``m`` through the project ``init`` helper and return its result.

    ``init`` receives an orthogonal initialiser, a constant-zero initialiser
    and a gain (presumably applied to the weight and bias respectively --
    confirm in ``algorithms.utils.util``).

    Args:
        m: module to initialise.
        gain: gain passed to ``init``; ignored when ``activate`` is True.
        activate: when True, use the gain recommended for ReLU instead of ``gain``.
    """
    def zero_fill(tensor):
        return nn.init.constant_(tensor, 0)

    chosen_gain = nn.init.calculate_gain('relu') if activate else gain
    return init(m, nn.init.orthogonal_, zero_fill, gain=chosen_gain)


def init_disc(module):
    """Initialise an ``nn.Linear`` in place: Xavier-normal weights, zero bias.

    Intended for ``nn.Module.apply``; modules that are not ``nn.Linear`` are
    left untouched. Linear layers built with ``bias=False`` have
    ``module.bias is None``, so the bias step is skipped for them instead of
    raising.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)





