import torch
import torch.nn as nn
from .layers.modulated_conv_layers import  ControlledModulatedConv2d
from .layers.modulated_linear_layers import  ModulatedDecoder
import einops
import torchvision


class ModulatedSubnet(nn.Module):
    """Stack of ``ControlledModulatedConv2d`` layers applied to a feature map.

    Every layer is built as a 512 -> 512 channel, 3x3 convolution with stride 1
    and padding 1. ``num_layers + 1`` layers are created in total (one leading
    layer plus ``num_layers`` more).

    Layers with index below ``mod_conv_idx`` are run with ``modulation=1`` and
    no controller parameters. Layers from ``mod_conv_idx`` onwards receive the
    caller-supplied ``modulation``, ``controller_params``, ``steps`` and
    ``num_tasks``. With the default ``mod_conv_idx = 0``, every layer belongs to
    the second group.
    """

    def __init__(self,
                 num_layers=2,
                 use_bias=True,
                 conv_activation="nn.ReLU",
                 ):
        """
        Args:
            num_layers: number of layers added after the first one, so the
                stack holds ``num_layers + 1`` layers.
                NOTE(review): confirm the ``+ 1`` is intended.
            use_bias: forwarded as ``bias`` to every conv layer.
            conv_activation: forwarded as ``activation`` to every conv layer.
                It is passed as a string (e.g. ``"nn.ReLU"``); how it is
                resolved is up to ``ControlledModulatedConv2d``.
        """
        super(ModulatedSubnet, self).__init__()
        # Index of the first layer that receives the caller's modulation.
        self.mod_conv_idx = 0
        # NOTE(review): not read anywhere in this class; kept in case external
        # code accesses it.
        self.modulator_parameters = None
        # Layers are constructed one after another in order, so parameter
        # initialisation (and RNG consumption) matches building them one by one.
        self.conv_layers = nn.ModuleList(
            ControlledModulatedConv2d(512, 512, kernel_size=3, stride=1, padding=1,
                                      input_shape=None, channel_wise_modulation=True,
                                      bias=use_bias, activation=conv_activation)
            for _ in range(num_layers + 1)
        )

    def get_modulated_layer(self):
        """Return the first layer that receives the caller's modulation."""
        return self.conv_layers[self.mod_conv_idx]

    def forward_until_modulator(self, x):
        """Run the layers before ``mod_conv_idx`` with ``modulation=1`` and no controller.

        Returns ``x`` unchanged when ``mod_conv_idx`` is 0.
        """
        for layer in self.conv_layers[:self.mod_conv_idx]:
            x = layer(x, modulation=1, controller_params=None)
        return x

    def forward_from_modulator(self, x, modulation, controller_params=None, steps=None, num_tasks=None):
        """Run the layers from ``mod_conv_idx`` to the end with the given modulation arguments.

        ``modulation``, ``controller_params``, ``num_tasks`` and ``steps`` are
        passed unchanged, as keyword arguments, to every layer in this range.
        """
        for layer in self.conv_layers[self.mod_conv_idx:]:
            x = layer(x, modulation=modulation, controller_params=controller_params,
                      num_tasks=num_tasks, steps=steps)
        return x

    def forward(self, x, modulation, controller_params=None, steps=None, num_tasks=None):
        """Run the whole stack.

        The layers before ``mod_conv_idx`` run unmodulated. The remaining layers
        receive ``modulation``, ``controller_params``, ``steps`` and
        ``num_tasks``; ``num_tasks`` defaults to ``None``.
        """
        x = self.forward_until_modulator(x)
        return self.forward_from_modulator(x, modulation, controller_params,
                                           steps=steps, num_tasks=num_tasks)

class MultiTaskCNN(nn.Module):
    """Multi-task network: ResNet-18 trunk -> ModulatedSubnet -> ModulatedDecoder -> per-group heads.

    The ResNet-18 is run only up to ``layer4``; its average pool and ``fc`` are
    not used. The feature map goes through a ``ModulatedSubnet``, is flattened,
    and is decoded into ``num_features_decoder`` (2048) features. When
    ``use_classifier`` is set, one ``nn.Linear`` head per task group maps the
    decoded features to that group's outputs.

    Args:
        input_shape: NOTE(review): accepted but not used here.
        task_groups: sequence of task groups; ``num_tasks`` is the total number
            of tasks across all groups.
        use_classifier: build one linear head per task group.
        trainable_weights: NOTE(review): accepted but not used here.
        image_size: NOTE(review): accepted but not used here. In particular,
            the decoder input size is hard-coded.
        num_possible_conv: ``num_layers`` for the ``ModulatedSubnet``.
        pretrained: forwarded to ``torchvision.models.resnet18``.
        dec_activation: ``activation`` for the ``ModulatedDecoder``.
        conv_activation: ``conv_activation`` for the ``ModulatedSubnet``.
        output_sizes: optional per-group head output sizes. When omitted, each
            head outputs one value per task in its group.
    """

    def __init__(self,
                 input_shape,
                 task_groups,
                 use_classifier=False,
                 trainable_weights=False,
                 image_size=64,
                 num_possible_conv=2,
                 pretrained=True,
                 dec_activation="nn.ReLU",
                 conv_activation="nn.ReLU",
                 output_sizes=None):
        super(MultiTaskCNN, self).__init__()

        self.tasks_groups = task_groups
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
        # favour of `weights=`; kept for compatibility with older versions.
        self.resnet18 = torchvision.models.resnet18(pretrained=pretrained)
        self.num_possible_conv = num_possible_conv
        self.modulated_net = ModulatedSubnet(num_layers=num_possible_conv,
                                             use_bias=True,
                                             conv_activation=conv_activation)
        self.num_tasks = sum(len(group) for group in self.tasks_groups)

        # NOTE(review): hard-coded flattened size of the subnet output.
        # 3072 = 512 * 6, presumably a 512-channel 2x3 feature map for the
        # expected input resolution -- TODO confirm. `image_size` is not used
        # to derive it.
        dec_input = 3072
        self.num_features_decoder = 2048
        self.decoder = ModulatedDecoder(dec_input, self.num_features_decoder, activation=dec_activation)

        self.use_classifier = use_classifier
        if use_classifier:
            if output_sizes is not None:
                # Indexed (not zipped) so that a too-short `output_sizes` still
                # raises IndexError instead of silently dropping heads.
                head_sizes = [output_sizes[i] for i in range(len(task_groups))]
            else:
                head_sizes = [len(group) for group in task_groups]
            self.classifiers = nn.ModuleList(
                nn.Linear(self.num_features_decoder, size) for size in head_sizes
            )

    def resnet_forward(self, x):
        """Run the ResNet-18 stem and ``layer1``..``layer4``; avgpool and fc are skipped."""
        net = self.resnet18
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        return x

    def forward(self, x, modulations=None, controller_params=None, tasks=None):
        """Encode, decode and (optionally) classify a batch of images.

        Without ``controller_params``:
            The decoder runs with ``modulation=1``. Returns ``(out, x_dec)``,
            where ``x_dec`` is a clone of the decoder output. ``out`` is a list
            with one head output per task group when ``use_classifier`` is set,
            and the decoder output otherwise.

        With ``controller_params``:
            ``controller_params.shape[0]`` must be either ``num_tasks`` or
            ``batch_size * num_tasks``.

            * ``num_tasks``: one set per task. It must be 4-D and is repeated
              for every image.
            * ``batch_size * num_tasks``: one set per (image, task) pair.

            Rows are ordered image-major, i.e. all tasks of image 0, then all
            tasks of image 1, and so on. The encoder runs under
            ``torch.no_grad()``. ``x_dec`` is reshaped to
            ``(batch_size, num_tasks, features)``. Returns
            ``(per-group outputs, x_dec)`` when ``use_classifier`` is set, and
            ``(x_dec, x_dec)`` otherwise.

        ``tasks`` is accepted but not used.

        Raises:
            ValueError: if ``controller_params.shape[0]`` matches neither
                allowed size.
        """
        if controller_params is None:
            x = self.encode_until_modulator(x)
            x = self.decode(x, 1)
            x_dec = x.clone()
            if self.use_classifier:
                x = [classifier(x) for classifier in self.classifiers]
            return x, x_dec

        batch_size = x.shape[0]
        num_params = controller_params.shape[0]
        if num_params == self.num_tasks:
            # One controller per task for each image.
            controller_params = controller_params.unsqueeze(0).repeat((batch_size, 1, 1, 1, 1))
            controller_params = einops.rearrange(controller_params, 'b n c x y -> (b n) c x y')
        elif num_params != batch_size * self.num_tasks:
            raise ValueError(
                f"controller_params.shape[0] must be num_tasks ({self.num_tasks}) or "
                f"batch_size * num_tasks ({batch_size * self.num_tasks}), got {num_params}"
            )

        # Encoder weights receive no gradient in the controller path.
        with torch.no_grad():
            x = self.encode_until_modulator(x)

        x_dec = self.decode(x, modulations, controller_params, num_tasks=self.num_tasks, steps=None)
        x_dec = einops.rearrange(x_dec, '(b n) c -> b n c', b=batch_size)
        # Without heads there is nothing to classify; mirror the branch above
        # and hand back the decoded features instead of failing on the missing
        # `self.classifiers`.
        x = self.classify_tasks(x_dec) if self.use_classifier else x_dec
        return x, x_dec

    def set_layer_to_modulate(self, layer_to_modulate):
        """Clear ``modulator_parameters`` when given ``None``; any other value is ignored."""
        if layer_to_modulate is None:
            self.modulator_parameters = None

    def get_layer_to_modulate(self):
        """Return the subnet layer that receives the caller's modulation."""
        return self.modulated_net.get_modulated_layer()

    def get_modulator_params(self):
        """Delegate to the modulated layer's ``get_modulator_params``."""
        return self.get_layer_to_modulate().get_modulator_params()

    def set_modulation_training(self, training):
        """Delegate to the modulated layer's ``set_modulation_training``."""
        self.get_layer_to_modulate().set_modulation_training(training)

    def encode_until_modulator(self, x):
        """Run the ResNet trunk and any unmodulated subnet layers."""
        x = self.resnet_forward(x)
        return self.modulated_net.forward_until_modulator(x)

    def decode(self, x, modulation, controller_params=None, num_tasks=None, steps=None):
        """Run the modulated subnet layers, flatten from dim 1, and apply the decoder."""
        x = self.modulated_net.forward_from_modulator(x, modulation, controller_params,
                                                      num_tasks=num_tasks, steps=steps)
        x = torch.flatten(x, 1)
        return self.decoder(x)

    def classify_tasks(self, x_dec):
        """Apply each group's head to that group's task features.

        ``x_dec`` is indexed as ``(batch, task, features)``. For a group
        covering tasks ``[start, stop)``, the head is applied to
        ``x_dec[:, start:stop]``. Output ``j`` of task ``start + j`` is taken
        from the diagonal, giving one ``(batch, len(group))`` tensor per group.
        """
        outputs = []
        start = 0
        for i, group in enumerate(self.tasks_groups):
            stop = start + len(group)
            logits = self.classifiers[i](x_dec[:, start:stop])
            outputs.append(torch.diagonal(logits, dim1=1, dim2=2))
            start = stop
        return outputs

    def get_pretrain_parameters(self):
        """Return all module parameters (same iterator as ``self.parameters()``)."""
        return self.parameters()

    def set_gain(self, gain):
        """Set ``gain`` on the decoder."""
        self.decoder.gain = gain

    def get_outputs_weights(self):
        """Stack every head's weight matrix row-wise into one tensor."""
        return torch.vstack([classifier.weight for classifier in self.classifiers])
