import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import einops

class QuadrantModulatedEncoder(nn.Module):
    """Encode a 2-D input with four independent per-quadrant linear maps.

    The input is cut into four equal square patches. Each patch is flattened
    and passed through its own ``nn.Linear``. The four outputs are
    concatenated along dim 1 and exponentiated. When modulation training is
    enabled, ``self.modulator * modulation`` is added before the exponential.

    Args:
        size_in: Total number of input elements covered by the four patches.
            It must be divisible by 4, and ``size_in // 4`` is the area of one
            square patch (e.g. 784 -> 14x14 patches).
        size_out: Total number of output units. It must be divisible by 4.

    Raises:
        ValueError: If ``size_in`` or ``size_out`` is not divisible by 4.
    """

    def __init__(self, size_in, size_out):
        super().__init__()
        # Both sizes are split evenly across the four quadrant maps.
        if size_in % 4 != 0 or size_out % 4 != 0:
            raise ValueError(
                f"size_in and size_out must be divisible by 4, "
                f"got size_in={size_in}, size_out={size_out}"
            )
        self.quadrant_weights = nn.ModuleList(
            [nn.Linear(size_in // 4, size_out // 4) for _ in range(4)]
        )
        # One learnable coefficient per output unit, initialised to 1.
        self.modulator = nn.Parameter(torch.ones(size_out), requires_grad=True)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.act = nn.ReLU()
        self.modulation_training = False

    def forward(self, data, modulation):
        """Return ``exp`` of the concatenated per-quadrant projections.

        Args:
            data: Tensor indexed as ``data[:, rows, cols]``. Presumably
                (batch, H, W) with H and W at least twice the patch side.
                TODO confirm with callers.
            modulation: Multiplied by ``self.modulator`` and added to the
                pre-activation when modulation training is on. It must
                broadcast against (batch, size_out).

        Returns:
            Tensor of shape (batch, 4 * (size_out // 4)).
        """
        # Patch side derived from the Linear input size, e.g. 196 -> 14.
        side = math.isqrt(self.quadrant_weights[0].in_features)
        quads = []
        for q, lin in enumerate(self.quadrant_weights):
            # q % 2 picks the offset along dim 1, q // 2 along dim 2.
            row0 = side * (q % 2)
            col0 = side * (q // 2)
            patch = data[:, row0:row0 + side, col0:col0 + side]
            quads.append(lin(patch.reshape(-1, side * side)))

        pre = torch.cat(quads, 1)
        if not self.modulation_training:
            return torch.exp(pre)
        return torch.exp(pre + self.modulator * modulation)

    def set_modulation_training(self, value):
        """Enable (True) or disable (False) the modulation term in forward."""
        self.modulation_training = value

class ModulatedLinear(nn.Module):
    """Linear layer followed by ReLU, with an optional multiplicative gate.

    With modulation training off, forward returns ``relu(linear(data))``.
    With it on, that result is multiplied element-wise by
    ``relu(self.modulator * modulation)``.

    Args:
        size_in: Input feature count.
        size_out: Output feature count (also the length of ``modulator``).
    """

    def __init__(self, size_in, size_out):
        super().__init__()
        self.linear = nn.Linear(size_in, size_out)
        # One learnable gate coefficient per output unit, initialised to 1.
        self.modulator = nn.Parameter(torch.ones(size_out), requires_grad=True)
        self.act = nn.ReLU()
        self.modulation_training = False

    def set_modulation_training(self, value):
        """Enable (True) or disable (False) the modulation gate in forward."""
        self.modulation_training = value

    def forward(self, data, modulation):
        """Apply linear + ReLU, then gate by the modulation when enabled."""
        hidden = self.act(self.linear(data))
        if self.modulation_training:
            gate = self.act(self.modulator * modulation)
            hidden = hidden * gate
        return hidden
              

def normalize(x, dim=1, eps=0.0):
    """Standardise ``x`` to zero mean and unit std along ``dim``.

    Args:
        x: Floating-point tensor.
        dim: Axis to standardise over (default 1, as before).
        eps: Added to the std before dividing. The default 0.0 reproduces
            the plain z-score, which yields NaN for constant slices. Pass a
            small positive value to avoid that.

    Returns:
        ``(x - mean) / (std + eps)``, using torch's default (unbiased) std.
    """
    mean = x.mean(dim, keepdim=True)
    std = x.std(dim, keepdim=True)
    return (x - mean) / (std + eps)

class ModulatedDecoder(nn.Module):
    """Linear layer + activation whose output is scaled by a configurable gain.

    Args:
        size_in: Input feature count.
        size_out: Output feature count.
        normal_bias_init: If > 0 (and ``use_bias``), re-initialise the bias
            from N(0, normal_bias_init). The default ``False`` keeps PyTorch's
            default bias init.
        activation: Source string naming an activation class (e.g.
            ``"nn.ReLU"``). It is ``eval``-ed and the result called with no
            arguments.
        use_bias: Whether the linear layer has a bias term.
    """

    def __init__(self, size_in, size_out, normal_bias_init=False, activation="nn.ReLU", use_bias=True):
        super().__init__()
        self.linear = nn.Linear(size_in, size_out, bias=use_bias)
        if normal_bias_init > 0 and use_bias:
            print("Normal bias init")
            torch.nn.init.normal_(self.linear.bias, 0, normal_bias_init)
        self.size_out = size_out
        # NOTE(review): eval() executes arbitrary code. Only pass trusted
        # strings (e.g. from in-repo configs) here, never external input.
        self.activation = eval(activation)()
        self.gain = 1
        # When True, array-like gains are broadcast directly instead of being
        # applied per row by the regrouping logic in forward().
        self.gradient_gain = False
        # When True, forward returns the squared activation and ignores gain.
        self.square_gain = False

    def set_gain(self, gain):
        """Set the multiplier applied to the activation output in forward."""
        self.gain = gain

    def forward(self, data, tasks=None):
        """Return ``gain * activation(linear(data))`` (or its square).

        Args:
            data: Input to the linear layer.
            tasks: Unused; accepted for call-site compatibility.

        Returns:
            If ``square_gain``: ``activation(linear(data)) ** 2``.
            If the gain is a Python number or a 0-dim tensor/array, or
            ``gradient_gain`` is set: ``gain * activation(...)``.
            Otherwise the gain is applied per leading row. When
            ``gain.shape[0]`` differs from the number of output rows, the rows
            are regrouped as ``(b t) c -> b t c`` with ``t = gain.shape[0]``
            before multiplying, and flattened back afterwards.
        """
        out = self.activation(self.linear(data))
        if self.square_gain:
            return out ** 2

        gain = self.gain
        # 0-dim tensors/arrays behave like plain scalars. Indexing
        # .shape[0] on them would raise IndexError.
        is_scalar = isinstance(gain, (int, float)) or getattr(gain, "ndim", None) == 0
        if is_scalar or self.gradient_gain:
            return gain * out

        n_rows = gain.shape[0]
        if n_rows != out.shape[0]:
            # Presumably rows are a flattened (batch, task) layout. TODO
            # confirm with callers. Regroup so the gain lines up with `t`.
            out = einops.rearrange(out, '(b t) c -> b t c', t=n_rows)
        out = gain * out
        if len(out.shape) == 3:
            out = einops.rearrange(out, 'b t c -> (b t) c')
        return out
        