import torch
import torch.nn as nn

from common.weight_space import (
    MoENetworkSpec,
    MoEWeightSpaceFeatures,
    LinearWeightSpaceFeatures,
    NetworkSpec,
)
from layers.layer_utils import shape_wsfeat_symmetry


class TupleOp(nn.Module):
    """Apply a single operation to every weight and bias feature tensor.

    Args:
        op: callable (typically an ``nn.Module``) invoked once per tensor in
            ``wsfeat.weights`` and once per tensor in ``wsfeat.biases``.
    """

    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, wsfeat: LinearWeightSpaceFeatures) -> LinearWeightSpaceFeatures:
        """Return a new ``LinearWeightSpaceFeatures`` with ``op`` applied elementwise.

        All weights are processed first, then all biases; the input container is
        not modified.
        """
        apply = self.op
        new_weights = list(map(apply, wsfeat.weights))
        new_biases = list(map(apply, wsfeat.biases))
        return LinearWeightSpaceFeatures(new_weights, new_biases)

    def __repr__(self):
        return f"TupleOp({self.op})"

class FlattenWeights(nn.Module):
    """Flatten per-layer weight/bias features into a single token sequence.

    For every layer described by ``network_spec``, the weight tensor is
    flattened from dim 2 onward and its bias tensor is appended after it. The
    per-layer pieces are then concatenated along dim 1.

    Args:
        network_spec: network description; its ``len`` sets the number of
            layers read from ``wsfeat`` and it is forwarded to
            ``shape_wsfeat_symmetry``.
    """

    def __init__(self, network_spec):
        super().__init__()
        self.network_spec = network_spec

    def forward(self, wsfeat):
        # Project helper that reshapes the features according to the spec;
        # its exact output layout is defined in layers.layer_utils.
        wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec)
        pieces = []
        for layer_idx in range(len(self.network_spec)):
            weight, bias = wsfeat[layer_idx]
            # transpose(1, 2) moves the flattened position axis to dim 1; the
            # result is presumably (B, N_layer, C) — confirm against the helper.
            pieces.extend(
                (
                    weight.flatten(start_dim=2).transpose(1, 2),
                    bias.transpose(1, 2),
                )
            )
        return torch.cat(pieces, dim=1)  # (B, N, C)


class TupleOpMoE(nn.Module):
    """Apply ``op`` to every tensor of each MoE weight-space field, except masked fields.

    Args:
        op: callable (typically an ``nn.Module``) applied to each tensor in
            every non-masked field.
        masked_features: optional collection of field names (e.g. ``"W_G"``)
            whose values are passed through unchanged. ``None`` or an empty
            collection means every field is transformed.
    """

    # Fields of MoEWeightSpaceFeatures processed by this op, passed back to the
    # constructor as keyword arguments.
    _KEYS = ("W_q", "W_k", "W_v", "W_o", "W_G", "W_A", "W_B", "b_G", "b_A", "b_B")

    def __init__(self, op, masked_features=None):
        super().__init__()
        self.op = op
        self.mask_features = masked_features if masked_features else []

    def forward(self, wsfeat: MoEWeightSpaceFeatures) -> MoEWeightSpaceFeatures:
        """Return new MoE features with ``op`` applied to each non-masked field.

        Each field of ``wsfeat`` is iterated and ``op`` is applied per element;
        masked fields are forwarded as the original object (not copied).
        """
        out_dict = {}
        for key in self._KEYS:
            feats = getattr(wsfeat, key)
            if key in self.mask_features:
                out_dict[key] = feats
            else:
                out_dict[key] = [self.op(f) for f in feats]
        return MoEWeightSpaceFeatures(**out_dict)

    def __repr__(self):
        # Show the mask so differently configured instances are distinguishable
        # when printing a model; output is unchanged when nothing is masked.
        if self.mask_features:
            return f"TupleOpMoE({self.op}, masked_features={self.mask_features!r})"
        return f"TupleOpMoE({self.op})"


class FlattenWeightsMoE(nn.Module):
    """Flatten per-layer MoE weight-space features into a single token sequence.

    For every layer index ``i`` in ``range(len(wsfeat))``, ``wsfeat[i]`` must
    unpack into the ten tensors ``W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A,
    b_B``. Each tensor is flattened from dim 2 onward and transposed so the
    flattened positions sit on dim 1. All pieces are then concatenated along
    dim 1.

    Args:
        encoder_weight_spec: stored on the module; not read by ``forward``.
    """

    def __init__(self, encoder_weight_spec: MoENetworkSpec):
        super().__init__()
        self.encoder_weight_spec = encoder_weight_spec

    @staticmethod
    def _to_tokens(feat):
        """Map ``(d0, d1, *rest)`` to ``(d0, prod(rest), d1)``."""
        return torch.flatten(feat, start_dim=2).transpose(1, 2)

    def forward(self, wsfeat):
        tokens = []
        for i in range(len(wsfeat)):
            W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B = wsfeat[i]
            # NOTE(review): W_k is emitted before W_q, unlike the unpacking
            # order. This is kept as-is because downstream consumers may rely
            # on this token layout; confirm it is intentional.
            layer_feats = (W_k, W_q, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)
            # Every tensor, including b_G, gets the same flatten+transpose, so a
            # b_G with extra trailing dims no longer breaks the concatenation.
            tokens.extend(self._to_tokens(feat) for feat in layer_feats)
        return torch.cat(tokens, dim=1)  # (B, N, C)

