import torch.nn as nn
import torch
from .loss import Criterion
import gorilla
import os 

class Detector(nn.Module):
    """Teacher/student wrapper around two MAFT models plus a training criterion.

    ``detector`` is the trainable (student) model. ``ema_detector`` is a second
    MAFT instance whose parameters are detached from autograd at construction
    time. Judging by its name, ``ema_detector`` is presumably updated elsewhere
    as an exponential moving average of ``detector``; no such update happens in
    this class (TODO confirm where it is done).
    """

    def __init__(self, config: dict, pretrain):
        """Build the student/teacher pair and the loss criterion.

        Args:
            config: Keyword arguments forwarded to ``MAFT``. Must also contain
                ``'criterion'`` (a mapping of keyword arguments for
                ``Criterion``) and ``'num_class'``.
            pretrain: Checkpoint handed to ``gorilla.load_checkpoint`` for both
                models. Loading is skipped when this is falsy. Presumably a
                checkpoint file path (TODO confirm).
        """
        super().__init__()

        # Local import kept from the original code. The reason is unverified
        # (possibly avoiding an import cycle or a heavy import at module load).
        from .maft import MAFT
        self.detector = MAFT(**config)
        self.ema_detector = MAFT(**config)

        if pretrain:
            print('Load pretrain......')
            # strict=False presumably tolerates key mismatches between the
            # checkpoint and the model (confirm against gorilla's docs).
            gorilla.load_checkpoint(self.detector, pretrain, strict=False)
            gorilla.load_checkpoint(self.ema_detector, pretrain, strict=False)
        # NOTE(review): when `pretrain` is falsy, the two MAFT instances are
        # constructed independently and are not synchronised here. They may
        # therefore start from different weights. Confirm this is intended, or
        # that the teacher is initialised from the student elsewhere.

        # Detach the teacher's parameters in place so autograd does not track
        # them. The teacher is also only run under torch.no_grad() in forward().
        for param in self.ema_detector.parameters():
            param.detach_()

        self.criterion = Criterion(**config['criterion'], num_class=config['num_class'])

    def forward(self, batch, mode: str = 'loss'):
        """Run either a training-loss pass or an inference pass.

        Args:
            batch: Input forwarded unchanged to the MAFT model(s). Its
                structure is defined by MAFT, not here.
            mode: ``'loss'`` to compute the training loss from the teacher and
                student outputs, or ``'predict'`` to run only the student.

        Returns:
            For ``'loss'``: the ``(loss, loss_dict)`` pair returned by the
            criterion. For ``'predict'``: whatever ``self.detector`` returns.

        Raises:
            ValueError: If ``mode`` is neither ``'loss'`` nor ``'predict'``.
        """
        if mode == 'loss':
            # Teacher pass without gradient tracking. The trailing 't' / 's'
            # arguments are presumably teacher/student role flags interpreted
            # by MAFT (TODO confirm).
            with torch.no_grad():
                t_outputs, insts = self.ema_detector(batch, mode, 't')
            # Student pass: gradients flow through here. Its second return
            # value is unused.
            s_outputs, _ = self.detector(batch, mode, 's')

            loss, loss_dict = self.criterion(s_outputs, insts, t_outputs)
            return loss, loss_dict

        if mode == 'predict':
            return self.detector(batch, mode)

        # Previously an unknown mode fell through to an unbound local. Fail
        # with an explicit, descriptive error instead.
        raise ValueError(f"Unsupported mode {mode!r}; expected 'loss' or 'predict'")
