import torch
import torch.nn as nn
from .layers.modulated_conv_layers import ModulatedConv2d, ControlledModulatedConv2d
from .layers.modulated_linear_layers import  ModulatedDecoder
import einops
import torchvision

class ModulatedSubnet(nn.Module):
    """Stack of ``ControlledModulatedConv2d`` layers with one selectable modulated layer.

    Layers before ``mod_conv_idx`` run with ``modulation=1`` and no controller
    parameters. The layer at ``mod_conv_idx`` and every later layer receive
    the caller's modulation / controller parameters. When residual connections
    are enabled, each layer's input is added to its output.
    """

    def __init__(self,
                 num_layers=3,
                 input_features=512,
                 channel_size=512,
                 use_bias=True,
                 conv_activation="nn.ReLU",
                 conv_normal_bias_init=False,
                 use_residual_connection=False,
                 ):
        """Build the convolution stack.

        Args:
            num_layers: total number of conv layers (at least one is always built).
            input_features: channel count passed as the first layer's input size.
            channel_size: channel size passed to every layer as its output size
                (and as the input size of layers after the first).
            use_bias: forwarded as ``bias`` to each conv layer.
            conv_activation: activation spec forwarded to each conv layer.
            conv_normal_bias_init: forwarded as ``normal_bias_init``.
            use_residual_connection: add each layer's input to its output.

        Raises:
            ValueError: if residual connections are requested while
                ``input_features != channel_size``. The first residual sum would
                then combine tensors with different channel counts (presumably
                an error, or silent broadcasting when one side has 1 channel).
        """
        super(ModulatedSubnet, self).__init__()
        if use_residual_connection and input_features != channel_size:
            raise ValueError(
                "use_residual_connection requires input_features == channel_size "
                f"(got input_features={input_features}, channel_size={channel_size})"
            )
        # Index of the first layer that receives the external modulation.
        self.mod_conv_idx = 0
        self.modulator_parameters = None
        self.use_residual_connection = use_residual_connection

        def make_layer(in_channels):
            # 3x3 kernel, stride 1, padding 1, channel-wise modulation.
            return ControlledModulatedConv2d(in_channels, channel_size, kernel_size=3, stride=1, padding=1,
                                             channel_wise_modulation=True, normal_bias_init=conv_normal_bias_init,
                                             bias=use_bias, activation=conv_activation)

        conv_layers = [make_layer(input_features)]
        conv_layers.extend(make_layer(channel_size) for _ in range(num_layers - 1))
        self.conv_layers = nn.ModuleList(conv_layers)

    def get_modulated_layer(self):
        """Return the layer at ``mod_conv_idx`` (the first modulated layer)."""
        return self.conv_layers[self.mod_conv_idx]

    def forward_until_modulator(self, x):
        """Run the unmodulated layers preceding ``mod_conv_idx``.

        Each layer is called with ``modulation=1`` and ``controller_params=None``.
        With residual connections enabled, each layer's input is added to its
        output.

        Returns:
            The activation fed to the modulated layer (``x`` itself when
            ``mod_conv_idx == 0``).
        """
        for layer in self.conv_layers[:self.mod_conv_idx]:
            # Clone so the residual is unaffected by anything the layer does to its input.
            residual = x.clone() if self.use_residual_connection else None
            x = layer(x, modulation=1, controller_params=None)
            if residual is not None:
                x = x + residual
        return x

    def forward_from_modulator(self, x, modulation, controller_params=None, steps=None):
        """Run the modulated layer and every layer after it.

        Args:
            x: activation entering the layer at ``mod_conv_idx``.
            modulation: forwarded to every layer as ``modulation``.
            controller_params: forwarded to every layer.
            steps: forwarded to every layer. When given and residuals are
                enabled, the initial residual is repeated ``steps`` times and
                folded into the batch dimension (``s b c h w -> (b s) c h w``).
                Presumably this matches how the layers expand ``x`` — confirm.

        Returns:
            Output of the last conv layer.
        """
        residual = None
        if self.use_residual_connection:
            residual = x.clone()
            if steps is not None:
                residual = residual.unsqueeze(0).repeat((steps, 1, 1, 1, 1))
                residual = einops.rearrange(residual, ' s b c h w -> (b s) c h w')
        for layer in self.conv_layers[self.mod_conv_idx:]:
            x = layer(x, modulation=modulation, controller_params=controller_params, num_tasks=None, steps=steps)
            if residual is not None:
                x = x + residual
                residual = x.clone()
        return x

    def forward(self, x, modulation, controller_params=None, steps=None):
        """Run the full stack: unmodulated prefix, then the modulated suffix."""
        x = self.forward_until_modulator(x)
        x = self.forward_from_modulator(x, modulation, controller_params, steps)
        return x
        
class MultiTaskCNN(nn.Module):
    """ResNet-18 backbone, modulated conv subnet, decoder and linear classifier.

    The backbone can stop after ``layer1``/``layer2``/``layer3`` via
    ``number_of_resnet_to_cut``; the modulated subnet's channel width follows
    that choice. ``forward`` returns ``(logits, decoder_features)``.
    """

    # Largest supported cut: resnet_forward can stop after layer1 at the earliest.
    _MAX_RESNET_CUT = 3

    def __init__(self,
                 input_shape,
                 task_groups,
                 use_classifier=False,
                 trainable_weights=False,
                 image_size=64,
                 num_possible_conv=3,
                 dec_activation="nn.ReLU",
                 conv_activation="nn.ReLU",
                 conv_normal_bias_init=False,
                 dec_normal_bias_init=False,
                 coarse_labels=False,
                 pretrained=False,
                 number_of_resnet_to_cut=0,
                 use_possible_bias=True,
                 use_dec_bias=True,
                 use_residual_connection=False):
        """Build the backbone, modulated subnet, decoder and classifier.

        Args:
            input_shape: not used by this constructor.
            task_groups: sequence of task groups; ``len(task_groups)`` is ``num_tasks``.
            use_classifier: stored as ``self.use_classifier``; ``forward`` applies
                the classifier regardless of its value.
            trainable_weights: not used by this constructor.
            image_size: not used by this constructor (see the ``dec_input`` note).
            num_possible_conv: number of layers in the modulated subnet.
            dec_activation: activation spec forwarded to the decoder.
            conv_activation: activation spec forwarded to the subnet's conv layers.
            conv_normal_bias_init: ``normal_bias_init`` for the subnet.
            dec_normal_bias_init: ``normal_bias_init`` for the decoder.
            coarse_labels: if True, one classifier output per task group;
                otherwise one output per element across all groups.
            pretrained: forwarded to ``torchvision.models.resnet18``.
            number_of_resnet_to_cut: number of trailing ResNet stages to skip
                (0 runs through layer4, 3 stops after layer1).
            use_possible_bias: ``use_bias`` for the modulated subnet.
            use_dec_bias: ``use_bias`` for the decoder.
            use_residual_connection: forwarded to the modulated subnet.

        Raises:
            ValueError: if ``number_of_resnet_to_cut`` is outside ``[0, 3]``.
        """
        super(MultiTaskCNN, self).__init__()
        # resnet_forward only has exits after layer1..layer3. Any other value
        # falls through to layer4, while first_conv_channel below would assume
        # a different width, so reject it before building anything.
        if not 0 <= number_of_resnet_to_cut <= self._MAX_RESNET_CUT:
            raise ValueError(
                f"number_of_resnet_to_cut must be between 0 and {self._MAX_RESNET_CUT}, "
                f"got {number_of_resnet_to_cut}"
            )

        self.task_groups = task_groups
        self.number_of_resnet_to_cut = number_of_resnet_to_cut
        self.resnet = torchvision.models.resnet18(pretrained=pretrained)
        self.num_possible_conv = num_possible_conv

        # Channel width of the last ResNet stage run by resnet_forward (512 after layer4).
        first_conv_channel = 512 // 2**number_of_resnet_to_cut
        self.modulated_net = ModulatedSubnet(num_layers=self.num_possible_conv,
                                             input_features=first_conv_channel,
                                             channel_size=first_conv_channel,
                                             use_bias=use_possible_bias,
                                             conv_activation=conv_activation,
                                             conv_normal_bias_init=conv_normal_bias_init,
                                             use_residual_connection=use_residual_connection)

        self.num_tasks = len(self.task_groups)

        # NOTE(review): dec_input does not depend on image_size. It presumably
        # has to equal the flattened size of the subnet output, which looks
        # tuned for 32x32 inputs (resnet_forward skips the max-pool) — confirm.
        dec_input = int(1024 * 2**(number_of_resnet_to_cut + 1))
        self.num_features_decoder = 2048
        self.decoder = ModulatedDecoder(dec_input, self.num_features_decoder, normal_bias_init=dec_normal_bias_init,
                                        activation=dec_activation, use_bias=use_dec_bias)

        self.use_classifier = use_classifier
        num_outputs = len(self.task_groups) if coarse_labels else sum(len(tg) for tg in self.task_groups)
        self.classifiers = nn.Linear(self.num_features_decoder, num_outputs)

    def resnet_forward(self, x):
        """Run the ResNet stem and stages up to the configured cut.

        The stem is conv1 -> bn1 -> relu; ``resnet.maxpool`` is not applied.
        Returns the output of layer1, layer2, layer3 or layer4 depending on
        ``number_of_resnet_to_cut`` (3, 2, 1, anything else).
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)

        x = self.resnet.layer1(x)
        if self.number_of_resnet_to_cut == 3:
            return x
        x = self.resnet.layer2(x)
        if self.number_of_resnet_to_cut == 2:
            return x
        x = self.resnet.layer3(x)
        if self.number_of_resnet_to_cut == 1:
            return x
        x = self.resnet.layer4(x)
        return x

    def forward(self, x, modulations=None, controller_params=None, tasks=None):
        """Encode ``x``, decode it and apply the classifier.

        Without ``controller_params``, the decoder path runs with a fixed
        modulation of 1 and ``modulations`` is ignored. With
        ``controller_params``, the encoder runs under ``torch.no_grad()``, and
        ``modulations`` and ``controller_params`` are forwarded to ``decode``.
        ``tasks`` is not used.

        Returns:
            ``(logits, x_dec)``: classifier output and decoder features.

        Raises:
            ValueError: if ``controller_params.shape[0]`` differs from the batch
                size ``x.shape[0]``.
        """
        if controller_params is None:
            encoded = self.encode_until_modulator(x)
            x_dec = self.decode(encoded, 1)
        else:
            batch_size = x.shape[0]
            if controller_params.shape[0] != batch_size:
                raise ValueError(
                    f"controller_params batch size {controller_params.shape[0]} "
                    f"does not match input batch size {batch_size}"
                )
            # No gradients flow into the backbone / pre-modulator layers on this path.
            with torch.no_grad():
                encoded = self.encode_until_modulator(x)
            x_dec = self.decode(encoded, modulations, controller_params, steps=None)
        return self.classifiers(x_dec), x_dec

    def set_layer_to_modulate(self, layer_to_modulate):
        """Reset ``self.modulator_parameters`` when ``layer_to_modulate`` is None.

        Non-None values are ignored. NOTE(review): no method of this class reads
        ``modulator_parameters``, and ``modulated_net.mod_conv_idx`` is not
        updated here — confirm this stub is intentional.
        """
        if layer_to_modulate is None:
            self.modulator_parameters = None

    def get_layer_to_modulate(self):
        """Return the subnet's modulated layer (the one at ``mod_conv_idx``)."""
        return self.modulated_net.get_modulated_layer()

    def get_modulator_params(self):
        """Delegate to the modulated layer's ``get_modulator_params``."""
        return self.get_layer_to_modulate().get_modulator_params()

    def set_modulation_training(self, training):
        """Delegate to the modulated layer's ``set_modulation_training``."""
        self.get_layer_to_modulate().set_modulation_training(training)

    def encode_until_modulator(self, x):
        """Run the backbone, then the subnet layers preceding the modulated layer."""
        x = self.resnet_forward(x)
        x = self.modulated_net.forward_until_modulator(x)
        return x

    def decode(self, x, modulation, controller_params=None, steps=None):
        """Run the modulated subnet suffix, flatten from dim 1 and apply the decoder."""
        x = self.modulated_net.forward_from_modulator(x, modulation, controller_params, steps)
        x = torch.flatten(x, 1)
        decoded = self.decoder(x)
        return decoded

    def get_pretrain_parameters(self):
        """Return all parameters of the model (a generator)."""
        return self.parameters()

    def get_backbone_parameters(self):
        """Return the ResNet backbone parameters (a generator)."""
        return self.resnet.parameters()

    def get_added_parameters(self):
        """Return a list of the subnet, decoder and classifier parameters."""
        return list(self.modulated_net.parameters()) + list(self.decoder.parameters()) + list(self.classifiers.parameters())

    def set_gain(self, gain):
        """Set ``gain`` on the decoder."""
        self.decoder.gain = gain

    def get_outputs_weights(self):
        """Return the classifier's weight matrix."""
        return self.classifiers.weight
