import torch
import torch.nn as nn
import numpy as np
import clip
import gc
import sys
import os
import yaml
import time
import torch.nn.functional as F
from PIL import Image
from models.softgroup.model import SoftGroup

from lib.ap_helper.ap_helper_fcos import parse_predictions

from torch.profiler import record_function
from utils.util import cuda_cast
from models.long_clip.model import longclip
class ThreeLayerMLP(nn.Module):
    """Three-stage MLP built from kernel-size-1 ``Conv1d`` layers.

    The first two stages keep the channel width at ``dim`` and each apply
    bias-free convolution -> batch norm -> ReLU -> dropout (p=0.3). The last
    stage is a single biased convolution projecting ``dim`` to ``out_dim``.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        stages = []
        # Two identical hidden stages. The conv bias is disabled because the
        # batch norm that follows supplies its own learnable shift.
        for _ in range(2):
            stages += [
                nn.Conv1d(dim, dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(p=0.3),
            ]
        # Output projection (keeps its bias, no normalisation afterwards).
        stages.append(nn.Conv1d(dim, out_dim, kernel_size=1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack; ``x`` can be (B, dim, N)."""
        out = self.net(x)
        return out

class JointNet(nn.Module):
    """Language-classification network on top of a frozen LongCLIP text encoder.

    The constructor optionally instantiates a frozen SoftGroup proposal
    module, loads a frozen LongCLIP model, and creates a trainable
    ``ThreeLayerMLP`` head that maps 768-channel text features to
    ``num_class`` scores. ``forward`` only uses the LongCLIP encoder and the
    classification head.
    """

    def __init__(self, num_class, class_name,
                 input_feature_dim=0, width=1,
                 num_proposal=128, num_target=32, num_rec_other=16, num_locals=-1, vote_factor=1, sampling="vote_fps",
                 no_caption=False, use_topdown=False, query_mode="corner",
                 use_lang_classifier=True, use_bidir=False, no_reference=False,
                 emb_size=300, hidden_size=256, args=None, cfg=None, vocabulary=None):
        """Store the configuration and build the sub-modules.

        Args:
            num_class: Number of output classes of the language classifier.
            class_name: Stored on the instance as ``self.class_name``.
            num_proposal: Stored, and forwarded to ``SoftGroup`` when it is built.
            num_target: Stored; ``num_other`` is ``num_proposal - num_target``.
            num_rec_other: Clamped to at most ``num_other`` before storing.
            args: Required despite the ``None`` default. The attributes
                ``pretrain_model_on`` and ``distribute`` are always read, and
                ``pretrain_model`` is read when ``pretrain_model_on`` is truthy.
            cfg: ``cfg.model.softgroup`` is expanded as keyword arguments for
                ``SoftGroup`` when that branch is taken.

        ``input_feature_dim``, ``vote_factor``, ``sampling``,
        ``use_lang_classifier``, ``use_bidir``, ``no_reference``,
        ``no_caption``, ``emb_size`` and ``hidden_size`` are only stored on
        the instance. ``width``, ``num_locals``, ``use_topdown``,
        ``query_mode`` and ``vocabulary`` are accepted but not used here.
        """
        super().__init__()
        self.num_class = num_class
        self.class_name = class_name
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.vote_factor = vote_factor
        self.sampling = sampling
        self.use_lang_classifier = use_lang_classifier
        self.use_bidir = use_bidir
        self.no_reference = no_reference
        self.no_caption = no_caption
        self.num_target = num_target
        self.num_other = num_proposal - num_target
        self.num_rec_other = min(num_rec_other, self.num_other)
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        self.args = args
        self.cfg = cfg

        if args.pretrain_model_on:
            if args.pretrain_model == "softgroup":
                # --------- SoftGroup PROPOSAL GENERATION ---------
                self.softgroup = SoftGroup(
                    **cfg.model.softgroup,
                    num_proposal=num_proposal)
                # Frozen: SoftGroup weights are never updated by this model.
                for _p in self.softgroup.parameters():
                    _p.requires_grad = False

        # Element 0 of the value returned by longclip.load is used as the
        # model (presumably the rest is preprocessing -- confirm against
        # longclip.load). The path is relative to the working directory.
        self.longclip_model = longclip.load("models/long_clip/checkpoints/LongCLIP-L/longclip-L.pt")[0]
        # Frozen text encoder: only the classification head below is trained.
        for param in self.longclip_model.parameters():
            param.requires_grad = False
        self.lang_cls = ThreeLayerMLP(768, self.num_class)

        if args.distribute:
            # `self` is not itself a batch-norm layer, so the conversion
            # replaces matching child modules in place and the return value
            # (which is `self`) can be discarded.
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    # @cuda_cast
    def forward(self, data_dict, use_tf=True, is_eval=False):
        """Classify the tokenised sentences in ``data_dict``.

        Args:
            data_dict: dict containing ``"clip_token"``, a tensor whose first
                two dimensions are merged into one before encoding
                (presumably (batch, sentences_per_sample, context_length)
                token ids -- confirm against the dataloader).
            use_tf: Unused; kept for interface compatibility.
            is_eval: Unused; kept for interface compatibility.

        Returns:
            The same ``data_dict``, with ``"lang_scores"`` added: the output
            of the classification head with its trailing length-1 axis
            removed, i.e. one row of ``num_class`` scores per encoded
            sentence.
        """
        clip_tokens = data_dict["clip_token"].flatten(start_dim=0, end_dim=1)
        lang_feat = self.longclip_model.encode_text(clip_tokens).float()
        # The head is Conv1d-based, so a length-1 axis is appended for it and
        # removed again afterwards. Only that axis is squeezed: a bare
        # `.squeeze()` would also drop the leading dimension when exactly one
        # sentence is encoded (and the class dimension when num_class == 1),
        # silently changing the rank of the result.
        lang_scores = self.lang_cls(lang_feat.unsqueeze(-1)).squeeze(-1)
        data_dict["lang_scores"] = lang_scores

        return data_dict
    
  