import torch
import numpy as np
import Model_pointnet
import train_pointnet_params as Params
from data import partialDataset
import os
import tools
import random

def _geodesic_errors_deg(gt_rmat, out_rmat):
    """Return per-sample geodesic errors between two rotation batches as a flat float64 array.

    The value from ``tools.compute_geodesic_distance_from_two_matrices`` is
    converted with ``/ np.pi * 180``. This assumes the helper returns radians
    (TODO confirm against ``tools``).
    """
    distances = tools.compute_geodesic_distance_from_two_matrices(gt_rmat, out_rmat)
    radians = np.asarray(distances.detach().tolist(), dtype=np.float64)
    # ravel() so chunks from every batch can be concatenated into one 1-D array.
    return (radians / np.pi * 180).ravel()


def _evaluate_modelnet(test_folder, model, supervision):
    """Run ``model`` over every file in ``test_folder``.

    Each file is loaded with ``torch.load`` and must contain ``'pc'`` and
    ``'rgt'`` tensors. Returns ``(error_chunks, loss_sum, num_files)``.
    """
    error_chunks = []
    loss_sum = 0
    # os.listdir returns entries in arbitrary order; sort so results are reproducible.
    paths = sorted(os.path.join(test_folder, name) for name in os.listdir(test_folder))
    for path in paths:
        # Load onto the CPU first so files saved from any GPU can be opened,
        # then move to the default CUDA device.
        sample = torch.load(path, map_location='cpu')
        pc2 = sample['pc'].cuda()
        gt_rmat = sample['rgt'].cuda()
        out_rmat, _, _, _ = model(pc2.transpose(1, 2))
        if supervision == 'self':
            # Compare the point clouds rotated by the predicted vs. ground-truth matrices.
            loss_sum += ((torch.bmm(pc2, out_rmat) - torch.bmm(pc2, gt_rmat)) ** 2).mean()
        else:
            loss_sum += ((gt_rmat - out_rmat) ** 2).sum()
        error_chunks.append(_geodesic_errors_deg(gt_rmat, out_rmat))
    return error_chunks, loss_sum, len(paths)


def _evaluate_partial(test_folder, model):
    """Run ``model`` over a ``partialDataset`` built from ``test_folder``.

    Always uses the summed squared rotation-matrix difference as loss.
    Returns ``(error_chunks, loss_sum, num_batches)``.
    """
    test_loader = torch.utils.data.DataLoader(
        partialDataset(test_folder),
        batch_size=10,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    error_chunks = []
    loss_sum = 0
    for pc2, gt_rmat in test_loader:
        pc2 = pc2.float().cuda()
        gt_rmat = gt_rmat.float().cuda()
        out_rmat, _, _, _ = model(pc2.transpose(1, 2))
        loss_sum += ((gt_rmat - out_rmat) ** 2).sum()
        error_chunks.append(_geodesic_errors_deg(gt_rmat, out_rmat))
    return error_chunks, loss_sum, len(test_loader)


def test(test_folder, model, dataset,supervision):
    """Evaluate ``model`` on a test set; return per-sample rotation errors and the loss.

    Python, NumPy and torch (CPU and all CUDA devices) are seeded with 302 first.

    Args:
        test_folder: For ``dataset == 'ModelNet'``, a directory whose files
            each hold ``'pc'`` and ``'rgt'`` tensors. Otherwise, the argument
            passed to ``partialDataset``.
        model: Called as ``model(pc.transpose(1, 2))``. It must return 4
            values, the first being the predicted rotation matrices.
        dataset: ``'ModelNet'`` selects the per-file loader. Anything else
            selects the ``partialDataset`` loader.
        supervision: Only used for ModelNet. ``'self'`` uses the mean squared
            difference of rotated point clouds; any other value uses the
            summed squared rotation-matrix difference. The partialDataset path
            always uses the latter.

    Returns:
        ``(geodesic_errors_lst, l)``:
        - ``geodesic_errors_lst``: a 1-D float64 array of per-sample geodesic
          errors converted to degrees.
        - ``l``: the accumulated loss divided by the number of files or
          batches. It is a tensor, or ``0`` when there was no data.
    """
    seed = 302
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # no_grad so the loss accumulated across batches does not keep every
    # batch's autograd graph alive if the caller did not disable gradients.
    with torch.no_grad():
        if dataset == 'ModelNet':
            error_chunks, l, count = _evaluate_modelnet(test_folder, model, supervision)
        else:
            error_chunks, l, count = _evaluate_partial(test_folder, model)

    geodesic_errors_lst = np.concatenate(error_chunks) if error_chunks else np.array([])
    # Guard against an empty folder/loader instead of raising ZeroDivisionError.
    if count:
        l /= count
    return geodesic_errors_lst, l


"""   
        if(geodesic_errors.max()>170):
            worst_id = np.argmax(geodesic_errors)
            head, tail = ntpath.split(pc_fn)
            print (tail)
            save_point_clouds(np.array(pc2[worst_id].data.tolist()), np.array(out_pc2[worst_id].data.tolist()), fn=save_out_pc_folder+tail[0:len(tail)-3]+"xyz")
        for error  in geodesic_errors:
            if(error>90):
                num_big_geodesic_error= num_big_geodesic_error+1
        
    
    avg_geodesic_error = geodesic_errors_lst.mean()
    max_geodesic_error = geodesic_errors_lst.max()
    print ("avg geodesic_error: " + str(avg_geodesic_error))
    print ("max geodesic_error: " + str(max_geodesic_error))   
    print (("big_geodesic_error_rate: ") + str(num_big_geodesic_error/len(geodesic_errors_lst)))
    
    if(model.regress_t==True):
        avg_t_error = t_errors_lst.mean()
        max_t_error = t_errors_lst.max()
        print ("avg t_error: " + str(avg_t_error))
        print ("max t_error: " + str(max_t_error))   
    return geodesic_errors_lst, t_errors_lst
"""
 
def get_error_lst(test_folder, weight_folder, model_name_lst,  iteration, dataset, supervision):
    """Load and evaluate one checkpoint per entry of ``model_name_lst``.

    Args:
        test_folder: Passed through to ``test``.
        weight_folder: Root experiments directory. It may be given with or
            without a trailing separator.
        model_name_lst: Iterable of
            ``(weight_sub_folder, out_rotation_mode, model_kind)`` triples.
        iteration: Training iteration whose checkpoint
            ``<weight_folder>/<weight_sub_folder>/weight/model_<iteration:07d>.weight``
            is loaded. The checkpoint must hold the state dict under ``'model'``.
        dataset: Passed through to ``test``.
        supervision: Passed through to ``test``.

    Returns:
        A list of ``(geodesic_errors, out_rotation_mode)`` tuples, one per entry.

    Side effects:
        Prints the checkpoint path and the loss. Saves each error array with
        ``np.save`` at ``<weight_folder>/<weight_sub_folder>``; numpy appends
        ``.npy`` to that path.
    """
    errors_lst = []
    for weight_sub_folder, out_rotation_mode, model_kind in model_name_lst:
        # os.path.join tolerates weight_folder with or without a trailing slash.
        save_path = os.path.join(weight_folder, weight_sub_folder)
        weight_fn = os.path.join(save_path, "weight", "model_%07d.weight" % iteration)
        with torch.no_grad():
            model = Model_pointnet.Model(out_rotation_mode=out_rotation_mode, kind=model_kind)
            print("Load " + weight_fn)
            # Map to the CPU so checkpoints saved on any GPU can be loaded;
            # model.cuda() moves the parameters to the default device afterwards.
            checkpoint = torch.load(weight_fn, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            model.cuda()
            model.eval()

            geodesic_errors, l = test(test_folder, model, dataset, supervision)
            np.save(save_path, geodesic_errors)
            print("Loss: ", l)
        errors_lst.append((geodesic_errors, out_rotation_mode))
    return errors_lst



if __name__ == "__main__":
    param=Params.Parameters()
    test_folder = os.path.join('../../../airplane', 'test_sampled')
    #test_folder = os.path.join('../../../ModelNet_partial_dataset/airplane', 'test')
    #test_path_list = [os.path.join(test_folder, i) for i in os.listdir(test_folder)]
    #test_folder = '/orion/u/yijiaw/artpose/data/nocs_data/render/test/6'
    #test_dataset = NOCSDataset(test_folder)

    weight_folder = "../experiments/"
    model_name_lst = [("self/svd_inf", "svd9d", 3)]

    iteration = 30000
    # iteration=200000
    errors_lst = get_error_lst(test_folder, weight_folder, model_name_lst, iteration, 'ModelNet', 'self')

    for errors, name in errors_lst:
        print(name)
        print("median:"+str(np.round(np.median(errors),2)))
        print("avg:" + str(np.round(errors.mean(), 2)))
        print("max:" + str(np.round(errors.max(), 2)))
        print("std:" + str(np.round(np.std(errors), 2)))
        print("1 accuracy:"+str(np.round((errors<1).sum()/len(errors),3)))
        print("3 accuracy:" + str(np.round((errors < 3).sum() / len(errors), 3)))
        print("5 accuracy:"+str(np.round((errors<5).sum()/len(errors),3)))