import os
import pickle
import json
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import csv
import argparse
import pandas as pd

def get_product_features(product_feature_path, product_num_record_path):
    """Stack every product's pickled feature array into one zero-padded array.

    Args:
        product_feature_path: Directory holding one pickle file per product.
            Each pickle is expected to contain an array whose first axis is
            the product's image count and whose last axis is the feature
            dimension.
        product_num_record_path: CSV file with an ``img_num`` column; its
            maximum value sets the number of image slots per product.

    Returns:
        A ``(features, padding_mask)`` tuple:
        ``features`` has shape ``[max_img_num, product_num, feature_dim]``
        with unused slots left at zero, and ``padding_mask`` has shape
        ``[max_img_num, product_num]`` with 1 in every unused slot and 0
        elsewhere.

    Note:
        Products are laid out in ``os.listdir`` order, which is not sorted.
        The pickles are loaded with ``pickle.load``, so the directory must
        only contain trusted files.
    """
    file_names = os.listdir(product_feature_path)
    slot_count = pd.read_csv(product_num_record_path)['img_num'].max()
    product_count = len(file_names)

    # Peek at the first file only to learn the feature dimension.
    probe_path = os.path.join(product_feature_path, file_names[0])
    with open(probe_path, 'rb') as handle:
        feature_dim = pickle.load(handle).shape[-1]

    stacked = np.zeros([slot_count, product_count, feature_dim])
    padding_mask = np.zeros([slot_count, product_count])

    for column, file_name in enumerate(file_names):
        with open(os.path.join(product_feature_path, file_name), 'rb') as handle:
            per_image = pickle.load(handle)
        used_slots = per_image.shape[0]
        stacked[:used_slots, column, :] = per_image
        # Everything past the real images is padding for this product.
        padding_mask[used_slots:, column] = 1

    return stacked, padding_mask

def load_multiple_batch(current_load_num, load_batch_num, load_num, raw_item_path):
    """Load one chunk of consecutive feature files and concatenate them.

    Files are named by their zero-padded 7-digit index (``0000000.pkl``,
    ``0000001.pkl``, ...). Chunk ``current_load_num`` covers file indexes
    ``[current_load_num * load_batch_num, (current_load_num + 1) * load_batch_num)``,
    clamped to ``load_num`` so that the final, possibly shorter, chunk never
    reaches past the last existing file.

    Args:
        current_load_num: Zero-based index of the chunk to load.
        load_batch_num: Number of files per chunk.
        load_num: Total number of feature files available.
        raw_item_path: Directory containing the feature files.

    Returns:
        The loaded arrays concatenated along axis 0.

    Raises:
        ValueError: If the chunk lies entirely past the last file.
    """
    begin_load_index = current_load_num * load_batch_num
    # Clamp instead of special-casing the last chunk: the previous
    # `begin < load_num - 1` test loaded nothing for the last file when
    # load_batch_num == 1 and over-read when the tail was shorter than a chunk.
    end_load_index = min(begin_load_index + load_batch_num, load_num)
    if begin_load_index >= end_load_index:
        raise ValueError(
            'chunk %d is out of range: only %d feature files are available'
            % (current_load_num, load_num)
        )

    load_features = []
    for file_index in range(begin_load_index, end_load_index):
        feature_path = os.path.join(raw_item_path, str(file_index).zfill(7) + '.pkl')
        # NOTE: pickle.load executes arbitrary code from the file; only point
        # this at feature files you produced yourself.
        with open(feature_path, 'rb') as f:
            load_features.append(pickle.load(f))
    return np.concatenate(load_features, 0)

def get_load_batch_sim(product_features, max_sim_assit_item, recording):
    """Return each product's best cosine similarity to every recorded feature.

    Args:
        product_features: Array of shape ``[max_num, product_num, dim]``
            (zero-padded per-product image features).
        max_sim_assit_item: Array of shape ``[max_num, product_num]`` with 1
            in padded slots; those slots are pushed down by 10000 so they can
            never win the maximum.
        recording: 2-D array of candidate features, one row per feature.

    Returns:
        NumPy array of shape ``[product_num, len(recording)]`` holding, for
        every product, the maximum similarity over its image slots.
    """
    products = torch.Tensor(product_features)
    padding_mask = torch.Tensor(max_sim_assit_item)
    candidates = torch.Tensor(recording)

    # L2-normalise both sides so the dot product below is a cosine similarity.
    products = F.normalize(products, dim=-1)
    candidates = F.normalize(candidates, dim=-1)

    slot_count, product_count, dim = products.shape
    flat_products = products.reshape(-1, dim)
    scores = torch.matmul(flat_products, candidates.permute(-1, -2))
    scores = scores.reshape(slot_count, product_count, -1)

    penalty = 10000 * padding_mask.unsqueeze(-1)
    best_per_product = torch.max(scores - penalty, 0)[0]
    return best_per_product.cpu().numpy()


def split_batch_to_single(batch_size,
                          load_batch_num,
                          item_nums,
                          item_names,
                          product_features,
                          max_sim_assit_item,
                          raw_item_path,
                          target_item_path):
    """Reduce image-level similarities to one score per (product, item) pair.

    The stored features are treated as one long sequence of rows that is read
    in windows of ``batch_size * load_batch_num`` rows. Item ``i`` owns the
    ``item_nums[i]`` consecutive rows that follow the rows of all earlier
    items. For each item, the product-vs-row similarity columns belonging to
    it are gathered (across as many windows as it spans) and reduced with a
    maximum.

    Args:
        batch_size: Number of rows assumed per feature file.
        load_batch_num: Number of feature files loaded per window.
        item_nums: Number of rows (images) owned by each item, in order.
        item_names: Item identifiers; only the length is used, and it must
            match ``item_nums``.
        product_features: Padded product features, forwarded to
            ``get_load_batch_sim``.
        max_sim_assit_item: Padding mask, forwarded to ``get_load_batch_sim``.
        raw_item_path: Directory containing the feature files.
        target_item_path: Unused; kept for interface compatibility.

    Returns:
        Array of shape ``[product_num, k]`` with one column per item that
        contributed at least one similarity column. Items with no columns
        (e.g. ``item_nums[i] == 0``) contribute nothing, exactly as before.

    Raises:
        ValueError: If ``item_names`` and ``item_nums`` differ in length.
    """
    if len(item_names) != len(item_nums):
        raise ValueError(
            'item_names and item_nums must have the same length (got %d and %d)'
            % (len(item_names), len(item_nums))
        )

    window_size = batch_size * load_batch_num
    load_num = len(os.listdir(raw_item_path))

    # Window 0 is loaded up front so that `max_sim` always exists, even for
    # leading items that own zero rows.
    loaded_window = 0
    recording = load_multiple_batch(loaded_window, load_batch_num, load_num, raw_item_path)
    max_sim = get_load_batch_sim(product_features, max_sim_assit_item, recording)

    max_sims = []
    begin_index = 0
    for item_num in tqdm(item_nums):
        end_index = begin_index + item_num

        if end_index > begin_index:
            first_window = begin_index // window_size
            # end_index is exclusive, so the last row actually owned by the
            # item decides the last window. Using end_index itself made the
            # old code advance one window too early or load a window that
            # does not exist.
            last_window = (end_index - 1) // window_size
            pieces = []
            for window in range(first_window, last_window + 1):
                # Load lazily: a window is only read once an item needs it.
                if window != loaded_window:
                    recording = load_multiple_batch(window, load_batch_num, load_num, raw_item_path)
                    max_sim = get_load_batch_sim(product_features, max_sim_assit_item, recording)
                    loaded_window = window
                window_start = window * window_size
                lo = max(begin_index - window_start, 0)
                hi = min(end_index - window_start, window_size)
                pieces.append(max_sim[:, lo:hi])
            item_max_sim = np.concatenate(pieces, -1)
        else:
            # Item owns no rows: keep an empty slice with the right row count.
            item_max_sim = max_sim[:, :0]

        if item_max_sim.shape[-1] == 0:
            max_sims.append(item_max_sim)
        else:
            max_sims.append(np.max(item_max_sim, -1)[:, np.newaxis])

        begin_index = end_index

    return np.concatenate(max_sims, -1)

if __name__ == "__main__":
    # The patent retrieval pool is huge and each patent may have a different
    # number of images, so a sliding-window approach with a fixed batch size
    # is used to obtain the similarity matrix quickly:
    #   1. Load the product features.
    #   2. Slide over the whole patent pool in fixed-size batches, comparing
    #      product features with patent features to get an image-level
    #      similarity matrix.
    #   3. Split the product-patent similarity scores out of that matrix.
    parser = argparse.ArgumentParser(description="Calculate and store top 1 similarities between product and patent features.")
    parser.add_argument('--product_features', type=str, required=True, help="Path to the folder containing product features")
    parser.add_argument('--product_num_record_path', type=str, required=True, help="Path to the CSV file recording the number of images per product (column 'img_num')")
    parser.add_argument('--patent_features', type=str, required=True, help="Path to the folder containing patent features")
    # NOTE(review): the default below ends in .json but the file is read with
    # pd.read_csv and must provide 'Index' and 'img_num' columns -- confirm
    # the intended default. It is left unchanged for backward compatibility.
    parser.add_argument('--patent_num_record_path', type=str, default='./jsons/a63_products_gt.json', help="Path to the CSV file recording each patent's 'Index' and number of images ('img_num')")
    parser.add_argument('--save_csv', type=str, default='.', help="Directory in which max_sim.pkl is saved")
    parser.add_argument('--batch_size', default=256, type=int, help='Number of patent image features stored in each patent feature file')
    parser.add_argument('--load_batch_num', default=1, type=int, help='Number of patent feature files to load and compare at once')

    args = parser.parse_args()

    # Create the output directory first so a long run cannot fail at the
    # final write because the directory is missing.
    os.makedirs(args.save_csv, exist_ok=True)

    # Obtain product features and the mask marking padded image slots.
    product_features, max_sim_assit_item = get_product_features(args.product_features, args.product_num_record_path)

    # Slide over the patent retrieval pool and split out per-patent scores.
    item_num_recordings = pd.read_csv(args.patent_num_record_path)
    max_sim = split_batch_to_single(args.batch_size,
                                    args.load_batch_num,
                                    list(item_num_recordings['img_num']),
                                    list(item_num_recordings['Index']),
                                    product_features,
                                    max_sim_assit_item,
                                    args.patent_features,
                                    args.save_csv)

    with open(os.path.join(args.save_csv, 'max_sim.pkl'), 'wb') as f:
        pickle.dump(max_sim, f)