import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed.nn as dist_nn


class LiBroMLP(nn.Module):
    """Stack of ``depth`` width-preserving linear layers, each followed by ``activation``.

    Each layer's weight is built from a low-rank factor ``V`` of shape
    ``(num_features, weight_rank)`` as ``W = I - 2 V (V^T V + eps I)^{-1} V^T``.
    With ``eps = 0`` this is ``I - 2P`` for the orthogonal projector ``P`` onto the
    column space of ``V`` (an orthogonal reflection); the small ridge ``eps``
    (1e-5 times the mean diagonal of ``V^T V``) keeps the linear solve well posed.

    In training mode the weights are rebuilt on every forward pass. Switching to
    eval mode (``.eval()`` / ``.train(False)``) computes them once and caches them
    in ``eval_weights``; call ``.eval()`` again after changing parameters while in
    eval mode (e.g. ``load_state_dict``) to refresh the cache.

    When ``torch.distributed`` is initialised and ``depth`` is divisible by the
    world size, each rank builds ``depth // world_size`` layers and the results are
    all-gathered.
    """

    def __init__(
        self,
        num_features: int,
        depth: int,
        activation: nn.Module,
        weight_rank_ratio: float = 0.25,
    ) -> None:
        """
        Args:
            num_features: Input and output width of every layer.
            depth: Number of layers.
            activation: Module applied after every layer, including the last one.
            weight_rank_ratio: Fraction of ``num_features`` used as the rank of each
                factor ``V`` (see ``weight_rank_mapping``).
        """
        super().__init__()
        weight_rank = self.weight_rank_mapping(num_features, weight_rank_ratio)
        self.weight_rank = weight_rank
        self.width = num_features
        self.depth = depth
        self.act = activation
        # 1/sqrt(depth) scales the learned perturbation of every factor.
        self.scale = depth**-0.5

        # Learned low-rank perturbations of the fixed base factor, one per layer.
        self.weights = nn.Parameter(
            torch.randn(depth, num_features, weight_rank) / num_features
        )
        self.bias = nn.Parameter(torch.zeros(depth, num_features))

        self.register_buffer("identity", torch.eye(num_features))
        self.register_buffer("identity_half", torch.eye(weight_rank))

        # Per-row gain applied to each layer's perturbation.
        self.gamma = nn.Parameter(torch.ones(depth, num_features, 1))
        # Fixed base factor: the first ``weight_rank`` columns of the identity.
        residual_identity = torch.zeros(num_features, weight_rank)
        residual_identity[:weight_rank, :weight_rank] = torch.eye(weight_rank)
        self.register_buffer("residual_identity", residual_identity)

        # ``is_available`` guards builds compiled without distributed support,
        # where ``is_initialized`` cannot be used.
        distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )
        self.flag = distributed and depth % torch.distributed.get_world_size() == 0

    def _factors(self, layers: slice = slice(None)) -> torch.Tensor:
        """Return the factors ``V`` of the selected layers, shape (n_layers, width, rank)."""
        return (
            self.residual_identity
            + self.weights[layers] * self.gamma[layers] * self.scale
        )

    def _reflect(self, V: torch.Tensor) -> torch.Tensor:
        """Map factors ``V`` (n, width, rank) to ``I - 2 V (V^T V + eps I)^{-1} V^T``."""
        Vt = V.transpose(1, 2)
        sigma = Vt @ V
        # Relative ridge: 1e-5 times the mean diagonal entry of each Gram matrix.
        eps = sigma.diagonal(dim1=1, dim2=2).mean(-1).mul(1e-5)
        sigma = sigma + eps[:, None, None] * self.identity_half
        return self.identity - 2 * V @ torch.linalg.solve(sigma, Vt)

    def get_weights(self) -> torch.Tensor:
        """Build the weights of all ``depth`` layers, shape (depth, width, width).

        In distributed mode each rank computes its contiguous share of layers and
        the shares are concatenated with ``torch.distributed.nn.functional.all_gather``;
        being a collective, every rank must call this method together.
        """
        if not self.flag:
            return self._reflect(self._factors())
        rank = torch.distributed.get_rank()
        world = torch.distributed.get_world_size()
        per_rank = self.depth // world
        local = slice(rank * per_rank, (rank + 1) * per_rank)
        local_weights = self._reflect(self._factors(local)).contiguous()
        return torch.cat(dist_nn.functional.all_gather(local_weights), dim=0)

    def _cache_eval_weights(self) -> torch.Tensor:
        """Compute the weights without autograd, store them as ``eval_weights``, return them."""
        with torch.no_grad():
            weights = self.get_weights().detach()
        self.eval_weights = weights
        return weights

    def forward(self, x):
        """Apply every layer as ``act(F.linear(x, W_i, b_i))`` in order.

        Training mode rebuilds the weights on each call. Eval mode uses the cache,
        rebuilding it if it is missing or no longer matches the parameters'
        device/dtype (e.g. after ``.to()`` / ``.double()`` while in eval mode).
        """
        if self.training:
            weights = self.get_weights()
        else:
            weights = getattr(self, "eval_weights", None)
            if (
                weights is None
                or weights.device != self.bias.device
                or weights.dtype != self.bias.dtype
            ):
                weights = self._cache_eval_weights()
        for weight, bias in zip(weights, self.bias):
            x = self.act(F.linear(x, weight, bias))
        return x

    def train(self, mode: bool = True):
        """Set training mode on this module and its children (including ``act``).

        ``train(False)`` (and therefore ``eval()``) rebuilds and caches the layer
        weights; ``train(True)`` drops the cache.
        """
        # nn.Module.train validates ``mode`` and propagates it to child modules.
        super().train(mode)
        if mode:
            if hasattr(self, "eval_weights"):
                del self.eval_weights
        else:
            self._cache_eval_weights()
        return self

    def lipschitz(self):
        return 1.0

    def extra_repr(self) -> str:
        return f"BRO: depth={self.depth}, weight_rank/width = {self.weight_rank}/{self.width}, distributed={self.flag}"

    @staticmethod
    def weight_rank_mapping(width, weight_rank_ratio):
        """Convert a rank ratio into the integer rank of each factor ``V``.

        A ratio of exactly 1.0 maps to ``width - 1`` and 0.0 maps to 1; any other
        ratio maps to ``int(width * weight_rank_ratio)``.

        Raises:
            ValueError: if the proportional rank is not in ``(0, width]``.
        """
        if weight_rank_ratio == 1.0:
            return width - 1
        if weight_rank_ratio == 0.0:
            return 1
        weight_rank = int(width * weight_rank_ratio)
        if not 0 < weight_rank <= width:
            raise ValueError(
                f"weight_rank_ratio={weight_rank_ratio} gives rank {weight_rank} "
                f"for width {width}; expected a rank in (0, {width}]"
            )
        return weight_rank
