import torch
import torch.nn as nn
import torch.jit as jit
import torch.nn.functional as F
from torch import Tensor



class Linear(nn.Linear):
    """``nn.Linear`` with orthogonal weight init and a constant-filled bias.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        gain: Scaling factor handed to ``nn.init.orthogonal_``.
        init_bias: Constant every bias entry is set to (ignored when
            ``bias`` is ``False``).
        bias: Whether the layer has an additive bias.
        device: Forwarded to ``nn.Linear``.
        dtype: Forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, gain=1.0, init_bias=0.0, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        nn.init.orthogonal_(self.weight, gain=gain)
        # nn.Linear registers ``bias`` as None when bias=False, so only
        # initialise it when it actually exists.
        if self.bias is not None:
            nn.init.constant_(self.bias, init_bias)


class Q_network_RNN(nn.Module):
    """Three-layer fully connected Q-value head.

    Despite the ``RNN`` suffix, the layers defined here are all feed-forward:
    ``ob_dim -> h_dim -> h_dim -> ac_dim`` with ReLU after the first two.

    Args:
        ob_dim: Size of the last axis of the input.
        ac_dim: Size of the last axis of the output (one value per action).
        h_dim: Width of both hidden layers.
    """

    def __init__(self, ob_dim, ac_dim, h_dim):
        super().__init__()
        # Layers are created in the order fc1, hidden, fc2 so parameter
        # ordering and the sequence of random init draws stay fixed.
        self.fc1 = Linear(in_features=ob_dim, out_features=h_dim)
        self.hidden = Linear(in_features=h_dim, out_features=h_dim)
        self.fc2 = Linear(in_features=h_dim, out_features=ac_dim)

    def forward(self, inputs) -> torch.Tensor:
        """Return the Q-values for ``inputs`` (last axis becomes ``ac_dim``)."""
        first_hidden = torch.relu(self.fc1(inputs))
        second_hidden = torch.relu(self.hidden(first_hidden))
        # The output layer is invoked through ``forward`` directly.
        return self.fc2.forward(second_hidden)
    

class QMIX_Net(nn.Module):
    """State-conditioned linear mixer of per-agent Q-values.

    The active implementation computes ``sum_i |w_i(s)| * q_i + b(s)`` where
    ``w`` and ``b`` are single linear layers of the state. A two-layer
    hypernetwork variant is kept below as commented-out reference code.

    Args:
        n_agents: Number of mixing weights produced per state.
        st_dim: Size of the last axis of the state input.
        h_dim: Stored on the instance; only the disabled hypernetwork
            variant uses it.
        activation: One of ``"elu"``, ``"relu"``, ``"tanh"``, ``"sigmoid"``.
            Any other value leaves ``act_fn`` unset and raises no error.
    """

    def __init__(self, n_agents, st_dim, h_dim, activation="elu"):
        super().__init__()
        self.h_dim = h_dim
        self.use_abs = True

        # Only the disabled hypernetwork path reads ``act_fn``; it is still
        # instantiated so the attribute exists for recognised names.
        known_activations = (
            ("elu", nn.ELU),
            ("relu", nn.ReLU),
            ("tanh", nn.Tanh),
            ("sigmoid", nn.Sigmoid),
        )
        for label, factory in known_activations:
            if activation == label:
                self.act_fn = factory()
                break

        # Disabled hypernetwork layers (reference only):
        # self.hyper_w1 = Linear(st_dim, n_agents * h_dim)
        # self.hyper_w2 = Linear(st_dim, h_dim)
        # self.hyper_b1 = Linear(st_dim, h_dim)
        # self.hyper_b2 = nn.Sequential(Linear(st_dim, h_dim), nn.ReLU(), Linear(h_dim, 1))

        self.fc_w = Linear(st_dim, n_agents)
        self.fc_b = Linear(st_dim, 1)

    def forward(self, q: Tensor, s: Tensor) -> Tensor:
        """Mix per-agent values ``q`` into one value per state ``s``.

        The last axis of ``q`` is multiplied element-wise by the weights and
        summed with ``keepdim=True``, so the result has a trailing axis of 1.
        """
        # Disabled two-layer hypernetwork mixer (reference only):
        # batch_size, _, st_dim = s.shape
        # n_agents = q.size(-1)
        # q = q.reshape(-1, 1, n_agents)
        # s = s.reshape(-1, st_dim)
        # w1 = self.hyper_w1.forward(s)
        # if self.use_abs:
        #     w1 = torch.abs(w1)
        # b1 = self.hyper_b1.forward(s)
        # w1 = w1.view(-1, n_agents, self.h_dim)
        # b1 = b1.view(-1, 1, self.h_dim)
        # q_hidden = self.act_fn(torch.bmm(q, w1) + b1)
        # w2 = self.hyper_w2.forward(s)
        # if self.use_abs:
        #     w2 = torch.abs(w2)
        # b2 = self.hyper_b2.forward(s)
        # w2 = w2.view(-1, self.h_dim, 1)
        # b2 = b2.view(-1, 1, 1)
        # q_total = torch.bmm(q_hidden, w2) + b2
        # q_total = q_total.view(batch_size, -1, 1)
        # return q_total

        # The absolute value keeps every mixing weight non-negative.
        mixing_weights = torch.abs(self.fc_w.forward(s))
        state_bias = self.fc_b.forward(s)
        weighted_sum = torch.sum(q * mixing_weights, dim=-1, keepdim=True)
        return weighted_sum + state_bias



class DQN(nn.Module):
    """Holds the online and target copies of the agent Q-network and mixer.

    Each target network is put in eval mode and initialised with the weights
    of its online counterpart.

    Args:
        ob_dim: Observation size passed to the Q-networks.
        st_dim: State size passed to the mixers.
        ac_dim: Number of actions (Q-network output size).
        n_agents: Number of agents passed to the mixers.
        h_dim: Hidden width passed to both kinds of network.
        activation: Activation name forwarded to the mixers.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu"):
        super().__init__()
        self.ob_dim, self.st_dim, self.ac_dim = ob_dim, st_dim, ac_dim
        self.n_agents, self.h_dim = n_agents, h_dim

        # Creation order (online Q, target Q, online mixer, target mixer)
        # fixes both submodule registration order and init RNG consumption.
        q_net_args = (ob_dim, ac_dim, h_dim)
        self.eval_Q_net = Q_network_RNN(*q_net_args)
        frozen_q_net = Q_network_RNN(*q_net_args)
        frozen_q_net.eval()
        self.target_Q_net = frozen_q_net
        self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())

        mixer_args = (n_agents, st_dim, h_dim, activation)
        self.eval_mix_net = QMIX_Net(*mixer_args)
        frozen_mixer = QMIX_Net(*mixer_args)
        frozen_mixer.eval()
        self.target_mix_net = frozen_mixer
        self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())


class IPL_IQL(DQN):
    """Reads rewards off the Q-functions inherited from ``DQN``.

    Both public methods evaluate
    ``R = Q(s, a) - gamma * (1 - done) * V(s')``,
    where ``Q`` comes from the online networks and ``V`` from the target
    networks, taken at the target network's best action after adding the
    log of the availability mask.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu", value_expert=False, chi_expert=False):
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim, activation)
        # Stored only; nothing in this class reads these two flags.
        self.value_expert = value_expert
        self.chi_expert = chi_expert

    def _bellman_terms(self, obs: Tensor, states: Tensor, actions: Tensor, next_obs: Tensor, next_states: Tensor, next_log_avails: Tensor):
        """Compute the online Q(s, a) and target V(s') shared by both reward methods.

        Args:
            obs: Observations fed to the online Q-network.
            states: States fed to the online mixer.
            actions: Action indices with a trailing axis of size 1, used as
                the ``gather`` index on the last axis of the Q-values.
            next_obs: Observations fed to the target Q-network.
            next_states: States fed to the target mixer.
            next_log_avails: Log of the next-step availability mask; it is
                added to the target Q-values before the argmax.

        Returns:
            ``(q_local, q_global, v_local, v_global)``. The two ``v`` tensors
            are computed under ``torch.no_grad()``.
        """
        curr_q = self.eval_Q_net.forward(obs)
        q_local = torch.gather(curr_q, -1, actions).squeeze(-1)
        q_global = self.eval_mix_net.forward(q_local, states).squeeze(-1)

        with torch.no_grad():
            next_q = self.target_Q_net.forward(next_obs)
            # log(0) = -inf pushes unavailable actions out of the argmax.
            greedy_actions = torch.argmax(next_q + next_log_avails, -1, True)
            v_local = torch.gather(next_q, -1, greedy_actions).squeeze(-1)
            v_global = self.target_mix_net.forward(v_local, next_states).squeeze(-1)

        return q_local, q_global, v_local, v_global

    @torch.compile(fullgraph=True)
    def compute_all_rewards(self, mb_obs: Tensor, mb_states: Tensor, mb_avails: Tensor, mb_actions: Tensor, mb_dones: Tensor, mb_actives: Tensor, gamma: float) -> Tensor:
        """Return the summed global reward of each trajectory in the batch.

        ``mb_obs``, ``mb_states`` and ``mb_avails`` are sliced along axis 1
        into current (``[:, :-1]``) and next (``[:, 1:]``) steps, so they are
        assumed to hold one more step than ``mb_actions`` / ``mb_dones`` /
        ``mb_actives`` -- TODO confirm against the caller.

        Args:
            mb_obs: Observations, including the final next-step entry.
            mb_states: States, including the final next-step entry.
            mb_avails: Availability mask; its log is added to the target Q.
            mb_actions: Action indices (a trailing axis is added here).
            mb_dones: Termination flags, cast to float.
            mb_actives: Mask of valid steps; entries where it is false/zero
                contribute 0 to the sum.
            gamma: Discount factor.

        Returns:
            Global rewards summed over the last axis.
        """
        device = self.eval_Q_net.fc1.weight.device
        mb_obs = mb_obs.to(device)
        mb_states = mb_states.to(device)
        mb_actions = mb_actions.to(device).unsqueeze(-1)
        mb_dones = mb_dones.to(device).float()
        # Cast to bool so ``~`` is a logical NOT and the index below is a
        # mask; with an integer tensor it would be integer fancy-indexing.
        inactive = ~mb_actives.to(device).bool()
        # Only the next-step availability is needed (for the target argmax).
        next_log_avails = mb_avails.to(device)[:, 1:].log()

        _, q_evals_global, _, v_targets_global = self._bellman_terms(
            mb_obs[:, :-1],
            mb_states[:, :-1],
            mb_actions,
            mb_obs[:, 1:],
            mb_states[:, 1:],
            next_log_avails,
        )

        global_rewards = q_evals_global - gamma * (1 - mb_dones) * v_targets_global
        global_rewards[inactive] = 0.0
        return global_rewards.sum(-1)

    @torch.compile(fullgraph=True)
    def predict_rewards(self, obs: Tensor, states: Tensor, avails: Tensor, actions: Tensor, next_obs: Tensor, next_states: Tensor, next_avails: Tensor, dones: Tensor, gamma: float):
        """Return per-agent rewards, mixed rewards and the mixing weights.

        Args:
            obs: Current observations.
            states: Current states.
            avails: Current availability mask. Accepted for interface
                compatibility; the computation does not use it.
            actions: Action indices (a trailing axis is added here).
            next_obs: Next-step observations.
            next_states: Next-step states.
            next_avails: Next-step availability mask; its log is added to
                the target Q-values before the argmax.
            dones: Termination flags, cast to float.
            gamma: Discount factor.

        Returns:
            ``(rewards_local, rewards_global, weights)`` where ``weights`` is
            the absolute value of the online mixer's ``fc_w`` output for
            ``states``.
        """
        device = self.eval_Q_net.fc1.weight.device
        obs = obs.to(device)
        states = states.to(device)
        actions = actions.to(device).unsqueeze(-1)

        next_obs = next_obs.to(device)
        next_states = next_states.to(device)
        next_log_avails = next_avails.to(device).log()

        dones = dones.to(device).float()

        q_evals_local, q_evals_global, v_targets_local, v_targets_global = self._bellman_terms(
            obs, states, actions, next_obs, next_states, next_log_avails
        )

        # R = Q(s, a) - gamma * (1 - done) * V(s')
        # ``dones`` gets a trailing axis so it broadcasts against the
        # per-agent (local) values.
        rewards_local = q_evals_local - gamma * (1 - dones.unsqueeze(-1)) * v_targets_local
        rewards_global = q_evals_global - gamma * (1 - dones) * v_targets_global

        weights = self.eval_mix_net.fc_w.forward(states).abs()

        return rewards_local, rewards_global, weights