import torch.nn as nn
import torch


class Expert(nn.Module):
    """A single expert: one square linear projection of width ``n_embd``."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        return self.net(x)


class SparseMoE(nn.Module):
    """Mixture of ``num_experts`` experts where each token only visits its selected experts.

    No router is built here: ``self.router`` starts as ``None`` and must be
    assigned a callable before ``forward`` runs. The router is called as
    ``router(x)`` and must return ``(gates, indices)``:

    * ``gates[..., i]`` is the weight applied to expert ``i``'s output;
    * ``indices[..., :]`` lists the expert ids selected for each position.

    Both must have leading dimensions equal to ``x.shape[:-1]``, because the
    per-expert mask derived from ``indices`` is used to index ``x``-shaped
    output directly.
    """

    def __init__(self, n_embed, num_experts):
        super(SparseMoE, self).__init__()
        # Assigned externally before the first forward pass.
        self.router = None
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])

    def forward(self, x):
        """Return the gate-weighted sum of selected expert outputs, same shape as ``x``.

        Positions that select no expert receive zeros.
        """
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)
        # reshape (not view) so that non-contiguous inputs, e.g. transposed
        # tensors, are accepted; view raises on those.
        flat_x = x.reshape(-1, x.size(-1))
        flat_gating_output = gating_output.reshape(-1, gating_output.size(-1))

        # Experts are evaluated one after another, each on only the rows that
        # selected it.
        for i, expert in enumerate(self.experts):
            # True where expert i appears anywhere in the position's top-k list.
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.reshape(-1)
            if not flat_mask.any():
                continue

            expert_output = expert(flat_x[flat_mask])
            # (rows, 1) so the gate broadcasts over the feature axis.
            gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)

            # Boolean-mask selection on final_output yields rows in the same
            # row-major order as flat_x[flat_mask], so shapes line up as
            # (rows, features). No squeeze here: squeezing would collapse the
            # feature axis when it has size 1.
            final_output[expert_mask] += expert_output * gating_scores
        return final_output


class Adapter(nn.Module):
    """Wrap ``layer`` and add a memory-routed sparse mixture-of-experts residual.

    ``memories`` are embedded once with ``sim`` and split into contiguous
    chunks, one per expert. Before each forward pass the caller must set
    ``self.inputs`` to the raw inputs of the batch. The router compares their
    embeddings against every chunk and gates on the experts whose best match
    reaches ``cfg.threshold``.

    Args:
        layer: module whose output is augmented.
        sim: embedding model exposing ``encode(texts, show_progress_bar=...)``
            and ``similarity(a, b)`` (presumably a sentence-transformers model).
        memories: texts to embed and partition among the experts.
        cfg: config object; reads ``n_embed``, ``n_experts``, ``threshold``
            and ``top_k``.
        device: device for the memory embeddings and the MoE (default ``'cuda'``).
    """

    def __init__(self, layer, sim, memories, cfg, *, device='cuda'):
        super().__init__()
        self.layer = layer
        self.sim = sim
        self.cfg = cfg
        self.memories = self._encode(memories, device)
        self.sparse_moe = SparseMoE(self.cfg.n_embed, self.cfg.n_experts).to(device)
        # The MoE calls back into this adapter for routing.
        self.sparse_moe.router = self.router
        self.is_activated = True
        # Raw inputs of the current batch; must be set before forward().
        self.inputs = []

    def _encode(self, texts, device):
        """Embed ``texts`` with ``self.sim`` and return them as a tensor on ``device``."""
        # as_tensor accepts ndarray, list or tensor output from encode without
        # the copy-construct warning torch.tensor emits for tensors.
        return torch.as_tensor(self.sim.encode(texts, show_progress_bar=False), device=device)

    def is_similar(self, query):
        """Return True if ``query``'s best-matching memory scores at least ``cfg.threshold``."""
        query_emb = self._encode(query, self.memories.device)
        scores = self.sim.similarity(query_emb, self.memories)
        best, _ = scores.topk(1, dim=-1)
        return (best[0][0] >= self.cfg.threshold).item()

    def router(self, x):
        """Compute per-position expert gates and the top-k expert indices.

        Args:
            x: tensor whose dim 0 must equal ``len(self.inputs)``. Dim 1 is the
                number of positions each input's gates are repeated over
                (presumably (batch, seq, n_embed) — confirm with callers).

        Returns:
            ``(gates, indices)``. ``gates`` has shape ``(x.size(0), x.size(1), C)``
            and holds 1.0 where the input's best similarity to memory chunk ``c``
            is ``>= cfg.threshold``, otherwise 0.0. ``indices`` is the top
            ``cfg.top_k`` of ``gates`` along the last axis.

            NOTE: chunks are ceil-sized, so ``C`` can be smaller than
            ``cfg.n_experts``. For example, 9 memories with 4 experts gives 3
            chunks of 3, and ``cfg.top_k`` must not exceed ``C``.
        """
        assert x.size(0) == len(self.inputs), (
            f"router got a batch of {x.size(0)} but {len(self.inputs)} inputs are set"
        )

        # TODO: brute force for now -> use faiss later
        input_emb = self._encode(self.inputs, self.memories.device)
        # Ceil division so that every memory falls into some chunk.
        split_size = -(-len(self.memories) // self.cfg.n_experts)
        chunks = torch.split(self.memories, split_size)

        # Best similarity of each input to each chunk: (len(inputs), C).
        best_per_chunk = torch.cat(
            [self.sim.similarity(input_emb, chunk).topk(1, dim=-1).values for chunk in chunks],
            dim=1,
        )

        # Binarize with a single comparison. The previous two in-place writes
        # (set >= thr to 1, then set < thr to 0) zeroed the freshly written 1s
        # whenever threshold > 1, e.g. with dot-product similarity.
        gates = (best_per_chunk >= self.cfg.threshold).to(best_per_chunk.dtype)
        gates = gates.unsqueeze(1).repeat(1, x.size(1), 1)
        _, indices = gates.topk(self.cfg.top_k)
        return gates, indices

    def forward(self, x):
        """Apply ``layer``; when activated, add the sparse-MoE output as a residual."""
        x = self.layer(x)
        if self.is_activated:
            x = x + self.sparse_moe(x)
        return x
