from models import register_model
import torch.nn as nn
import torch
from models.base_model import BaseModel
from models.transformer_encoder_input import TransformerEncoderInput
from models import build_model

import logging

log = logging.getLogger(__name__)


@register_model("pt_downstream_model_custom")
class PtDownstreamModelCustom(BaseModel):
    """Downstream model: an upstream model followed by a linear output layer.

    The instance is not usable right after construction; ``build_model(cfg)``
    must be called first, because it creates the ``upstream`` and
    ``linear_out`` sub-modules that ``forward`` relies on.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, src_key_mask, positions, rep_from_layer=-1):
        """Run the upstream model and project its first position.

        Args:
            inputs: Forwarded unchanged to the upstream model.
            src_key_mask: Forwarded unchanged to the upstream model.
            positions: Forwarded unchanged to the upstream model.
            rep_from_layer: Currently unused; kept so existing callers that
                pass it keep working.

        Returns:
            ``linear_out`` applied to ``outs[:, 0, :]``. When the upstream
            model returns a 2-tuple, the second element is passed through and
            ``(projection, weights)`` is returned instead.
        """
        outs = self.upstream(inputs, src_key_mask, positions, intermediate_rep=True)

        # The upstream model returns a tuple when it also emits attention
        # weights (presumably when configured with attention_weights=True --
        # confirm against the upstream implementation).
        has_weights = isinstance(outs, tuple)
        weights = None
        if has_weights:
            outs, weights = outs

        # Only index 0 along the second axis is used for the prediction
        # (presumably a CLS-style summary position -- TODO confirm).
        h = self.linear_out(outs[:, 0, :])

        if has_weights:
            return h, weights
        return h

    def build_model(self, cfg):
        """Create the upstream model and the linear output layer from ``cfg``.

        Config entries read:
            upstream_cfg: Passed to the module-level ``build_model`` factory.
            random_init (optional, default False): When true, no checkpoint
                is loaded and the upstream model keeps its fresh weights.
            upstream_path: Checkpoint file; read only when ``random_init`` is
                false. The loaded object must have a ``"model"`` entry
                holding the upstream state dict.
            frozen_upstream (optional): When present and true, every upstream
                parameter gets ``requires_grad = False``.
            hidden_dim: Input size of the output layer.
            output_dim (optional, default 1): Output size of the output layer.
        """
        self.cfg = cfg
        # Inside a method the bare name resolves to the imported factory
        # function, not to this method.
        upstream = build_model(cfg.upstream_cfg)

        if cfg.get('random_init', False):
            log.info("Randomly initialized model")
        else:
            log.info("Loading upstream model")
            # map_location="cpu" lets a checkpoint saved on a GPU be restored
            # on a CPU-only host and avoids placing the whole checkpoint on
            # the GPU; load_state_dict copies values into the model's own
            # parameters regardless of the source device.
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(cfg.upstream_path, map_location="cpu")
            upstream.load_state_dict(checkpoint["model"])

        if "frozen_upstream" in cfg and cfg.frozen_upstream:
            log.info("Freezing upstream parameters")
            for name, param in upstream.named_parameters():
                log.debug("Freezing parameter %s", name)
                param.requires_grad = False

        self.upstream = upstream
        self.linear_out = nn.Linear(cfg.hidden_dim, cfg.get('output_dim', 1))
