import tqdm
from transformers import AutoModelForSequenceClassification
from transformers import BertModel,BertTokenizer
from torch.utils.data import DataLoader
from enum import Enum
import torch


class FeatureExtractorType(Enum):
    """Selects which encoder ``FeatureExtractor`` loads.

    ``FineTune`` loads a sequence-classification checkpoint from a weight path;
    ``PreTrain`` loads the stock ``bert-base-uncased`` ``BertModel``.
    """

    FineTune = 'fine-tune'  # checkpoint loaded from the given weight path
    PreTrain = 'pre-train'  # off-the-shelf pre-trained BERT encoder


class FeatureExtractor(object):
    """Extract per-example features (or logits) from a BERT-style encoder.

    The encoder runs on ``<device>:0``; every extracted batch is moved to
    ``<device>:1`` before concatenation. With the default ``device="cuda"``
    this therefore assumes at least two GPUs are visible.
    """

    def __init__(self, num_class, weight_path=None, device="cuda", feature_extractor_type=FeatureExtractorType.FineTune):
        """Load the encoder and place it on ``device + ":0"``.

        Args:
            num_class: number of labels; labels are taken to be
                ``0 .. num_class - 1``.
            weight_path: checkpoint passed to
                ``AutoModelForSequenceClassification.from_pretrained``;
                required for the FineTune extractor.
            device: device type WITHOUT an index (e.g. ``"cuda"``); ``":0"``
                and ``":1"`` are appended to it here.
            feature_extractor_type: a ``FeatureExtractorType`` member or its
                string value (``'fine-tune'`` / ``'pre-train'``).

        Raises:
            ValueError: if the extractor type is unknown, or the FineTune
                extractor is requested without ``weight_path``.
        """
        # Enum(value) returns the member itself for a member and converts a
        # valid string; anything else raises ValueError instead of leaving
        # self.feature_extractor unset.
        fe_type = FeatureExtractorType(feature_extractor_type)
        if fe_type == FeatureExtractorType.FineTune:
            if weight_path is None:
                raise ValueError("weight_path is required for a fine-tune feature extractor")
            # Hidden states are needed to read intermediate-layer features.
            self.feature_extractor = AutoModelForSequenceClassification.from_pretrained(weight_path, output_hidden_states=True)
        else:
            self.feature_extractor = BertModel.from_pretrained('bert-base-uncased')
        self.fe_type = fe_type
        # NOTE(review): device + ":1" must exist (a second GPU for "cuda").
        self.model_device = device + ":0"
        self.save_device = device + ":1"
        self.feature_extractor.to(self.model_device)
        # Inference only: evaluation mode disables dropout.
        self.feature_extractor.eval()
        self.labels = list(range(num_class))

    def cal_class_distribution(self, class_x_set, bs=64, is_logits=False):
        """Run the encoder over ``class_x_set`` and return one row per example.

        Args:
            class_x_set: dataset whose items are dicts of tensors that the
                encoder accepts as keyword arguments (e.g. a torch-formatted
                tokenized dataset).
            bs: inference batch size.
            is_logits: return classifier logits instead of hidden features
                (FineTune extractor only).

        Returns:
            Tensor on ``self.save_device`` whose first dimension equals
            ``len(class_x_set)``.

        Raises:
            ValueError: if ``class_x_set`` is empty, or logits are requested
                from the PreTrain extractor.
            RuntimeError: if the number of extracted rows does not match the
                number of input examples.
        """
        if is_logits and self.fe_type == FeatureExtractorType.PreTrain:
            raise ValueError("is_logits=True requires a fine-tune extractor; the pre-train BertModel has no classification head")
        if len(class_x_set) == 0:
            raise ValueError("cal_class_distribution received an empty dataset")
        loader = DataLoader(class_x_set, batch_size=bs)
        feature_list = []
        with torch.no_grad():
            for batch in tqdm.tqdm(loader):
                batch = {k: v.to(self.model_device) for k, v in batch.items()}
                outputs = self.feature_extractor(**batch)
                if is_logits:
                    features = outputs.logits
                elif self.fe_type == FeatureExtractorType.PreTrain:
                    # Mean-pool over the sequence axis -> one row per example.
                    # (An extra unsqueeze(0) here previously yielded one row
                    # per batch and broke the concatenation/length check.)
                    # NOTE(review): padding positions are included in the mean;
                    # consider masking with batch["attention_mask"] if present.
                    features = torch.mean(outputs.last_hidden_state, dim=1)
                else:
                    # hidden_states is indexed by layer, each
                    # (batch_size, max_seq, feature_dim); take the first token
                    # of the second-to-last layer.
                    features = outputs.hidden_states[-2][:, 0, :]
                feature_list.append(features.to(self.save_device))
        features = torch.cat(feature_list, dim=0)
        if features.shape[0] != len(class_x_set):
            raise RuntimeError(
                f"cal_class_distribution error: extracted {features.shape[0]} rows "
                f"for {len(class_x_set)} examples"
            )
        return features

    def extractor_features_from_dst(self, x, y, batch_size=64, is_split_by_class=True, is_logits=False):
        """Extract features for dataset ``x``, optionally grouped by label.

        Args:
            x: dataset exposing ``set_format`` and ``filter`` (presumably a
                Hugging Face ``datasets.Dataset`` — confirm). Its format is
                switched to ``"torch"`` IN PLACE.
            y: per-example labels, indexable by example position.
            batch_size: inference batch size.
            is_split_by_class: if True, return a dict keyed by label.
            is_logits: return logits instead of hidden features.

        Returns:
            ``{label: features}`` for every label in ``self.labels`` when
            ``is_split_by_class``; otherwise a single feature tensor.
        """
        x.set_format("torch")
        if not is_split_by_class:
            return self.cal_class_distribution(x, batch_size, is_logits)
        feature_distri_dict = {}
        for label in self.labels:
            # filter runs eagerly, so the lambda sees the current label.
            class_x = x.filter(lambda e, i: y[i] == label, with_indices=True)
            feature_distri_dict[label] = self.cal_class_distribution(class_x, batch_size, is_logits)
        return feature_distri_dict


if __name__ == '__main__':
    # Smoke test for cross-device arithmetic: tensors on different devices
    # cannot be combined directly, so move one onto the other's device first.
    # Falls back to fewer devices when fewer than two GPUs are available.
    n_gpu = torch.cuda.device_count()
    dev0 = "cuda:0" if n_gpu >= 1 else "cpu"
    dev1 = "cuda:1" if n_gpu >= 2 else dev0
    a = torch.tensor([1, 2, 3]).to(dev0)
    b = torch.tensor([2, 3, 4]).to(dev1)
    print(a + b.to(a.device))