import clip
import torch
import pandas as pd
import numpy as np

def have_overlap(seg1, seg2):
    """Return whether two closed intervals intersect.

    Each segment is indexed as ``seg[0]`` (start) and ``seg[1]`` (end).
    Segments that merely touch at an endpoint count as overlapping.
    """
    # The segments are disjoint only when one starts strictly after the other ends.
    disjoint = seg1[0] > seg2[1] or seg2[0] > seg1[1]
    return not disjoint


def get_overlap(seg1, seg2):
    """Return the length of the intersection of two segments (0 if disjoint).

    Each segment is indexed as ``seg[0]`` (start) and ``seg[1]`` (end).
    """
    latest_start = max(seg1[0], seg2[0])
    earliest_end = min(seg1[1], seg2[1])
    # A negative difference means the segments do not intersect; clamp to 0.
    return max(0, earliest_end - latest_start)


def load_and_freeze_clip(clip_version, device='cpu'):
    """Load a CLIP model, put it in eval mode and freeze all of its weights.

    Args:
        clip_version: Model identifier forwarded to ``clip.load``.
        device: Device forwarded to ``clip.load``.

    Returns:
        The loaded CLIP model with ``requires_grad`` disabled on every
        parameter. The preprocessing transform returned by ``clip.load`` is
        discarded.
    """
    # jit=False: must be set for training (per the original author's note).
    model, _preprocess = clip.load(clip_version, device=device, jit=False)

    # Original author's note: this conversion is believed to be redundant
    # because CLIP is already on float16 by default.
    # NOTE(review): confirm this is also what is wanted when device='cpu'.
    clip.model.convert_weights(model)

    # Freeze CLIP weights.
    model.eval()
    for weight in model.parameters():
        weight.requires_grad = False

    return model


def load_and_freeze_t5_encoder(model_name='google/flan-t5-xxl', device='cpu'):
    """Load a seq2seq T5 checkpoint and freeze its encoder.

    Args:
        model_name: Hugging Face model identifier or local path passed to
            ``from_pretrained``.
        device: Device the model is moved to.

    Returns:
        ``(tokenizer, model)``. Only ``model.encoder`` is switched to eval
        mode and has ``requires_grad`` disabled; the rest of the model is
        returned as loaded.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Tokenizers are not nn.Modules and have no ``.to()``; only the tensors
    # they produce are moved to a device (by the caller, at encode time).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

    model.encoder.eval()

    for param in model.encoder.parameters():
        param.requires_grad = False

    return tokenizer, model


# def encode_text(clip_model, raw_text, force_empty_zero=True):
#     device = next(clip_model.parameters()).device
#     # raw_text - list (batch_size length) of strings with input text prompts
#     texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length]
#     text_embedding = clip_model.encode_text(texts).float() # [bs, 512]
#     if force_empty_zero:  # force empty string to have zero embedding, same as being masked out in original MDM
#         empty_text = [text == '' for text in raw_text]
#         text_embedding[empty_text, :] = 0
#     return text_embedding


def encode_text(clip_model, raw_text, force_empty_zero=True, text_sep=False, max_segs=20, sep_mode=0):
    """Encode text prompts with CLIP, optionally split into sub-segments.

    Args:
        clip_model: CLIP model providing ``encode_text``; its device is taken
            from its first parameter.
        raw_text: Batch (list) of prompt strings.
        force_empty_zero: If True, embeddings of empty strings / empty
            segments are overwritten with zeros.
        text_sep: If False, each prompt is encoded as a whole. If True, each
            prompt is split into at most ``max_segs`` segments which are
            encoded individually.
        max_segs: Number of segment slots per prompt when ``text_sep`` is
            True. Prompts with fewer segments are padded with empty strings;
            any text beyond the last split stays in the final segment.
        sep_mode: Separator set used when ``text_sep`` is True.
            ``0`` splits on ``,`` and ``.``; ``1`` additionally splits on the
            whole words ``and`` and ``while``.

    Returns:
        When ``text_sep`` is False: a float tensor ``[B, D]``.
        When ``text_sep`` is True: a tuple ``(text_embedding, text_mask)``
        where ``text_embedding`` is ``[B, max_segs, D]`` and ``text_mask`` is
        a bool tensor ``[B, max_segs]`` that is True for empty segments.

    Raises:
        ValueError: If ``text_sep`` is True and ``sep_mode`` is not 0 or 1,
            or ``max_segs`` is smaller than 1.
    """
    device = next(clip_model.parameters()).device

    if not text_sep:
        texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
        text_embedding = clip_model.encode_text(texts).float()  # [B, D]
        if force_empty_zero:
            empty_text = [t == '' for t in raw_text]
            text_embedding[empty_text, :] = 0
        return text_embedding

    if max_segs < 1:
        raise ValueError(f"max_segs must be >= 1, got {max_segs!r}")
    if sep_mode == 0:
        split_pattern = r'[,.]'
    elif sep_mode == 1:
        split_pattern = r'\band\b|\bwhile\b|,|\.'
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unsupported sep_mode {sep_mode!r}; expected 0 or 1")

    batch_size = len(raw_text)
    raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')

    if max_segs == 1:
        # pandas interprets n=0 as "return all splits", which would split the
        # prompt and then silently drop everything after the first piece.
        # With a single slot the whole prompt is the only segment.
        split_df = raw_series.to_frame(0)
    else:
        split_df = raw_series.str.split(split_pattern, n=max_segs - 1, expand=True)

    # Column-wise .str.strip() replaces DataFrame.applymap, which is
    # deprecated since pandas 2.1.
    split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())

    # Pad with empty-string columns so every prompt has exactly max_segs slots.
    split_df = split_df.reindex(columns=range(max_segs), fill_value='')

    segs_matrix = split_df.values
    segs_flat = segs_matrix.reshape(-1).tolist()

    # True marks an empty (padding) segment.
    text_mask = (segs_matrix == '').astype(bool)
    text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

    tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
    text_embedding = clip_model.encode_text(tokenized).float()      # [B*max_segs, D]

    # Read the width from the encoder output instead of assuming 512, so the
    # reshape does not depend on one particular CLIP text width.
    embed_dim = text_embedding.shape[-1]
    text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, D]

    if force_empty_zero:
        text_embedding[text_mask] = 0

    return text_embedding, text_mask


def encode_text_t5(tokenizer, t5_model, raw_text, force_empty_zero=True):
    """Encode text prompts with the encoder of a T5 model.

    Args:
        tokenizer: Callable tokenizer returning a mapping of tensors that
            supports ``.to(device)`` (e.g. a Hugging Face tokenizer).
        t5_model: Model exposing ``.encoder``; its device is taken from its
            first parameter.
        raw_text: Batch (list) of prompt strings.
        force_empty_zero: If True, the embeddings of prompts equal to ``''``
            are overwritten with zeros.

    Returns:
        ``encoder_outputs.last_hidden_state`` computed under
        ``torch.no_grad()``. Because the batch is padded to its longest
        prompt, the output includes positions for padding tokens.
    """
    device = next(t5_model.parameters()).device
    # padding=True: prompts in a batch generally tokenize to different
    # lengths, and return_tensors="pt" cannot build a rectangular tensor from
    # ragged sequences. The resulting attention_mask is forwarded to the
    # encoder through **texts.
    texts = tokenizer(raw_text, padding=True, return_tensors="pt").to(device)  # [bs, context_length]
    with torch.no_grad():
        encoder_outputs = t5_model.encoder(**texts)
    text_embedding = encoder_outputs.last_hidden_state
    if force_empty_zero:  # force empty string to have zero embedding, same as being masked out in original MDM
        empty_text = [text == '' for text in raw_text]
        text_embedding[empty_text, :] = 0
    return text_embedding


def compose_texts_with_and(texts):
    """Join the given texts, in sorted order, with ``' and '`` between them."""
    ordered = list(texts)
    ordered.sort()
    return ' and '.join(ordered)

def dict_to_args(dict_args):
    """Wrap a dict in a dynamically created dataclass instance.

    A dataclass named ``DynamicMotionModelArgs`` is built with one field per
    key (in the dict's iteration order), each annotated with the runtime type
    of its value, and instantiated with the dict's contents so entries can be
    read as attributes.
    """
    from dataclasses import make_dataclass

    field_specs = [(name, type(value)) for name, value in dict_args.items()]
    args_cls = make_dataclass('DynamicMotionModelArgs', fields=field_specs)
    return args_cls(**dict_args)
