from typing import Optional, Tuple, List

import torch
import torch.nn.functional as F
from clip.model import CLIP
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

from data_utils import collate_fn
from models import Phi


if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32


@torch.no_grad()
def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection, batch_size: Optional[int] = 32,
                           num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.
    """
    # Create data loader
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    index_features = []
    index_names = []
    try:
        print(f"extracting image features {dataset.__class__.__name__} - {dataset.split}")
    except Exception as e:
        pass

    # Extract features
    for batch in tqdm(loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(clip_model.device)
        with torch.no_grad():
            batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds #.encode_image(images)
            index_features.append(batch_features.cpu())
            index_names.extend(names)

    index_features = torch.vstack(index_features)
    return index_features, index_names


@torch.no_grad()
def extract_text_features(
    dataset: Dataset,
    clip_model: CLIPTextModelWithProjection,
    batch_size: Optional[int] = 32,
    num_workers: Optional[int] = 10,
) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.
    """
    # Create data loader
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    index_features = []
    index_names = []
    try:
        print(f"extracting textfeatures {dataset.__class__.__name__} - {dataset.split}")
    except Exception as e:
        pass

    # Extract features
    for batch in tqdm(loader):
        texts = batch.get("caption")
        names = batch.get("image_name")
        if texts is None:
            texts = batch.get("reference_text")
        if names is None:
            names = batch.get("reference_name")

        texts = clip.tokenize(texts, context_length=77).to(clip_model.device)
        with torch.no_grad():
            batch_features = clip_model(input_ids=texts).text_embeds
            index_features.append(batch_features.cpu())
            index_names.extend(names)

    index_features = torch.vstack(index_features)

    return index_features, index_names

@torch.no_grad()
def extract_image_features_BLIP(dataset: Dataset, blip_model, batch_size: Optional[int] = 32,
                           num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.
    """
    # Create data loader
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    index_features = []
    index_names = []
    try:
        print(f"extracting image features {dataset.__class__.__name__} - {dataset.split}")
    except Exception as e:
        pass

    # Extract features
    for batch in tqdm(loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        with torch.no_grad():
            #batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds #.encode_image(images)
            #batch_blip = blip_model.visual_encoder(images.type(torch.float16) )#.image_embeds
            batch_blip = blip_model.visual_encoder(images )#.image_embeds
            
            batch_features = blip_model.vision_proj(batch_blip[:,0,:])
        
            index_features.append(batch_features.cpu())
            index_names.extend(names)

    index_features = torch.vstack(index_features)
    return index_features, index_names

def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
    # Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
    v1 = F.normalize(v1, dim=1)
    v2 = F.normalize(v2, dim=1)

    numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
    numerator = torch.cat((numerator, numerator), 0)
    joint_vector = torch.cat((v1, v2), 0)
    pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)
    denominator = torch.sum(pairs_product - pairs_product * torch.eye(joint_vector.shape[0]).to(device), 0)

    loss = -torch.mean(torch.log(numerator / denominator))

    return loss


@torch.no_grad()
def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset, args) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts pseudo tokens from a dataset using a CLIP model and a phi model
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print(f"Extracting tokens using phi model")
    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        image_from_clip = clip_model(pixel_values=images.half())
        image_features=image_from_clip.image_embeds
        image_last_hidden = image_from_clip.last_hidden_state
 #       print('image_last_hidden', image_last_hidden.shape)
        
        if args.l2_normalize:
            image_features = F.normalize(image_features, dim=-1)
        batch_predicted_tokens = phi(image_features)
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)
    predicted_tokens = torch.vstack(predicted_tokens)    
  #  print("predicted_tokens: ", predicted_tokens.shape)

    return predicted_tokens, names_list



@torch.no_grad()
def extract_pseudo_tokens_with_phi_BLIP(blip_model, phi: Phi, dataset: Dataset, args) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts pseudo tokens from a dataset using a CLIP model and a phi model
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print(f"Extracting tokens using phi model")
    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        blip_output = blip_model.visual_encoder(images )#.image_embeds
        #blip_output = blip_model.visual_encoder(images.type(torch.float16) )
        #if args.l2_normalize:
        #    image_features = F.normalize(image_features, dim=-1)
        image_features = blip_model.vision_proj(blip_output[:,0,:]) #F.normalize(blip_model.vision_proj(blip_output[:,0,:]),dim=-1)  
        
        batch_predicted_tokens = phi(image_features)
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list
@torch.no_grad()
def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection, dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print(f"Extracting tokens using phi model")
    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds

        #batch_predicted_tokens = phi(image_features)
        batch_predicted_tokens = image_features
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list

class CustomTensorDataset(Dataset):
    """
    Custom Tensor Dataset which yields image_features and image_names
    """



    def __init__(self, images: torch.Tensor, names: torch.Tensor):
        self.images = images
        self.names = names

    def __getitem__(self, index) -> dict:
        return {'image': self.images[index],
                'image_name': self.names[index]
                }

    def __len__(self):
        return len(self.images)


def get_templates():
    """
    Return a list of templates
    Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694
    """
    return [
        "This is a photo of a {}",
        "This photo contains a {}",
        "A photo of a {}",
        "This is an illustration of a {}",
        "This illustration contains a {}",
        "An illustrations of a {}",
        "This is a sketch of a {}",
        "This sketch contains a {}",
        "A sketch of a {}",
        "This is a diagram of a {}",
        "This diagram contains a {}",
        "A diagram of a {}",
        "A {}",
        "We see a {}",
        "{}",
        "We see a {} in this photo",
        "We see a {} in this image",
        "We see a {} in this illustration",
        "We see a {} photo",
        "We see a {} image",
        "We see a {} illustration",
        "{} photo",
        "{} image",
        "{} illustration",
    ]
