import os, sys
from .naive_gate import NaiveGate

import os
import sys
import torch
import torch.nn.functional as F
from .utils import limit_by_capacity
import fmoe_cuda
from fmoe.functions import count_by_gate

nw_per_node = 8
try:
    nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE'])
except Exception:
    pass

class FasterGate(NaiveGate):
    def __init__(self, d_model, n_expert, world_size, node_rank):
        super().__init__(d_model, n_expert, world_size, top_k=2)
        self.ne_per_node = nw_per_node * n_expert
        self.ogn_ratio = .14
        try:
            self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION'])
        except Exception:
            pass
        self.node_rank = node_rank

        mask = [1] * world_size * n_expert
        for i in range(n_expert * world_size):
            if i // self.ne_per_node == self.node_rank:
                mask[i] = 0
        self.mask = torch.Tensor(mask).bool()
        self.policy_fn = None
        print('node rank {} mask {}'.format(node_rank, mask))

    def forward(self, inp):
        if self.mask.device != inp.device:
            self.mask = self.mask.to(inp.device)

        gate_score = self.gate(inp)
        lim_mask = self.mask

        top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1)
        S = gate_score.shape[0]
        top_k = 2

        with torch.no_grad():
            top1_idx = top2_idx.view((-1, top_k))[:, 0]
            top1_val = top2_val.view((-1, top_k))[:, 0]
        c_e = torch.scatter_add(
                torch.zeros(self.tot_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
                ) / S
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        with torch.no_grad():
            if self.policy_fn is None:
                stored_models = torch.zeros(self.num_expert * self.world_size,
                        dtype=torch.bool)
            else:
                # TODO: Fix this after expert shadowing is ported
                _, lec, aec, gec, agec = count_by_gate(top2_idx, 
                        self.num_expert, self.world_size, require_pos=False)
                stored_models = self.policy_fn(aec, agec,
                        self.num_expert, self.world_size, inp.shape[-1], True)
            lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device)

            ogn_mask = lim_mask[top1_idx]
            ogn_thres = int(inp.shape[0] * self.ogn_ratio)

        if ogn_mask.sum().item() < ogn_thres:
            topk_val, topk_idx = torch.topk(gate_score, k=self.top_k)
            topk_val = F.softmax(topk_val, dim=-1)
            return topk_idx, topk_val

        with torch.no_grad():
            top1_val[~ogn_mask] = float('-inf')
            _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres)
            cand = gate_score.clone()
            cand[:, lim_mask] = float('-inf')
            _, topk_idx = torch.topk(cand, k=self.top_k)
            topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn]

        idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2)
        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)

        topk_val = F.softmax(topk_val, dim=-1)

        return topk_idx, topk_val

def gen_faster_gate(rank):
    def _gen(d_model, n_expert, world_size, top_k=2):
        assert top_k == 2
        return FasterGate(d_model, n_expert, world_size, rank // nw_per_node)
    return _gen