import sys
import torch
import transformers
import datasets
import os

# Absolute path of the ``util`` directory located in the parent of this file's directory.
utils_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../util'))
# Absolute path of the directory one level above the one containing this file.
folder_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

# Extend the module search path with both directories before the imports below run.
# NOTE(review): presumably this is what makes ``magicattr`` importable -- confirm where it lives.
sys.path.append(utils_path)
sys.path.append(folder_path)

from magicattr import  setattr_by_module_name
from torch import nn
from torch.nn.functional import linear


#############################################
##### Compute the Features ##################
#############################################

class FeatureLayer(nn.Module):
    """Linear layer wrapper that records statistics of the inputs it receives.

    The wrapper reuses the ``weight`` and ``bias`` of the wrapped layer (no copy)
    and returns exactly ``linear(input, weight, bias)``.  As a side effect every
    forward pass updates one of two running averages over the inputs, flattened
    to shape ``(-1, in_features)``:

    * ``computeMean == True``  -> ``correlationMean``, shape ``(1, in_features)``:
      running average of the per-call feature means.
    * ``computeMean == False`` -> ``correlationM``, shape
      ``(in_features, in_features)``: running average of the per-call
      ``Xc^T Xc`` where ``Xc`` is the flattened input minus ``correlationMean``.

    ``computation`` counts the forward passes of the phase that is currently
    being accumulated; it restarts whenever ``computeMean`` is switched, so each
    phase is a proper average over its own calls.

    Args:
        layer: module providing ``weight`` (last dim = input features) and ``bias``.
        name: identifier stored on the wrapper (e.g. the module's dotted name).
        device: device on which the two accumulators are allocated.
        dtype: accepted for signature compatibility; currently unused.
    """

    weight: torch.Tensor

    def __init__(self, layer, name , device=None, dtype=None,) -> None:
        super().__init__()

        # Share the wrapped layer's tensors instead of cloning them.
        self.weight = layer.weight
        self.bias = layer.bias

        self.name = name
        self.device = device
        self.kernelShape = layer.weight.shape[-1]

        # Number of forward passes seen in the current accumulation phase.
        self.computation = 0
        # Phase 1: running mean of the inputs.
        self.correlationMean = torch.zeros(size=(1, self.kernelShape), device=device)
        # Phase 2: running average of the mean-subtracted second-moment matrix.
        self.correlationM = torch.zeros(size=(self.kernelShape, self.kernelShape), device=device)

        self.computeMean = True
        # Phase the counter currently belongs to; used to detect a mode switch.
        self._countedMode = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Update the active statistic with ``input`` and return the linear output."""
        # A running average needs a divisor that counts only the samples of the
        # current phase, so restart the counter when the mode has been switched.
        currentMode = bool(self.computeMean)
        if currentMode != self._countedMode:
            self.computation = 0
            self._countedMode = currentMode
        self.computation += 1

        # The statistics are bookkeeping only: keep them out of autograd so the
        # accumulators never start requiring grad or retain per-batch graphs.
        with torch.no_grad():
            # reshape (unlike view) also accepts non-contiguous inputs; the cast
            # makes low-precision inputs accumulate in the accumulator's dtype.
            flat = input.reshape(-1, self.kernelShape).to(self.correlationMean.dtype)

            if currentMode:
                self.correlationMean += (flat.mean(dim=0) - self.correlationMean) / self.computation
            else:
                centered = flat - self.correlationMean
                self.correlationM += (torch.matmul(centered.t(), centered) - self.correlationM) / self.computation

        return linear(input, self.weight, self.bias)

def replaceWithFMmodules( model, device=None , layers = None):
    '''
    Replace the nn.Linear modules of ``model`` by FeatureLayer wrappers (in place).

    A linear module is skipped when its weight is not 2-dimensional, when its
    name contains "pooler", or when its name is exactly "head.module.layers.0".

    model:  module whose ``named_modules()`` are scanned.
    device: device handed to every FeatureLayer for its accumulators; when None,
            'cuda' is used if available, otherwise 'cpu'.
    layers: optional iterable of block identifiers (numbers or strings). When
            given, only modules whose dotted name contains ``str(identifier)``
            for at least one identifier are replaced. This is a plain substring
            match, so e.g. 1 also matches a name containing "11".

    Returns the list of names of the modules that were replaced.
    '''

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Snapshot the candidates first: the model is mutated inside the loop.
    linearModules = [(name, module) for name, module in model.named_modules() if isinstance(module, nn.Linear)]
    replacedM = []

    for name, module in linearModules:
        if module.weight.dim() != 2 or "pooler" in name or name == "head.module.layers.0":
            continue

        # str() lets numeric identifiers work; a bare int would raise a
        # TypeError in the ``in`` test against the module name.
        if layers is not None and not any(str(entry) in name for entry in layers):
            continue

        saving_layer = FeatureLayer(name=name, layer=module, device=device)

        # Exchange the module -- helper imported from magicattr.
        # NOTE(review): presumably it assigns by dotted module path; confirm in magicattr.
        setattr_by_module_name(model, name, saving_layer)
        replacedM.append(name)

    return replacedM

def setComputationMode(model , mode):
    '''
    Select which statistic every FeatureLayer inside ``model`` accumulates.

    mode="mean" -> the layers update their mean (this has to be the first step)
    any other value, e.g. mode="FM" -> the layers update the FM with the mean subtracted
    '''
    wantsMean = (mode == "mean")

    # Collect the wrappers first, then flip the flag on each of them.
    featureLayers = [candidate for candidate in model.modules() if isinstance(candidate, FeatureLayer)]
    for featureLayer in featureLayers:
        featureLayer.computeMean = wantsMean
