from pathlib import Path
import json
import sys
from itertools import chain

from tqdm import tqdm

import numpy as np

from compute_metrics import analyze_file, Instance, align_disparity, align_depth

def instance_from_folder(folder):
    """Build an ``Instance`` from the path of an instance folder.

    The last four components of *folder* are taken, in order, as the object,
    the variation, the scene and the instance id; everything before them is
    joined back together and used as the root.

    Args:
        folder: Path (``str`` or ``Path``) with at least four components.

    Returns:
        ``Instance(root, obj, variation, scene, id)``.

    Raises:
        ValueError: if *folder* has fewer than four path components.
    """
    components = Path(folder).parts
    leading, trailing = components[:-4], components[-4:]
    obj, variation, scene, instance_id = trailing
    return Instance(str(Path(*leading)), obj, variation, scene, instance_id)

def avg_error(model, align_method, full_scene, scale_only, path):
    """Return the mean ``absrel`` error over the instances found in *path*.

    Every entry of *path* is converted to an ``Instance`` and evaluated with
    ``analyze_file`` using ``compare_with_base=False``.  Entries for which
    ``analyze_file`` returns ``None`` are skipped.  If the instances report
    that their base is not part of the folder, the base instance is evaluated
    too and its error is included in the mean.

    Args:
        model: Forwarded unchanged to ``analyze_file``.
        align_method: Forwarded unchanged to ``analyze_file``.
        full_scene: Forwarded as the ``full_scene`` keyword.
        scale_only: Forwarded as the ``scale_only`` keyword.
        path: Directory whose entries are instance folders.

    Returns:
        ``np.mean`` of the collected ``absrel`` values.

    Raises:
        ValueError: if *path* contains no entries at all.
    """
    errors = []
    instance = None
    for instance_folder in Path(path).iterdir():
        instance = instance_from_folder(instance_folder)
        res = analyze_file(model, align_method, instance, full_scene=full_scene, scale_only=scale_only, compare_with_base=False)
        if res is not None:
            errors.append(res['absrel'])

    # Without this guard an empty directory fails below with an opaque
    # UnboundLocalError on `instance`.
    if instance is None:
        raise ValueError(f"no instance folders found in {path}")

    # The flag is read from the last instance iterated; presumably it is the
    # same for every instance of one folder -- TODO confirm.
    if not instance.base_included_in_folder:
        base_res = analyze_file(model, align_method, instance.get_base(), full_scene=full_scene, scale_only=scale_only, compare_with_base=False)
        errors.append(base_res['absrel'])

    return np.mean(errors)

def accuracy_stability(model, align_method, full_scene, scale_only, path):
    """Return the sample variance of the ``absrel`` errors found in *path*.

    The errors are collected the same way as for the mean error: every entry
    of *path* is evaluated with ``analyze_file`` using
    ``compare_with_base=False``, ``None`` results are skipped, and the base
    instance is evaluated additionally when the instances report that it is
    not part of the folder.

    Args:
        model: Forwarded unchanged to ``analyze_file``.
        align_method: Forwarded unchanged to ``analyze_file``.
        full_scene: Forwarded as the ``full_scene`` keyword.
        scale_only: Forwarded as the ``scale_only`` keyword.
        path: Directory whose entries are instance folders.

    Returns:
        Sum of squared deviations from the mean divided by ``n - 1``.  With
        fewer than two collected errors the division is degenerate, so the
        result is not a meaningful variance.

    Raises:
        ValueError: if *path* contains no entries at all.
    """
    errors = []
    instance = None
    for instance_folder in Path(path).iterdir():
        instance = instance_from_folder(instance_folder)
        res = analyze_file(model, align_method, instance, full_scene=full_scene, scale_only=scale_only, compare_with_base=False)
        if res is not None:
            errors.append(res['absrel'])

    # Without this guard an empty directory fails below with an opaque
    # UnboundLocalError on `instance`.
    if instance is None:
        raise ValueError(f"no instance folders found in {path}")

    # The flag is read from the last instance iterated; presumably it is the
    # same for every instance of one folder -- TODO confirm.
    if not instance.base_included_in_folder:
        base_res = analyze_file(model, align_method, instance.get_base(), full_scene=full_scene, scale_only=scale_only, compare_with_base=False)
        errors.append(base_res['absrel'])

    errors = np.array(errors)

    # Unbiased estimator: divide by n - 1 rather than n.
    return ((errors - errors.mean())**2).sum() / (len(errors) - 1)

def self_inconsistency(model, align_method, full_scene, scale_only, path):
    """Return the mean squared ``absrel`` of the instances in *path*.

    Each entry of *path* is evaluated with ``analyze_file`` using
    ``compare_with_base=True``; entries for which it returns ``None`` are
    skipped.

    Args:
        model: Forwarded unchanged to ``analyze_file``.
        align_method: Forwarded unchanged to ``analyze_file``.
        full_scene: Forwarded as the ``full_scene`` keyword.
        scale_only: Forwarded as the ``scale_only`` keyword.
        path: Directory whose entries are instance folders.

    Returns:
        Sum of the squared ``absrel`` values divided by their count.
    """
    collected = []
    for folder in Path(path).iterdir():
        outcome = analyze_file(
            model,
            align_method,
            instance_from_folder(folder),
            full_scene=full_scene,
            scale_only=scale_only,
            compare_with_base=True,
        )
        if outcome is None:
            continue
        collected.append(outcome['absrel'])

    values = np.array(collected)
    squared_total = (values**2).sum()
    return squared_total / len(values)

def agg_avg_errors(model, align_method, full_scene, scale_only, paths):
    """Return the mean of ``avg_error`` over all *paths*.

    ``avg_error`` is evaluated once per path (with a ``tqdm`` progress bar)
    and the per-path means are averaged with equal weight.
    """
    per_path = []
    for folder in tqdm(paths):
        per_path.append(avg_error(model, align_method, full_scene, scale_only, folder))

    return np.mean(np.array(per_path))

def agg_accuracy_stability(model, align_method, full_scene, scale_only, paths):
    """Return the mean over *paths* of the square root of ``accuracy_stability``.

    ``accuracy_stability`` is evaluated once per path (with a ``tqdm``
    progress bar); the square root is taken per path before averaging.
    """
    per_path = []
    for folder in tqdm(paths):
        per_path.append(accuracy_stability(model, align_method, full_scene, scale_only, folder))

    roots = np.sqrt(np.array(per_path))
    return roots.mean()

def agg_self_inconsistency(model, align_method, full_scene, scale_only, paths):
    """Return the mean over *paths* of the square root of ``self_inconsistency``.

    ``self_inconsistency`` is evaluated once per path (with a ``tqdm``
    progress bar); the square root is taken per path before averaging.
    """
    per_path = []
    for folder in tqdm(paths):
        per_path.append(self_inconsistency(model, align_method, full_scene, scale_only, folder))

    roots = np.sqrt(np.array(per_path))
    return roots.mean()

class FileLoaderModel:
    """Model stand-in that reads precomputed predictions from ``.npy`` files.

    Attributes set by ``__init__``:
        name: Model name; also the sub-directory the predictions are read from.
        prediction_root: Root directory of the stored predictions.
        align: ``align_disparity`` or ``align_depth``, chosen from ``name``.
            NOTE: for a name found in neither table below the attribute is
            left unset, so reading ``align`` raises ``AttributeError``.
    """

    # Model names whose predictions are aligned with ``align_disparity``.
    _DISPARITY_ALIGNED = (
        "DepthAnything", "DepthAnything-Base", "DepthAnything-Small",
        "DepthAnythingV2", "DepthAnythingV2-Base", "DepthAnythingV2-Small",
        'MiDaS', 'MiDaS-BEiT-Base-384', 'MiDaS-BEiT-Large-384',
    )

    # Model names whose predictions are aligned with ``align_depth``.
    _DEPTH_ALIGNED = (
        "DepthPro",
        'Marigold',
        'Metric3DV2', 'Metric3DV2-ConvNext-Large', 'Metric3DV2-ConvNext-Tiny',
        'Metric3DV2-ViT-Large', 'Metric3DV2-ViT-Small',
        'MoGe',
        'UniDepth', 'UniDepth-Base', 'UniDepth-Small',
        'ZoeDepth',
    )

    def __init__(self, prediction_root, name):
        """Store the prediction root and model name and pick the alignment."""
        self.name = name
        self.prediction_root = prediction_root
        if name in self._DISPARITY_ALIGNED:
            self.align = align_disparity
        elif name in self._DEPTH_ALIGNED:
            self.align = align_depth

    def predict(self, I):
        """Load the stored prediction for instance *I*.

        The file is looked up at
        ``<prediction_root>/<I.obj>/<I.variation_type>/<I.scene>/<name>/<I.id>.npy``.
        """
        model_dir = Path(self.prediction_root) / I.obj / I.variation_type / I.scene / self.name
        return np.load(model_dir / f"{I.id}.npy")

if __name__ == "__main__":
    # Placeholder roots: replace with the real prediction and dataset folders.
    PRED_ROOT = "<PREDICTIONS_ROOT>"
    DATASET_ROOT = "<DATASET_ROOT>"
    model = FileLoaderModel(PRED_ROOT, "DepthPro")

    full_scene = False
    scale_only = False

    # Gather every entry under <DATASET_ROOT>/<category>/cam_pan_tilt,
    # keeping the categories in this order.
    paths = []
    for category in ("chairs", "desks", "cabinets", "fishes", "cactus"):
        paths.extend((Path(DATASET_ROOT) / category / "cam_pan_tilt").iterdir())

    # All three aggregate metrics share the same settings and paths.
    shared_kwargs = dict(full_scene=full_scene, scale_only=scale_only, paths=paths)

    err = agg_avg_errors(model, model.align, **shared_kwargs)
    stab = agg_accuracy_stability(model, model.align, **shared_kwargs)
    self_con = agg_self_inconsistency(model, model.align, **shared_kwargs)

    print("err:", err)
    print("stability:", stab)
    print("self-inconsistency:", self_con)