import json
import torch
import torch.nn as nn
import numpy as np

from utils import aux_data
from models.base_models import load_backbone


class OcrnBaseModel(nn.Module):
    """Shared construction-time state for OCRN models.

    Builds the feature source (optional backbone), object prior buffers,
    placeholder category tables, optional per-class loss weights, optional
    BCE positive-class weights, and the attribute/affordance BCE criteria.

    Files read during construction:

    * ``features/{args.data}_{args.backbone_type}/obj_prior.t7`` -- a dict
      holding ``"mean_obj_features"`` and a rank-1 ``"freqency"`` tensor.
    * ``utils/aux_data/{args.data}_bce_pos_weight.json`` -- only when
      ``args.bce_pos_weight`` is not ``None``; must contain array-like
      ``"attr"`` and ``"aff"`` entries.

    Args:
        dataset: Object exposing ``objs`` and ``attrs`` (only their lengths are
            used), ``num_aff``, and ``feature_dim`` when
            ``args.data_type == "feature"``.
        args: Namespace with ``data_type``, ``backbone_type``,
            ``backbone_weight``, ``data``, ``loss_class_weight`` and
            ``bce_pos_weight``.

    Raises:
        ValueError: If the prior's ``"freqency"`` tensor is not 1-D.
        FileNotFoundError: If a required prior/weight file is missing.
    """

    def __init__(self, dataset, args):
        super().__init__()

        self.args = args
        self.num_obj = len(dataset.objs)
        self.num_attr = len(dataset.attrs)
        self.num_aff = dataset.num_aff

        # --- feature source -------------------------------------------------
        # Pre-extracted features need no backbone; otherwise the backbone
        # loader also reports the feature dimension it produces.
        if args.data_type == "feature":
            self.backbone = None
            self.feat_dim = dataset.feature_dim
        else:
            self.backbone, self.feat_dim = load_backbone(args.backbone_type, args.backbone_weight)

        # --- object priors --------------------------------------------------
        # map_location="cpu": the prior file may have been saved from GPU
        # tensors. Loading onto CPU keeps construction working on CPU-only
        # hosts and leaves device placement to a later ``model.to(device)``,
        # like every other buffer registered here.
        prior_info = torch.load(
            f"features/{args.data}_{args.backbone_type}/obj_prior.t7",
            map_location="cpu",
        )
        # NOTE: the key is read exactly as spelled ("freqency"); it must match
        # the key used when the prior file was written.
        obj_frequency = prior_info["freqency"]
        # Explicit check (not ``assert``) so it survives ``python -O``, and it
        # runs before anything is registered.
        if obj_frequency.dim() != 1:
            raise ValueError(
                "obj_prior 'freqency' must be a 1-D tensor, "
                f"got shape {tuple(obj_frequency.shape)}"
            )
        self.register_buffer("mean_obj_features", prior_info["mean_obj_features"])  # (n_obj, dim)
        self.register_buffer("obj_frequence", obj_frequency)  # (n_obj,)

        # --- category-level attribute / affordance tables -------------------
        # NOTE: the category affordance/attribute will be opened later; for now
        # these are random placeholders (presumably overwritten elsewhere --
        # not visible here).
        self.register_buffer("category_attr", torch.randn(self.num_obj, self.num_attr).float())
        self.register_buffer("category_aff", torch.randn(self.num_obj, self.num_aff).float())

        print(f"CA: attr={self.category_attr.shape}, aff={self.category_aff.shape}")

        # --- per-class loss weights -----------------------------------------
        if args.loss_class_weight:
            obj_wgt, attr_wgt, aff_wgt = aux_data.load_loss_weight(args.data)
        else:
            obj_wgt = attr_wgt = aff_wgt = None
        # Registering ``None`` still declares these names as buffers, so a
        # tensor assigned to them later is tracked by ``.to()`` / state_dict
        # (and re-registering the name does not raise KeyError).
        self.register_buffer("obj_loss_wgt", obj_wgt)
        self.register_buffer("attr_loss_wgt", attr_wgt)
        self.register_buffer("aff_loss_wgt", aff_wgt)

        # --- BCE positive-class weights -------------------------------------
        if args.bce_pos_weight is not None:
            pos_weight_path = f"utils/aux_data/{args.data}_bce_pos_weight.json"
            with open(pos_weight_path, "r", encoding="utf-8") as fp:
                pos_weight = json.load(fp)
            # args.bce_pos_weight is a global multiplier on the per-class
            # weights read from the file.
            pos_weight_attr = args.bce_pos_weight * torch.as_tensor(
                pos_weight["attr"], dtype=torch.float32)
            pos_weight_aff = args.bce_pos_weight * torch.as_tensor(
                pos_weight["aff"], dtype=torch.float32)
        else:
            pos_weight_attr = pos_weight_aff = None
        self.register_buffer("pos_weight_attr", pos_weight_attr)
        self.register_buffer("pos_weight_aff", pos_weight_aff)

        # --- losses ----------------------------------------------------------
        # obj_loss_wgt is not consumed here; presumably used by subclasses.
        self.attr_bce = nn.BCEWithLogitsLoss(weight=self.attr_loss_wgt, pos_weight=self.pos_weight_attr)
        self.aff_bce = nn.BCEWithLogitsLoss(weight=self.aff_loss_wgt, pos_weight=self.pos_weight_aff)
