import torch
import torch.nn as nn
from lib.models import VanillaFlava
from lib.models.incremental_classifier import IncrementalClassifier
from lib.data.base import VLInputs 

class FlavaEWCCL(nn.Module):
    """FLAVA feature extractor plus an incremental classifier, with EWC-style bookkeeping.

    Each call to :meth:`adaptation` appends a snapshot of the feature
    extractor's parameters and a matching dict of diagonal Fisher scores to
    ``parameters_history`` / ``fisher_scores_history`` before the classifier
    head is adapted. The regularisation penalty that consumes these
    histories is not computed in this class.
    """

    def __init__(self, n_output_classes: int):
        """Build the backbone and a classifier head with ``n_output_classes`` outputs."""
        super().__init__()
        self.feature_extractor = VanillaFlava()
        self.incremental_classifier = IncrementalClassifier(
            self.feature_extractor.d_model, n_output_classes
        )
        # One entry per completed `adaptation()` call; entry i of both lists
        # belongs to the same point in time. These are plain Python lists of
        # dicts, so they are not registered as parameters or buffers.
        self.parameters_history = []
        self.fisher_scores_history = []

    def forward(self, inputs: VLInputs) -> torch.Tensor:
        """Return classifier logits for ``inputs``.

        Only index 0 along the second dimension of the backbone output is fed
        to the classifier (presumably the [CLS] position -- confirm against
        ``VanillaFlava``).
        """
        hidden_states = self.feature_extractor(inputs)
        return self.incremental_classifier(hidden_states[:, 0])

    def adaptation(self, n_output_classes: int, **kwargs) -> None:
        """Record the current parameters and Fisher scores, then adapt the head.

        Args:
            n_output_classes: forwarded to ``incremental_classifier.adaptation``.
            **kwargs: forwarded unchanged to :meth:`get_fisher_scores`, so they
                must supply ``dataloader``, ``optimizer`` and ``loss``.
        """
        # Snapshot first: both records must describe the model as it is
        # *before* the classifier head is changed.
        self.parameters_history.append(self.get_parameters())
        self.fisher_scores_history.append(self.get_fisher_scores(**kwargs))
        self.incremental_classifier.adaptation(n_output_classes)

    def get_parameters(self) -> dict:
        """Return ``{name: tensor}`` with detached copies of the backbone parameters."""
        return {
            name: param.detach().clone()
            for name, param in self.feature_extractor.named_parameters()
        }

    def get_fisher_scores(
        self,
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        loss: nn.Module,
    ) -> dict:
        """Estimate diagonal Fisher scores for the feature extractor's parameters.

        For every batch the loss gradient is computed and squared element-wise;
        the squares are summed over batches and divided by the number of
        batches. Squaring gradients that were accumulated over the whole
        dataloader would instead give the square of the summed gradient, which
        is close to zero near an optimum and grows with the dataset size.

        Side effects: puts the module in training mode (it is not restored)
        and clears gradients through ``optimizer.zero_grad()``, both before
        each batch and once more after the last one. This relies on
        ``optimizer`` managing the feature extractor's parameters.

        Args:
            dataloader: iterable yielding ``(inputs, targets)`` pairs.
            optimizer: used only to reset gradients; no step is taken.
            loss: callable invoked as ``loss(logits, targets)``.

        Returns:
            ``{name: tensor | None}`` keyed like ``get_parameters()``. The
            value is ``None`` for a parameter that never received a gradient
            (including when ``dataloader`` yields no batches).
        """
        device = next(self.parameters()).device

        fisher_scores = {
            name: None for name, _ in self.feature_extractor.named_parameters()
        }
        n_batches = 0

        self.train()

        for inputs, targets in dataloader:
            # Start every batch from clean gradients so that the square below
            # is taken per batch rather than over a running sum.
            optimizer.zero_grad()

            # `.to()` may return a moved copy (as torch.Tensor.to does) or move
            # in place and return None; support both conventions.
            moved = inputs.to(device)
            if moved is not None:
                inputs = moved

            logits = self(inputs)
            J = loss(logits, targets.to(device))
            J.backward()
            n_batches += 1

            for name, param in self.feature_extractor.named_parameters():
                if param.grad is None:
                    continue
                squared_grad = param.grad.detach() ** 2  # fresh tensor, safe to keep
                if fisher_scores[name] is None:
                    fisher_scores[name] = squared_grad
                else:
                    fisher_scores[name] += squared_grad

        # Do not leak the gradients computed here into the next training step.
        optimizer.zero_grad()

        if n_batches:
            for score in fisher_scores.values():
                if score is not None:
                    score /= n_batches

        return fisher_scores