#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

from .layers import LoRALayer 
from typing import Optional, List 

class Bayes(nn.Module):
    """Multiply each row of a 2-D tensor by an independently drawn random factor.

    One factor is sampled per row of the input and shared across all of that
    row's columns.

    Args:
        dimension: Stored and returned by ``get_dimension``. It is not used to
            size the noise; the input's row count is.
        noise_rate: Noise parameter ``p``.
            - ``"Bernoulli"``: ``p`` is passed to ``torch.bernoulli`` as the
              probability that a row factor is 1 (otherwise it is 0).
            - ``"Gaussian"``: factors are drawn from N(1, p / (1 - p)), the
              mean/variance of an inverted-dropout mask with drop rate ``p``.
            NOTE(review): the two branches read ``p`` differently (keep
            probability vs. drop rate). With ``p == 0``, Bernoulli zeroes every
            row while Gaussian is the identity. Confirm this is intended.
        noise_type: ``"Bernoulli"`` or ``"Gaussian"``. Any other value raises
            ``ValueError`` on the first ``forward`` call.
    """

    def __init__(self, dimension, noise_rate, noise_type="Bernoulli"):
        super(Bayes, self).__init__()
        self.dimension = dimension
        self.noise_rate = noise_rate
        self.noise_type = noise_type
        if self.noise_type == "Gaussian":
            self.mean = 1.0
            # NOTE(review): noise_rate == 1 raises ZeroDivisionError here.
            self.std = math.sqrt(self.noise_rate / (1.0 - self.noise_rate))

    def get_dimension(self):
        """Return the ``dimension`` given at construction."""
        return self.dimension

    def forward(self, inputs):
        """Scale every row of ``inputs`` by a freshly sampled noise factor.

        NOTE(review): noise is sampled on every call regardless of
        ``self.training``, so it is also applied in eval mode.

        Args:
            inputs: 2-D tensor of shape ``(n, r)`` with ``n > r``.

        Returns:
            A tensor of the same shape, device and dtype as ``inputs``.

        Raises:
            AssertionError: if ``n <= r``.
            ValueError: if ``noise_type`` is not recognised.
        """
        n, r = inputs.shape
        # Presumably guards against being handed the (r, n) orientation by mistake.
        assert n > r
        if self.noise_type == "Bernoulli":
            # float() so torch.full yields a floating tensor, which torch.bernoulli needs.
            v = torch.bernoulli(torch.full((n, 1), float(self.noise_rate)))
        elif self.noise_type == "Gaussian":
            v = torch.randn(n) * self.std + self.mean
        else:
            raise ValueError('Noise type not found')
        # Noise is still drawn from the CPU generator, as before, then placed on
        # the input's device and dtype. A hard-coded "cuda" fails on CPU, and a
        # float32 factor would upcast half-precision parameters.
        v = v.view(-1, 1).to(device=inputs.device, dtype=inputs.dtype)
        # (n, 1) broadcasts over the r columns.
        return v * inputs


class BayesLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a LoRA update whose low-rank factors are perturbed by ``Bayes`` noise.

    The base projection uses the frozen ``weight``/``bias``. When ``r > 0`` and
    the layer is not merged, ``forward`` adds
    ``lora_dropout(x) @ A_noisy.T @ B_noisy.T * scaling``, where:
    - ``A_noisy`` is ``lora_A`` with each input-feature column scaled by a
      random factor;
    - ``B_noisy`` is ``lora_B`` with each output-feature row scaled by a
      random factor.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        r: LoRA rank. ``0`` disables the LoRA branch.
        lora_alpha: LoRA scaling numerator. ``scaling = lora_alpha / r``.
        lora_dropout: Dropout rate handed to ``LoRALayer``.
        noise_rate: Noise parameter passed to both ``Bayes`` modules.
        noise_type: ``"Bernoulli"`` or ``"Gaussian"``, passed to both ``Bayes`` modules.
        fan_in_fan_out: Set to True if the layer to replace stores its weight as
            ``(fan_in, fan_out)``.
        merge_weights: Handed to ``LoRALayer``. Merging is not performed by this
            class (see ``train``).
        **kwargs: Forwarded to ``nn.Linear.__init__``.
    """

    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        r: int = 0, 
        lora_alpha: int = 1, 
        lora_dropout: float = 0.,
        noise_rate: float = 0.,
        noise_type: str = 'Bernoulli',
        fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        self.noise_rate = noise_rate
        self.noise_type = noise_type
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))

            # bayes_A acts on lora_A.T, drawing one factor per input feature.
            # bayes_B acts on lora_B, drawing one factor per output feature.
            self.bayes_A = Bayes(dimension=in_features, noise_rate=noise_rate, noise_type=noise_type)
            self.bayes_B = Bayes(dimension=out_features, noise_rate=noise_rate, noise_type=noise_type)

            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Reset the base ``nn.Linear`` parameters and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the hasattr guard.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Set training mode and return ``self``, matching ``nn.Module.train``.

        NOTE(review): the usual LoRA merge/unmerge of ``lora_B @ lora_A`` into
        ``weight`` is disabled in this layer, so ``merged`` is never changed
        here. This is presumably because the LoRA factors get fresh noise on
        every forward call.
        """
        nn.Linear.train(self, mode)
        return self

    def eval(self):
        """Switch to evaluation mode and return ``self``, matching ``nn.Module.eval``."""
        # nn.Module.eval dispatches to self.train(False), i.e. the override above.
        nn.Linear.eval(self)
        return self

    def forward(self, x: torch.Tensor):
        """Compute the frozen linear projection plus the noisy LoRA update.

        The LoRA term is added only when ``r > 0`` and the layer is not merged.
        """
        weight = self.weight.T if self.fan_in_fan_out else self.weight
        result = F.linear(x, weight, bias=self.bias)
        if self.r > 0 and not self.merged:
            # lora_A is (r, in_features), so its transpose is (in_features, r).
            # Noise is drawn per input feature in that orientation.
            noisy_a_t = self.bayes_A(self.lora_A.T)
            noisy_b = self.bayes_B(self.lora_B)  # (out_features, r)
            result += (self.lora_dropout(x) @ noisy_a_t @ noisy_b.T) * self.scaling
        return result

