import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import einops

class ModulatedIdentity(nn.Module):
    """Pass-through layer whose output can be gated by a learned modulator.

    While modulation training is off, ``forward`` returns its input object
    unchanged. Once it is switched on, the input is multiplied element-wise by
    ``ReLU(modulator * modulation)``.

    Args:
        modulator_shape: shape of one sample without the batch dimension. In
            channel-wise mode only its first entry (the channel count) is used.
        channel_wise_modulation: if True, learn a single modulator value per
            channel (parameter shape ``(1, C, 1, 1)``); otherwise learn one
            value per element (parameter shape ``(1, *modulator_shape)``).
    """

    def __init__(self, modulator_shape, channel_wise_modulation=False):
        super().__init__()

        if channel_wise_modulation:
            param_shape = (1, modulator_shape[0], 1, 1)
        else:
            # Prepend a batch dim of 1 so the modulator broadcasts over the batch.
            param_shape = (1,) + tuple(modulator_shape)

        # Initialised to ones: the gate starts as ReLU(modulation).
        self.modulator = nn.Parameter(torch.ones(param_shape))
        self.act = nn.ReLU()
        self.modulation_training = False

    def forward(self, data, modulation, controller_params=None):
        """Return ``data``, gated by the modulator when modulation training is on.

        Args:
            data: input tensor; must broadcast against ``self.modulator``.
            modulation: multiplier applied to the modulator before the ReLU gate.
            controller_params: accepted but unused by this layer.
        """
        if self.modulation_training:
            gate = self.act(self.modulator * modulation)
            return data * gate
        return data

    def get_modulator_params(self):
        """Return the learnable modulator parameters as a list."""
        return [self.modulator]

    def set_modulation_training(self, value):
        """Enable (truthy ``value``) or disable modulation gating in ``forward``."""
        self.modulation_training = value


class ModulatedConv2d(nn.Module):
    """ReLU convolution whose output can be gated by a learned modulator.

    ``forward`` always computes ``ReLU(conv(data))``. While modulation training
    is on, that result is additionally multiplied element-wise by
    ``ReLU(modulator * modulation)``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded to ``nn.Conv2d``.
        input_shape: shape of one input sample, e.g. ``(C, H, W)``; only used to
            size the element-wise modulator.
        channel_wise_modulation: if True, learn one modulator value per output
            channel (shape ``(1, out_channels, 1, 1)``) and ignore
            ``input_shape``; otherwise the modulator matches the conv output
            shape for a single sample (batch dim 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding=0,
                 input_shape=(1, 28, 28), channel_wise_modulation=False, bias=True):
        super(ModulatedConv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias)
        if not channel_wise_modulation:
            # Probe the conv once with a single zero sample to learn its output
            # shape (including the leading batch dim of 1). Only `.shape` is
            # used, so skip recording an autograd graph over the new weights.
            with torch.no_grad():
                probe = self.conv(torch.zeros(input_shape).unsqueeze(0))
            modulator_shape = probe.shape
        else:
            # one modulator value per channel
            modulator_shape = (1, out_channels, 1, 1)

        # Initialised to ones: the gate starts as ReLU(modulation).
        self.modulator = nn.Parameter(torch.ones(modulator_shape))
        self.act = nn.ReLU()
        self.modulation_training = False

    def forward(self, data, modulation, controller_params=None):
        """Apply ``ReLU(conv(data))``, gated by the modulator when training it.

        Args:
            data: input batch for ``self.conv``.
            modulation: multiplier applied to the modulator before the ReLU gate;
                must broadcast against the conv output.
            controller_params: accepted but unused by this layer.
        """
        out = self.act(self.conv(data))
        if not self.modulation_training:
            return out
        return out * self.act(self.modulator * modulation)

    def get_modulator_params(self):
        """Return the learnable modulator parameters as a list."""
        return [self.modulator]

    def set_modulation_training(self, value):
        """Enable (truthy ``value``) or disable modulation gating in ``forward``."""
        self.modulation_training = value


class ControlledModulatedConv2d(nn.Module):
    """Convolution whose activations are scaled by externally supplied parameters.

    The layer owns no modulator of its own (``get_modulator_params`` returns
    ``None``). While modulation training is on, ``forward`` replicates the
    activated conv output along the batch axis and multiplies it by
    ``controller_params * modulation``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded to ``nn.Conv2d``.
        channel_wise_modulation: accepted for interface compatibility; not used
            by this layer.
        input_shape: optional shape of one input sample, e.g. ``(C, H, W)``.
            When given, a zero sample is pushed through the conv once to record
            the per-sample output shape in ``conv_out_shape``; otherwise
            ``conv_out_shape`` is ``(out_channels, 1, 1)``.
        normal_bias_init: if a positive number and ``bias`` is True, the conv
            bias is re-initialised from a normal with mean 0 and this std.
        activation: activation applied after the conv. May be an ``nn.Module``
            instance, a callable returning one (e.g. ``nn.ReLU``), or a string
            that evaluates to such a callable (e.g. ``"nn.ReLU"``).
    """

    def __init__(self, in_channels,
                 out_channels,
                 kernel_size=5,
                 stride=1,
                 padding=0,
                 channel_wise_modulation=False,
                 input_shape=None,
                 bias=True,
                 normal_bias_init=False,
                 activation="nn.ReLU"):
        super(ControlledModulatedConv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias)

        # normal_bias_init doubles as the std of the bias init; False/0 disables it.
        if normal_bias_init > 0 and bias:
            print("Normal bias init")
            torch.nn.init.normal_(self.conv.bias, 0, normal_bias_init)

        if input_shape is not None:
            # Only the probe's shape is needed, so don't build an autograd graph.
            with torch.no_grad():
                probe = self.conv(torch.zeros(input_shape).unsqueeze(0))
            self.conv_out_shape = probe.shape[1:]
        else:
            self.conv_out_shape = (out_channels, 1, 1)

        if isinstance(activation, nn.Module):
            self.act = activation
        elif callable(activation):
            self.act = activation()
        else:
            # NOTE(review): eval() executes arbitrary code; only pass trusted,
            # config-controlled strings such as "nn.ReLU" here.
            self.act = eval(activation)()
        self.modulation_training = False

    def set_modulation_training(self, value):
        """Enable (truthy ``value``) or disable controller scaling in ``forward``."""
        self.modulation_training = value

    def forward(self, data, modulation, controller_params, num_tasks=None, steps=None):
        """Apply the activated conv, optionally replicated and controller-scaled.

        Args:
            data: input batch for ``self.conv`` (presumably ``(B, C, H, W)``).
            modulation: multiplier broadcast against the scaled activations.
            controller_params: scales broadcast against the replicated
                activations of batch size ``B * steps * num_tasks``.
            num_tasks: optional number of copies per sample (treated as 1 if None).
            steps: optional number of copies per sample (treated as 1 if None).

        Returns:
            ``act(conv(data))`` when modulation training is off; otherwise the
            replicated activations times ``controller_params * modulation``.
        """
        conv_out = self.act(self.conv(data))

        if not self.modulation_training:
            return conv_out

        if steps is not None or num_tasks is not None:
            # Give every sample steps*num_tasks contiguous copies along the
            # batch axis: row b*S*T + s*T + t holds sample b ('(b s t)' layout),
            # so controller_params must presumably be ordered the same way.
            copies = (1 if steps is None else steps) * (1 if num_tasks is None else num_tasks)
            conv_out = conv_out.repeat_interleave(copies, dim=0)
        return conv_out * controller_params * modulation

    def get_modulator_params(self):
        """Return ``None``: this layer has no modulator of its own."""
        return None


