from collections import defaultdict
import os.path
import numpy as np
import csv
import bisect
import sys


def my_in(L,val):
    """Return True when ``val`` is present in ``L``, False otherwise.

    The lookup is a binary search via ``bisect.bisect_right``, so ``L`` is
    expected to be sorted in ascending order.
    """
    # bisect_right gives the insertion point *after* any entries equal to
    # val, so a match (if there is one) sits immediately to its left.
    pos = bisect.bisect_right(L, val)
    return not (pos == 0 or L[pos - 1] != val)

def get_freq_rating_pairs(src_dir):
    """Count mutual male/female rating pairs stored under ``src_dir``.

    Two comma-separated files are read from ``src_dir``:

    * ``gender.dat``  -- rows of ``uid,gender``.  Only rows whose gender is
      ``'M'`` or ``'F'`` are kept; every other gender is ignored.
    * ``ratings.dat`` -- rows of ``uid_from,uid_to,rate`` (all integers).

    A rating is used only when both users were kept from ``gender.dat`` and
    their genders differ.  For each (male, female) pair the latest
    male-to-female and female-to-male ratings are remembered; pairs where
    both ratings are at least 1 add one to the result cell
    ``[male_to_female_rating, female_to_male_rating]``.

    As a side effect, one line ``num: <males> <females> 0`` is printed.

    Args:
        src_dir: Directory containing ``gender.dat`` and ``ratings.dat``.

    Returns:
        A ``(10, 10)`` numpy array of counts, indexed directly by the two
        rating values.

    Raises:
        OSError: If either input file cannot be opened.
        ValueError: If a non-blank row has the wrong number of fields or a
            field that should be an integer is not one.
        IndexError: If a mutual pair carries a rating of 10 or more (see
            the review note near the end of the function).
    """
    # uid -> (gender, position of that user among users of the same gender).
    # A uid listed twice keeps its last entry but still consumes a position
    # each time it appears.
    users = {}
    per_gender = {'M': 0, 'F': 0}

    gender_filename = os.path.join(src_dir, 'gender.dat')
    with open(gender_filename, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                # csv.reader yields [] for a blank line; skip it instead of
                # failing on the unpacking below.
                continue
            uid, gender = row
            uid = int(uid)
            if gender not in per_gender:
                continue
            users[uid] = (gender, per_gender[gender])
            per_gender[gender] += 1

    num_m = per_gender['M']
    num_f = per_gender['F']
    # Users with any other gender are skipped above, so this count is always
    # zero; it is kept so the printed line keeps its shape.
    unknown_count = 0
    print("num:", num_m, num_f, unknown_count)

    # (male position, female position) -> [male-to-female, female-to-male];
    # -1 marks a direction for which no rating has been seen.
    edges = defaultdict(lambda: [-1, -1])

    rating_filename = os.path.join(src_dir, 'ratings.dat')
    with open(rating_filename, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                continue
            uid_from, uid_to, rate = row
            uid_from = int(uid_from)
            uid_to = int(uid_to)
            rate = int(rate)

            if uid_from not in users or uid_to not in users:
                continue

            gender_from, pos_from = users[uid_from]
            gender_to, pos_to = users[uid_to]
            if gender_from == gender_to:
                continue

            if gender_from == 'M':
                edges[(pos_from, pos_to)][0] = rate
            else:
                edges[(pos_to, pos_from)][1] = rate

    # NOTE(review): the matrix is 10x10 and indexed by the raw rating, so
    # row/column 0 is never incremented and a rating >= 10 raises IndexError.
    # Confirm the rating range of the input files before changing this.
    freq_rating_pair = np.zeros((10, 10))
    for male_to_female, female_to_male in edges.values():
        if male_to_female < 1 or female_to_male < 1:
            continue
        freq_rating_pair[male_to_female, female_to_male] += 1

    return freq_rating_pair


if __name__=='__main__':
    # Script entry point: an optional first command-line argument replaces
    # the default data directory; the resulting matrix is printed via repr.
    src_dir = sys.argv[1] if len(sys.argv) > 1 else "datasets/libimseti"
    print(repr(get_freq_rating_pairs(src_dir)))

    