import json, os, io
from abc import ABCMeta, abstractmethod
from typing import List, Dict, Any, Optional
from PIL import Image
from pathlib import Path
from absl import logging
import pandas as pd
import numpy as np
import torch
import clip
from statistics import mean
from torchvision.io import read_image
from torchvision.utils import make_grid, save_image
from torchvision.transforms import transforms

from tools.helpers import pip_images_make_grid
import wandb
from libs.face_eval import FaceEvaluator

class FeedEvaluator(object):
    """Base evaluator: generate images for JSON-described cases and score them.

    Each case file is a JSON document. The keys read by this class are
    ``source_group`` (a list of dicts with ``path`` and ``mask`` entries),
    ``caption_list``, ``type`` and ``class_word``; a ``name`` key holding the
    file stem is added when the case is loaded.

    Subclasses provide :meth:`gen_one`. Scores are CLIP (ViT-B/32)
    similarities and, optionally, a face similarity computed by
    ``FaceEvaluator``.

    NOTE: ``gen_one`` and ``gather_files`` are decorated with
    ``@abstractmethod``, but this class does not use ``ABCMeta``, so the
    decorator is informational only and instantiation is not blocked.
    """

    def __init__(self, case_files, output_path, device,
                process_index=0, num_processes=1, eval_face=False) -> None:
        """Load the cases and the CLIP model.

        Args:
            case_files: iterable of paths to JSON case files.
            output_path: directory where images and summaries are written.
            device: device for the CLIP model; a ``torch.device`` or anything
                ``torch.device(...)`` accepts (e.g. ``"cuda:0"``).
            process_index: index of this worker among ``num_processes``.
            num_processes: total number of workers sharing the cases.
            eval_face: when true, also build a ``FaceEvaluator``.
        """
        self.case_files = case_files
        self.cases = self.init_cases()
        self.output_path = output_path
        self.device = device
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        self.process_index = process_index
        self.num_process = num_processes

        self.eval_face = eval_face
        if self.eval_face:
            # torch.device(...) normalises strings as well as torch.device
            # objects; a missing index (e.g. "cuda") falls back to GPU 0.
            self.face_model = FaceEvaluator(gpu_id=torch.device(device).index or 0)

    def set_output_path(self, new_path):
        """Redirect all subsequent outputs to ``new_path``."""
        self.output_path = new_path

    @torch.no_grad()
    def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:
        """Encode ``text`` with CLIP; L2-normalise the result when ``norm`` is true."""
        tokens = clip.tokenize(text).to(self.device)
        text_features = self.clip_model.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    @torch.no_grad()
    def get_image_features(self, images:List[Image.Image], norm: bool = True) -> torch.Tensor:
        """Encode a list of PIL images with CLIP as one batch.

        The result has one row per input image; rows are L2-normalised when
        ``norm`` is true.
        """
        batch = torch.stack(
            [self.clip_preprocess(image).to(self.device) for image in images]
        )
        image_features = self.clip_model.encode_image(batch)
        if norm:
            image_features /= image_features.norm(dim=-1, keepdim=True)

        return image_features

    def face_to_face_similarity(self, src_images:List[Image.Image], generated_images:List[Image.Image]):
        """Return the face similarity between generated and source images.

        Requires the evaluator to have been built with ``eval_face=True``
        (otherwise ``self.face_model`` does not exist).
        """
        # Argument order (generated first, then source) mirrors the direct
        # face_model.face_similarity calls made in test_case.
        return self.face_model.face_similarity(generated_images, src_images)

    def img_to_img_similarity(self, src_images:List[Image.Image], generated_images:List[Image.Image]):
        """Mean CLIP cosine similarity over every (source, generated) pair."""
        src_img_features = self.get_image_features(src_images)
        gen_img_features = self.get_image_features(generated_images)
        return (src_img_features @ gen_img_features.T).mean().item()

    def txt_to_img_similarity(self, text, generated_images):
        """Mean CLIP cosine similarity between ``text`` and each generated image."""
        text_features = self.get_text_features(text)
        gen_img_features = self.get_image_features(generated_images)
        return (text_features @ gen_img_features.T).mean().item()

    def get_img_embedding(self, img:Image.Image):
        """Return the L2-normalised CLIP embedding of one image as a batch of one."""
        return self.get_image_features([img])

    def init_cases(self):
        """Load every case file and tag each case with its file stem as ``name``."""
        cases = []
        for case_file in self.case_files:
            # Context manager so the handle is closed deterministically.
            with open(case_file, 'r') as json_f:
                case = json.load(json_f)
            case["name"] = Path(case_file).stem
            cases.append(case)
        return cases

    def test_case(self, case) -> pd.DataFrame:
        """Generate and score images for every caption of one case.

        Generated images are saved as ``<output_path>/<case name>/<caption>_<i>.png``
        and a grid made of the first image of each caption is logged to wandb.

        Returns:
            A DataFrame with one row per caption and the columns ``caption``,
            ``clip_image_sim``, ``clip_caption_sim`` and ``face_sim``
            (``face_sim`` is -1 when face evaluation is disabled).
        """
        case_dir = os.path.join(self.output_path, case["name"])
        os.makedirs(case_dir, exist_ok=True)

        image_list = [s["path"] for s in case["source_group"]]
        mask_list = [s["mask"] for s in case["source_group"]]
        pil_image_list = [Image.open(p).convert('RGB') for p in image_list]

        rows = []
        wandb_img = []
        for p in case["caption_list"]:
            cur_images = self.gen_one(case["type"], p, image_list, mask_list, class_word=case["class_word"])
            wandb_img.append(cur_images[0])

            # CLIP similarity to the source images and to the caption.
            clip_image_sim = self.img_to_img_similarity(pil_image_list, cur_images)
            clip_text_sim = self.txt_to_img_similarity(p, cur_images)

            # Face similarity; -1 marks "not evaluated".
            if self.eval_face:
                face_sim = self.face_model.face_similarity(cur_images, pil_image_list)
            else:
                face_sim = -1

            for i, img in enumerate(cur_images):
                img.save(os.path.join(case_dir, f"{p}_{i}.png"))

            rows.append({
                "caption": p,
                "clip_image_sim": clip_image_sim,
                "clip_caption_sim": clip_text_sim,
                "face_sim": face_sim,
            })

        # Build the frame once (no quadratic concat) and always include
        # "face_sim" so an empty caption list still yields every column.
        df = pd.DataFrame(
            rows,
            columns=["caption", "clip_image_sim", "clip_caption_sim", "face_sim"],
        )

        # Combine the preview images and upload them to wandb; with no
        # captions there is nothing to put in the grid.
        if wandb_img:
            cls_name = case["name"]
            grid_img = pip_images_make_grid(wandb_img, cols=5)
            wandb.log({cls_name: wandb.Image(grid_img, caption=cls_name)})

        return df

    def test(self):
        """Evaluate the cases assigned to this process and write a summary.

        Cases are distributed round-robin: case ``i`` is handled by the
        process whose index equals ``i % num_processes``. The per-case means
        are averaged and logged to wandb under ``eval``; the per-case tables
        are written to ``summary_<process_index>.json`` in ``output_path``.
        Nothing is logged or written when no case is assigned to this process.
        """
        df_ls = []
        for i, case in enumerate(self.cases):
            if i % self.num_process != self.process_index:
                continue
            logging.info(f"process {self.process_index} is sampling {case['name']}")
            df_ls.append((case, self.test_case(case)))
        if not df_ls:
            return

        # Summarize: one mean per case, then the mean over cases.
        df_image_sim_means = [df["clip_image_sim"].mean() for _, df in df_ls]
        df_caption_sim_means = [df["clip_caption_sim"].mean() for _, df in df_ls]
        df_face_sim_means = [df["face_sim"].mean() for _, df in df_ls]
        logging.info(f"{df_image_sim_means}, {df_caption_sim_means}, {df_face_sim_means}")

        summary = {
            "image_sim": mean(df_image_sim_means),
            "text_sim": mean(df_caption_sim_means),
            "face_sim": mean(df_face_sim_means),
        }
        wandb.log({"eval": summary})

        detail_info = {case["name"]: json.loads(df.to_json()) for case, df in df_ls}
        # Write the per-case report for this process.
        with open(os.path.join(self.output_path, f"summary_{self.process_index}.json"), 'w') as json_f:
            json.dump(detail_info, json_f, indent=4)

    @abstractmethod
    def gen_one(self, type:Optional[str], caption:str, ref_images:List[str], ref_masks:List[str]=None, class_word="")-> List[Image.Image]:
        """
        Generate images from one caption and the reference images.
        return: a list of PIL.Image.Image
        """
        raise NotImplementedError

    @abstractmethod
    def gather_files(self):
        """Hook for merging per-process outputs; a no-op in this base class."""
        pass
    
    
class SourceGroupEvaluator(FeedEvaluator):
    """
    Same as FeedEvaluator, but gen_one receives the whole ``source_group``
    list of the case instead of separate image-path and mask lists.
    """
    def __init__(self, case_files, output_path, device, 
                process_index=0, num_processes=1, eval_face=False) -> None:
        """Forward every argument unchanged to ``FeedEvaluator.__init__``."""
        super().__init__(case_files, output_path, device, process_index, num_processes, eval_face)

    def test_case(self, case) -> pd.DataFrame:
        """Generate and score images for every caption of one case.

        Generated images are saved as ``<output_path>/<case name>/<caption>_<i>.png``
        and a grid made of the first image of each caption is logged to wandb.

        Returns:
            A DataFrame with one row per caption and the columns ``caption``,
            ``clip_image_sim``, ``clip_caption_sim`` and ``face_sim``
            (``face_sim`` is -1 when face evaluation is disabled).
        """
        case_dir = os.path.join(self.output_path, case["name"])
        os.makedirs(case_dir, exist_ok=True)

        image_list = [s["path"] for s in case["source_group"]]
        pil_image_list = [Image.open(p).convert('RGB') for p in image_list]

        rows = []
        wandb_img = []
        for p in case["caption_list"]:
            cur_images = self.gen_one(case["type"], p, case["source_group"], class_word=case["class_word"])
            wandb_img.append(cur_images[0])

            # CLIP similarity to the source images and to the caption.
            clip_image_sim = self.img_to_img_similarity(pil_image_list, cur_images)
            clip_text_sim = self.txt_to_img_similarity(p, cur_images)

            # Face similarity; -1 marks "not evaluated".
            if self.eval_face:
                face_sim = self.face_model.face_similarity(cur_images, pil_image_list)
            else:
                face_sim = -1

            for i, img in enumerate(cur_images):
                img.save(os.path.join(case_dir, f"{p}_{i}.png"))

            rows.append({
                "caption": p,
                "clip_image_sim": clip_image_sim,
                "clip_caption_sim": clip_text_sim,
                "face_sim": face_sim,
            })

        # Build the frame once (no quadratic concat) and always include
        # "face_sim" so an empty caption list still yields every column.
        df = pd.DataFrame(
            rows,
            columns=["caption", "clip_image_sim", "clip_caption_sim", "face_sim"],
        )

        # Combine the preview images and upload them to wandb; with no
        # captions there is nothing to put in the grid.
        if wandb_img:
            cls_name = case["name"]
            grid_img = pip_images_make_grid(wandb_img, cols=5)
            wandb.log({cls_name: wandb.Image(grid_img, caption=cls_name)})

        return df

    @abstractmethod
    def gen_one(self, type:Optional[str], caption:str, source_group:List[Dict], class_word="")-> List[Image.Image]:
        """
        Generate images from one caption and the case's source group.
        return: a list of PIL.Image.Image
        """
        raise NotImplementedError
    

class SourceGroupEvaluatorMultiplier(FeedEvaluator):
    """
    Same as SourceGroupEvaluator (gen_one receives the case's source_group),
    with additional multiplier settings.

    Multi-process support: every process writes its images and its
    ``summary_<process_index>.json`` to disk in :meth:`test`; the main process
    then calls :meth:`gather_files` to read them back and log to wandb.
    """
    def __init__(self, case_files, output_path, device, 
                process_index=0,
                num_processes=1, 
                eval_face=False) -> None:
        """Forward the arguments to ``FeedEvaluator.__init__`` and set the sweep."""
        super().__init__(case_files, output_path, device, process_index, num_processes, eval_face)
        # Values passed to gen_one as ``cfg``.
        self.cfgs = [5,] 
        # (cond_multiplier, uncond_multiplier) pairs passed to gen_one.
        self.ucms = [(1., 0.), (1., 1.)]

    def test(self):
        """Evaluate the cases assigned to this process and write its summary.

        Case ``i`` is handled by the process whose index equals
        ``i % num_processes``. Unlike ``FeedEvaluator.test`` nothing is logged
        to wandb here; the per-case tables are only written to
        ``summary_<process_index>.json`` for :meth:`gather_files`. No file is
        written when no case is assigned to this process.
        """
        df_ls = []
        for i, case in enumerate(self.cases):
            if i % self.num_process != self.process_index:
                continue
            print(f"process {self.process_index} is sampling {case['name']}")
            df_ls.append((case, self.test_case(case)))

        if not df_ls:
            return

        detail_info = {case["name"]: json.loads(df.to_json()) for case, df in df_ls}

        # Write the per-case report for this process.
        with open(os.path.join(self.output_path, f"summary_{self.process_index}.json"), 'w') as json_f:
            json.dump(detail_info, json_f, indent=4)

    def gather_files(self):
        """Merge the per-process summaries and log metrics and grids to wandb.

        Reads every ``summary_<i>.json`` written by :meth:`test`, logs the mean
        image/text/face similarity under ``eval``, then for each setting and
        case builds a grid from the first generated image of every caption,
        saves it as ``summary.png`` next to those images and logs it to wandb.
        """
        # Process i wrote a summary only if it was assigned at least one case,
        # which with round-robin assignment means i < len(self.cases).
        df_ls = []
        for i in range(min(len(self.cases), self.num_process)):
            summary_path = os.path.join(self.output_path, f"summary_{i}.json")
            with open(summary_path, "r") as json_f:
                dic = json.load(json_f)
            for k, v in dic.items():
                # Round-trip through JSON bytes so pandas parses the table
                # exactly as it was serialised by DataFrame.to_json.
                df_ls.append((k, pd.read_json(io.BytesIO(json.dumps(v).encode()))))

        if not df_ls:
            # Averaging an empty list would only produce NaN and a warning.
            logging.warning("gather_files found no summaries to aggregate")
            return

        # Summarize: one mean per case, then the mean over cases.
        df_image_sim_means = [df["clip_image_sim"].mean() for _, df in df_ls]
        df_caption_sim_means = [df["clip_caption_sim"].mean() for _, df in df_ls]
        df_face_sim_means = [
            df["face_sim"].mean() if df.get("face_sim", None) is not None else -1
            for _, df in df_ls
        ]
        logging.info(f"{df_image_sim_means}, {df_caption_sim_means}, {df_face_sim_means}")

        summary = {
            "image_sim": np.mean(df_image_sim_means),
            "text_sim": np.mean(df_caption_sim_means),
            "face_sim": np.mean(df_face_sim_means),
        }
        wandb.log({"eval": summary})

        # read_image yields integer tensors; make_grid/save_image get floats.
        to_float = transforms.ConvertImageDtype(dtype=torch.float)
        for cfg in self.cfgs:
            for cm, ucm in self.ucms:
                setting = f"cm{cm}_ucm{ucm}_cfg{cfg}"
                for case in self.cases:
                    case_dir = os.path.join(self.output_path, setting, case["name"])
                    tensors = [
                        to_float(read_image(os.path.join(case_dir, f"{p}_0.png")))
                        for p in case["caption_list"]
                    ]
                    if not tensors:
                        # No captions, hence no images to arrange in a grid.
                        continue
                    grid_img = make_grid(tensors, padding=5, nrow=3)
                    save_image(grid_img, os.path.join(case_dir, "summary.png"))
                    cls_name = case["name"] + setting
                    wandb.log({cls_name: wandb.Image(grid_img, caption=cls_name)})

    def test_case(self, case) -> pd.DataFrame:
        """Generate and score images for every setting and caption of one case.

        Images are saved as
        ``<output_path>/cm<cm>_ucm<ucm>_cfg<cfg>/<case name>/<caption>_<i>.png``.

        Returns:
            A DataFrame with one row per (setting, caption) and the columns
            ``caption``, ``clip_image_sim``, ``clip_caption_sim`` and
            ``face_sim`` (-1 when face evaluation is disabled). The setting
            itself is not recorded in the rows.
        """
        image_list = [s["path"] for s in case["source_group"]]
        pil_image_list = [Image.open(p).convert('RGB') for p in image_list]

        rows = []
        for cfg in self.cfgs:
            for cm, ucm in self.ucms:
                # The output directory depends only on the setting and the
                # case, so create it once rather than once per caption.
                save_dir = os.path.join(self.output_path, f"cm{cm}_ucm{ucm}_cfg{cfg}", case["name"])
                os.makedirs(save_dir, exist_ok=True)
                for p in case["caption_list"]:
                    cur_images = self.gen_one(case["type"], p, case["source_group"], class_word=case["class_word"], 
                                                cond_multiplier=cm,
                                                uncond_multiplier=ucm,
                                                cfg=cfg)
                    # CLIP similarity to the source images and to the caption.
                    clip_image_sim = self.img_to_img_similarity(pil_image_list, cur_images)
                    clip_text_sim = self.txt_to_img_similarity(p, cur_images)

                    # Face similarity; -1 marks "not evaluated".
                    if self.eval_face:
                        face_sim = self.face_model.face_similarity(cur_images, pil_image_list)
                    else:
                        face_sim = -1

                    for i, img in enumerate(cur_images):
                        img.save(os.path.join(save_dir, f"{p}_{i}.png"))

                    rows.append({
                        "caption": p,
                        "clip_image_sim": clip_image_sim,
                        "clip_caption_sim": clip_text_sim,
                        "face_sim": face_sim,
                    })

        # Build the frame once (no quadratic concat) and always include
        # "face_sim" so an empty caption list still yields every column.
        return pd.DataFrame(
            rows,
            columns=["caption", "clip_image_sim", "clip_caption_sim", "face_sim"],
        )

    @abstractmethod
    def gen_one(self, type:Optional[str],
                caption:str,
                source_group:List[Dict],
                class_word="", 
                cond_multiplier=1.,
                uncond_multiplier=0.,
                cfg=None)-> List[Image.Image]:
        """
        Generate images from one caption and the case's source group.
        return: a list of PIL.Image.Image
        """
        raise NotImplementedError