
import os
import sys
import numpy as np
from PIL import Image
import pickle
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
from custom_datasets import get_dataset
from utils.art_filter import Art_filter
import torch
from matplotlib import pyplot as plt
import math
import argparse
import socket
import time
from tqdm import tqdm
import torch
def parse_args():
    parser = argparse.ArgumentParser(description="Filter the coco dataset")
    parser.add_argument("--check", action="store_true", help="Check the complete")
    parser.add_argument("--mode", default="clip_logit", help="Filter mode: clip_logit, clip_filt, caption_filt")
    parser.add_argument("--split" , default="val", help="Dataset split, val/train")
    # parser.add_argument("--start_idx", default=0, type=int, help="Start index")
    args = parser.parse_args()
    return args

def get_feat(save_path, dataloader, filter):
    clip_feat_file = save_path
    # compute_new = False
    clip_feat={}
    if os.path.exists(clip_feat_file):
        with open(clip_feat_file, 'rb') as f:
            clip_feat = pickle.load(f)
    else:
        print(f"computing clip feat",flush=True)
        clip_feature_ret = filter.clip_feature(dataloader)
        clip_feat["image_features"] = clip_feature_ret["clip_features"]
        clip_feat["ids"] = clip_feature_ret["ids"]

        with open(clip_feat_file, 'wb') as f:
            pickle.dump(clip_feat, f)
        print(f"clip_feat_result saved to {clip_feat_file}",flush=True)
    return clip_feat

def get_clip_logit(save_root, dataloader, filter):
    feat_path = os.path.join(save_root, "clip_feat.pickle")
    clip_feat = get_feat(feat_path, dataloader, filter)
    clip_logits_file = os.path.join(save_root, "clip_logits.pickle")
    # if clip_logit:
    if os.path.exists(clip_logits_file):
        with open(clip_logits_file, 'rb') as f:
            clip_logits = pickle.load(f)
    else:
        clip_logits = filter.clip_logit_by_feat(clip_feat["image_features"])
        clip_logits["ids"] = clip_feat["ids"]
        with open(clip_logits_file, 'wb') as f:
            pickle.dump(clip_logits, f)
        print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
    return clip_logits

def clip_filt(save_root, dataloader, filter):
    clip_filt_file = os.path.join(save_root, "clip_filt_result.pickle")
    if os.path.exists(clip_filt_file):
        with open(clip_filt_file, 'rb') as f:
            clip_filt_result = pickle.load(f)
    else:
        clip_logits = get_clip_logit(save_root, dataloader, filter)
        clip_filt_result = filter.clip_filt(clip_logits)
        with open(clip_filt_file, 'wb') as f:
            pickle.dump(clip_filt_result, f)
        print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
    return clip_filt_result

def caption_filt(save_root, dataloader, filter):
    caption_filt_file = os.path.join(save_root, "caption_filt_result.pickle")
    if os.path.exists(caption_filt_file):
        with open(caption_filt_file, 'rb') as f:
            caption_filt_result = pickle.load(f)
    else:
        caption_filt_result = filter.caption_filt(dataloader)
        with open(caption_filt_file, 'wb') as f:
            pickle.dump(caption_filt_result, f)
        print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
    return caption_filt_result

def gather_result(save_dir, dataloader, filter):
    all_remain_ids=[]
    all_remain_ids_train=[]
    all_remain_ids_val=[]
    all_filtered_id_num = 0

    clip_filt_result = clip_filt(save_dir, dataloader, filter)
    caption_filt_result = caption_filt(save_dir, dataloader, filter)

    caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
    all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
    remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
    remain_ids = list(remain_ids)
    remain_ids.sort()
    with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
        pickle.dump(remain_ids, f)
    print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
    return remain_ids

@torch.no_grad()
def main(args):
    filter = Art_filter()
    if args.mode == "caption_filt" or args.mode == "gather_result":
        filter.clip_filter = None
        torch.cuda.empty_cache()


    def collate_fn(examples):
        # {"image": image, "id":id}
        ret = {}
        if "image" in examples[0]:
            pixel_values = [example["image"] for example in examples]
            ret["images"] = pixel_values
        if "caption" in examples[0]:
            # prompts = [example["caption"] for example in examples]
            prompts = []
            for example in examples:
                if isinstance(example["caption"][0], list):
                    prompts.append([" ".join(example["caption"][0])])
                else:
                    prompts.append(example["caption"])
            ret["text"] = prompts
        id = [example["id"] for example in examples]
        ret["ids"] = id
        return ret
    if args.split == "val":
        dataset = get_dataset("coco_val")["val"]
    elif args.split == "train":
        dataset = get_dataset("coco_train", get_val=False)["train"]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)

    error_files=[]



    save_root = f"coco/filt/{args.split}"
    os.makedirs(save_root, exist_ok=True)

    if args.mode == "clip_feat":
        feat_path = os.path.join(save_root, "clip_feat.pickle")
        clip_feat = get_feat(feat_path, dataloader, filter)

    if args.mode == "clip_logit":
        clip_logit = get_clip_logit(save_root, dataloader, filter)

    if args.mode == "clip_filt":
        # if os.path.exists(clip_filt_file):
        #     with open(clip_filt_file, 'rb') as f:
        #         ret = pickle.load(f)
        # else:
        clip_filt_result = clip_filt(save_root, dataloader, filter)

    if args.mode == "caption_filt":
        caption_filt_result = caption_filt(save_root, dataloader, filter)

    if args.mode == "gather_result":
        filtered_result = gather_result(save_root, dataloader, filter)

    print("finished",flush=True)
    for file in error_files:
        # os.remove(file)
        print(file,flush=True)

if __name__ == "__main__":
    args = parse_args()

    log_file = "sam_filt"
    idx=0
    hostname = socket.gethostname()
    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
        idx+=1

    main(args)
    # clip_logits_analysis()


