from collections import OrderedDict

import torch.nn as nn


class ContinualMultiheadMLP(nn.Module):
    """Shared trunk followed by one linear output head per task.

    This class does not build the shared trunk itself: ``shared_net`` is left
    as ``None`` here and presumably assigned by a subclass or caller before
    ``forward`` is used (TODO confirm). Each task gets its own ``nn.Linear``
    head mapping the trunk's output (width ``shared_structure[-1]``) to
    ``out_size``.

    Args:
        shared_structure: Sequence of layer sizes for the shared trunk; only
            its last element (the trunk's output width) is read here.
        in_size: Input feature size; stored but not used in this class.
        out_size: Output size of every task head.
        device: Device on which new heads are created.
    """

    def __init__(self, shared_structure, in_size, out_size, device):
        super().__init__()
        self.out_size = out_size
        self.in_size = in_size
        self.device = device
        # task_id -> nn.Linear head. Kept as a plain dict so integer task ids
        # can be used directly as keys; every head is additionally registered
        # as a submodule in adapt_new_task so parameters()/state_dict()/.to()
        # include it.
        self.linear_heads = {}
        self.shared_net = None
        self.linear_layer_size = shared_structure[-1]
        # Largest task id added so far; -1 means no task has been added yet.
        self.max_task_id = -1

    def adapt_new_task(self, task_id=None):
        """Create an output head for a task unless one already exists.

        Args:
            task_id: Id of the task. If ``None``, a new id one greater than
                the current ``max_task_id`` is allocated.

        Returns:
            The id of the task whose head now exists.
        """
        if task_id is None:
            # Resolve the id before storing, so the head is keyed by the same
            # id that forward() falls back to (max_task_id).
            task_id = self.max_task_id + 1
        if task_id not in self.linear_heads:
            self.max_task_id = max(self.max_task_id, task_id)
            head = nn.Linear(self.linear_layer_size, self.out_size, device=self.device)
            self.linear_heads[task_id] = head
            # A plain dict is invisible to nn.Module; registering the head makes
            # its parameters trainable via parameters(), saved in state_dict(),
            # and moved by .to().
            self.add_module(f"linear_head_{task_id}", head)
        return task_id

    def forward(self, task_data, task_id=None):
        """Run the shared trunk, then the output head for ``task_id``.

        Args:
            task_data: Input passed unchanged to ``shared_net``.
            task_id: Which head to use; defaults to the largest task id added
                so far (``max_task_id``).

        Returns:
            The selected head's output.

        Raises:
            KeyError: If no head exists for ``task_id``.
        """
        if task_id is None:
            task_id = self.max_task_id
        shared_out = self.shared_net(task_data)
        return self.linear_heads[task_id](shared_out)

    def loss(self, task_data, task_labels, label_loss, task_id=None, is_test=False):
        """Return ``label_loss(self.forward(task_data, task_id), task_labels)``.

        ``is_test`` is accepted but not used here; presumably kept so this
        method matches an interface shared with other models (TODO confirm).
        """
        return label_loss(self.forward(task_data, task_id), task_labels)
