import torch
import torch.nn as nn
import einops



class TaskController(nn.Module):
    """Map a task (or task-group) index to a per-channel parameter vector.

    Each distinct input index is one-hot encoded and fed through a small MLP
    whose final ``nn.Linear`` has ``parameter_shape[0]`` outputs (``C``). The
    result gets two trailing singleton dims, i.e. shape ``(..., C, 1, 1)``
    (presumably so it broadcasts over the spatial dims of a ``(B, C, H, W)``
    feature map -- confirm against callers).

    Args:
        parameter_shape: Shape of the controlled parameter; only
            ``parameter_shape[0]`` (the channel count ``C``) is used.
        task_groups: Iterable of task groups, each a sized collection of tasks.
        hidden_dim: Width of the hidden layers.
        num_layers: Number of ``nn.Linear`` layers for ``num_layers >= 1``
            (``num_layers - 1`` hidden Linear+ReLU layers, then the output layer).
        guided_init: If True, zero the output weight and set the output bias to
            one, so every input initially maps to an all-ones vector (before any
            ``act_end``). Requires ``use_end_bias=True``.
        params_per_group: If True there is one input index per group
            (``len(task_groups)`` inputs); otherwise one per task.
        device: Device on which the one-hot table ``tasks_rep`` is created. The
            Linear layers are created on the default device; move the module
            with ``.to()`` as usual (``tasks_rep`` follows it).
        use_end_bias: Whether the output ``nn.Linear`` has a bias.
        do_act_end: Whether to append an activation after the output layer.
        act_end: Output activation. It may be an ``nn.Module`` instance, a
            module class or zero-argument factory, or a string expression
            evaluating to either (e.g. ``"nn.Sigmoid()"`` or ``"nn.Sigmoid"``).

    Raises:
        ValueError: If ``guided_init`` is requested without an output bias.
    """

    def __init__(
        self,
        parameter_shape,
        task_groups,
        hidden_dim=256,
        num_layers=2,
        guided_init=False,
        params_per_group=False,
        device="cpu",
        use_end_bias=False,
        do_act_end=False,
        act_end="nn.Sigmoid()",
    ):
        super().__init__()
        if guided_init and not use_end_bias:
            # guided_init sets the output bias to one; with bias=False there is
            # no bias tensor to initialise.
            raise ValueError("guided_init=True requires use_end_bias=True")

        self.feature_channels = parameter_shape[0]
        self.params_per_group = params_per_group
        self.device = device
        self.num_distinct_tasks = sum(len(group) for group in task_groups)
        self.num_distinct_inputs = (
            len(task_groups) if params_per_group else self.num_distinct_tasks
        )

        layers = []
        in_dim = self.num_distinct_inputs
        for _ in range(num_layers - 1):
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim

        # Use the running input width: with num_layers == 1 there is no hidden
        # layer and the one-hot input feeds the output layer directly.
        head = nn.Linear(in_dim, self.feature_channels, bias=use_end_bias)
        if guided_init:
            nn.init.zeros_(head.weight)
            nn.init.ones_(head.bias)
        layers.append(head)

        if do_act_end:
            layers.append(self._build_end_activation(act_end))

        self.task_embeddings = nn.Sequential(*layers)
        # Non-persistent buffer: moves with .to()/.cuda() together with the
        # module, but stays out of state_dict so existing checkpoints load.
        self.register_buffer(
            "tasks_rep", self.get_all_tasks_repr().float(), persistent=False
        )

    @staticmethod
    def _build_end_activation(act_end):
        """Turn ``act_end`` (instance, class/factory, or string) into an ``nn.Module``."""
        # NOTE(review): strings go through eval(); only pass trusted config values.
        act = eval(act_end, {"torch": torch, "nn": nn}) if isinstance(act_end, str) else act_end
        # "nn.Sigmoid()" already yields an instance; "nn.Sigmoid" yields a class.
        return act if isinstance(act, nn.Module) else act()

    def get_all_tasks_repr(self):
        """Return the one-hot table for every distinct input index.

        Returns:
            Identity matrix of shape ``(num_distinct_inputs, num_distinct_inputs)``
            on ``self.device``; row ``i`` is the one-hot encoding of input ``i``.
        """
        return torch.eye(self.num_distinct_inputs, device=self.device)

    def get_controller_params(self, x, tasks=None):
        """Compute controller outputs from the precomputed one-hot table.

        Args:
            x: Unused; kept for interface compatibility.
            tasks: Optional index (int, slice, or index tensor) selecting rows of
                ``tasks_rep``. If None, all distinct inputs are used.

        Returns:
            Tensor of shape ``(N, C, 1, 1)``, where ``N`` is the number of
            selected rows.
        """
        reps = self.tasks_rep if tasks is None else self.tasks_rep[tasks]
        return self.task_embeddings(reps).unsqueeze(-1).unsqueeze(-1)

    def forward(self, tasks, x):
        """Compute controller outputs for a batch of integer input indices.

        Args:
            tasks: Integer (int64) tensor of input indices in
                ``[0, num_distinct_inputs)``, typically shape ``(B,)``. It must
                be on the same device as the module.
            x: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(*tasks.shape, C, 1, 1)``, e.g. ``(B, C, 1, 1)``.
        """
        one_hot = torch.nn.functional.one_hot(tasks, self.num_distinct_inputs).float()
        return self.task_embeddings(one_hot).unsqueeze(-1).unsqueeze(-1)
    
