# import html
# import re

# import ftfy
import torch
from transformers import AutoTokenizer, T5EncoderModel

import clip



class T5TextEncoder:
    """Frozen T5 encoder that maps text to per-token embeddings.

    Loads a tokenizer and a ``T5EncoderModel`` from ``from_pretrained``, puts
    the model in eval mode, turns off gradients for all of its parameters and
    moves it to ``device``.
    """

    def __init__(self,
                 device,
                 local_files_only,
                 from_pretrained=None,
                 model_max_length=120):
        """Load and freeze the tokenizer/encoder pair.

        Args:
            device: Device the encoder is moved to and that input tensors are
                sent to in ``get_text_embeddings``.
            local_files_only: Forwarded to both ``from_pretrained`` calls.
            from_pretrained: Model identifier or path for tokenizer and
                encoder. Required in practice; the ``None`` default is kept
                only so the signature stays backward compatible.
            model_max_length: Every input is padded/truncated to exactly this
                many tokens.

        Raises:
            ValueError: If ``from_pretrained`` is ``None``.
        """
        # Fail fast with a clear message instead of an obscure error from
        # deep inside the transformers loader.
        if from_pretrained is None:
            raise ValueError(
                "T5TextEncoder requires `from_pretrained` (a model id or local path)."
            )
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            local_files_only=local_files_only,
            legacy=False,
        )

        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            local_files_only=local_files_only,
        ).eval()
        # The encoder is used as a fixed feature extractor; never train it.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model_max_length = model_max_length
        self.model.to(self.device)

    @torch.no_grad()
    def get_text_embeddings(self, texts):
        """Encode ``texts`` with the frozen T5 encoder.

        Args:
            texts: Input passed unchanged to the tokenizer (a string or a list
                of strings).

        Returns:
            Tuple ``(embeddings, attention_mask)``:

            - ``embeddings``: the encoder's ``last_hidden_state``, presumably
              ``(batch, model_max_length, d_model)``; confirm against the
              model config.
            - ``attention_mask``: the tokenizer's mask, padded to
              ``model_max_length`` tokens.

            Both are on ``self.device``.
        """
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = text_tokens_and_mask["input_ids"].to(self.device)
        attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
        # The @torch.no_grad() decorator already disables autograd here, so no
        # nested no_grad context or .detach() is needed.
        text_encoder_embs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )["last_hidden_state"]
        return text_encoder_embs, attention_mask


class CLIPTextEncoder:
    """Frozen CLIP text tower that yields one pooled embedding per caption."""

    def __init__(self,
                 device,
                 clip_version='ViT-B/32'):
        """Load the CLIP checkpoint ``clip_version`` onto ``device`` and freeze it.

        Args:
            device: Device passed to ``clip.load``; on anything other than
                ``"cpu"`` the weights are also run through
                ``clip.model.convert_weights``.
            clip_version: Checkpoint name understood by ``clip.load``.
        """
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        on_cpu = str(device) == "cpu"
        if not on_cpu:
            # NOTE(review): earlier notes disagreed on whether this conversion
            # (presumably to float16) is still needed after clip.load. It is
            # kept as-is; confirm before removing.
            clip.model.convert_weights(model)

        # Freeze CLIP: eval mode plus no gradients on any parameter.
        model.eval()
        self.clip_model = model
        for param in self.clip_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_text_embeddings(self, raw_text):
        """Encode ``raw_text`` into a single CLIP token per sample.

        Args:
            raw_text: Text passed to ``clip.tokenize`` (with truncation on).

        Returns:
            Tuple ``(features, mask)``:

            - ``features``: the ``encode_text`` output cast to float32, with a
              singleton dimension inserted at position 1.
            - ``mask``: a tensor of ones of shape ``(batch, 1)``.

            Both are on the same device as the CLIP parameters.
        """
        target = next(self.clip_model.parameters()).device
        tokens = clip.tokenize(raw_text, truncate=True).to(target)
        pooled = self.clip_model.encode_text(tokens).float()
        batch = pooled.shape[0]
        mask = torch.ones((batch, 1), device=target)
        return pooled.unsqueeze(1), mask
    
class T5CLIPTextEncoder:
    """Run a frozen T5 encoder and a frozen CLIP text encoder side by side."""

    def __init__(self, device,
                 local_files_only,
                 t5_version,
                 clip_version,
                 max_length=120):
        """Build both sub-encoders on ``device``.

        Args:
            device: Device for both encoders.
            local_files_only: Forwarded to the T5 encoder's loaders.
            t5_version: T5 model identifier or path.
            clip_version: CLIP checkpoint name.
            max_length: Padding/truncation length for the T5 tokenizer.
        """
        # CLIP is loaded first, then T5.
        self.clip_model = CLIPTextEncoder(device, clip_version=clip_version)
        self.t5_model = T5TextEncoder(
            device,
            local_files_only,
            from_pretrained=t5_version,
            model_max_length=max_length,
        )

    def get_text_embeddings(self, raw_text):
        """Encode ``raw_text`` with both encoders.

        Returns:
            Tuple ``(t5_emb, clip_emb, att_mask)``:

            - ``t5_emb``: the T5 token embeddings.
            - ``clip_emb``: the CLIP embedding with one more singleton dim
              inserted at position 1.
            - ``att_mask``: the T5 attention mask with an always-on column
              prepended, presumably reserving a slot for the CLIP token.
        """
        t5_tokens, t5_mask = self.t5_model.get_text_embeddings(raw_text)
        clip_token = self.clip_model.get_text_embeddings(raw_text)[0]
        clip_slot = torch.ones_like(t5_mask[:, :1])
        combined_mask = torch.cat((clip_slot, t5_mask), dim=-1)
        # NOTE(review): CLIPTextEncoder.get_text_embeddings already applies
        # unsqueeze(1), so this is a second unsqueeze and the result gains an
        # extra dimension. It is kept for compatibility with existing callers;
        # confirm this is intended.
        return t5_tokens, clip_token.unsqueeze(1), combined_mask

# def basic_clean(text):
#     text = ftfy.fix_text(text)
#     text = html.unescape(html.unescape(text))
#     return text.strip()


# BAD_PUNCT_REGEX = re.compile(
#     r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
# )  # noqa


# def clean_caption(caption):
#     import urllib.parse as ul

#     from bs4 import BeautifulSoup

#     caption = str(caption)
#     caption = ul.unquote_plus(caption)
#     caption = caption.strip().lower()
#     caption = re.sub("<person>", "person", caption)
#     # urls:
#     caption = re.sub(
#         r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
#         "",
#         caption,
#     )  # regex for urls
#     caption = re.sub(
#         r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
#         "",
#         caption,
#     )  # regex for urls
#     # html:
#     caption = BeautifulSoup(caption, features="html.parser").text

#     # @<nickname>
#     caption = re.sub(r"@[\w\d]+\b", "", caption)

#     # 31C0—31EF CJK Strokes
#     # 31F0—31FF Katakana Phonetic Extensions
#     # 3200—32FF Enclosed CJK Letters and Months
#     # 3300—33FF CJK Compatibility
#     # 3400—4DBF CJK Unified Ideographs Extension A
#     # 4DC0—4DFF Yijing Hexagram Symbols
#     # 4E00—9FFF CJK Unified Ideographs
#     caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
#     caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
#     caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
#     caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
#     caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
#     caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
#     caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#     #######################################################

#     # all types of dash --> "-"
#     caption = re.sub(
#         r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
#         "-",
#         caption,
#     )

#     # normalize all quote characters to a single standard form
#     caption = re.sub(r"[`´«»“”¨]", '"', caption)
#     caption = re.sub(r"[‘’]", "'", caption)

#     # &quot;
#     caption = re.sub(r"&quot;?", "", caption)
#     # &amp
#     caption = re.sub(r"&amp", "", caption)

#     # ip addresses:
#     caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

#     # article ids:
#     caption = re.sub(r"\d:\d\d\s+$", "", caption)

#     # \n
#     caption = re.sub(r"\\n", " ", caption)

#     # "#123"
#     caption = re.sub(r"#\d{1,3}\b", "", caption)
#     # "#12345.."
#     caption = re.sub(r"#\d{5,}\b", "", caption)
#     # "123456.."
#     caption = re.sub(r"\b\d{6,}\b", "", caption)
#     # filenames:
#     caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

#     #
#     caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
#     caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

#     caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
#     caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

#     # this-is-my-cute-cat / this_is_my_cute_cat
#     regex2 = re.compile(r"(?:\-|\_)")
#     if len(re.findall(regex2, caption)) > 3:
#         caption = re.sub(regex2, " ", caption)

#     caption = basic_clean(caption)

#     caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
#     caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
#     caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

#     caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
#     caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
#     caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
#     caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
#     caption = re.sub(r"\bpage\s+\d+\b", "", caption)

#     caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

#     caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

#     caption = re.sub(r"\b\s+\:\s+", r": ", caption)
#     caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
#     caption = re.sub(r"\s+", " ", caption)

#     caption.strip()

#     caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
#     caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
#     caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
#     caption = re.sub(r"^\.\S+$", "", caption)

#     return caption.strip()


# def text_preprocessing(text, use_text_preprocessing: bool = True):
#     if use_text_preprocessing:
#         # The exact text cleaning as was in the training stage:
#         text = clean_caption(text)
#         text = clean_caption(text)
#         return text
#     else:
#         return text.lower().strip()