import torch
import copy
from .base_learner import BaseOCGLearner

class OCGLearner(BaseOCGLearner):
    """Learner that runs ``n_passes`` gradient steps on each observed batch.

    The concrete ``observe`` entry point is selected at construction time:
    ``observe_GCN`` when ``args.backbone == 'GCN'``, otherwise
    ``observe_features``.

    NOTE(review): ``self.net``, ``self.opt``, ``self.ce``, ``self.n_passes``,
    ``self.output_mask``, ``self.neighborhood_processor`` and
    ``self.before_passes`` are not defined in this class; they are presumably
    provided by ``BaseOCGLearner`` -- confirm against the base class.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Initialise the learner.

        Args:
            model: Forwarded unchanged to ``BaseOCGLearner.__init__``.
            neighborhood_processor: Forwarded unchanged to
                ``BaseOCGLearner.__init__``.
            args: Configuration object; this class reads ``batch_size``,
                ``gpu`` and ``backbone`` from it.
        """
        super().__init__(model, neighborhood_processor, args)
        # Not read anywhere in this class; presumably consumed by the base
        # class or by callers -- TODO confirm.
        self.center_features = False
        self.batch_size = args.batch_size
        self.gpu = args.gpu
        # Bind the backbone-specific training routine once so callers can
        # simply invoke ``learner.observe(...)``.
        self.observe = self.observe_GCN if args.backbone == 'GCN' else self.observe_features

    def observe_features(self, g, train_ids, train_labels):
        """Extract inputs once in mini-batches, then train for ``n_passes`` steps.

        Inputs are extracted ``self.batch_size`` ids at a time, concatenated
        along dimension 0, and the resulting tensor is reused for every pass.

        Args:
            g: Graph object handed to ``neighborhood_processor.extract``.
            train_ids: Sliceable, sized collection of ids to train on.
            train_labels: Labels for ``train_ids``; passed through
                ``self.before_passes`` before being used in the loss.
        """
        labels = self.before_passes(train_labels)
        batched_input_data = []
        for i in range(0, len(train_ids), self.batch_size):
            batch_ids = train_ids[i:i + self.batch_size]
            batched_input_data.append(self.neighborhood_processor.extract(g, batch_ids))
            # Release cached allocator blocks between extractions when the
            # device is getting full (no-op without CUDA).
            self.empty_gpu_cache()
        # NOTE: torch.cat needs at least one tensor, so an empty ``train_ids``
        # is not supported here.
        input_data = torch.cat(batched_input_data, dim=0)
        for _ in range(self.n_passes):
            self.net.train()
            self.net.zero_grad()
            # Keep only the output columns selected by ``self.output_mask``.
            output = self.net(input_data)[:, self.output_mask]
            loss = self.ce(output, labels)
            loss.backward()
            self.opt.step()

    def observe_GCN(self, g, train_ids, train_labels):
        """Train for ``n_passes`` steps, re-extracting the inputs on every pass.

        Unlike ``observe_features``, extraction is done for all ``train_ids``
        at once and is repeated inside the loop.

        Args:
            g: Graph object handed to ``neighborhood_processor.extract``.
            train_ids: Ids to train on.
            train_labels: Labels for ``train_ids``; passed through
                ``self.before_passes`` before being used in the loss.
        """
        labels = self.before_passes(train_labels)
        for _ in range(self.n_passes):
            self.net.train()
            self.net.zero_grad()
            # Deliberately kept inside the loop: whether ``extract`` is
            # deterministic is not visible from here, so it is not hoisted.
            input_data = self.neighborhood_processor.extract(g, train_ids)
            # Keep only the output columns selected by ``self.output_mask``.
            output = self.net(input_data)[:, self.output_mask]
            loss = self.ce(output, labels)
            loss.backward()
            self.opt.step()

    def empty_gpu_cache(self, threshold=0.66):
        """Empty the CUDA caching allocator when reserved memory is high.

        Does nothing when CUDA is unavailable, so the learner can also run on
        CPU-only machines.

        Args:
            threshold: Fraction of the device's total memory; the cache is
                emptied only when reserved memory exceeds this fraction.
        """
        # Querying device properties initialises CUDA and raises on hosts
        # without it, so bail out early on CPU-only setups.
        if not torch.cuda.is_available():
            return
        gpu_memory = torch.cuda.memory_reserved(self.gpu)
        total_memory = torch.cuda.get_device_properties(self.gpu).total_memory
        if gpu_memory > total_memory * threshold:
            torch.cuda.empty_cache()