import torch

class BaseOCGLearner:
    """Base learner that fits a model on successive batches of graph nodes.

    The learner records every class id it has been trained on and restricts
    the model output to the matching columns (``output_mask``).  Inside the
    learner the observed class ids are renumbered to ``0..k-1`` in ascending
    order (``label_mapping``); predictions are translated back to the original
    ids before being returned.

    Subclasses customise training through the hooks ``before_training``,
    ``after_loss``, ``after_passes`` and ``compute_auxiliary_loss``, all of
    which do nothing in this base class.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Store the collaborators and initialise the class bookkeeping.

        Args:
            model: Network called as ``model(input_data)``; its output is
                indexed along dim 1 with ``output_mask``.  When feature
                centering is active it must also expose ``linear.weight`` and
                ``linear.bias``.
            neighborhood_processor: Object providing ``extract(g, node_ids)``.
            args: Namespace read for ``gpu``, ``epochs``, ``n_cls`` and
                ``center_features``; ``lr`` and ``weight_decay`` are read only
                when the model has trainable parameters, ``backbone`` only
                when ``center_features`` is truthy.
        """
        self.net = model
        self.neighborhood_processor = neighborhood_processor
        # Device on which predict_labels places its result.  Fall back to the
        # CPU when CUDA is not available so the learner stays usable there.
        self.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'

        self.n_passes = args.epochs
        # NOTE: opt / ce are only created when there is something to train;
        # observe() requires both.
        if any(param.requires_grad for param in self.net.parameters()):
            self.opt = torch.optim.Adam(self.net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            self.ce = torch.nn.functional.cross_entropy

        self.observed_classes = []
        self.output_mask = torch.zeros(args.n_cls).bool()
        # original class id -> compact label; rebuilt by before_passes().
        self.label_mapping = {}

        self.center_features = args.center_features and args.backbone in ['UGCN', 'GRNF', 'UMIXED']
        if self.center_features:
            self.feature_mean = None
            self.n_samples = 0

    def predict_labels(self, g, node_ids):
        """Predict original class ids for ``node_ids``.

        The model is switched to eval mode and run without gradients; only the
        output columns of already observed classes are considered.

        Args:
            g: Graph object forwarded to ``neighborhood_processor.extract``.
            node_ids: Node identifiers forwarded to the same call.

        Returns:
            An int64 tensor on ``self.device`` with one original class id per
            row of the model output (empty, but still int64, when there are
            no rows).
        """
        self.net.eval()
        with torch.no_grad():
            input_data = self.neighborhood_processor.extract(g, node_ids)
            if self.center_features:
                input_data = self.online_centering(input_data)
            output = self.net(input_data)[:, self.output_mask]
            _, predicted_labels = torch.max(output, dim=1)
            # Compact label i stands for the original id that label_mapping
            # sends to i, so ordering the keys by their value yields a lookup
            # table that can be indexed with the whole prediction at once.
            class_ids = torch.tensor(
                sorted(self.label_mapping, key=self.label_mapping.get),
                dtype=torch.long,
                device=self.device,
            )
            predicted_labels_adj = class_ids[predicted_labels.to(self.device)]
        return predicted_labels_adj

    def observe(self, g, train_ids, train_labels):
        """Train the model on one batch of labelled nodes.

        Runs ``self.n_passes`` optimisation steps on the same extracted input,
        invoking the subclass hooks around each step.

        Args:
            g: Graph object forwarded to ``neighborhood_processor.extract``.
            train_ids: Node identifiers forwarded to the same call.
            train_labels: Tensor of original class ids for those nodes.
        """
        self.net.train()
        # Registers new classes first so output_mask is current below.
        labels = self.before_passes(train_labels)
        input_data = self.neighborhood_processor.extract(g, train_ids)
        if self.center_features:
            input_data = self.online_centering(input_data)
        for _ in range(self.n_passes):
            self.before_training(input_data, labels)
            self.net.zero_grad()
            output = self.net(input_data)[:, self.output_mask]
            loss = self.ce(output, labels)
            loss_aux = self.compute_auxiliary_loss(input_data, labels, output)
            if loss_aux is not None:
                loss = loss + loss_aux
            loss.backward()
            # Hook runs after backward() and before the optimiser step.
            self.after_loss(input_data, labels)
            self.opt.step()
        self.after_passes(input_data, labels)

    def before_passes(self, labels):
        """Register unseen classes and renumber ``labels`` compactly.

        Every class id in ``labels`` that has not been seen before is appended
        to ``observed_classes`` and enabled in ``output_mask``.
        ``label_mapping`` is then rebuilt so that the observed ids, in
        ascending order, map to ``0..k-1``.

        Args:
            labels: Tensor of original class ids.

        Returns:
            A copy of ``labels`` with every id replaced by its compact label.
        """
        # unique() returns its values sorted, as Python ints after tolist().
        for class_id in labels.unique().tolist():
            if class_id not in self.observed_classes:
                self.observed_classes.append(class_id)
                self.output_mask[class_id] = True

        self.label_mapping = {old_label: new_label for new_label, old_label in enumerate(sorted(self.observed_classes))}
        adjusted_labels = labels.clone()
        # Compare against the untouched input so that an already rewritten
        # entry can never be remapped a second time.
        for old_label, new_label in self.label_mapping.items():
            adjusted_labels[labels == old_label] = new_label
        return adjusted_labels

    def before_training(self, input_data, labels):
        """Hook called at the start of every pass in ``observe``; no-op here."""
        pass

    def after_loss(self, input_data, labels):
        """Hook called after ``loss.backward()`` and before the optimiser step; no-op here."""
        pass

    def after_passes(self, input_data, labels):
        """Hook called once after all passes of ``observe``; no-op here."""
        pass

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return an extra loss term to add to the cross-entropy, or ``None``."""
        return None

    def online_centering(self, features):
        """Subtract the running feature mean from ``features``.

        While the model is in training mode the running mean (over dim 0) is
        first updated with the batch.  On every update after the first, the
        bias entries of ``net.linear`` selected by ``output_mask`` are shifted
        by ``W @ (new_mean - old_mean)`` so that the layer's output for any
        fixed raw input is unchanged by the move of the mean.

        Empty batches do not touch the statistics, and if no mean has been
        accumulated yet the features are returned unchanged.

        Args:
            features: Tensor whose dim 0 indexes samples.

        Returns:
            ``features`` minus the running mean (or ``features`` itself when
            no mean is available).
        """
        batch_size = features.shape[0]
        # Skipping empty batches keeps a NaN mean from poisoning the
        # statistics on the first call.
        if self.net.training and batch_size > 0:
            if self.feature_mean is None:
                self.feature_mean = features.mean(dim=0)
                self.n_samples = batch_size
            else:
                total = self.n_samples + batch_size
                # (new_mean - old_mean) * total
                shift = features.sum(dim=0) - batch_size * self.feature_mean
                with torch.no_grad():
                    self.net.linear.bias[self.output_mask] += (1 / total) * self.net.linear.weight[self.output_mask] @ shift
                self.feature_mean += shift / total
                self.n_samples = total
        if self.feature_mean is None:
            return features
        return features - self.feature_mean
    