import igl
from metagen.util import *
import numpy as np
from .validation import *

import pickle
from dataclasses import dataclass, asdict

from functools import cache

@dataclass
class DatasetInfo:
    """Dataset-level metadata needed to score predictions against ground truth.

    Built (and serialised to JSON) by GatherDatasetInfo and consumed by
    eval_vs_gt. Field order matters: instances are constructed positionally.
    """

    # Property name -> rounding step handed to round_to_delineation.
    rounding_factors: dict[str, float]
    # Raw stats-file property name -> name used downstream (e.g. 'D' -> 'rho').
    property_map: dict[str, str]
    # Property name -> [min, max]; used to normalise absolute errors.
    property_ranges: dict[str, list]
    # Ground-truth example locations, indexed by a prediction's 'id_number'.
    # NOTE(review): populated from the stats file's 'example_order', which may
    # really be a list rather than a dict — confirm.
    gt_examples: dict
    # Property names to predict, in order.
    to_predict: list[str]

@cache
def load_mesh_VF(example: str):
    """Load the mesh for ``example`` and return it as a (vertices, faces) pair.

    Both elements are numpy arrays obtained via np.asarray from the loaded
    mesh. Results are memoised per ``example``, so every caller receives the
    same array objects — treat them as read-only.
    """
    loaded = load_mesh(example)
    vertices = np.asarray(loaded.vertices)
    faces = np.asarray(loaded.faces)
    return vertices, faces

def IoU(v1, v2):
    """Return the intersection-over-union of two voxel occupancy grids.

    Any non-zero entry counts as occupied, so boolean, integer and float grids
    are all accepted. (A ``v1 * v2`` / ``v1 | v2`` formulation raises on float
    arrays and miscounts integer grids whose values are not 0/1.)

    Args:
        v1, v2: array-likes of the same (or broadcast-compatible) shape.

    Returns:
        |v1 AND v2| / |v1 OR v2| as a float. When both grids are empty the
        shapes agree trivially, so 1.0 is returned instead of a 0/0 NaN.
    """
    a = np.asarray(v1, dtype=bool)
    b = np.asarray(v2, dtype=bool)
    union = np.count_nonzero(a | b)
    if union == 0:
        return 1.0
    return float(np.count_nonzero(a & b) / union)

def chamfer_dist(m1, m2):
    """Symmetric point-to-mesh distance between two meshes.

    Each mesh is a (vertices, faces) pair, as returned by load_mesh_VF. The
    vertices of one mesh are used as sample points against the surface of the
    other via igl.point_mesh_squared_distance. The mean is taken per direction,
    and the two directional means are averaged.

    Note: the distances are *squared*, so the result is in squared length units.
    """
    def _mean_sq(src, dst):
        # Mean squared distance from src's vertices to dst's surface.
        sq, _, _ = igl.point_mesh_squared_distance(src[0], *dst)
        return sq.mean()

    forward = _mean_sq(m1, m2)
    backward = _mean_sq(m2, m1)
    return (forward + backward) / 2

@cache
def get_voxels(pred):
    """Load the active voxel cells stored in ``<pred>/vox_active_cells.txt``.

    Memoised per ``pred``, so repeated evaluations of the same example reuse
    one object — callers must not mutate the returned value.
    NOTE(review): whether the handle from open_file is closed depends on
    open_file/load_voxels — confirm.
    """
    voxel_file = path_append(pred, 'vox_active_cells.txt')
    return load_voxels(open_file(voxel_file))

def eval_vs_gt(pred, pred_info, dataset_info: DatasetInfo):
    """Score one prediction against its ground-truth example.

    Args:
        pred: For the 'Generate' task, the key/location from which the
            generated example's properties, voxels and mesh are loaded.
            For any other task, a mapping of property name -> predicted value.
        pred_info: dict providing at least 'task' and 'id_number'.
        dataset_info: dataset metadata (rounding factors, property ranges and
            the ground-truth example lookup).

    Returns:
        dict with:
          'IoU', 'ChamferDistance': geometry metrics, None unless task == 'Generate';
          'error': per-property absolute error after normalising by the
              property's [min, max] range (raw difference if that range is empty);
          'accuracy': per-property exact-equality flag;
          'MeanError', 'MeanAccuracy': means of the above, None when no
              properties were compared;
          'GroundTruthLocation': the ground-truth example used.
    """
    if pred_info['task'] == 'Generate':
        pred_sim_results = {k: float(v) for k, v in load_and_format_properties(pred).items()}
        pred_voxels = get_voxels(pred)
        pred_mesh = load_mesh_VF(pred)
    else:
        # Predict task: snap predictions onto the dataset's rounding grid
        # (they should be already).
        pred_sim_results = {
            k: float(round_to_delineation(v, dataset_info.rounding_factors[k]))
            for k, v in pred.items()
        }
        pred_voxels = None
        pred_mesh = None

    gt = dataset_info.gt_examples[pred_info['id_number']]
    gt_sim_results = {k: float(v) for k, v in load_and_format_properties(gt).items()}

    # Geometry metrics only exist when a shape was generated.
    if pred_voxels is not None:
        gt_voxels = get_voxels(gt)
        gt_mesh = load_mesh_VF(gt)
        iou = float(IoU(pred_voxels, gt_voxels))
        chamf = float(chamfer_dist(pred_mesh, gt_mesh))
    else:
        iou = None
        chamf = None

    # Per-property comparison: exact float equality, plus absolute error with
    # both values normalised by the dataset's [min, max] range for that property.
    error = {}
    accuracy = {}
    for k, pred_prop in pred_sim_results.items():
        gt_prop = gt_sim_results[k]
        pmin, pmax = dataset_info.property_ranges[k]
        span = pmax - pmin
        if span:
            error[k] = abs((pred_prop - pmin) / span - (gt_prop - pmin) / span)
        else:
            # Degenerate range (pmin == pmax): normalising would divide by zero,
            # so report the raw absolute difference instead.
            error[k] = abs(pred_prop - gt_prop)
        accuracy[k] = pred_prop == gt_prop

    # No comparable properties: report None instead of dividing by zero,
    # mirroring how the unavailable geometry metrics are reported.
    mean_error = sum(error.values()) / len(error) if error else None
    mean_accuracy = sum(accuracy.values()) / len(accuracy) if accuracy else None
    return {
        'IoU': iou,
        'ChamferDistance': chamf,
        'error': error,
        'accuracy': accuracy,
        'MeanError': mean_error,
        'MeanAccuracy': mean_accuracy,
        'GroundTruthLocation': gt
    }

def GatherDatasetInfo(stats_path, dataset_info_path):
    """Build a DatasetInfo from a pickled stats file and write it as JSON.

    The stats file must provide 'property_order', 'property_ranges' and
    'example_order'. Property names are translated through ``property_map``
    ('D' -> 'rho', 'NU' -> 'nu'). Names not in the map pass through unchanged,
    and the same rule is applied to both the prediction order and the ranges.

    Args:
        stats_path: path to the pickled dataset statistics.
        dataset_info_path: destination path for the JSON-serialised DatasetInfo.

    Returns:
        The DatasetInfo that was written.
    """
    import json  # used for serialisation below; not imported explicitly at file level

    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced stats files.
    with open(stats_path, 'rb') as f:
        dataset_stats = pickle.load(f)

    # Rounding step per (downstream) property name.
    rounding_factors = {
        'E': 0.1,
        'G': 0.05,
        'K': 0.5,
        'rho': 0.1,
        'nu': 0.1
    }
    # Stats-file property name -> downstream property name.
    property_map = {
        'E': 'E',
        'K': 'K',
        'G': 'G',
        'D': 'rho',
        'NU': 'nu'
    }
    to_predict = [property_map.get(k, k) for k in dataset_stats['property_order']]
    property_ranges = {
        property_map.get(k, k): [float(x) for x in v]
        for k, v in dataset_stats['property_ranges'].items()
    }
    gt_examples = dataset_stats['example_order']

    dataset_info = DatasetInfo(
        rounding_factors, property_map, property_ranges, gt_examples, to_predict
    )

    with open(dataset_info_path, 'w') as f:
        json.dump(asdict(dataset_info), f)
    return dataset_info


def ExtractNovaGenerations(gen_dir, generated_code_dir = 'codefiles', model_name = 'NovaPro', dryrun=True):
    """Extract generated code from batch-inference output files into per-example files.

    Walks ``gen_dir`` for output files, skipping any whose name contains
    'manifest', and loads each one with load_jsonl. Only records whose parsed
    record id has task 'Generate' are processed. For each such record, the
    single code block in the expected language is saved: 'python' for the 'DSL'
    representation, 'json' otherwise. The file is named
    ``<model_name>_<recordId>.py|.json`` and placed at
    ``<model_name>/<rep_type>/`` under ``generated_code_dir``.

    A record that does not yield exactly one block of the expected language
    gets a deliberately unparseable placeholder instead, so the failure shows
    up in downstream results. The same applies to a record with no readable
    model text.

    Args:
        gen_dir: directory tree containing the batch output files.
        generated_code_dir: root directory for the extracted code files.
        model_name: prefix for the output subdirectory and file names.
        dryrun: if True, only print the target paths; nothing is written.
    """
    bad_generation = "}}BAD GENERATION{{"  # This will break the parser and show up in our results
    filepaths = [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(gen_dir)
        for filename in filenames
        if 'manifest' not in filename
    ]
    for fp in tqdm(filepaths):
        recs = load_jsonl(fp)
        for rec in tqdm(recs):
            example_code = rec['recordId']
            record_info = parse_record_id(example_code)
            if record_info['task'] != 'Generate':
                continue
            rep = record_info['rep_type']
            lang = 'python' if rep == 'DSL' else 'json'

            try:
                response = rec['modelOutput']['output']['message']['content'][0]['text']
            except (KeyError, IndexError, TypeError):
                # NOTE(review): presumably a failed invocation with no model output;
                # record it as a bad generation rather than aborting the whole run.
                response = None

            blocks = extract_and_classify_blocks(response) if response is not None else None
            if blocks is None or len(blocks[lang]) != 1:
                code = bad_generation
            else:
                code = blocks[lang][0]
            code_file_name = model_name + '_' + example_code + ('.py' if lang == 'python' else '.json')
            code_file_path = path_append(generated_code_dir, model_name + '/' + rep + '/' + code_file_name)

            if dryrun:
                print(code_file_path)
            else:
                os.makedirs(os.path.dirname(code_file_path), exist_ok=True)
                # Explicit UTF-8: model output may contain non-ASCII text, and the
                # platform default encoding is not guaranteed to handle it.
                with open(code_file_path, 'w', encoding='utf-8') as f:
                    f.write(code)

