import torch
import torch.nn as nn
from .RACONET import RacoNet
from .GIAO import giao



##############################################################
#                                                            #
###################### train model ###########################
#                                                            #
##############################################################

class RacoNetClassify(nn.Module):
    """Pair a frozen, checkpoint-initialised GIAO encoder with a trainable RacoNet.

    The GIAO encoder is built from the ``coder_*`` arguments, loaded from the
    checkpoint at ``coder_parameters`` and frozen (``requires_grad=False`` on
    all of its parameters). The RacoNet network is built from the
    ``classify_*`` arguments and stays trainable.

    Args:
        coder_in_channels: ``in_channels`` passed to ``giao.GIAO``.
        coder_out_channels: ``out_channels`` passed to ``giao.GIAO``.
        coder_in_size: ``img_size`` passed to ``giao.GIAO``.
        coder_out_size: ``out_img_size`` passed to ``giao.GIAO``.
        coder_parameters: Checkpoint source handed to ``torch.load`` (loaded
            onto the CPU); its contents must be a mapping of parameter names
            to values accepted by ``self.encoder.load_state_dict``.
        classify_in_channels: ``n_channels`` passed to ``RacoNet.RacoNet``.
        classify_num: ``num_classes`` passed to ``RacoNet.RacoNet``.
        classify_out_channels: ``n_classes`` passed to ``RacoNet.RacoNet``.
        classify_out_size: ``output_size`` passed to ``RacoNet.RacoNet``.
    """

    def __init__(self,
                 coder_in_channels,
                 coder_out_channels,
                 coder_in_size,
                 coder_out_size,
                 coder_parameters,
                 classify_in_channels,
                 classify_num,
                 classify_out_channels,
                 classify_out_size
                 ):
        super().__init__()
        self.encoder = giao.GIAO(
            in_channels=coder_in_channels,
            img_size=coder_in_size,
            out_channels=coder_out_channels,
            out_img_size=coder_out_size,
        )

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from
        # trusted sources; prefer weights_only=True if the file holds only
        # tensors.
        weights = torch.load(coder_parameters, map_location=torch.device('cpu'), weights_only=False)
        self.encoder.load_state_dict(self._strip_parallel_prefix(weights))

        # Freeze the encoder: no gradients are computed for its parameters.
        # NOTE(review): this does not switch the encoder to eval mode, so a
        # later .train() call on this module also puts GIAO in train mode
        # (relevant only if GIAO contains BatchNorm/Dropout) — confirm intent.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.raconet = RacoNet.RacoNet(n_channels=classify_in_channels,
            n_classes=classify_out_channels,
            num_classes=classify_num,
            output_size=classify_out_size
        )

    @staticmethod
    def _strip_parallel_prefix(state_dict, prefix='module.'):
        """Return a new dict whose keys have a single leading *prefix* removed.

        Checkpoints saved from a model wrapped in ``nn.DataParallel`` /
        ``DistributedDataParallel`` have every key prefixed with ``'module.'``.
        Only that leading prefix is stripped; occurrences of the prefix
        elsewhere in a key (e.g. ``'submodule.bias'``) are left untouched.
        Key order is preserved and values are reused as-is.
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def forward(self, img, truth):
        """Run the frozen encoder on *truth* and RacoNet on *img*.

        Args:
            img: Input to RacoNet.
            truth: Input to the frozen GIAO encoder (presumably the
                ground-truth target image — TODO confirm against callers).

        Returns:
            Tuple ``(encoder, out_img, logit)``: ``encoder`` is the first of
            the two outputs GIAO returns for ``truth`` (the second is
            discarded); ``out_img`` and ``logit`` are RacoNet's two outputs
            for ``img``.
        """
        encoder, _ = self.encoder(truth)
        out_img, logit = self.raconet(img)
        return encoder, out_img, logit





