import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import scipy.optimize

from src.data_loader import NpzDataset
from src.utils import network_name, loss_weight_name, agent_name, dat_filename, binarize, is_match, count_blocking_pairs, calc_SexEqualityCost, calc_EgalitarianCost, calc_BalanceCost
from src.models import ModelSM

from tqdm import tqdm

import os
import pathlib
import argparse

import time

class StopWatch:
    """Tiny wall-clock timer built on time.perf_counter().

    Call tic() to start, then toc() to read the seconds elapsed since that
    tic(). toc() also remembers its reading in ``end_time``. Calling toc()
    before any tic() raises TypeError, because ``start_time`` is still None.
    """

    def __init__(self):
        self.start_time = None

    def tic(self):
        """Start (or restart) the timer."""
        self.start_time = time.perf_counter()

    def toc(self):
        """Store the current time in ``end_time`` and return seconds since tic()."""
        now = time.perf_counter()
        self.end_time = now
        return now - self.start_time

def parse_args(argv=None):
    """Build the command-line parser for this evaluation script and parse it.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    # args to identify .pth file. (required)
    parser.add_argument("model_path", type=str,
                        help='set the trained model .pth file. (required)')

    # args for test data loader.
    parser.add_argument("test_dat", type=str,
                        help='set test dataset dat file path. (required)')

    # args to identify network structure.
    # main() re-derives these from the model directory name when n_layers <= 0.
    parser.add_argument("-L", "--n_layers", type=int, default=-1,
                        help='set the number of feature weaving layers. If <= 0, the script automatically identifies the whole network structure from the model_path directory name. (default: -1)')
    parser.add_argument("-w", "--width1", type=int, default=64,
                        help='set the width of a layer (number of channels). (default: 64)')
    parser.add_argument("-W", "--width2", type=int, default=64,
                        help='set the second width of a layer (number of channels). This option is valid only with WeaveNet. (default: 64)')

    parser.add_argument("-a", "--asymmetric", action='store_true',
                        help='use asymmetric adjacent edge encoders for each side. (default: False)')
    parser.add_argument("--network_type", type=str, default='WeaveNet',
                        help='Select the network type from [MLP|GCN|WeaveNet]. (default: WeaveNet)')

    # args to measure the loss
    parser.add_argument('--constraint_p', dest='constraint_p', type=float, default=2.,
                        help='set p for the matrix-constraint loss. (default: 2.0)')

    # args for acceleration
    parser.add_argument("--num_workers", type=int, default=1,
                        help="set num_workers to data loaders. If set to -1, use maximum number of cpus. (default: 1)")
    parser.add_argument("--pin_memory", action='store_true',
                        help="set the pin_memory flag to data loaders.")
    parser.add_argument("--cudnn_benchmark", action='store_true',
                        help="set the cudnn benchmark flag.")
    parser.add_argument("--non_blocking", action='store_true',
                        help="set the non_blocking flag.")
    parser.add_argument("--device", type=int, default=-1,
                        help="set device. If -1, automatically select cpu or gpu. If -2, use cpu. (default: -1)")

    return parser.parse_args(argv)

    
def _apply_network_params_from_path(args):
    """Overwrite the network hyper-parameters in *args* from the model path.

    The third-from-last component of ``dirname(args.model_path)`` is split on
    '-' and read as ``<type>-<L>-<w1>-<w2>-<asym>-<resnet>``. A field equal to
    'NA' leaves the corresponding argument untouched. ``asym == 'a'`` turns
    on ``args.asymmetric``, and ``resnet == 'res'`` turns on ``args.use_resnet``.

    Raises:
        ValueError: if the path is too short or the directory name has fewer
            than six '-'-separated fields.
    """
    dir_parts = pathlib.Path(os.path.dirname(args.model_path)).parts
    if len(dir_parts) < 3:
        raise ValueError(
            "cannot infer the network structure from '{}'; "
            "pass -L/--n_layers explicitly".format(args.model_path))
    net_dir_name = dir_parts[-3]
    params = net_dir_name.split('-')
    if len(params) < 6:
        raise ValueError(
            "directory name '{}' does not look like "
            "'<type>-<L>-<w1>-<w2>-<asym>-<resnet>'; "
            "pass -L/--n_layers explicitly".format(net_dir_name))

    args.network_type = params[0]
    args.n_layers, args.width1 = int(params[1]), int(params[2])
    if params[3] != 'NA':
        args.width2 = int(params[3])
    if params[4] != 'NA':
        args.asymmetric = params[4] == 'a'
    if params[5] != 'NA':
        args.use_resnet = params[5] == 'res'


def main():
    """Evaluate a trained model on a test dataset and write CSV reports.

    Two files are written into the directory that holds the model file:
    ``<model>_<testset>_result.csv`` (one row per sample) and
    ``<model>_<testset>_result_avg.csv`` (averages over the test set).
    """
    args = parse_args()

    # Without an explicit layer count, recover the structure from the path.
    if args.n_layers <= 0:
        _apply_network_params_from_path(args)

    # prepare the dataloader
    num_workers = args.num_workers if args.num_workers > 0 else os.cpu_count()
    test_data_dir = os.path.dirname(args.test_dat)
    dataset = NpzDataset(args.test_dat, test_data_dir)
    # The first sample's sab is only used to tell ModelSM the input shape.
    sab = dataset.load(os.path.join(dataset.root_dir, dataset.filelist[0]))[0]
    dl_test = DataLoader(dataset, batch_size=1, shuffle=False,
                         num_workers=num_workers, pin_memory=args.pin_memory)

    # -1: cuda if available else cpu; < -1: cpu; >= 0: that cuda device index.
    if args.device == -1:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    elif args.device < -1:
        device = 'cpu'
    else:
        device = 'cuda:{}'.format(args.device)
    device = torch.device(device)

    checkpoints_dir, model_filename = os.path.split(args.model_path)
    model = ModelSM(
        network_type=args.network_type,
        sab_shape=sab.shape,
        L=args.n_layers,
        w1=args.width1,
        w2=args.width2,
        asymmetric=args.asymmetric,
        constraint_p=args.constraint_p,
        device=device,
        checkpoints_dir=checkpoints_dir
    )
    model.load_network(model_filename)
    model.net.to(device)

    test_name = os.path.splitext(os.path.basename(args.test_dat))[0]
    base_name = os.path.splitext(model_filename)[0] + "_" + test_name

    csv_path = os.path.join(checkpoints_dir, '{}_result.csv'.format(base_name))
    avg_path = os.path.join(checkpoints_dir, '{}_result_avg.csv'.format(base_name))

    test(dl_test, model, csv_path, avg_path,
         device,
         args.constraint_p,
         cudnn_benchmark=args.cudnn_benchmark,
         non_blocking=args.non_blocking,
         )

def is_complete_enumeration(dl_test):
    """Report whether any sample's ``matches`` differ from its ``gs_matches``.

    Each batch is expected to carry the reference matchings at index 4 and
    ``gs_matches`` at index 7, both with a leading batch dimension of 1.
    As soon as one sample disagrees (in shape or in any entry), the answer
    is True. If every sample agrees, or the loader is empty, it is False.

    NOTE(review): callers use this as "the dataset lists all stable matchings,
    so optimal costs are known". This presumes that ``gs_matches`` holds the
    Gale–Shapley results and ``matches`` a full enumeration. Confirm this
    against the data generator.
    """
    for batch in dl_test:
        reference = batch[4][0]
        gs_result = batch[7][0]
        if gs_result.shape != reference.shape:
            return True
        if not (gs_result == reference).all():
            return True
    return False

def get_worse_scores(gs_matches,cab,cba):
    """Return the largest (worst) SEq and Bal costs over ``gs_matches``.

    Every matching in ``gs_matches`` is scored with calc_SexEqualityCost and
    calc_BalanceCost using the cost matrices ``cab``/``cba``. Each score is
    cast to int, and the per-metric maxima are returned as ``(SEq, Bal)``.
    An empty ``gs_matches`` raises ValueError from ``max``.
    """
    seq_costs = [int(calc_SexEqualityCost(match, cab, cba).cpu()) for match in gs_matches]
    bal_costs = [int(calc_BalanceCost(match, cab, cba).cpu()) for match in gs_matches]
    return max(seq_costs), max(bal_costs)

def eval_stable_match(m,sab,sba,cab,cba,min_SEq,min_Bal,know_optimal_solution,sms,Bals):
    """Score a candidate 0/1 assignment matrix ``m``.

    Returns a 7-tuple:
        (SEq, Bal, is_matching, num_blocking_pairs, is_opt_SEq, is_opt_Bal,
         num_matching_duplication)

    - SEq / Bal: int costs from calc_SexEqualityCost / calc_BalanceCost,
      computed even when ``m`` is not a valid matching.
    - If ``is_match(m)`` fails, blocking pairs are not counted (reported as 0).
      ``num_matching_duplication`` then counts the 1s in excess of one per
      row plus those in excess of one per column.
    - If ``m`` is a matching, blocking pairs are counted. When there are none
      and ``know_optimal_solution`` is set, SEq/Bal are compared with
      ``min_SEq``/``min_Bal``. A stable matching cannot beat those minima,
      and this is asserted.

    ``sms`` and ``Bals`` are accepted but unused.
    """
    seq_cost = int(calc_SexEqualityCost(m, cab, cba).cpu())
    bal_cost = int(calc_BalanceCost(m, cab, cba).cpu())

    if not is_match(m):
        # Extra 1s beyond one per row, plus extra 1s beyond one per column.
        row_excess = torch.clamp(m.sum(dim=-1) - 1, min=0).sum()
        col_excess = torch.clamp(m.sum(dim=-2) - 1, min=0).sum()
        return seq_cost, bal_cost, False, 0, False, False, int(row_excess + col_excess)

    blocking_pairs = count_blocking_pairs(sab, sba, m)
    opt_seq = False
    opt_bal = False
    if blocking_pairs == 0 and know_optimal_solution:
        assert(min_SEq <= seq_cost)  # a stable m cannot undercut the known minimum
        assert(min_Bal <= bal_cost)
        opt_seq = int(min_SEq) == int(seq_cost)
        opt_bal = int(min_Bal) == int(bal_cost)
    return seq_cost, bal_cost, True, blocking_pairs, opt_seq, opt_bal, 0

def get_pseudo_score(score, gs_worth, is_stable):
    """Return ``score`` for a stable result, otherwise the fallback ``gs_worth``.

    In this script, ``gs_worth`` is the worst SEq/Bal cost over ``gs_matches``.
    Unstable outputs are therefore charged that worst-case value.
    """
    return score if is_stable else gs_worth


def _cuda_sync(device):
    """Wait for queued work on *device* to finish; a no-op for non-CUDA devices.

    CUDA kernels launch asynchronously. Without a synchronization point, a
    time.perf_counter() measurement may mostly capture launch overhead.
    """
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device)


def test(dl_test, model, csv_path, avg_path,
         device,
         constraint_p,
         cudnn_benchmark=False,
         non_blocking=False,
        ):
    """Evaluate *model* on every sample of *dl_test* and write two CSV reports.

    For each sample, the network output is turned into a matching in two ways:
    ``binarize`` and a Hungarian assignment (scipy ``linear_sum_assignment``)
    on ``min(softmax over dim -1, softmax over dim -2)``. Both are scored with
    eval_stable_match and pseudo-scores. The loss terms from ``model`` are
    also recorded.

    Args:
        dl_test: DataLoader with batch_size == 1. Each batch is
            (sab, sba, na, nb, matches, fairness, satisfaction, gs_matches,
             SEqs, Egals).
        model: object exposing ``net`` and the ``criterion_*`` loss methods.
        csv_path: per-sample CSV output (overwritten).
        avg_path: averaged CSV output (overwritten).
        device: torch device (or device string) the inputs are moved to.
        constraint_p: ``p`` forwarded to
            ``criterion_matrix_constraint_normed_correlation``.
        cudnn_benchmark: value assigned to ``torch.backends.cudnn.benchmark``.
        non_blocking: forwarded to ``Tensor.to`` when moving inputs to *device*.

    Raises:
        ValueError: if ``dl_test.batch_size`` is not 1.
    """
    # Every per-sample tensor below is indexed with [0], so batches must hold one sample.
    if dl_test.batch_size != 1:
        raise ValueError("test() requires dl_test.batch_size == 1, got {}".format(dl_test.batch_size))

    torch.backends.cudnn.benchmark = cudnn_benchmark
    model.net.eval()

    print("output1: ", csv_path)
    print("output2: ", avg_path)

    sm_col = torch.nn.Softmax(dim=-1)
    sm_row = torch.nn.Softmax(dim=-2)

    # Optimality columns are only emitted when is_complete_enumeration() is True.
    know_optimal_solution = is_complete_enumeration(dl_test)

    sw = StopWatch()

    cum_time = 0
    cum_L_m = 0
    cum_L_u = 0
    cum_L_f = 0
    cum_L_b = 0
    num_matching = 0
    num_stable = 0
    cum_SEq = 0
    cum_Bal = 0
    num_opt_SEq = 0
    num_opt_Bal = 0
    total_matching_duplication = 0
    total_blocking_pairs = 0

    num_stable_hung = 0
    cum_SEq_hung = 0
    cum_Bal_hung = 0
    num_opt_SEq_hung = 0
    num_opt_Bal_hung = 0
    total_blocking_pairs_hung = 0

    N = 0

    # Keep the per-sample CSV open for the whole run instead of reopening it per row.
    with torch.no_grad(), open(csv_path, 'w') as f:
        f.write("sample_id, elasped_time, L_m, L_u, L_f, L_b, is_matching, num_matching_duplication, is_stable, num_blocking_pairs, SEq, Bal, pSEq, pBal, is_stable(w Hung), num_blocking_pairs(w Hung), SEq(w Hung), Bal(w Hung.),pSEq(w Hung), pBal(w Hung.)")
        if know_optimal_solution:
            f.write(",is_opt:SEq, is_opt:Bal,is_opt:SEq(w Hung), is_opt:Bal(w Hung.)")
        f.write("\n")

        for i, batch in enumerate(tqdm(dl_test)):
            sab, sba, na, nb = [b.to(device, non_blocking=non_blocking) for b in batch[:4]]
            matches, fairness, satisfaction, gs_matches, SEqs, Egals = [b[0] for b in batch[4:]]
            Bals = (SEqs + Egals) / 2
            na = na[0]
            nb = nb[0]
            N += 1

            # Time only the forward pass + binarization. Synchronize on both
            # sides so that asynchronous CUDA work is inside the measurement.
            _cuda_sync(device)
            sw.tic()
            m_ = model.net.forward([sab, sba])
            m = binarize(m_, na, nb)
            _cuda_sync(device)
            elapsed_time = sw.toc()
            cum_time += elapsed_time

            # Maximum-likelihood matching by the Hungarian algorithm (O(N^3)),
            # maximizing the element-wise min of the two softmax normalizations.
            m_min = torch.min(sm_col(m_), sm_row(m_))
            row_idx, col_idx = scipy.optimize.linear_sum_assignment(-m_min[0].cpu().numpy())
            m_hungarian = m.new_zeros(na, nb, dtype=torch.long)
            m_hungarian[row_idx, col_idx] = 1

            # Drop the batch dimension.
            sab = sab[0]
            sba = sba[0]
            m = m[0]
            m_ = m_[0]

            mc = sm_col(m_)
            mr = sm_row(m_)
            L_m = model.criterion_matrix_constraint_normed_correlation(m_, p=constraint_p).cpu()
            cum_L_m += L_m
            L_u = (model.criterion_unstability_HV(sab, sba, mc) + model.criterion_unstability_HV(sab, sba, mr)).cpu() / 2
            cum_L_u += L_u
            L_b = (model.criterion_satisfaction_with_fairness(mc, sab, sba) + model.criterion_satisfaction_with_fairness(mr, sab, sba)).cpu() / 2
            cum_L_b += L_b
            L_f = (model.criterion_fairness(mc, sab, sba) + model.criterion_fairness(mr, sab, sba)).cpu() / 2
            cum_L_f += L_f

            min_SEq = int(min(SEqs))
            min_Bal = int(min(Bals))
            # Integer rank costs recovered from sab/sba. This presumably inverts a
            # 0.1..1.0 satisfaction scaling done by the data generator; confirm there.
            cab = torch.round((nb - 1) * (1 - (sab - 0.1) / 0.9))
            cba = torch.round((na - 1) * (1 - (sba - 0.1) / 0.9))

            # --- binarized network output ---
            (SEq, Bal, _is_matching, num_blocking_pairs,
             is_opt_SEq, is_opt_Bal, num_matching_duplication) = eval_stable_match(
                m, sab, sba, cab, cba, min_SEq, min_Bal, know_optimal_solution, matches, Bals)
            cum_SEq += SEq
            cum_Bal += Bal
            num_matching += int(_is_matching)
            total_matching_duplication += num_matching_duplication

            _is_stable = (num_blocking_pairs == 0 and _is_matching)
            num_stable += int(_is_stable)
            total_blocking_pairs += num_blocking_pairs

            worse_SEq, worse_Bal = get_worse_scores(gs_matches, cab, cba)
            pSEq = get_pseudo_score(SEq, worse_SEq, _is_stable)
            pBal = get_pseudo_score(Bal, worse_Bal, _is_stable)

            if know_optimal_solution:
                if is_opt_SEq:
                    num_opt_SEq += 1
                if is_opt_Bal:
                    num_opt_Bal += 1

            # --- Hungarian-assignment output ---
            (SEq_hung, Bal_hung, _is_matching_hung, num_blocking_pairs_hung,
             is_opt_SEq_hung, is_opt_Bal_hung, num_matching_duplication_hung) = eval_stable_match(
                m_hungarian, sab, sba, cab, cba, min_SEq, min_Bal, know_optimal_solution, matches, Bals)
            # Sanity checks: the assignment built above must be a proper matching.
            assert(_is_matching_hung)
            assert(num_matching_duplication_hung == 0)
            cum_SEq_hung += SEq_hung
            cum_Bal_hung += Bal_hung
            _is_stable_hung = (num_blocking_pairs_hung == 0)
            num_stable_hung += int(_is_stable_hung)
            total_blocking_pairs_hung += num_blocking_pairs_hung

            # Pseudo-scores follow the stability of the Hungarian matching itself.
            pSEq_hung = get_pseudo_score(SEq_hung, worse_SEq, _is_stable_hung)
            pBal_hung = get_pseudo_score(Bal_hung, worse_Bal, _is_stable_hung)

            if know_optimal_solution:
                if is_opt_SEq_hung:
                    num_opt_SEq_hung += 1
                if is_opt_Bal_hung:
                    num_opt_Bal_hung += 1

            f.write(", ".join([str(i), str(float(elapsed_time)),
                               str(float(L_m)),
                               str(float(L_u)),
                               str(float(L_f)),
                               str(float(L_b)),
                               str(_is_matching),
                               str(int(num_matching_duplication)),
                               str(_is_stable),
                               str(int(num_blocking_pairs)),
                               str(int(SEq)),
                               str(int(Bal)),
                               str(int(pSEq)),
                               str(int(pBal)),
                               str(_is_stable_hung),
                               str(int(num_blocking_pairs_hung)),
                               str(int(SEq_hung)),
                               str(int(Bal_hung)),
                               str(int(pSEq_hung)),
                               str(int(pBal_hung)),
                               ]))
            if know_optimal_solution:
                f.write(", " + ", ".join([
                    str(int(is_opt_SEq)),
                    str(int(is_opt_Bal)),
                    str(int(is_opt_SEq_hung)),
                    str(int(is_opt_Bal_hung)),
                ]))
            f.write("\n")

    # Blocking pairs of the binarized output are averaged over valid matchings only.
    avg_blocking_pairs = total_blocking_pairs / num_matching if num_matching else 0
    # Guard against an empty test set: every sum is then 0, so the averages are 0.
    denom = max(N, 1)
    with open(avg_path, 'w') as f:
        f.write("avg. elasped_time, L_m, L_u, L_f, L_b, matching success rate, avg matching duplication, stable matching success rate, avg # of blocking pairs, avg. SEq, avg. Bal, stable matching success rate (w Hung.), avg # of blocking pairs (w Hung.), avg. SEq(w Hung.), avg. Bal(w Hung.)")
        if know_optimal_solution:
            f.write(",is_opt_SEq, is_opt_Bal, is_opt_SEq(w Hung.), is_opt_Bal(w Hung.)")
        f.write("\n")

        # Every entry is divided by N below. avg_blocking_pairs is pre-multiplied
        # by N so that it comes out as the per-matching average computed above.
        # Pseudo-scores (pSEq/pBal) are reported per sample only, in csv_path.
        values = [cum_time, cum_L_m, cum_L_u, cum_L_f, cum_L_b,
                  num_matching, total_matching_duplication, num_stable, avg_blocking_pairs * N, cum_SEq, cum_Bal,
                  num_stable_hung, total_blocking_pairs_hung, cum_SEq_hung, cum_Bal_hung]
        f.write(", ".join("{:.6f}".format(v / denom) for v in values))

        if know_optimal_solution:
            opt_values = [num_opt_SEq, num_opt_Bal, num_opt_SEq_hung, num_opt_Bal_hung]
            f.write(", " + ", ".join("{:.6f}".format(v / denom) for v in opt_values))
        f.write("\n")
  
                     

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
