import clip
import numpy as np
import torch
import torch.nn as nn
import os
import random


class SemanticExtractor(nn.Module):
    """Provide text-embedding features for object and relationship labels.

    The module works in one of two modes, selected by ``use_clip``:

    * ``use_clip=True``: a frozen CLIP model is loaded and texts are encoded
      on demand with ``encode_text``.
    * ``use_clip=False``: features are read from ``.pt`` files located in
      ``annotation_dir`` (``object_classes.pt`` or ``object_classes_multiple.pt``
      and ``relationship_classes.pt``) and served from in-memory dictionaries.

    :param device: device the returned tensors (and CLIP, if used) are moved to
    :param use_clip: load CLIP instead of the cached feature files
    :param clip_model: CLIP model name passed to ``clip.load``
    :param annotation_dir: directory containing the cached feature files
    :param masking_prob: probability of replacing a sampled feature with zeros;
        only applied by ``get_from_index`` in training mode with augmentation on
    :param augmentation_semantic: sample a random feature variant per index
        while training (multi-variant files only)
    :param simple_semantics: use the single-feature-per-class file instead of
        the ``_multiple`` one
    """

    def __init__(self, device="cuda:0",
                 use_clip=False,
                 clip_model="ViT-L/14",
                 annotation_dir="/annotations/",
                 masking_prob=0.0,
                 augmentation_semantic=False,
                 simple_semantics=False):
        super().__init__()
        self.device = device
        self.use_clip = use_clip
        self.annotation_dir = annotation_dir
        self.augmentation_semantic = augmentation_semantic
        self.simple_semantics = simple_semantics
        self.masking_prob = masking_prob
        self.model = self.load_and_freeze_clip(clip_model) if use_clip else nn.Identity()
        # NOTE(review): hard-coded and not derived from `clip_model` or the
        # cached files -- confirm it matches the features actually in use.
        self.embedding_size = 768

        # Always define the lookup tables so that the get_feature_* helpers do
        # not fail with AttributeError in CLIP mode; with empty tables every
        # text is simply treated as a cache miss and encoded by CLIP.
        self.object_features = {}
        self.index_to_obj = {}
        self.index_to_feat = {}
        self.relationship_features = {}

        if not self.use_clip:
            suffix = "" if simple_semantics else "_multiple"
            # NOTE(security): torch.load unpickles the file; only point
            # `annotation_dir` at trusted data.
            self.object_features = torch.load(
                os.path.join(self.annotation_dir, "object_classes{}.pt".format(suffix)))
            # Each object entry is a mapping holding the class "index" and its "feat".
            self.index_to_obj = {value["index"]: key for key, value in self.object_features.items()}
            self.index_to_feat = {value["index"]: value["feat"] for value in self.object_features.values()}
            self.relationship_features = torch.load(
                os.path.join(self.annotation_dir, "relationship_classes.pt"))

    def load_and_freeze_clip(self, clip_version):
        """Load the CLIP model ``clip_version`` on ``self.device`` and freeze it.

        :param clip_version: model name passed to ``clip.load``
        :return: the CLIP model in eval mode with ``requires_grad`` disabled
            on all parameters
        """
        # Must set jit=False for training. The preprocessing transform is not needed here.
        clip_model, _ = clip.load(clip_version, device=self.device, jit=False)
        clip.model.convert_weights(clip_model)  # Actually this line is unnecessary since clip by default already on float16

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def extract_text(self, text):
        """Encode ``text`` with CLIP.

        :param text: a string or a list of strings
        :return: float tensor with one row of text features per input string
        :raises RuntimeError: if no CLIP model is loaded (``use_clip=False``)
        """
        if not hasattr(self.model, "encode_text"):
            # Without CLIP, self.model is nn.Identity and cannot encode text.
            raise RuntimeError(
                "Cannot encode {!r}: no cached feature is available and CLIP "
                "is not loaded (use_clip=False).".format(text))
        token_text = clip.tokenize(text).to(self.device)
        return self.model.encode_text(token_text).float()

    def _cached_to_tensor(self, entry):
        """Return the feature tensor stored in a cached table entry.

        Object entries are mappings with a ``"feat"`` field; in multi-variant
        mode the first variant is used, mirroring the deterministic branch of
        ``get_from_index``. Any other entry is returned unchanged.
        """
        if isinstance(entry, dict):
            feat = entry["feat"]
            return feat if self.simple_semantics else feat[0]
        return entry

    def extract_and_batch(self, texts, features__):
        """Look up (or encode) each text and stack the results.

        :param texts: iterable of strings
        :param features__: mapping from text to a cached feature entry
        :return: tensor on ``self.device`` with one row per text
        :raises RuntimeError: if a text is missing from ``features__`` and CLIP
            is not loaded
        """
        features = []
        for text in texts:
            if text in features__:
                feat = self._cached_to_tensor(features__[text])
            else:
                # Cache miss: encode this single text with CLIP.
                feat = self.extract_text([text])[0]
            # Move each item first so cached and CLIP-encoded tensors that
            # live on different devices can still be stacked together.
            features.append(feat.to(self.device))
        return torch.stack(features, dim=0)

    def get_feature_object(self, texts):
        """Return the features of the given object class names.

        :param texts: list of strings
        :return: tensor on ``self.device`` with one row per text
        """
        return self.extract_and_batch(texts, self.object_features)

    def get_from_index(self, indices):
        """Return the cached object features for the given class indices.

        * ``simple_semantics``: the single stored feature of each index.
        * training mode with ``augmentation_semantic``: a randomly chosen
          variant per index, replaced by zeros with probability
          ``masking_prob``.
        * otherwise: the first variant of each index.

        :param indices: iterable of object class indices
        :return: tensor on ``self.device`` with one row per index
        """
        if self.simple_semantics:
            return torch.stack([self.index_to_feat[i] for i in indices]).to(self.device)
        elif self.training and self.augmentation_semantic:
            res = []
            for i in indices:
                variants = self.index_to_feat[i]
                f = variants[random.randint(0, variants.shape[0] - 1)]
                if random.uniform(0, 1) < self.masking_prob:
                    f = torch.zeros_like(f)
                res.append(f)
            return torch.stack(res).to(self.device)
        else:
            return torch.stack([self.index_to_feat[i][0] for i in indices]).to(self.device)

    def get_feature_object_idx(self, index):
        """Return object features for class indices, resolved through class names.

        :param index: get from labels indices of the object classes
        :return: tensor on ``self.device`` with one row per index
        """
        return self.extract_and_batch([self.index_to_obj[i] for i in index], self.object_features)

    def get_feature_relationships(self, texts):
        """Return the features of the given relationship names.

        :param texts: list of strings
        :return: tensor on ``self.device`` with one row per text
        """
        return self.extract_and_batch(texts, self.relationship_features)

    def forward(self, texts):
        """Return features for ``texts`` from the cached tables or from CLIP.

        Without CLIP, the object table is used when every text is an object
        class name, otherwise the relationship table when every text is a
        relationship name. In all remaining cases the texts are encoded with
        CLIP.

        :param texts: a string or a list of strings
        :return: tensor with one row per text
        :raises RuntimeError: if the texts are not fully cached and CLIP is not
            loaded
        """
        # A lone string would otherwise be iterated character by character.
        texts = [texts] if isinstance(texts, str) else list(texts)

        if not self.use_clip and texts:
            # Test membership per text: a list itself is unhashable and can
            # never be a dictionary key.
            if all(t in self.object_features for t in texts):
                return self.extract_and_batch(texts, self.object_features)
            if all(t in self.relationship_features for t in texts):
                return self.extract_and_batch(texts, self.relationship_features)
        return self.extract_text(texts)

if __name__ == '__main__':
    # Manual smoke test: load the cached features and print those of a few class indices.
    # Explicit import instead of a wildcard so the origin of `annotations_dir` is visible.
    from configs.paths import annotations_dir

    # Cached-feature mode does not load CLIP, so fall back to CPU when CUDA is unavailable.
    demo_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    semantic_extractor = SemanticExtractor(device=demo_device, use_clip=False, clip_model="ViT-L/14",
                                           annotation_dir=annotations_dir)

    # Eval mode: get_from_index only samples/masks features while training.
    semantic_extractor = semantic_extractor.eval()

    r = semantic_extractor.get_from_index([0, 1, 10])
    print(r)
