import os
import gc
import csv
import glob
import torch
import pickle
import argparse
from collections import defaultdict
from tqdm import tqdm
import pandas as pd
from itertools import chain
import numpy as np


def pickle_load(path):
    """Deserialize and return the object pickled in the file at *path*.

    NOTE(review): ``pickle`` can execute arbitrary code while loading; only
    call this on files produced by a trusted step of this pipeline.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)


def load_qid2pos(file_path):
    """Map each query id to its positively-judged passage ids from a qrels file.

    The file is read as tab-separated ``qid<TAB>posid<TAB>score`` rows. The
    first line is treated as a header and skipped, and blank lines are
    ignored. Only rows whose integer score is > 0 are kept. Each passage id
    is recorded at most once per query, in first-seen order.

    Args:
        file_path: Path to the UTF-8 qrels TSV.

    Returns:
        ``defaultdict(list)`` mapping qid (str) -> list of posid (str).

    Raises:
        ValueError: if a non-blank data row does not split into exactly three
            tab-separated fields, or its score is not an integer.
    """
    qid2pos = defaultdict(list)
    seen = set()  # (qid, posid) pairs already recorded: O(1) dedup per row
    with open(file_path, "r", encoding="utf-8") as fi:
        next(fi, None)  # skip the header row
        for line in fi:
            line = line.strip()
            if not line:
                # A blank line would otherwise crash the 3-way unpack below.
                continue
            qid, posid, score = line.split("\t")

            if int(score) <= 0:
                continue

            if (qid, posid) not in seen:
                seen.add((qid, posid))
                qid2pos[qid].append(posid)
    return qid2pos


def get_qry_pos_tensor(qid2pos, q_reps, q_lookup, p_reps, look_up):
    """Gather aligned (query, positive-passage) embedding pairs as tensors.

    For every query id in *q_lookup* (row ``i`` of *q_reps*) and every
    positive passage id listed for it in *qid2pos*, one row is appended to
    each output: the query embedding and the passage embedding. The passage
    embedding is the row of *p_reps* at that id's first position in
    *look_up*. Positive ids absent from *look_up*, and queries with no entry
    in *qid2pos*, contribute nothing.

    Args:
        qid2pos: Mapping qid -> list of positive passage ids. Not mutated.
        q_reps: Query embeddings indexable by position in *q_lookup*
            (presumably a ``(num_queries, dim)`` array; TODO confirm).
        q_lookup: Sequence of query ids aligned with *q_reps*.
        p_reps: Passage embeddings indexable by position in *look_up*.
        look_up: Sequence of passage ids aligned with *p_reps*.

    Returns:
        ``(qry_tensor, pos_tensor)``; row ``k`` of each forms one pair.
    """
    # Build the id -> row map once. Per-pair ``in`` / ``list.index`` scans
    # of the full corpus id list were O(len(look_up)) for every pair.
    pos_index = {}
    for row, pid in enumerate(look_up):
        pos_index.setdefault(pid, row)  # keep first occurrence, like list.index

    qry_list = []
    pos_list = []
    for idx, qid in enumerate(q_lookup):
        # .get avoids inserting empty entries into a defaultdict argument.
        for posid in qid2pos.get(qid, ()):
            pos_idx = pos_index.get(posid)
            if pos_idx is None:
                continue
            qry_list.append(q_reps[idx])
            pos_list.append(p_reps[pos_idx])

    qry_tensor = torch.tensor(np.array(qry_list))
    pos_tensor = torch.tensor(np.array(pos_list))
    return qry_tensor, pos_tensor


def get_align_loss(x, y, alpha=2):
    """Return the mean of ``||x_i - y_i||_2 ** alpha`` over paired rows.

    Args:
        x, y: Tensors of identical shape; row ``i`` of each is one pair and
            dim 1 is reduced by the L2 norm.
        alpha: Exponent applied to each pairwise L2 distance.

    Returns:
        Scalar tensor.
    """
    diff = x - y
    distances = torch.norm(diff, p=2, dim=1)
    return torch.mean(distances ** alpha)

def get_uniform_loss_standard(x, y, t=2):
    """Uniformity loss ``log(mean_{i<j} exp(-t * ||x_i - x_j||_2^2))`` over *x*.

    Only the rows of *x*, compared pairwise via ``torch.pdist``, contribute
    to the result. *y* is accepted for call-site compatibility but is not
    used.

    NOTE(review): the previous body also concatenated ``x`` and ``y`` into a
    tensor that was never read. That suggests a joint (x + y) variant may
    once have been intended. The dead concatenation copied every passage
    embedding on each call and has been removed. If the joint variant is
    wanted, note that ``pdist`` grows quadratically with the row count.

    Args:
        x: 2-D tensor of shape ``(n, dim)``.
        y: Unused.
        t: Temperature scaling the squared pairwise distances.

    Returns:
        Scalar tensor.
    """
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()



if __name__ == "__main__":
    # Compute per-layer alignment and uniformity for one dataset and record
    # the results in three CSV logs under --result_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name')
    parser.add_argument('--result_dir')
    parser.add_argument('--qrels_path')
    args = parser.parse_args()

    dataset_dir = os.path.join(args.result_dir, args.dataset_name)

    ## ********************************************
    ## load qry
    # q_reps is indexed per layer below: q_reps[i] holds all query embeddings
    # for layer i (presumably (layers, queries, dim); TODO confirm).
    q_reps, q_lookup = pickle_load(os.path.join(dataset_dir, "query/qry.pt"))

    ## ********************************************
    ## load psg
    # sorted() makes shard order (and hence look_up order) reproducible.
    corpus_pattern = os.path.join(dataset_dir, "corpus/*")
    index_files = sorted(glob.glob(corpus_pattern))
    if not index_files:
        raise FileNotFoundError("no passage shards match {}".format(corpus_pattern))

    p_reps = []
    look_up = []
    for shard_path in tqdm(index_files, desc='Loading shards into index'):
        shard_reps, shard_lookup = pickle_load(shard_path)
        p_reps.append(shard_reps)
        look_up.extend(shard_lookup)  # ids may be str (e.g. BEIR), not int
    # Shards are joined along axis 1; axis 0 is iterated as the layer axis
    # in lockstep with q_reps below.
    p_reps = np.concatenate(p_reps, axis=1)

    ## ********************************************
    ## load qrels
    qid2pos = load_qid2pos(args.qrels_path)

    ## ********************************************
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(q_reps) != len(p_reps):
        raise ValueError(
            "query reps have {} layers but passage reps have {}".format(
                len(q_reps), len(p_reps)
            )
        )

    align_results = []
    uniform_results = []
    for layer_q, layer_p in tqdm(zip(q_reps, p_reps), total=len(q_reps)):
        qry_tensor, pos_tensor = get_qry_pos_tensor(
            qid2pos=qid2pos,
            q_reps=layer_q,
            q_lookup=q_lookup,
            p_reps=layer_p,
            look_up=look_up,
        )
        align_loss = get_align_loss(qry_tensor, pos_tensor)

        # as_tensor shares memory with NumPy input instead of copying the
        # whole (possibly very large) layer matrix on every iteration.
        uniform_loss_standard = get_uniform_loss_standard(
            torch.as_tensor(layer_q),
            torch.as_tensor(layer_p),
        )

        align_results.append('{:.3f}'.format(align_loss.item()))
        uniform_results.append('{:.3f}'.format(uniform_loss_standard.item()))

    def _update_table(path, columns):
        # Merge `columns` (name -> list of values) into the CSV at `path`,
        # creating the file if it is absent. Existing columns are kept and
        # same-named ones are overwritten.
        # NOTE(review): pandas rejects a column whose length differs from
        # the existing file's row count, so prior runs must have had the
        # same number of layers.
        table = pd.read_csv(path) if os.path.exists(path) else {}
        for name, values in columns.items():
            table[name] = values
        pd.DataFrame(table).to_csv(path, mode='w', index=False)

    ## ********************************************
    ## write this sub-res
    # Despite the .tsv name, this file is written and read with pandas'
    # default comma separator, so round-trips stay consistent.
    _update_table(
        os.path.join(dataset_dir, "{}.tsv".format(args.dataset_name)),
        {"alignment": align_results, "uniformity": uniform_results},
    )

    ## ****************************************************
    ## align
    # NOTE(review): "alignmet.csv" is misspelled but kept so existing logs
    # and downstream readers keep working.
    _update_table(
        os.path.join(args.result_dir, "alignmet.csv"),
        {args.dataset_name: align_results},
    )

    ## uniform
    _update_table(
        os.path.join(args.result_dir, "uniformity.csv"),
        {args.dataset_name: uniform_results},
    )
            