import os
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from .st_gcn.st_gcn import Model

class STGCNFeatureExtractor(nn.Module):
    """Wrap a pretrained ST-GCN ``Model`` and expose a pooled feature vector.

    The backbone is built with 3 input channels, 60 classes, the
    ``ntu-rgb+d`` graph layout with ``uniform`` partitioning, and edge
    importance weighting. It is initialised from a checkpoint on
    construction.

    Train/eval mode is not forced here. ``self.model`` is a registered
    submodule, so it follows this wrapper's mode; call ``.eval()`` on the
    extractor for inference.
    """

    def __init__(self, weight_path=None, *, strict=False):
        """Build the ST-GCN backbone and load pretrained weights into it.

        Args:
            weight_path: Path to the checkpoint file. Defaults to
                ``gcn_weight.pth`` located next to this source file.
            strict: Forwarded to ``load_state_dict``. With the default
                ``False``, missing/unexpected keys are tolerated, but a
                warning lists them so a mismatched checkpoint is noticed.
        """
        super().__init__()

        if weight_path is None:
            weight_path = os.path.join(os.path.dirname(__file__), 'gcn_weight.pth')

        self.model = Model(
            in_channels=3,
            num_class=60,
            graph_args={'layout': 'ntu-rgb+d', 'strategy': 'uniform'},
            edge_importance_weighting=True,
        )

        # weights_only=True limits unpickling to tensors and plain containers.
        checkpoint = torch.load(weight_path, map_location='cpu', weights_only=True)
        state_dict = self._extract_state_dict(checkpoint)
        result = self.model.load_state_dict(state_dict, strict=strict)
        # With strict=False torch does not raise on key mismatches; surface
        # them instead of silently running with partially random weights.
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"ST-GCN checkpoint {weight_path!r} did not match the model exactly: "
                f"missing keys={list(result.missing_keys)}, "
                f"unexpected keys={list(result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )

    @staticmethod
    def _extract_state_dict(checkpoint, prefix='model.'):
        """Return the weights dict from *checkpoint* with a leading *prefix* stripped.

        Checkpoints that nest weights under ``'state_dict'`` are unwrapped;
        otherwise *checkpoint* itself is treated as the state dict. Only a
        key's leading *prefix* is removed, so keys that merely contain
        *prefix* elsewhere (e.g. ``'submodel.x'``) are left untouched.
        """
        state_dict = checkpoint.get('state_dict', checkpoint)
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def forward(self, x):
        """Return the ST-GCN feature map averaged over its trailing dims.

        Args:
            x: Skeleton batch passed unchanged to ``Model.extract_feature``.
                Presumably (N, 3, T, V, M) for this backbone; confirm against
                the data pipeline.

        Returns:
            Tensor of shape (N, C): the feature map averaged over dims 2, 3
            and 4. C is 256 for the standard ST-GCN.
        """
        # Element [1] of extract_feature's result is the feature map,
        # shaped (N, 256, T, V, M).
        features = self.model.extract_feature(x)[1]
        return features.mean(dim=[2, 3, 4])  # (N, 256)
