import json
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import argparse
import os
import pickle
import h5py
import re
from tqdm import tqdm
import ast

def obtain_top_k_sorted_of_topK2(k, all_sim_embeds, topK_patents, index_transfer, output_dir, batch_size, device):
    """Re-rank each query's pre-selected candidate patents by similarity and keep the top ``k``.

    For query ``i`` the candidate row is
    ``topK_patents[index_transfer['index_of_OOD'].iloc[i]]``. The query's similarity
    row ``all_sim_embeds[i]`` is gathered at those candidate positions. The gathered
    scores are sorted in descending order, in batches of ``batch_size`` rows on
    ``device``, and the first ``k`` entries are kept.

    Args:
        k: Number of candidates to keep per query. Must not exceed the number of
            candidates per row (``topK_patents.shape[1]``).
        all_sim_embeds: Row-indexable 2-D similarity matrix. Row ``i`` is indexed
            with the candidate positions of query ``i``.
        topK_patents: 2-D array of candidate positions. Its row count is used as
            the number of queries ``h``.
        index_transfer: DataFrame with an ``'index_of_OOD'`` column. Positional row
            ``i`` gives the row of ``topK_patents`` that belongs to query ``i``.
        output_dir: Directory for the ``.npy`` outputs; created (with parents) if missing.
        batch_size: Number of query rows sorted per batch (must be positive).
        device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``, ``"cuda:0"``).

    Returns:
        A tuple ``(top_k_sorted_x, top_k_sorted_indices)``:
        - ``top_k_sorted_x``: float64 array of shape ``(h, k)`` with the top-``k``
          scores per query.
        - ``top_k_sorted_indices``: int64 array of shape ``(h, k)`` with the
          matching candidate positions, taken from the ``topK_patents`` entries.
        Both arrays are also saved as ``top_<k>_sorted_x.npy`` and
        ``top_<k>_sorted_indices.npy`` in ``output_dir``.

    Raises:
        ValueError: if ``k`` exceeds the number of candidates per query.
    """
    candidates = np.asarray(topK_patents).astype(np.int64)
    h, w = candidates.shape
    if k > w:
        raise ValueError('k=%d exceeds the number of candidates per query (%d)' % (k, w))

    # Gather each query's similarity scores at its candidate positions.
    # Column-first lookup keeps the integer dtype of 'index_of_OOD' even when the
    # frame has float columns (row-first ``.iloc[i]`` could upcast it to float).
    ood_column = index_transfer['index_of_OOD']
    query_candidates = np.empty((h, w), dtype=np.int64)
    topK_embeds = np.empty((h, w), dtype=np.float32)
    for i in range(h):
        query_candidates[i] = candidates[int(ood_column.iloc[i])]
        topK_embeds[i] = np.asarray(all_sim_embeds[i], dtype=np.float32)[query_candidates[i]]

    device = torch.device(device)
    sorter = data_sorter().to(device)
    os.makedirs(output_dir, exist_ok=True)

    top_k_sorted_x = np.zeros([h, k])
    top_k_positions = np.zeros([h, k], dtype=np.int64)
    # Stepping by batch_size never yields an empty trailing batch.
    for init_id in tqdm(range(0, h, batch_size)):
        end_id = min(init_id + batch_size, h)
        batch = torch.from_numpy(topK_embeds[init_id:end_id]).to(device)
        sorted_x, sorted_positions = sorter(batch)
        top_k_sorted_x[init_id:end_id] = sorted_x[:, :k].cpu().numpy()
        top_k_positions[init_id:end_id] = sorted_positions[:, :k].cpu().numpy()
        if device.type == 'cuda':
            # Release cached blocks between batches to keep GPU memory low.
            torch.cuda.empty_cache()

    # Map positions within each candidate row back to the candidate ids.
    # Stay in int64: float32 represents integers exactly only up to 2**24.
    top_k_sorted_indices = np.take_along_axis(query_candidates, top_k_positions, axis=1)

    top_k_sorted_x_path = os.path.join(output_dir, 'top_%d_sorted_x.npy' % k)
    top_k_sorted_indices_path = os.path.join(output_dir, 'top_%d_sorted_indices.npy' % k)
    np.save(top_k_sorted_x_path, top_k_sorted_x)
    np.save(top_k_sorted_indices_path, top_k_sorted_indices)

    return top_k_sorted_x, top_k_sorted_indices

def obtain_patent_id(label_list, index, result_dir):
    """Translate a matrix of positions into labels from ``label_list`` and save it.

    Args:
        label_list: Sequence whose element ``p`` is the label for position ``p``.
        index: 2-D array-like of shape ``(num_queries, k)`` holding positions into
            ``label_list``. Float-valued positions are converted to int64 (truncated).
        result_dir: Existing directory in which the table is written.

    Returns:
        DataFrame with rows ``0..num_queries-1`` and columns ``0..k-1``, where
        cell ``(i, j)`` is ``label_list[index[i][j]]``.
    """
    labels = np.asarray(label_list)
    positions = np.asarray(index).astype(np.int64)
    file_num, k = positions.shape
    # One vectorized gather instead of assigning row by row through ``.iloc``.
    patent_names = pd.DataFrame(labels[positions], index=range(file_num), columns=range(k))

    # NOTE: the table is written as CSV even though the file name ends in
    # ``.h5``. The name is kept unchanged because other tooling may already read it.
    save_path = os.path.join(result_dir, 'top_' + str(k) + '_patent_id.h5')
    patent_names.to_csv(save_path)
    return patent_names


def evaluate_new(GT_path, patent_ids, model_name):
    """Count queries whose retrieved patents overlap their ground-truth patents.

    Row ``i`` of the ground-truth CSV is paired positionally with row ``i`` of
    ``patent_ids``. The CSV's ``'Patent'`` column holds a Python literal parsed
    with ``ast.literal_eval`` (e.g. a list of ids). A single scalar literal is
    treated as a one-element collection. A query counts as correct when at least
    one of its retrieved ids appears in its ground truth.

    Args:
        GT_path: Path to the ground-truth CSV with a ``'Patent'`` column.
        patent_ids: DataFrame with one row per query and one column per retrieved id.
        model_name: Label used in the printed report.

    Returns:
        The number of correct queries. Also prints the count and the accuracy,
        i.e. the count divided by the number of evaluated queries.
    """
    gt_column = pd.read_csv(GT_path)['Patent']
    file_num = patent_ids.shape[0]
    correct = 0

    for i in tqdm(range(file_num)):
        raw = gt_column.iloc[i]
        gt_codes = ast.literal_eval(raw) if isinstance(raw, str) else raw
        if not isinstance(gt_codes, (list, tuple, set, frozenset)):
            gt_codes = [gt_codes]
        if not set(patent_ids.iloc[i]).isdisjoint(gt_codes):
            correct += 1

    k = patent_ids.shape[-1]
    # Divide by the number of evaluated queries (previously a hard-coded 439).
    accuracy = correct / file_num if file_num else 0.0
    print(model_name, ' Top ', k, ' Correct Num: ', correct)
    print(model_name, ' Top ', k, ' Accuracy:', accuracy)
    return correct

class data_sorter(torch.nn.Module):
    """Parameter-free module that sorts each row of its input, largest first.

    Calling it returns the ``(values, indices)`` result of sorting along
    dimension 1 in descending order.
    """

    def forward(self, x):
        ordered = x.sort(dim=1, descending=True)
        return ordered

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate top-K hit accuracy from a similarity matrix.")
    parser.add_argument('--simi_file', type=str, required=True, help="Path to the pickled similarity matrix")
    parser.add_argument('--patent_dict_path', type=str, required=True, help="Path to the patent index")
    parser.add_argument('--gt_path', type=str, required=True, help="Path to the gt file")
    parser.add_argument('--batch_size', default=256, type=int, help='Number of patents to process in parallel')
    parser.add_argument('--output_dir', default='./output', type=str, help='Output dir')
    # Both inputs below are loaded unconditionally, so they are required.
    parser.add_argument('--topK_patents_file', type=str, required=True, help='Path to the .npy array of candidate patent positions')
    parser.add_argument('--transfer_index', type=str, required=True, help="Path to the CSV whose 'index_of_OOD' column maps each query to its candidate row")

    args = parser.parse_args()
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    os.makedirs(args.output_dir, exist_ok=True)

    # Similarity matrix.
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open(args.simi_file, 'rb') as f:
        simi = pickle.load(f)
    # Position -> patent id table.
    patent_record = pd.read_csv(args.patent_dict_path)
    patent_list = list(patent_record['Index'])

    index_transfer = pd.read_csv(args.transfer_index)
    topK_patents = np.load(args.topK_patents_file)

    top_k_values = [100, 500, 1000, 2000]
    # Sort once at the largest cutoff; every smaller cutoff is a column prefix of it.
    K = max(top_k_values)
    top_k_sorted_x, top_k_sorted_indices = obtain_top_k_sorted_of_topK2(K, simi, topK_patents, index_transfer, args.output_dir, args.batch_size, device)

    patent_ids = obtain_patent_id(patent_list, top_k_sorted_indices, args.output_dir)

    # The model name is the directory that contains the similarity file.
    model_name = os.path.basename(os.path.dirname(os.path.abspath(args.simi_file)))
    for k in top_k_values:
        evaluate_new(args.gt_path, patent_ids.iloc[:, :k], model_name)