import pandas as pd
import os
from testing_utils import *


def wasserstein_2d(A, B):
    """Combine per-axis 1-D Wasserstein distances of two 2-D point sets.

    Column 0 of ``A`` is compared with column 0 of ``B`` and column 1 with
    column 1, each via ``wasserstein_distance``. The two results are combined
    as a Euclidean norm. This is a marginal-based approximation, not the exact
    2-D optimal-transport distance.

    Args:
        A, B: arrays indexable as ``X[:, 0]`` / ``X[:, 1]`` (at least two columns).

    Returns:
        ``sqrt(d_x**2 + d_y**2)`` where ``d_x``/``d_y`` are the per-axis distances.
    """
    per_axis = [wasserstein_distance(A[:, col], B[:, col]) for col in (0, 1)]
    dist_x, dist_y = per_axis
    return np.sqrt(dist_x**2 + dist_y**2)

def avg_err(chull, ptsets):
    """Return the mean directional error of predicted hulls against their point sets.

    ``chull[i]`` is scored against ``ptsets[i]`` with
    ``directional_err(hull, points, 2, n=1000)``. Entries of ``ptsets`` beyond
    ``len(chull)`` are ignored. An empty ``chull`` yields NumPy's mean of an
    empty list (``nan``).
    """
    per_set = [
        directional_err(hull, ptsets[idx], 2, n=1000)
        for idx, hull in enumerate(chull)
    ]
    return np.mean(per_set)

def avg_wasserstein(chull, gt_hulls):
    """Return the mean ``wasserstein_2d`` distance over paired hulls.

    ``chull[k]`` is compared with ``gt_hulls[k]`` for every index of
    ``chull``. An empty ``chull`` yields NumPy's mean of an empty list (``nan``).
    """
    return np.mean([wasserstein_2d(pred, gt_hulls[k]) for k, pred in enumerate(chull)])

def eval_model(model_fp, modeltype, datafile, device, n_layers, od = 8,
               data_fp='elliptical-50',
               models_root='/data/oren/coreset/models',
               data_root='/data/oren/coreset/data',
               max_sets=5000,
               batch_size=128):
    """Load a trained convex-hull model and measure its directional error.

    Args:
        model_fp: model directory relative to ``models_root/data_fp``; must
            contain ``final_model.pt``.
        modeltype: ``'ConvexHullNN'`` selects the softmax/``get_approx_chull``
            model; any other value selects ``ConvexHullNN_new``.
        datafile: file name inside ``data_root`` loaded with ``np.load``.
        device: torch device (or device string) for the model and batches.
        n_layers: layer count passed to ``ConvexHullNN_new`` (ignored otherwise).
        od: output dimension passed to the model constructor. For
            ``ConvexHullNN_new`` the output is also split into chunks of this
            size, one chunk per input set (assumed; the original hard-coded 8,
            which only matched the default).
        data_fp, models_root, data_root: path components (defaults reproduce
            the previously hard-coded locations).
        max_sets: number of leading point sets from ``datafile`` to evaluate.
        batch_size: batch size passed to ``npz_to_batches``.

    Returns:
        ``(model_name, datafile, mean_err, std_err)``. ``model_name`` is the
        last two ``/``-separated components of ``model_fp``. The errors come
        from ``directional_err(hull, points[:, :2], 2, n=1000)`` per set.
    """
    fp = os.path.join(models_root, data_fp, model_fp)

    if modeltype == 'ConvexHullNN':
        model = ConvexHullNN(2, 16, 128, od, 2)
    else:
        model = ConvexHullNN_new(2, 16, 1024, od, 2, n_layers)

    state_dict_path = os.path.join(fp, 'final_model.pt')
    print(state_dict_path)

    # map_location lets a checkpoint saved on one device load on another
    # (e.g. GPU-trained weights evaluated on a CPU-only machine).
    state_dict = torch.load(state_dict_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates mismatched keys silently; surface them so a wrong
    # checkpoint/architecture pairing is not evaluated with random weights.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f'warning: missing keys {incompatible.missing_keys}, '
              f'unexpected keys {incompatible.unexpected_keys}')
    model = model.to(device)
    # Disable training-only behaviour (dropout, batch-norm statistic updates).
    model.eval()

    raw_data = np.load(os.path.join(data_root, datafile))[:max_sets]
    dataloader, _gt = npz_to_batches(raw_data, batch_size)
    # Points per set; reshapes the ConvexHullNN output to (sets, points, channels).
    n_points = raw_data[0].shape[0]

    chull = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)
            if modeltype == 'ConvexHullNN':
                out = out.view(-1, n_points, out.size(-1))
                # Normalise over the points axis before extracting hulls.
                out = F.softmax(out, dim=1)
                out = out.view(-1, out.size(-1))
                hulls = get_approx_chull(out, batch)
            else:
                hulls = out.split(od)
            chull += [h.detach().cpu().numpy() for h in hulls]

    # Predicted hulls are assumed to be in the same order as raw_data
    # (i.e. npz_to_batches does not shuffle) — TODO confirm.
    errs = [
        directional_err(hull, pts[:, :2], 2, n=1000)
        for hull, pts in zip(chull, raw_data)
    ]

    model_name = '/'.join(model_fp.rsplit('/', 2)[-2:])
    return model_name, datafile, np.mean(errs), np.std(errs)


def main():
    """Command-line entry point: evaluate one trained model on one dataset.

    Parses the CLI arguments, runs ``eval_model`` and prints the model name,
    data file, and mean/std directional error on a single tab-separated line.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate a trained convex-hull model.')
    parser.add_argument('--datafile', type=str, required=True,
                        help='data file name inside the data root')
    parser.add_argument('--modeltype', type=str, required=True,
                        help="'ConvexHullNN' or any other value for ConvexHullNN_new")
    parser.add_argument('--model_fp', type=str, required=True,
                        help='model directory relative to the models root')
    parser.add_argument('--n_layers', type=int, default=None,
                        help='layer count for ConvexHullNN_new')
    parser.add_argument('--od', type=int, default=8,
                        help='model output dimension')
    parser.add_argument('--device', type=str, default=None,
                        help="torch device; defaults to 'cuda' if available, else 'cpu'")
    args = parser.parse_args()

    # Only ConvexHullNN_new consumes n_layers; fail early instead of passing None.
    if args.modeltype != 'ConvexHullNN' and args.n_layers is None:
        parser.error('--n_layers is required unless --modeltype is ConvexHullNN')

    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
    model_name, datafile, mean_err, std_err = eval_model(
        args.model_fp, args.modeltype, args.datafile, device, args.n_layers, od=args.od)
    print(f'{model_name}\t{datafile}\tmean_err={mean_err:.6f}\tstd_err={std_err:.6f}')

if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()