#!/usr/bin/env python
# coding: utf-8

import os
import numpy as np
import time
import math
import sys
import random

from joblib import Parallel, delayed

import torch
from torch.utils.data import DataLoader

from attrdict import AttrDict

from src.data_loader import UniversalSMIGenerator
import argparse

from src.utils import calc_fairness, calc_satisfaction, enumerate_stable_match_bf, calc_fairness_np, calc_satisfaction_np
from src.match_gs_algo import Matcher as GSAlgo


# Command-line options. They are parsed at import time; every later block reads `opt`.
parser = argparse.ArgumentParser(description='validation_dataset_generator')
parser.add_argument('-N', type=int, default=10,
                    help='Number of agents in each group.')
parser.add_argument('-s', '--seed', type=int, default=None,
                    help='Seed for sample generation. default: derived from --dist_dir')
parser.add_argument('--n_samples', type=int, default=1000,
                    help='Number of samples the script generates.')
parser.add_argument('--use_continuous_preference', action='store_true',
                    help='Use continuous preference scores instead of preference rank. default: False')
parser.add_argument('--distrib_m', type=str, default='U',
                    help='Distribution of male group. [U:uniform, D:discrete, G:gaussian] default: U')
parser.add_argument('--distrib_f', type=str, default='U',
                    help='Distribution of female group. [U:uniform, D:discrete, G:gaussian] default: U')
parser.add_argument('--dist_dir', type=str, default='validation',
                    help='Destination sub-directory under ./datasets. default: validation')
parser.add_argument('--sigma_m', type=float, default=0.4,
                    help='Parameter of the male group distribution (see --distrib_m). default: 0.4')
parser.add_argument('--sigma_w', type=float, default=0.4,
                    help='Parameter of the female group distribution (see --distrib_f). default: 0.4')
parser.add_argument('--n_jobs', type=int, default=-1,
                    help='Number of CPUs used for generation. default: -1 (use all CPUs)')
parser.add_argument('--brute_force', action='store_true', default=False,
                    help='Enumerate stable matchings by brute force instead of running the GS algorithm.')
opt = parser.parse_args()
# Must stay 1: only the first element of each DataLoader batch (sab[0], sba[0])
# is saved below, so a larger batch would silently drop samples.
BatchSize = 1

def _seed_from_name(name):
    """Derive a deterministic 32-bit seed from a directory name.

    CRC-32 of the UTF-8 bytes is stable across runs and interpreter
    processes (unlike ``hash(str)``, which is salted per process). It is
    always in ``[0, 2**32)``, the range ``np.random.seed`` accepts, and it
    handles the empty string (returns 0).
    """
    import zlib
    return zlib.crc32(name.encode('utf-8'))


if opt.seed is None:
    # fix seed to reproduce the dataset used in the paper.
    if opt.dist_dir == 'validation':
        seed = 135789  # the same seed in the paper
    elif opt.dist_dir == 'test':
        seed = 2456  # the same seed in the paper
    else:
        # The same directory name always yields the same random seed.
        seed = _seed_from_name(opt.dist_dir)
else:
    seed = opt.seed

print("seed: ", seed)

np.random.seed(seed=seed)
torch.manual_seed(seed)
# Also seed Python's own RNG, in case the generator or matcher draws from it;
# this does not affect the numpy/torch streams seeded above.
random.seed(seed)
print(opt.N)
ds = UniversalSMIGenerator(opt.distrib_m, opt.distrib_f, opt.sigma_m, opt.sigma_w,
                           N_range_m=(opt.N, opt.N), N_range_w=(opt.N, opt.N),
                           batch_size=BatchSize, len=opt.n_samples,
                           transform=not opt.use_continuous_preference)

dl = DataLoader(ds, batch_size=BatchSize)



def get_GS(sab,sba):
    """Return the two matchings produced by the GS matcher for one instance.

    The matcher runs twice: once with ``sab`` as the first argument, and once
    with the two preference arrays swapped. The second result is transposed
    with ``.t()``. Presumably this puts it in the same orientation as the
    first one — TODO confirm against ``Matcher.match``.

    Args:
        sab, sba: preference tensors of one instance. They are converted
            with ``.numpy()``, so they must be CPU tensors.

    Returns:
        list: ``[forward_matching, swapped_matching_transposed]``.
    """
    prefs_ab = sab.numpy()
    prefs_ba = sba.numpy()
    forward = GSAlgo(prefs_ab, prefs_ba).match()
    swapped = GSAlgo(prefs_ba, prefs_ab).match()
    return [forward, swapped.t()]

# Choose how the stable matchings of each instance are obtained.
# `alg` is consulted in save1sample and written into every output file.
if opt.brute_force:
    sms_enumerate_func = enumerate_stable_match_bf
    alg = 'Enumeration'
else:
    sms_enumerate_func = get_GS
    alg = 'GS Algorithm'
    
    

def save1sample(sab,sba,filepath,verbose=False):
    """Compute the matchings of one instance and their costs, then save them with ``np.savez``.

    Args:
        sab, sba: preference tensors of one instance. ``sab.shape[0]`` is
            taken as the instance size N.
        filepath: output path handed to ``np.savez`` (which appends ``.npz``
            when the name lacks it).
        verbose: when True, print a completion line after saving.

    The archive stores:
    - the inputs;
    - every matching returned by ``sms_enumerate_func``;
    - their fairness and satisfaction values (scaled by N);
    - the preference matrices mapped to rounded ranks (``cab``/``cba``);
    - the SexEqualityCost and EgalitarianCost computed on those ranks
      (scaled by N);
    - the GS matchings;
    - the name of the algorithm used.
    """
    n = sab.shape[0]

    def _ranks(scores):
        # Linear map from score to rank (1.0 -> 0, 0.1 -> n-1), then rounded.
        # NOTE(review): assumes scores lie in [0.1, 1.0]; confirm against the generator.
        return np.round((n - 1) * (1 - (scores.numpy() - 0.1) / 0.9))

    stable = sms_enumerate_func(sab, sba)
    fairness = np.array([calc_fairness(m, sab, sba) for m in stable]) * n
    satisfaction = np.array([calc_satisfaction(m, sab, sba) for m in stable]) * n
    stable_np = np.array([m.numpy() for m in stable])

    cab = _ranks(sab)
    cba = _ranks(sba)
    sex_equality = np.array([calc_fairness_np(m, cab, cba) * n for m in stable_np])
    egalitarian = np.array([calc_satisfaction_np(m, cab, cba) * n for m in stable_np])

    if alg == 'GS Algorithm':
        # In GS mode the matchings above already come from get_GS; reuse them.
        gs_matches = stable_np
    else:
        gs_matches = np.array([m.numpy() for m in get_GS(sab, sba)])

    np.savez(filepath,
             sab=sab,
             sba=sba,
             matches=stable_np,
             fairness=fairness,
             satisfaction=satisfaction,
             gs_matches=gs_matches,
             cab=cab,
             cba=cba,
             SexEqualityCost=sex_equality,
             EgalitarianCost=egalitarian,
             algorithm=alg,
             )
    if verbose:
        print("done: ", filepath)


# Output directory layout: datasets/<dist_dir>/<distrib_m><distrib_f>/size-<NN>
dist_dir = os.path.join(
    "datasets",
    opt.dist_dir,
    f"{opt.distrib_m}{opt.distrib_f}",
    f"size-{opt.N:02d}",
)
print(dist_dir)
os.makedirs(dist_dir, exist_ok=True)


def get_path(i):
    """Return the output path (no extension) for sample number ``i`` inside ``dist_dir``."""
    name = f"instance_size-{opt.N:02d}_{i:04d}"
    return os.path.join(dist_dir, name)

print("start to generate samples...")
it = iter(dl)

# Debugging tip: run save1sample on a single sample in-process first, e.g.
#   sab, sba, _, _ = next(it)
#   save1sample(sab[0], sba[0], get_path(0), True)

# The task list is built completely before Parallel is called. Every sample is
# therefore drawn from the loader in this process before any job executes.
# Only the first element of each batch is used, and every 100th job prints
# a completion line.
tasks = [
    delayed(save1sample)(sab[0], sba[0], get_path(idx), idx % 100 == 99)
    for idx, (sab, sba, _, _) in enumerate(it)
]
r = Parallel(n_jobs=opt.n_jobs)(tasks)
print("... finished.")




