

import os

import numpy as np
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from collections import OrderedDict
from transformers import BatchFeature
import clip
import copy
import lpips
from transformers import ViTImageProcessor, ViTModel

## CSD_CLIP
def convert_weights_float(model: nn.Module):
    """Cast the weights of selected layers of ``model`` to fp32, in place.

    Every sub-module is visited via ``model.apply``. The following tensors
    are up-cast with ``.float()``:

    * ``weight`` / ``bias`` of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``;
    * the projection weights and biases of ``nn.MultiheadAttention``;
    * any ``text_projection`` / ``proj`` attribute found on a module.

    Other parameters (e.g. normalisation layers) are left untouched.
    Returns ``None``.
    """
    attention_params = (
        "in_proj_weight",
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )
    projection_params = ("text_projection", "proj")

    def _upcast(param):
        # Rebinding ``.data`` swaps the storage while keeping the same
        # Parameter object registered on the module.
        if param is not None:
            param.data = param.data.float()

    def _to_fp32(layer):
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            _upcast(layer.weight)
            _upcast(layer.bias)

        if isinstance(layer, nn.MultiheadAttention):
            for param_name in attention_params:
                _upcast(getattr(layer, param_name))

        # Bare projection matrices stored directly on a module.
        for param_name in projection_params:
            _upcast(getattr(layer, param_name, None))

    model.apply(_to_fp32)

class ReverseLayerF(Function):
    """Autograd function that is the identity forward and flips gradients backward.

    ``forward`` returns a view of ``x`` unchanged; ``backward`` multiplies the
    incoming gradient by ``-alpha``. ``alpha`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # The scale is only needed again in backward, so keep it on the context.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = torch.neg(grad_output) * ctx.alpha
        # One return value per forward input: (grad for x, grad for alpha).
        return reversed_grad, None



class ProjectionHead(nn.Module):
    """Residual projection head: Linear -> GELU -> Linear -> Dropout, plus skip, then LayerNorm.

    Args:
        embedding_dim: size of the last dimension of the input.
        projection_dim: size of the last dimension of the output.
        dropout: dropout probability applied after the second linear layer.
    """

    def __init__(self, embedding_dim, projection_dim, dropout=0):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and refine it with a residual branch before normalising."""
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        # The skip connection re-adds the plain projection before LayerNorm.
        return self.layer_norm(refined + shortcut)

def convert_state_dict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from keys.

    Only the prefix is stripped. Later occurrences of ``"module."`` inside a
    key (for example in ``"encoder.submodule.weight"``) are preserved, so
    parameter names that merely contain the substring are not corrupted.

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        An ``OrderedDict`` with the same values and insertion order.
    """
    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict
def init_weights(m):
    """Initialise an ``nn.Linear`` in place; any other module is left untouched.

    Weights get Xavier-uniform initialisation and the bias, when present, is
    drawn from a normal distribution with ``std=1e-6``. Intended for use with
    ``module.apply(init_weights)``.
    """
    # TODO: do we need init for layernorm?
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.normal_(m.bias, std=1e-6)

class Metric(nn.Module):
    """Base class for the image metrics in this file: image loading and preprocessing helpers.

    ``image_preprocess`` is ``None`` here; subclasses assign a callable that
    turns a PIL image into a tensor before ``preprocess_image`` is used.
    """

    def __init__(self):
        super().__init__()
        self.image_preprocess = None

    def load_image(self, image_path):
        """Open the file at ``image_path`` and return it as an RGB PIL image."""
        with open(image_path, 'rb') as f:
            # ``convert`` decodes the pixels, so the file can be closed afterwards.
            image = Image.open(f).convert("RGB")
        return image

    def load_image_path(self, image_path):
        """Turn ``image_path`` into a list of PIL images.

        Args:
            image_path: one of
                * a folder path (``str``) - every ``.jpg`` / ``.png`` file in
                  it is loaded, ordered by sorted path;
                * a sequence of file paths;
                * a sequence of ``np.ndarray`` images;
                * a sequence of PIL images (returned as is).

        Returns:
            The loaded images. An image folder without matching files yields
            an empty list.

        Raises:
            ValueError: if a sequence is passed and it is empty.
            TypeError: if the first element is of an unsupported type.
        """
        if isinstance(image_path, str):
            # should be a image folder path
            paths = sorted(
                os.path.join(image_path, name)
                for name in os.listdir(image_path)
                if name.endswith((".jpg", ".png"))
            )
            return [self.load_image(path) for path in paths]

        if len(image_path) == 0:
            raise ValueError("Invalid input: empty image sequence")

        # The type of the first element decides how the whole sequence is read.
        first = image_path[0]
        if isinstance(first, str):
            return [self.load_image(path) for path in image_path]
        if isinstance(first, np.ndarray):
            return [Image.fromarray(array) for array in image_path]
        if isinstance(first, Image.Image):
            return image_path
        raise TypeError("Invalid input")

    def preprocess_image(self, image, **kwargs):
        """Load ``image`` and run ``self.image_preprocess`` on it, returning a batch tensor.

        Args:
            image: an image folder path, a single image file path, or a
                non-empty list of PIL images / ``np.ndarray`` images /
                existing file paths.
            **kwargs: forwarded to ``self.image_preprocess``.

        Returns:
            The preprocessed images stacked along a new first dimension (a
            single file gives a batch of one).

        Raises:
            TypeError: if ``image`` is a tensor (tensors are not supported here).
            ValueError: if ``image`` matches none of the supported forms.
        """
        # Checked first: passing a tensor to ``os.path.isfile`` would raise an
        # unrelated, confusing error.
        if isinstance(image, torch.Tensor):
            raise TypeError("Unsupported input")

        is_folder = isinstance(image, str) and os.path.isdir(image)
        is_batch = (
            isinstance(image, list)
            and len(image) > 0
            and (
                isinstance(image[0], Image.Image)
                or isinstance(image[0], np.ndarray)
                or os.path.isfile(image[0])
            )
        )
        if is_folder or is_batch:
            loaded = self.load_image_path(image)
            return torch.stack([self.image_preprocess(item, **kwargs) for item in loaded])

        if isinstance(image, (str, bytes, os.PathLike)) and os.path.isfile(image):
            single = self.image_preprocess(self.load_image(image), **kwargs)
            return single.unsqueeze(0)

        raise ValueError(
            "Unsupported input: expected an image file, an image folder, "
            "or a non-empty list of images / image paths"
        )

class Clip_Basic_Metric(Metric):
    """Metric base that defines two CLIP-style preprocessing pipelines.

    * ``tensor_preprocess`` - for tensor input: bicubic resize to 224,
      ``(x + 1) / 2`` rescale, then normalisation with the statistics below
      (presumably the input is in ``[-1, 1]`` - TODO confirm with callers).
    * ``image_preprocess`` - for PIL input: bicubic resize to 224, centre
      crop to 224, ``ToTensor`` and the same normalisation.
    """

    def __init__(self):
        super().__init__()
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        bicubic = transforms.InterpolationMode.BICUBIC

        tensor_steps = [
            transforms.Resize(224, interpolation=bicubic),
            # Normalize(mean=-1, std=2) computes (x + 1) / 2.
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Normalize(channel_mean, channel_std),
        ]
        self.tensor_preprocess = transforms.Compose(tensor_steps)

        pil_steps = [
            transforms.Resize(size=224, interpolation=bicubic),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        self.image_preprocess = transforms.Compose(pil_steps)

class Clip_metric(Clip_Basic_Metric):
    """CLIP-based scorer comparing images with a text description of a target style.

    Text features of ``target_style_prompt`` and of a fixed vocabulary of art
    style names are computed once at construction. ``forward`` scores images
    against the style prompt (and optionally against a content prompt);
    ``get_style_loss`` compares a soft distribution over the style vocabulary
    with that of reference images.
    """

    # Style vocabulary used to build the style "prototype" distribution.
    WIKIART_STYLES = (
        "Realism",
        "Impressionism",
        "Romanticism",
        "Expressionism",
        "Post-Impressionism",
        "Baroque",
        "Art Nouveau",
        "Surrealism",
        "Symbolism",
        "Abstract Expressionism",
        "Neoclassicism",
        "Naïve Art (Primitivism)",
        "Rococo",
        "Cubism",
        "Northern Renaissance",
        "Academicism",
        "Pop Art",
        "Mannerism (Late Renaissance)",
        "Minimalism",
        "Conceptual Art",
        "Abstract Art",
        "Neo-Expressionism",
        "Art Informel",
        "Early Renaissance",
        "Ukiyo-e",
        "Magic Realism",
        "High Renaissance",
        "Contemporary Realism",
        "Color Field Painting",
        "Orientalism",
        "Lyrical Abstraction",
        "Fauvism",
        "Op Art",
        "Contemporary",
        "Neo-Impressionism",
        "Art Deco",
        "Social Realism",
        "Neo-Pop Art",
        "Naturalism",
        "Kitsch",
        "Neo-Romanticism",
        "Ink and wash painting",
        "Concretism",
        "Socialist Realism",
        "Post-Minimalism",
        "Neo-Dada",
        "Hard Edge Painting",
        "Transavantgarde",
        "Constructivism",
        "Pointillism",
        "Regionalism",
        "Tachisme",
        "Native Art",
        "Feminist Art",
        "Tenebrism",
        "Pictorialism",
        "Street art",
        "Art Brut",
        "Outsider art",
        "Dada",
        "Sōsaku hanga",
        "Biedermeier",
        "New Realism",
        "Proto Renaissance",
        "Light and Space",
        "Divisionism",
        "Shin-hanga",
        "Futurism",
        "New European Painting",
        "Fantastic Realism",
        "Photorealism",
        "Luminism",
        "Tonalism",
        "Post-Painterly Abstraction",
        "American Realism",
        "Figurative Expressionism",
        "Precisionism",
        "Romanesque",
        "Neo-Figurative Art",
        "Postcolonial art",
        "International Gothic",
        "Nouveau Réalisme",
        "Hyper-Realism",
        "Neo-Minimalism",
        "Kinetic Art",
        "Orphism",
        "Classicism",
        "Metaphysical art",
        "Byzantine",
        "Neo-baroque",
        "Synthetic Cubism",
        "Muralism",
        "Digital Art",
        "Classical Realism",
        "Japonism",
        "Cubo-Futurism",
        "Environmental (Land) Art",
        "Cloisonnism",
        "Neoplasticism",
        "Purism",
        "Spatialism",
        "Zen",
        "Modernismo",
        "Middle Byzantine",
        "Documentary photography",
        "P&D (Pattern and Decoration)",
        "Street Photography",
        "Neo-Geo",
        "Late Byzantine/Palaeologan Renaissance",
        "Suprematism",
        "Hellenistic",
        "Safavid Period",
        "Sumi-e (Suiboku-ga)",
        "Fantasy Art",
        "Costumbrismo",
        "Confessional Art",
        "Analytical Cubism",
        "New Kingdom",
        "Intimism",
        "Action painting",
        "Verism",
        "Early Byzantine",
        "Lowbrow Art",
        "Neo-Rococo",
        "Performance Art",
        "Post-classic",
        "Macedonian Renaissance",
        "Neo-Concretism",
        "Analytical Realism",
        "Classical",
        "Mozarabic",
        "Transautomatism",
        "Indian Space painting",
        "Ottoman Period",
        "Stuckism",
        "Synthetism",
        "Coptic art",
        "Moscow school of icon painting",
        "Modernism",
        "Automatic Painting",
        "Mechanistic Cubism",
        "Junk Art",
        "Viking art",
        "New Ink Painting",
        "Maximalism",
        "Neo-Suprematism",
        "Nanga",
        "Mosan art",
        "Queer art",
        "Poster Art Realism",
        "Lettrism",
        "Cartographic Art",
        "Timurid Period",
        "New Casualism",
        "Nihonga",
        "Komnenian style",
        "Yamato-e",
        "Existential Art",
        "Archaic",
        "Cyber Art",
        "Gongbi",
        "Gothic",
        "Joseon Dynasty",
        "Renaissance",
        "Fiber art",
        "Novgorod school of icon painting",
        "Superflat",
        "Excessivism",
        "Latin Empire of Constantinople",
        "Yoruba",
        "Cubo-Expressionism",
        "Medieval Art",
        "Folk art",
        "Art Singulier",
        "Tubism",
        "Neo-Orthodoxism",
        "Galicia-Volyn school",
        "Miserablism",
        "Neo-Byzantine",
        "Mail Art",
        "Hyper-Mannerism (Anachronism)",
        "Kyiv school of icon painting",
        "Mughal",
        "Synchromism",
        "Sots Art",
        "3rd Intermediate Period",
        "Abbasid Period",
        "Rayonism",
        "Kanō school style",
        "Ilkhanid",
        "Nas-Taliq",
        "Yaroslavl school of icon painting",
        "Crusader workshop",
        "Geometric",
        "Amarna",
        "Cretan school of icon painting",
        "Ero guro",
        "New Medievialism",
        "Site-specific art",
        "Severe Style",
        "Middle Kingdom",
        "Perceptism",
        "Pskov school of icon painting",
        "Vologda school of icon painting",
        "Spectralism",
        "Late Period",
        "Ptolemaic",
        "Vladimir school of icon painting",
        "Chernihiv school of icon painting",
        "Sky Art",
        "Early Dynastic",
        "Macedonian school of icon painting",
        "Early Christian",
        "Graffiti Art",
        "Old Kingdom",
        "2nd Intermediate Period",
        "New media art",
        "Stroganov school of icon painting",
    )

    @torch.no_grad()
    def __init__(self, target_style_prompt: str, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
                 bath_size=8, alpha=0.5):
        """Load the CLIP model and pre-compute the text features.

        Args:
            target_style_prompt: text describing the target style.
            clip_model_name: model identifier passed to ``from_pretrained``.
            device: device the model is moved to and inputs are sent to.
            bath_size: stored as ``self.batch_size`` (the parameter name is
                kept as is for backward compatibility).
            alpha: weight of the style score in the combined score; the
                content score is weighted by ``1 - alpha``.
        """
        super().__init__()
        self.styles = list(self.WIKIART_STYLES) + ["Photograph"]
        self.device = device
        self.alpha = alpha
        self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
        # Put the model in inference mode before any feature is extracted.
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        # Kept on CPU between uses; get_style_prototype moves it temporarily.
        self.style_class_features = self.get_text_features(self.styles).cpu()
        self.batch_size = bath_size
        self.ref_style_features = self.get_text_features(target_style_prompt)

        self.ref_image_style_prototype = None

    def get_text_features(self, text):
        """Return L2-normalised CLIP text features for a string or list of strings."""
        prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True).to(self.device)
        prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
        prompt_features = F.normalize(prompt_features, p=2, dim=-1)
        return prompt_features

    def get_image_features(self, images):
        """Return L2-normalised CLIP image features.

        Args:
            images: either a tensor batch (run through ``tensor_preprocess``)
                or anything ``load_image_path`` accepts (then handled by the
                Hugging Face image processor).
        """
        # The tensor check must come first: ``load_image_path`` rejects tensors.
        if isinstance(images, torch.Tensor):
            pixel_values = self.tensor_preprocess(images)
            model_inputs = BatchFeature(data={"pixel_values": pixel_values}, tensor_type="pt").to(self.device)
        else:
            images = self.load_image_path(images)
            model_inputs = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
                                                                                              non_blocking=True)

        image_features = self.model.get_image_features(**model_inputs).to(self.device)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features

    def img_text_similarity(self, image_features, text=None):
        """Mean row-wise dot product between image features and text features.

        Args:
            image_features: 2-D tensor, one row per image.
            text: a single prompt (repeated for every image), a list with one
                prompt per image, or ``None`` to use the target style prompt.
        """
        if text is not None:
            prompt_feature = self.get_text_features(text)
            if isinstance(text, str):
                prompt_feature = prompt_feature.repeat(len(image_features), 1)
        else:
            prompt_feature = self.ref_style_features

        similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
        return similarity_each.mean()

    def forward(self, output_imgs, prompt=None):
        """Score images against the target style and, optionally, a content prompt.

        Returns:
            ``{"style_score"}`` when ``prompt`` is ``None``; otherwise
            ``{"score", "style_score", "content_score"}`` where ``score`` is
            the ``alpha``-weighted combination.
        """
        image_features = self.get_image_features(output_imgs)
        # The style score is computed on the batch-averaged feature (one row).
        style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
        if prompt is None:
            return {"style_score": style_score}

        content_score = self.img_text_similarity(image_features, prompt)
        score = self.alpha * style_score + (1 - self.alpha) * content_score
        return {"score": score, "style_score": style_score, "content_score": content_score}

    # for style loss
    def get_style_prototype(self, image_features):
        """Return a normalised, batch-averaged soft assignment over ``self.styles``.

        Image/style similarities are scaled by 50, soft-maxed over the style
        vocabulary, averaged over the batch and L2-normalised.
        """
        self.style_class_features = self.style_class_features.to(self.device)
        style_prototype = torch.einsum("np, cp -> nc", image_features, self.style_class_features)
        style_prototype = style_prototype * 50
        style_prototype = F.softmax(style_prototype, dim=-1)
        style_prototype = style_prototype.mean(dim=0)
        style_prototype = F.normalize(style_prototype, p=2, dim=-1)
        # Move the vocabulary features back to CPU once done.
        self.style_class_features = self.style_class_features.cpu()
        return style_prototype

    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        """Compute and store (on CPU) the style prototype of the reference images."""
        ref_image_features = self.get_image_features(ref_image_path)
        self.ref_image_style_prototype = self.get_style_prototype(ref_image_features)
        self.ref_image_style_prototype = self.ref_image_style_prototype.cpu()

    def get_style_loss(self, output_images):
        """Return ``1 - <prototype(output_images), reference prototype>``.

        ``define_ref_image_style_prototype`` must have been called first.
        """
        assert self.ref_image_style_prototype is not None, "Please define the reference image style prototype first"
        self.ref_image_style_prototype = self.ref_image_style_prototype.to(self.device)
        image_features = self.get_image_features(output_images)
        style_prototype = self.get_style_prototype(image_features)
        similarity = style_prototype @ self.ref_image_style_prototype
        loss = 1 - similarity
        self.ref_image_style_prototype = self.ref_image_style_prototype.cpu()
        return loss.mean()

class CSD_CLIP(Clip_Basic_Metric):
    """backbone + projection head

    A CLIP visual backbone whose final projection is replaced by two separate
    heads, one for style and one for content. All parameters are frozen. The
    module is parked on CPU after construction; ``forward`` and
    ``define_ref_image_style_prototype`` move it to ``self.device`` for the
    duration of the call.
    """
    def __init__(self, name='vit_large',content_proj_head='default', ckpt_path = "code/diffusion/stable_diffusion/libs/CSD/weights/CSD-checkpoint.pth", device="cuda",
                 alpha=0.5, **kwargs):
        """Build the backbone and heads and optionally load a checkpoint.

        Args:
            name: ``'vit_large'`` (CLIP ViT-L/14) or ``'vit_base'`` (CLIP ViT-B/16).
            content_proj_head: ``'custom'`` for a ``ProjectionHead``; any other
                value copies the backbone projection matrix.
            ckpt_path: checkpoint file to load, or ``None`` to skip loading.
            device: device used while computing features.
            alpha: stored on the instance; not used by the visible methods.
            **kwargs: accepted and ignored.

        Raises:
            Exception: if ``name`` is not a supported backbone.
        """
        super(CSD_CLIP, self).__init__()
        self.alpha = alpha
        self.content_proj_head = content_proj_head
        self.device = device
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
            # Output size of the ViT-L/14 projection; needed by the custom head.
            self.feat_dim = 768
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feat_dim = 512
        else:
            raise Exception('This model is not implemented')

        convert_weights_float(self.backbone)
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == 'custom':
            self.last_layer_content = ProjectionHead(self.embedding_dim,self.feat_dim)
            self.last_layer_content.apply(init_weights)
        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)

        # With ``proj`` removed the backbone no longer applies the projection
        # itself; the two heads above take its place.
        self.backbone.proj = None
        self.backbone.requires_grad_(False)
        self.last_layer_style.requires_grad_(False)
        self.last_layer_content.requires_grad_(False)
        self.backbone.eval()

        if ckpt_path is not None:
            self.load_ckpt(ckpt_path)
        self.to("cpu")

    def load_ckpt(self, ckpt_path):
        """Load ``checkpoint['model_state_dict']`` from ``ckpt_path`` (non-strict)."""
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = convert_state_dict(checkpoint['model_state_dict'])
        msg = self.load_state_dict(state_dict, strict=False)
        print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")

    @property
    def dtype(self):
        """dtype of the backbone's first convolution weight."""
        return self.backbone.conv1.weight.dtype

    def get_image_features(self, input_data, get_style=True,get_content=False,feature_alpha=None):
        """Compute backbone, content and style features.

        The module is not moved here: it must already be on ``self.device``,
        because the input is sent there before the backbone is called.

        Args:
            input_data: a tensor batch (run through ``tensor_preprocess``) or
                anything ``preprocess_image`` accepts (folder, file, list of
                images / paths).
            get_style: compute the L2-normalised style embedding.
            get_content: compute the L2-normalised content embedding.
            feature_alpha: if given, the content branch goes through a
                gradient-reversal layer scaled by this value.

        Returns:
            ``(feature, content_output, style_output)``; an output that was
            not requested is ``None``.
        """
        if isinstance(input_data, torch.Tensor):
            input_data = self.tensor_preprocess(input_data)
        else:
            # Same loading rules as the inherited helper, so reuse it.
            input_data = self.preprocess_image(input_data)
        input_data = input_data.to(self.device)

        feature = self.backbone(input_data)

        style_output = None
        if get_style:
            style_output = feature @ self.last_layer_style
            style_output = nn.functional.normalize(style_output, dim=-1, p=2)

        content_output = None
        if get_content:
            if feature_alpha is not None:
                reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
            else:
                reverse_feature = feature
            if self.content_proj_head == 'custom':
                content_output = self.last_layer_content(reverse_feature)
            else:
                content_output = reverse_feature @ self.last_layer_content
            content_output = nn.functional.normalize(content_output, dim=-1, p=2)

        return feature, content_output, style_output

    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        """Store the style embeddings of the reference images, one row per image."""
        self.to(self.device)
        _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")

    @torch.no_grad()
    def forward(self, styled_data):
        """Compare the style of ``styled_data`` with the stored reference embeddings.

        Returns:
            A dict with the overall mean / max similarity plus per-image
            details: mean similarity, best similarity, and the index of the
            best-matching reference image.
        """
        self.to(self.device)
        _, _, style_output = self.get_image_features(styled_data, get_content=False)
        # Rows: styled images, columns: reference images.
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()

        self.to("cpu")
        return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
                "CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}

    def get_style_loss(self, styled_data):
        """Return ``1 - mean style similarity`` to the stored reference embeddings.

        Unlike ``forward`` this neither moves the module nor disables
        gradients; the caller is responsible for placing it on ``self.device``.
        """
        _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
        # The reference holds one row per image, so it is transposed as in forward.
        style_similarity = (style_output @ self.ref_style_feature.T).mean()
        loss = 1 - style_similarity
        return loss.mean()

class LPIPS_metric(Metric):
    """LPIPS perceptual distance between two aligned sets of images."""

    def __init__(self, type="vgg", device="cuda"):
        """Create the LPIPS network and its preprocessing pipeline.

        Args:
            type: passed to ``lpips.LPIPS`` as ``net``.
            device: device used while computing the metric; the module is
                kept on CPU otherwise.
        """
        super(LPIPS_metric, self).__init__()
        self.lpips = lpips.LPIPS(net=type)
        self.device = device
        self.image_preprocess = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            # Maps ToTensor's output to (x - 0.5) / 0.5.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.to("cpu")

    def _as_image_list(self, images):
        """Return ``images`` as a list that can safely be sliced into batches."""
        if isinstance(images, str):
            # Slicing a path string would cut it character-wise, so resolve it
            # to images up front.
            if os.path.isdir(images):
                return self.load_image_path(images)
            return [self.load_image(images)]
        return [self.load_image(item) if isinstance(item, str) else item for item in images]

    @torch.no_grad()
    def forward(self, img1, img2, batch_size=50):
        """Compute the LPIPS distance between corresponding images.

        Args:
            img1, img2: image folders, single image files, or sequences of
                images / image paths. Both must yield the same number of images.
            batch_size: number of image pairs evaluated per network call.

        Returns:
            ``{"LPIPS_content_difference": mean distance,
            "LPIPS_content_difference_details": per-pair distances}``.

        Raises:
            ValueError: if no images are given or the two sets differ in size.
        """
        img1 = self._as_image_list(img1)
        img2 = self._as_image_list(img2)
        if len(img1) == 0:
            raise ValueError("LPIPS_metric needs at least one image pair")
        if len(img1) != len(img2):
            raise ValueError(
                f"LPIPS_metric needs equally sized inputs, got {len(img1)} and {len(img2)}"
            )

        self.to(self.device)
        differences = []
        for start in range(0, len(img1), batch_size):
            img1_batch = self.preprocess_image(img1[start:start + batch_size]).to(self.device)
            img2_batch = self.preprocess_image(img2[start:start + batch_size]).to(self.device)
            # flatten (not squeeze) keeps one dimension even for a batch of
            # one, which torch.cat requires.
            differences.append(self.lpips(img1_batch, img2_batch).flatten())
        differences = torch.cat(differences)
        difference = differences.mean()
        self.to("cpu")
        return {"LPIPS_content_difference": difference,  "LPIPS_content_difference_details": differences}

class Vit_metric(Metric):
    """Content / style similarity based on flattened DINO ViT token features."""

    def __init__(self, device="cuda"):
        """Load ``facebook/dino-vitb8`` and its image processor; keep the module on CPU."""
        super(Vit_metric, self).__init__()
        self.device = device
        self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
        self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
        self.to("cpu")

    def get_image_features(self, images, batch_size=20):
        """Return one L2-normalised feature row per image.

        The model's ``last_hidden_state`` is flattened per image. The module
        is not moved here; callers put it on ``self.device`` first.

        Args:
            images: anything ``load_image_path`` accepts.
            batch_size: number of images per model call.
        """
        images = self.load_image_path(images)
        all_image_features = []
        for start in range(0, len(images), batch_size):
            image_batch = images[start:start + batch_size]
            image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
            image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
            image_features = F.normalize(image_features, p=2, dim=-1)
            all_image_features.append(image_features)
        all_image_features = torch.cat(all_image_features)
        return all_image_features

    @torch.no_grad()
    def content_metric(self, img1, img2):
        """Pair-wise similarity between ``img1[i]`` and ``img2[i]``.

        Each argument is either a 2-D tensor of pre-computed features or
        something ``get_image_features`` accepts. The module is left on
        ``self.device``.
        """
        self.to(self.device)
        if not(isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
            img1 = self.get_image_features(img1)
        if not(isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
            img2 = self.get_image_features(img2)
        similarities = torch.einsum("nc, nc -> n", img1, img2)
        similarity = similarities.mean()
        return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}

    # style
    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        """Store the features of the reference images, one row per image."""
        self.to(self.device)
        self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")

    @torch.no_grad()
    def style_metric(self, styled_data):
        """Similarity of every styled image to every stored reference image.

        ``styled_data`` is either a 2-D tensor of pre-computed features or
        something ``get_image_features`` accepts. The module is left on
        ``self.device``.
        """
        self.to(self.device)
        if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
            style_output = styled_data
        else:
            style_output = self.get_image_features(styled_data)
        # Rows: styled images, columns: reference images.
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()

        return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
                "Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}

    @torch.no_grad()
    def forward(self, styled_data, original_data=None):
        """Compute the style metric and, if ``original_data`` is given, the content metric.

        Returns:
            A dict with ``"Vit_style"`` and, when ``original_data`` is not
            ``None``, ``"Vit_content"``. The module is moved back to CPU.
        """
        self.to(self.device)
        # Features of the styled images are computed once and shared.
        styled_features = self.get_image_features(styled_data)
        ret = {}
        if original_data is not None:
            ret["Vit_content"] = self.content_metric(styled_features, original_data)
        ret["Vit_style"] = self.style_metric(styled_features)
        self.to("cpu")
        return ret



class StyleContentMetric(nn.Module):
    """Bundle of the CSD, ViT and LPIPS metrics evaluated against a folder of style references."""

    def __init__(self, style_ref_image_folder, device="cuda"):
        """Create the sub-metrics and register the reference images.

        Args:
            style_ref_image_folder: folder with the ``.jpg`` / ``.png`` style
                reference images.
            device: device used by the sub-metrics while computing.
        """
        super(StyleContentMetric, self).__init__()
        self.device = device
        self.clip_metric = CSD_CLIP(device=device)
        # Sorted file names, so an index returned by a sub-metric can be
        # mapped back to a reference file (the sub-metrics also load the
        # folder in sorted order).
        self.ref_image_file = sorted(
            name for name in os.listdir(style_ref_image_folder)
            if name.endswith((".jpg", ".png"))
        )
        self.clip_metric.define_ref_image_style_prototype(style_ref_image_folder)
        self.vit_metric = Vit_metric(device=device)
        self.vit_metric.define_ref_image_style_prototype(style_ref_image_folder)
        self.lpips_metric = LPIPS_metric(device=device)
        self.to("cpu")

    def _ref_names(self, indices):
        """Map a tensor of reference indices to the matching file names."""
        return [self.ref_image_file[i] for i in indices.reshape(-1).tolist()]

    @staticmethod
    def _to_python(value):
        """Convert a tensor to rounded Python numbers; pass other values through."""
        if not isinstance(value, torch.Tensor):
            return value
        if value.numel() == 1:
            return round(value.item(), 4)
        return [round(v, 4) for v in value.tolist()]

    def forward(self, styled_data, original_data=None):
        """Evaluate all metrics.

        Args:
            styled_data: the stylised images.
            original_data: optional content images; when given, the content
                metrics (ViT and LPIPS) are computed as well.

        Returns:
            A dict with ``"Style_CSD"`` and ``"Style_VIT"`` (each including a
            ``"max_query"`` list naming the best-matching reference file per
            image), optionally ``"Content_VIT"`` and ``"Content_LPIPS"``, and
            ``"ref_image_file"``. Tensor values are converted to Python
            numbers rounded to 4 decimals.
        """
        ret = {}
        csd_score = self.clip_metric(styled_data)
        csd_score["max_query"] = self._ref_names(csd_score["CSD_similarity_max_id_details"])
        torch.cuda.empty_cache()
        ret["Style_CSD"] = csd_score

        vit_score = self.vit_metric(styled_data, original_data)
        torch.cuda.empty_cache()
        vit_style = vit_score["Vit_style"]
        vit_style["max_query"] = self._ref_names(vit_style["Vit_style_similarity_max_id_details"])
        ret["Style_VIT"] = vit_style

        if original_data is not None:
            ret["Content_VIT"] = vit_score["Vit_content"]
            lpips_score = self.lpips_metric(styled_data, original_data)
            torch.cuda.empty_cache()
            ret["Content_LPIPS"] = lpips_score

        # Build fresh dicts instead of mutating them while iterating.
        ret = {
            type_key: {key: self._to_python(value) for key, value in type_value.items()}
            for type_key, type_value in ret.items()
        }

        self.to("cpu")
        ret["ref_image_file"] = list(self.ref_image_file)
        return ret
