import clip 
import torchvision 
import torch 
import os 
import matplotlib.pyplot as plt 
from utils import Euclidean
import torch.nn as nn
def check_vec_exsit(save_dir):
    """Return whether *save_dir* exists on disk.

    Args:
        save_dir: path to a file or directory.

    Returns:
        ``True`` if the path exists, ``False`` otherwise (always a real bool).
    """
    return os.path.exists(save_dir)
def encode_models(
    model_name,
    checkpoint_path=r'D:\Living_and_Study_In_University\Dataset\CIFA-10\resnet50.pth',
    device='cuda',
):
    """Build the feature-extraction ("inversion") model named *model_name*.

    Args:
        model_name: ``"resnet50"`` or ``"clip"``.
        checkpoint_path: state-dict file loaded into the ResNet-50; only used
            for ``"resnet50"``. Defaults to the previously hard-coded path.
        device: device the model is placed on (default ``'cuda'``).

    Returns:
        ``(inversion_model, final_shape)``. ``final_shape`` is the feature
        width: 2048 for ResNet-50 (its ``fc`` head is replaced by
        ``nn.Identity``) and 512 for CLIP ViT-B/32.

    Raises:
        ValueError: if *model_name* is not a supported name.
    """
    if model_name == "resnet50":
        # Build with a 10-class head so the checkpoint's fc weights match on
        # load (presumably a CIFAR-10 classifier, judging by the path).
        inversion_model = torchvision.models.resnet50(num_classes=10, pretrained=False)
        # Load onto CPU first; the model is moved to `device` once, below.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        inversion_model.load_state_dict(state_dict)
        # Drop the classification head so the model outputs pooled features.
        inversion_model.fc = nn.Identity()
        inversion_model = inversion_model.to(device).eval()
        final_shape = 2048
    elif model_name == "clip":
        # NOTE(review): clip.load returns a (model, preprocess) tuple and that
        # tuple is returned unchanged as `inversion_model`; callers must unpack
        # it (e.g. before wrapping it in Mean_Dis).
        inversion_model = clip.load("ViT-B/32", device=device)
        final_shape = 512
    else:
        raise ValueError(
            f"unsupported model_name {model_name!r}; expected 'resnet50' or 'clip'"
        )
    return inversion_model, final_shape
class Mean_Dis(nn.Module):
    """Append the flattened raw input to a wrapped model's output.

    ``forward(x)`` returns ``cat([inversion_model(x), flatten(x)], dim=-1)``,
    where ``flatten`` is ``nn.Flatten()`` (collapses every dim after the first).
    """

    def __init__(self, inverion_model) -> None:
        super().__init__()
        # Wrapped feature extractor, applied to the input first.
        self.inversion_model = inverion_model
        # Flattens the raw input so it can be concatenated with the features.
        self.flatten = nn.Flatten()

    def forward(self, x):
        pieces = (self.inversion_model(x), self.flatten(x))
        return torch.cat(pieces, dim=-1)
def check_duplications(dupli, cars_duplicate, top_k=150):
    """Display a query image and its nearest gallery images by Euclidean distance.

    Args:
        dupli: query image. Values above 1 are treated as 0-255 and divided by
            255. A 3-D input is treated as HWC and converted to (1, C, H, W);
            other ranks are used as-is. May be a numpy array or a tensor.
        cars_duplicate: gallery as a stacked tensor or a list of tensors,
            each presumably (C, H, W) — TODO confirm.
        top_k: maximum number of nearest gallery images to plot.

    Side effects:
        Shows the query image, prints the sorted distances, then shows a grid
        of the closest gallery images.
    """
    # Scale 0-255 pixel values down to [0, 1].
    if dupli.max() > 1:
        dupli = dupli / 255
    if isinstance(cars_duplicate, list):
        cars_duplicate = torch.stack(cars_duplicate, dim=0)
    # Convert numpy input of any rank (4-D arrays have no flatten(start_dim=...)).
    dupli = torch.as_tensor(dupli)
    if dupli.dim() == 3:
        # Single HWC image -> (1, C, H, W).
        dupli = dupli.permute(2, 0, 1).unsqueeze(0)
    # Compare on the gallery's device so a CPU query works against a GPU gallery.
    distances = Euclidean(
        dupli.to(cars_duplicate.device).flatten(start_dim=1),
        cars_duplicate.flatten(start_dim=1),
    )
    # NOTE(review): Euclidean's output shape is not visible here. Flattening makes
    # both (N,) and (1, N) give one ranking over the gallery; assumes one query.
    value, idxs = torch.sort(distances.reshape(-1), descending=False)
    plt.imshow(dupli.permute(0, 2, 3, 1).detach().cpu().numpy().squeeze())
    plt.show()
    plt.figure(figsize=(10, 10))
    print(value)
    n_show = min(top_k, len(idxs))
    # A real grid: the old 1x1 layout made subplot index 2+ raise ValueError.
    cols = min(10, max(n_show, 1))
    rows = max(1, (n_show + cols - 1) // cols)
    for i in range(n_show):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(cars_duplicate[idxs[i]].permute(1, 2, 0).detach().cpu().numpy())
        plt.axis('off')
    plt.show()

def list_to_tensor(torch_device, frog_ori):
    """Stack a list of images into one tensor on *torch_device*.

    Each element is converted with ``torchvision.transforms.ToTensor`` and
    moved to *torch_device*, and the results are stacked along a new dim 0.
    The caller's list is not modified. Non-list inputs are returned unchanged.

    Args:
        torch_device: target device for each converted image.
        frog_ori: list of images accepted by ``ToTensor``, or anything else
            (returned as-is).

    Returns:
        The stacked tensor, or *frog_ori* itself if it is not a list.
    """
    if not isinstance(frog_ori, list):
        return frog_ori
    # Build the transform once rather than once per image.
    to_tensor = torchvision.transforms.ToTensor()
    converted = [to_tensor(img).to(torch_device) for img in frog_ori]
    return torch.stack(converted, dim=0)
