import clip
import numpy as np
import pandas as pd
import torch
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import easyocr
from PIL import Image
import faiss
import logging
import requests
import os
import json
import argparse

def download_fig(url, timeout=30):
	"""Download the resource at ``url`` and return its raw bytes.

	Args:
		url: HTTP(S) URL to fetch.
		timeout: value forwarded to ``requests.get`` as its connect/read
			timeout (seconds), so a stalled server cannot block the crawl
			forever. Pass ``None`` to wait indefinitely (the old behaviour).

	Returns:
		The response body as ``bytes``.

	Raises:
		requests.HTTPError: if the server answers with an error status.
		requests.RequestException: on connection failures or timeouts.
	"""
	response = requests.get(url, timeout=timeout)
	response.raise_for_status()
	return response.content

def retrieval(ind, text, k=1000, clip_model=None, clip_device=None):
	"""Embed ``text`` with CLIP and return the top-``k`` hits from ``ind``.

	Args:
		ind: faiss index to search with the L2-normalised text embedding.
		text: query string; tokenised with truncation to CLIP's context length.
		k: number of neighbours requested from the index.
		clip_model: model providing ``encode_text``; defaults to the
			module-level ``model`` loaded in the ``__main__`` block.
		clip_device: device the tokens are moved to; defaults to the
			module-level ``device``.

	Returns:
		``(D, I)``: 1-D arrays holding the scores and ids returned by
		``ind.search`` for the single query.
	"""
	if clip_model is None:
		clip_model = model
	if clip_device is None:
		clip_device = device
	text_tokens = clip.tokenize([text], truncate=True)
	# Inference only: don't build an autograd graph through the text encoder.
	with torch.no_grad():
		text_features = clip_model.encode_text(text_tokens.to(clip_device))
		text_features /= text_features.norm(dim=-1, keepdim=True)
	# faiss expects a float32 array; the model's output dtype may differ.
	text_embeddings = text_features.cpu().numpy().astype('float32')
	D, I = ind.search(text_embeddings, k)
	return D[0], I[0]

def ocr(img, reader):
	"""Run OCR on ``img`` and count the words it contains.

	Args:
		img: image accepted by ``reader.readtext`` (the caller passes the raw
			downloaded bytes).
		reader: OCR reader exposing ``readtext`` (an ``easyocr.Reader``).

	Returns:
		``(result, total)``: ``result`` is the list of detected text strings,
		``total`` the number of whitespace-separated words across all of them.
	"""
	result = reader.readtext(img, detail=0)
	# split() with no argument ignores leading/trailing/repeated whitespace,
	# so empty or padded detections no longer inflate the word count.
	total = sum(len(s.split()) for s in result)
	return result, total

def write_json(meta, save_path):
	"""Serialise the metadata list ``meta`` to ``save_path`` as JSON.

	Each entry's ``'text_similarity'`` is converted in place to a Python
	``float`` first, because values read from parquet (e.g. numpy scalars)
	are not JSON-serialisable.

	The dump goes to ``save_path + '.tmp'`` and is then moved over
	``save_path`` with ``os.replace``. An interrupted write therefore leaves
	the previous file intact instead of a truncated one, which matters
	because the script re-reads this file when resuming.
	"""
	for item in meta:
		item['text_similarity'] = float(item['text_similarity'])
	tmp_path = save_path + '.tmp'
	with open(tmp_path, 'w') as f:
		json.dump(meta, f)
	os.replace(tmp_path, save_path)

def go_slice(length_meta, Is):
	"""Bucket global row ids into per-file local row offsets.

	``length_meta`` is walked in insertion order. Each value is treated as
	the exclusive end offset of that file's rows, and the previous file's
	value (0 for the first file) as its inclusive start. ``Is`` is consumed
	with a single forward cursor, so it must be sorted ascending for every id
	to land in the correct file. Ids below a file's start (e.g. negative ids)
	are dropped, as are ids at or beyond the last end offset.

	Returns:
		dict mapping each file key that received at least one id to the list
		of ``id - start`` offsets, in the order they appear in ``Is``.
	"""
	grouped = {}
	cursor = 0
	n_ids = len(Is)
	lower = 0
	for parquet_path, upper in length_meta.items():
		local_rows = []
		while cursor < n_ids and Is[cursor] < upper:
			global_id = Is[cursor]
			cursor += 1
			if global_id < lower:
				continue
			local_rows.append(global_id - lower)
		if local_rows:
			grouped[parquet_path] = local_rows
		lower = upper
	return grouped

if __name__ == "__main__":
	# Pipeline: embed --search_key with CLIP, search each faiss index shard,
	# keep hits scoring above --search_th, map them back to parquet rows via
	# --length_meta, then download, OCR-filter and save up to --obj images
	# plus a JSON metadata file into --image_path.
	parser = argparse.ArgumentParser()
	parser.add_argument('--par_path', type=str, default="/mnt/petrelfs/zhangyongting/zyt/research/RLHF/index/embeddings",help='parquet path')
	parser.add_argument('--index_path', type=str, default="/mnt/petrelfs/zhangyongting/zyt/research/RLHF/index/index_files/",help='index path')
	parser.add_argument('--image_path', type=str, default="/mnt/petrelfs/zhangyongting/zyt/research/RLHF/test_img_par", help='output image path')
	parser.add_argument('--length_meta', type=str, default="/mnt/petrelfs/zhangyongting/zyt/research/RLHF/index/clip-retrieval/notebook/meta_parquet.json", help='json mapping each parquet path to its exclusive end row offset')
	parser.add_argument('--obj', type=int, default=5, help='object number')
	parser.add_argument('--text_th', type=int, default=5, help='threshold for text number')
	# NOTE(review): --par_end, --par_start, --par_block_number and --par_path
	# are not read anywhere below; kept so existing command lines still parse.
	parser.add_argument('--par_end', type=int, default=50, help='load parquet end number')
	parser.add_argument('--par_start', type=int, default=0, help='load parquet start number')
	parser.add_argument('--index_end', type=int, default=1, help='load index end number')
	parser.add_argument('--index_start', type=int, default=0, help='load index start number')
	parser.add_argument('--meta_start', type=int, default=0, help='meta start')
	parser.add_argument('--recall_number', type=int, default=1000, help='number of neighbours to retrieve from each index')
	parser.add_argument('--par_block_number', type=int, default=2, help='load parquet block number')
	parser.add_argument('--search_th', type=float, default=0.20, help='threshold for search result')
	# Without a query the CLIP tokenizer fails later, so fail fast here.
	parser.add_argument('--search_key', type=str, required=True, help='search text')
	args = parser.parse_args()

	os.makedirs(args.image_path, exist_ok=True)
	logging.basicConfig(filename= os.path.join(args.image_path, "log.txt"),
					format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s-%(funcName)s',
					level=logging.INFO)
	json_path = os.path.join(args.image_path, "meta0.json")

	# `device` and `model` must stay module-level names: retrieval() reads them.
	device = "cuda" if torch.cuda.is_available() else "cpu"
	model, preprocess = clip.load("ViT-L/14", device=device)
	reader = easyocr.Reader(['en', 'fr', 'es', 'de'], gpu=(device == "cuda"))
	logging.info("Finish loading ocr and vit model")

	# --index_end is inclusive; shard numbers above 54 are clamped away.
	index_ids = range(args.index_start, min(args.index_end + 1, 55))
	Ds = []
	Is = []
	print("Indexing")
	with tqdm(total=len(index_ids)) as pbar:
		for k in index_ids:
			logging.info('Load index {}'.format(k))
			# Shard files are named knn.index00, knn.index01, ... (two-digit pad).
			ind = faiss.read_index(os.path.join(args.index_path, "knn.index{:02d}".format(k)))
			D, I = retrieval(ind, args.search_key, args.recall_number)
			Ds.extend(D)
			Is.extend(I)
			pbar.update(1)
	# NOTE(review): ids from all shards are merged into one id space. This is
	# only correct if each shard's ids are already global offsets into the
	# rows described by --length_meta; confirm against how the shards were built.
	Is = np.array(Is)[np.array(Ds) > args.search_th]
	# go_slice scans ids with a single forward cursor and needs them sorted.
	Is = np.sort(Is)
	with open(args.length_meta, 'r') as f:
		length_meta = json.load(f)
	slice_result = go_slice(length_meta, Is)

	c = 0
	if args.meta_start != 0:
		with open(json_path, 'r') as f:
			meta = json.load(f)
	else:
		meta = []
	# md5s already saved (by a resumed run or earlier in this run); a set
	# keeps the per-row duplicate check O(1).
	lut = {result['md5'] for result in meta}
	print("Retriving")

	blocked_url_parts = ("pics", "https://data.whicdn.com", "http://www.quickmeme.com")
	with tqdm(total=args.obj) as pbar:
		for path, results in slice_result.items():
			# Stop as soon as exactly args.obj images have been saved.
			if c >= args.obj:
				logging.info('Finish at parquet' + path)
				break
			logging.info('Load parquet' + path)
			df = pd.read_parquet(path)
			for i in results:
				if c >= args.obj:
					logging.info('Finish at parquet' + path)
					break
				row = df.iloc[i]
				# Skip pictures that are too small.
				if row['width'] + row['height'] < 350:
					continue
				suffix = row['url'].split('.')[-1]
				if suffix not in ['jpg', 'png', 'jpeg']:
					continue
				if row['md5'] in lut:
					continue
				if any(part in row['url'] for part in blocked_url_parts):
					continue
				file_name = str(c + args.meta_start) + "." + suffix
				save_path = os.path.join(args.image_path, file_name)

				image = {
					"save_path": save_path,
					"width": int(row['width']),
					"height": int(row['height']),
					'search_key': args.search_key,
					"url": row['url'],
					"md5": row['md5'],
					"caption_laion": row['caption'],
					"text_similarity": row['similarity'],
					"image_key": row['key']}
				try:
					img = download_fig(row['url'])
					result, total = ocr(img, reader)
					# Reject images containing too much text.
					if total > args.text_th:
						continue
					image['text_count'] = total
					with open(save_path, 'wb') as file:
						file.write(img)
					logging.info('Finish write image {}'.format(c))
					meta.append(image)
					# Remember the md5 so duplicate rows later in this run are skipped.
					lut.add(image['md5'])
					c += 1
					# Checkpoint metadata periodically so a crash loses little work.
					if c % 10 == 0:
						write_json(meta, json_path)
					pbar.update(1)
				except Exception as e:
					logging.error(e)
					logging.error("Error downloading figure {} failed".format(row['url']))
					continue
		write_json(meta, json_path)



	




