import torch
import os
import warnings
from src.nlp.finetune_heads.PoolingModule import PoolingModule


class BaseModel(torch.nn.Module):
    """Common base for models built from a backbone plus a linear head.

    This class only stores the configuration and device and provides shared
    helpers (pooling setup, distributed device placement, linear-head
    management). It does not create ``backbone``, ``linear_head`` or
    ``hidden_dim`` itself; the helpers below read those attributes, so they
    are expected to be set elsewhere (presumably by subclasses) before the
    corresponding helper is called.
    """

    # Default for the flag written by ``set_linear_finetune``. Keeping a
    # class-level default means ``change_linear_head`` can be called on a
    # fresh model without first calling ``set_linear_finetune`` (previously
    # this raised AttributeError because the attribute did not exist yet).
    linear_finetune = False

    def __init__(self, cfg, device):
        """Store the configuration and target device.

        Args:
            cfg: Experiment configuration object. This class reads
                ``cfg.pooling`` and ``cfg.svd.add_to_pooling`` from it.
            device: Device the model should live on. Overridden by
                ``_distributed_setup`` when distributed training is used.
        """
        super().__init__()

        self.cfg = cfg
        self.device = device

        # Language modeling head, only used in next token prediction,
        # initialised in experiments.
        self.lm_head = None
        # Initialised if needed.
        self.svd_concat = None

    def init_pooling(self, dtype):
        """Create the pooling module(s) selected by the configuration.

        Requires ``self.hidden_dim`` to be set before this is called.

        Args:
            dtype: Forwarded to ``PoolingModule`` as its ``dtype`` argument.
        """
        self.pooling = PoolingModule(
            self.cfg.pooling,
            cfg=self.cfg,
            hidden_dim=self.hidden_dim,
            dtype=dtype,
        )
        if self.cfg.svd.add_to_pooling:
            # Used to perform singular pooling NEXT TO the normal pooling
            # module.
            self.singular_pooler = PoolingModule("singular", cfg=self.cfg)

    def _distributed_setup(self):
        """Set up distributed computing.

        Reads the ``LOCAL_RANK`` and ``RANK`` environment variables,
        replaces ``self.device`` with the local rank and moves the backbone,
        the linear head and any optional sub-modules to that device.

        Returns:
            A ``(local_rank, global_rank)`` tuple of ints.

        Raises:
            KeyError: If ``LOCAL_RANK`` or ``RANK`` is not set in the
                environment.
        """
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        warnings.warn(f"Default provided device {self.device} is being overridden by distributed setup, new device {self.local_rank}.")
        self.device = self.local_rank

        # Put modules on the local rank device.
        self.backbone = self.backbone.to(self.local_rank)
        self.linear_head = self.linear_head.to(self.local_rank)

        if self.svd_concat is not None:
            self.svd_concat = self.svd_concat.to(self.local_rank)

        if self.lm_head is not None:
            self.lm_head = self.lm_head.to(self.local_rank)

        # Only these two pooling variants have a sub-module moved here.
        # NOTE(review): ``singular_pooler`` is not moved; presumably it has
        # no parameters -- confirm against PoolingModule.
        if self.cfg.pooling == "attention_pool":
            self.pooling.attention_pooling = self.pooling.attention_pooling.to(self.local_rank)

        if self.cfg.pooling == "weighted_avg":
            self.pooling.weighted_average_pooling = self.pooling.weighted_average_pooling.to(self.local_rank)

        return self.local_rank, self.global_rank

    def set_linear_finetune(self, value):
        """Set the linear-finetune flag.

        Call this function before model finetuning is done. While the flag
        is truthy, ``change_linear_head`` refuses to replace the head.

        Args:
            value: New value of the flag.
        """
        self.linear_finetune = value

    def get_dtype(self):
        """Get the dtype of the model.

        Returns:
            The dtype of ``self.linear_head.weight``.
        """
        return self.linear_head.weight.dtype

    def change_linear_head(self, linear_head):
        """Change the linear head of the model.

        Args:
            linear_head: The module to install as ``self.linear_head``.

        Raises:
            AssertionError: If the linear-finetune flag is set.
        """
        # Raised explicitly rather than via ``assert`` so the guard is still
        # enforced when Python runs with -O (which strips assert statements).
        if self.linear_finetune:
            raise AssertionError("Cannot change linear head after finetuning has started.")
        self.linear_head = linear_head