import torch
import torch.nn as nn
import einops

class MultiTaskModulatedNet(nn.Module):
    def __init__(self, 
                 network,
                 controller=None,
                 device="cpu",
                 steps=100,
                 use_modulations_in_gain=True,
                 only_train_added_params=False,
                 only_retrain_last=False,
                 normalize_gain_minmax_after=False,
                 comod_backprop=False,
                 num_tasks=1,
                 retrain_last_layer=False,
                 retrain_everything=False,
                 is_comodulation=False,
                 is_attention=False,
                 modulation_noise_mean=1,
                 modulation_noise_std=0.1,
                 compute_gain_once_with_train_set=False):
        super(MultiTaskModulatedNet, self).__init__()
        self.network = network        
        self.controller = controller
        self.device = device
        assert not (is_comodulation and is_attention), "Cannot be both comodulation and attention"

        self.is_attention = is_attention
        self.is_comodulation = is_comodulation
        self.backprop_mod = torch.ones((1,)).to(device) #("backprop_mod", .to(device))
        self.steps = steps
        self.gain = 1
        self.gains_over_time = []
        self.num_tasks = num_tasks
        self.modulation_for_gain = torch.normal(modulation_noise_mean,modulation_noise_std,(self.steps,1)).unsqueeze(-1).unsqueeze(-1).to(device).abs()
        if self.is_comodulation:
            self.gain = torch.ones((1, 1, num_tasks, network.num_features_decoder)).to(device)
            
        self.retrain_last = retrain_last_layer
        self.retrain_everything = retrain_everything
        self.use_modulations_in_gain = use_modulations_in_gain
        self.comod_backprop = comod_backprop
        self.normalize_gain_minmax_after = normalize_gain_minmax_after
        self.only_train_added_params = only_train_added_params
        self.only_retrain_last = only_retrain_last
        self.compute_gain_once_with_train_set = compute_gain_once_with_train_set
                
    def forward(self, x, tasks=None):
        
        if self.controller is not None and (self.is_attention or self.is_comodulation ):
            if tasks is None:
                controller_params = self.controller.get_controller_params(x)
            else:
                controller_params = self.controller.get_controller_params(x,tasks)
        else:
            controller_params = None
        
        if self.is_comodulation:
            if self.compute_gain_once_with_train_set:
                if self.training:
                    if self.comod_backprop:
                        if tasks is  None:
                            new_gain = self.update_gain(x,tasks)   
                            self.network.set_gain(new_gain)
                        else:
                            new_gain = self.update_gain_single_task(x,tasks)   
                            self.network.set_gain(new_gain)
                    else:
                        self.network.set_gain(1)
                else:
                    assert self.tasks_gain is not None, "Need to compute gain with training set first"
                    gains = self.tasks_gain[tasks]
                    self.network.set_gain(gains)
                    
            elif (not self.training) or (self.training and self.comod_backprop):
                with torch.no_grad():
                    if tasks is  None:
                        new_gain = self.update_gain(x,tasks)   
                        self.network.set_gain(new_gain)
                    else:
                        new_gain = self.update_gain_single_task(x,tasks)   
                        self.network.set_gain(new_gain)

            else:
                self.network.set_gain(1)
        return *self.network(x, self.backprop_mod, controller_params=controller_params,tasks=tasks),controller_params
        #return self.network(x)
    
    def set_modulation_training(self, train):
        self.network.set_modulation_training(train)
    
    def set_layer_to_modulate(self, layer):
        self.network.set_layer_to_modulate(layer)
        
    def get_layer_to_modulate(self):
        return self.network.get_layer_to_modulate()
    
    def update_gain(self, data, tasks):
        """Estimate a gain row for every (sample, controller input) pair.

        Each sample is decoded ``self.steps`` times, once per stored noise
        modulation and per controller input. The gain is the mean, over the
        step axis, of the product of step-centred decoder activity and
        step-centred modulation values (their covariance across steps).

        Side effects: sets the network gain to 1 and switches
        ``self.network`` to eval mode; the previous mode is not restored.

        Args:
            data: Input batch; ``data.shape[0]`` is taken as the batch size.
            tasks: Unused here; the controller is queried with ``data`` only.

        Returns:
            Detached tensor of shape
            ``(batch_size * controller.num_distinct_inputs, c)``, where ``c``
            is the trailing dimension of what ``network.decode`` returns.

        Raises:
            AssertionError: If the leading dimension of the controller
                parameters is neither ``num_distinct_inputs`` nor
                ``num_distinct_inputs * batch_size``.
        """
        # The result is returned detached, so recording an autograd graph over
        # batch * steps * inputs decoded samples would only waste memory.
        with torch.no_grad():
            self.network.set_gain(1)
            self.network.eval()
            batch_size = data.shape[0]
            modulation_params = self.controller.get_controller_params(data)

            num_controller_params = self.controller.num_distinct_inputs
            if modulation_params.shape[0] == num_controller_params:
                # One parameter set shared by the whole batch: (t c h w).
                modulation_params = modulation_params.unsqueeze(0)
                modulation_params = modulation_params.repeat((batch_size, 1, 1, 1, 1))
            else:
                assert modulation_params.shape[0] == num_controller_params * batch_size, "Wrong number of controller params"
                modulation_params = einops.rearrange(
                    modulation_params, '(b t) c h w -> b t c h w',
                    b=batch_size, t=num_controller_params)

            # Replicate across steps and flatten to the (b s t) sample order.
            modulation_params = modulation_params.unsqueeze(0)
            modulation_params = modulation_params.repeat((self.steps, 1, 1, 1, 1, 1))
            modulation_params = einops.rearrange(modulation_params, ' s b t c h w -> (b s t) c h w')

            # Tile the stored noise to the same (b s t) sample order.
            modulations = self.modulation_for_gain.unsqueeze(0).repeat((num_controller_params, 1, 1, 1, 1))
            modulations = modulations.unsqueeze(0).repeat((batch_size, 1, 1, 1, 1, 1))
            modulations = einops.rearrange(modulations, ' b t s c h w -> (b s t) c h w')

            data_for_modulator = self.network.encode_until_modulator(data)
            layer_activity = self.network.decode(
                data_for_modulator, modulations,
                controller_params=modulation_params,
                num_tasks=num_controller_params, steps=self.steps)
            del data_for_modulator

            layer_activity = einops.rearrange(
                layer_activity, '(b s t) c-> b s t c',
                b=batch_size, t=num_controller_params)
            layer_activity = layer_activity - layer_activity.mean(1).unsqueeze(1)

            # reshape(-1, 1) rather than squeeze(): a bare squeeze() would also
            # drop the sample axis when it happens to have length 1.
            modulations = einops.rearrange(
                modulations.reshape(-1, 1), '(b s t) c -> b s t c',
                b=batch_size, t=num_controller_params)
            modulations = modulations - modulations.mean(1).unsqueeze(1)

            gain = (layer_activity * modulations).mean(1)
            gain = einops.rearrange(gain, 'b t c -> (b t) c')

            if self.normalize_gain_minmax_after:
                row_min = gain.min(dim=1).values.unsqueeze(1)
                row_max = gain.max(dim=1).values.unsqueeze(1)
                # The epsilon avoids 0/0 for constant rows.
                gain = (gain - row_min) / (row_max - row_min + 1e-7)

            return gain.detach()

    def update_gain_single_task(self, data, tasks):
        """Estimate one gain row per sample for the task given for that sample.

        Each sample is decoded ``self.steps`` times, once per stored noise
        modulation. The gain is the mean, over the step axis, of the product
        of step-centred decoder activity and step-centred modulation values
        (their covariance across steps).

        Side effects: sets the network gain to 1 and switches
        ``self.network`` to eval mode; the previous mode is not restored.

        Args:
            data: Input batch; ``data.shape[0]`` is taken as the batch size.
            tasks: Task identifiers passed to
                ``controller.get_controller_params`` together with ``data``.

        Returns:
            Detached tensor of shape ``(batch_size, c)``, where ``c`` is the
            trailing dimension of what ``network.decode`` returns.
        """
        # The result is returned detached, so recording an autograd graph over
        # batch * steps decoded samples would only waste memory.
        with torch.no_grad():
            self.network.set_gain(1)
            self.network.eval()
            batch_size = data.shape[0]

            # Replicate the per-sample controller parameters across steps and
            # flatten to the (b s) sample order.
            modulation_params = self.controller.get_controller_params(data, tasks)
            modulation_params = modulation_params.unsqueeze(0)
            modulation_params = modulation_params.repeat((self.steps, 1, 1, 1, 1))
            modulation_params = einops.rearrange(modulation_params, ' s b c h w -> (b s) c h w')

            # Tile the stored noise to the same (b s) sample order.
            modulations = self.modulation_for_gain.unsqueeze(0).repeat((batch_size, 1, 1, 1, 1))
            modulations = einops.rearrange(modulations, ' b s c h w -> (b s) c h w')

            data_for_modulator = self.network.encode_until_modulator(data)
            layer_activity = self.network.decode(
                data_for_modulator, modulations,
                controller_params=modulation_params, steps=self.steps)
            del data_for_modulator

            layer_activity = einops.rearrange(layer_activity, '(b s) c-> b s c', b=batch_size)
            layer_activity = layer_activity - layer_activity.mean(1).unsqueeze(1)

            # reshape(-1, 1) rather than squeeze(): a bare squeeze() would also
            # drop the sample axis when it happens to have length 1.
            modulations = einops.rearrange(modulations.reshape(-1, 1), '(b s) c -> b s c', b=batch_size)
            modulations = modulations - modulations.mean(1).unsqueeze(1)

            gain = (layer_activity * modulations).mean(1)

            if self.normalize_gain_minmax_after:
                row_min = gain.min(dim=1).values.unsqueeze(1)
                row_max = gain.max(dim=1).values.unsqueeze(1)
                # The epsilon avoids 0/0 for constant rows.
                gain = (gain - row_min) / (row_max - row_min + 1e-7)

            return gain.detach()

    def compute_gain_on_training_set(self, train_dataloader):            
        with torch.no_grad():
            self.network.set_gain(1)
            #self.network.decoder.ln = nn.Identity()
            self.network.eval()
            tasks_gains = torch.zeros((self.num_tasks,self.network.num_features_decoder)).to(self.device)
            tasks_counts = torch.zeros((self.num_tasks,)).to(self.device)
            for (images,labels,tasks) in train_dataloader:
                images,labels,tasks = images.to(self.device),labels.to(self.device),tasks.to(self.device)
                gains = self.update_gain_single_task(images,tasks)
                for t in torch.unique(tasks):
                    tasks_gains[t] += gains[tasks ==t].sum(0)
                    tasks_counts[t] += (tasks ==t).sum()
            tasks_gains = tasks_gains/tasks_counts.unsqueeze(1)
            self.tasks_gain = tasks_gains

    def get_modulator_params(self):
        params = []
        if self.controller is not None:
            params += list(self.controller.parameters())
        else:
            params += list(self.network.get_modulator_params())
        
        if self.retrain_everything:
            params += list(self.network.parameters())
        elif  self.retrain_last:
            print("Retraining last layer")
            params += list(self.network.classifiers.parameters())
        return params 

    def get_parameters_for_optimizer(self):
        if not( self.is_comodulation or self.is_attention):
            if self.only_train_added_params:
                return self.network.get_added_parameters()
            elif self.only_retrain_last:
                return self.network.classifiers.parameters()
            else:
                return self.network.get_pretrain_parameters()
        else:
            return self.get_modulator_params()
        
    def get_np_params(self):
        if self.controller is not None:
            return list(self.controller.parameters()) + (list(self.network.out.parameters()))
        else:
            return list(self.network.get_modulator_params()) + (list(self.network.out.parameters()))
            
