import os, sys
sys.path.append(".")
from torch_geometric.datasets import ModelNet
import torch
from tqdm import tqdm
from real_world_assessment.utils import C_refinement, VD_refinement, test_symmetry, farthest_pretransform
from functools import partial

# basic settings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Names of the refinement procedures to evaluate; matched by the dispatch at the
# bottom of the script ("C_refinement" and/or "VD_refinement").
test_methods = ["C_refinement"]

# dataset
# Each cloud is passed through `farthest_pretransform` with npoint=down_sample_num
# (presumably farthest-point down-sampling to that many points — confirm in utils).
# NOTE(review): torch_geometric applies `pre_transform` when it processes the raw
# data and caches the result on disk, so changing `down_sample_num` may require
# deleting ./ModelNet40/processed before it takes effect.
down_sample_num = 256
farthest_pretransform = partial(farthest_pretransform, npoint=down_sample_num, device=device)
dataset_train = ModelNet(root='./ModelNet40/', name='40', pre_transform=farthest_pretransform, transform=None, pre_filter=None, train=True)
dataset_test = ModelNet(root='./ModelNet40/', name='40', pre_transform=farthest_pretransform, transform=None, pre_filter=None, train=False)
# Train and test splits are evaluated together as one dataset.
dataset = dataset_train + dataset_test

# Noise tolerance
# Tolerances handed to `test_symmetry`, from loosest to tightest.
errors = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
# Dictionary labels such as "1e_3", one per tolerance. The exponent is computed
# in float64 and rounded rather than truncated with int(): a float32 log10 that
# lands one ulp short (e.g. -2.9999998) would otherwise yield the wrong exponent
# and make two tolerances share one result-dict key.
error_strings = [
    f"1e_{round(abs(torch.log10(torch.tensor(error, dtype=torch.float64)).item()))}"
    for error in errors
]
round_num = 1
print(error_strings, round_num)

# result bookkeeping
size = len(dataset)
# Per-method counters: tolerance label -> number of shapes found symmetric.
GAC_symm_graphs = {err_str: 0 for err_str in error_strings}
GAVD_symm_graphs = {err_str: 0 for err_str in error_strings}


# refinement symmetry testing
def run_refinement(refinement_method, method_name, result_dict):
    """Count the dataset shapes that ``test_symmetry`` reports as symmetric.

    Each point cloud is moved to the selected device and centred on its mean.
    It is then scaled so that its farthest point lies at distance 1. Every
    point starts with the same label (1). If ``refinement_method`` is truthy,
    it is called as ``refinement_method(atom_type, pos, device, round_num)``
    to produce refined labels. ``test_symmetry(atom_type, pos, error)`` is
    then evaluated once for every tolerance in the module-level ``errors``.

    Args:
        refinement_method: Label-refinement callable, or a falsy value to test
            with the uniform initial labels.
        method_name: Short tag used in the output file name and report lines.
        result_dict: Mapping from every entry of ``error_strings`` to an int;
            incremented in place once per shape found symmetric at that
            tolerance.

    Side effects:
        Saves the collected enclosing radii with ``torch.save`` to
        ``./real_world_assessment/output/enclosing_radius_modelnet_<method_name>_<round_num>.pickle``
        (creating the directory if needed). Prints one summary line per
        tolerance.
    """
    enclosing_radius_list = []
    for data in tqdm(dataset, total=size):
        pos = data.pos.to(device)
        pos = pos - torch.mean(pos, dim=0, keepdim=True)
        # Clamp the scale so a degenerate cloud (all points identical) stays at
        # the origin instead of becoming NaN through 0/0. For any cloud with a
        # positive (normal) radius this is the original division.
        max_norm = torch.max(torch.norm(pos, dim=1))
        pos = pos / max_norm.clamp_min(torch.finfo(pos.dtype).tiny)

        atom_type = torch.ones(pos.shape[0], dtype=torch.long, device=device)
        if refinement_method:
            atom_type = refinement_method(atom_type, pos, device, round_num)

        for err_str, error in zip(error_strings, errors):
            symm_result, enclosing_radius = test_symmetry(atom_type, pos, error)
            if symm_result:
                result_dict[err_str] += 1
        # NOTE(review): only the radius returned for the last tolerance is kept
        # per shape — confirm test_symmetry's radius does not depend on `error`.
        enclosing_radius_list.append(enclosing_radius)

    out_path = f"./real_world_assessment/output/enclosing_radius_modelnet_{method_name}_{round_num}.pickle"
    # Make sure the output directory exists; open(..., "wb") would raise otherwise.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    enclosing_radius_tensor = torch.tensor(enclosing_radius_list)
    with open(out_path, "wb") as f:
        torch.save(enclosing_radius_tensor, f)

    for err_str, error in zip(error_strings, errors):
        total_graphs = result_dict[err_str]
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        percent = (total_graphs / size) * 100 if size else 0.0
        print(f"{method_name}, {(error, round_num)}, {total_graphs}, {size}, {percent}%")

# test
# Ordered dispatch table:
# (key in test_methods, refinement callable, report tag, counter dict).
# Entries run in this order when their key is listed in `test_methods`.
_REFINEMENT_RUNS = (
    ("C_refinement", C_refinement, "GAC", GAC_symm_graphs),
    ("VD_refinement", VD_refinement, "GAVD", GAVD_symm_graphs),
)
for method_key, method_fn, method_tag, method_counts in _REFINEMENT_RUNS:
    if method_key not in test_methods:
        continue
    run_refinement(method_fn, method_tag, method_counts)