import argparse
from gzip import GzipFile
import numpy as np
from tqdm import tqdm
import torch
from torch import nn

from main import get_conf_dict, get_img_pathes, get_img_pathes_gt
from pathlib import Path

from pipeline.metrics.LDL.metric import get_refined_artifact_map
from pipeline.metrics.bd_jup.metric import CombDet
from pipeline.metrics.ssim.method import cal_ssim
from pipeline.metrics.ssm_jup.metric import Structure_Similarity
from pipeline.metrics.vryl_texture.mm_vers import VRYL
from pipeline.metrics.vryl_texture_mod.mm_vers_lpips import VRYL as VRYL_lpips
from pipeline.metrics.vryl_texture_mod.mm_vers_dists import VRYL as VRYL_dists
from pipeline.metrics.dists.dists import method as dists


# When True, use the DeSRA variant: it switches the confidence CSV read
# below and prefixes the output directory name with "desra_".
DESRA = False


# Root of the image dataset; placeholder to be filled in before running.
IMAGES = Path("<path_to_dataset>")

# Method names forwarded as ``gt=`` / ``rf=`` to the image-path helpers and
# embedded in the output directory name.
gt, rf = "RLFN", "bicubic"

OUTPUT = Path("output") / f"{'desra_' if DESRA else ''}heatmaps_gt{gt}_rf{rf}"

# Loaded once at import time; passed to get_img_pathes_gt for the GT subset.
conf_dict = get_conf_dict("gt_conf_desra.csv" if DESRA else "gt_conf.csv")

# Device used to load the NN checkpoint and run its inference in main();
# set to torch.device("cpu") to run that part without a GPU.
device = torch.device("cuda")


class NeuralNetwork(nn.Module):
    """Small MLP mapping ``n_features`` input values to a single output value.

    Layout: Linear(n_features, 128) -> ReLU -> Linear(128, 128) -> ReLU
    -> Linear(128, 1). The submodule name ``linear_relu_stack`` and the layer
    positions inside it determine the ``state_dict`` keys, so they must not
    change or existing checkpoints will no longer load.
    """

    def __init__(self, n_features: int = 3):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(n_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(..., n_features)`` to raw outputs of shape ``(..., 1)``."""
        return self.linear_relu_stack(x)


# Lazily filled caches for get_gt_img_paths() / get_img_paths();
# None means "not computed yet".
_gt_img_paths = _img_paths = None


def get_gt_img_paths():
    """Return the image paths of the GT subset, computing them only once.

    The value returned by ``get_img_pathes_gt`` is stored in the module-level
    ``_gt_img_paths`` and handed back unchanged on later calls.
    """
    global _gt_img_paths
    if _gt_img_paths is not None:
        return _gt_img_paths
    _gt_img_paths = get_img_pathes_gt(str(IMAGES), conf_dict, gt=gt, rf=rf)
    return _gt_img_paths


def get_img_paths():
    """Return the image paths of the full SR set, computing them only once.

    The value returned by ``get_img_pathes`` is stored in the module-level
    ``_img_paths`` and handed back unchanged on later calls.
    """
    global _img_paths
    if _img_paths is not None:
        return _img_paths
    _img_paths = get_img_pathes(str(IMAGES), gt=gt, rf=rf)
    return _img_paths


def evaluate(method_name, method, big_set=False):
    """Compute and store a heatmap for every SR image using ``method``.

    Each heatmap is saved with ``np.save`` into a gzip file at
    ``OUTPUT / method_name / f"{sr_stem}.npy.gz"``. Images whose output file
    already exists are skipped, so an interrupted run can be resumed.

    Args:
        method_name: Name of the sub-directory of OUTPUT receiving the results.
        method: Callable ``method(hr_path, sr_path, rf_path)`` returning the
            heatmap to save.
        big_set: If True, process every SR image from ``get_img_paths()``;
            otherwise only the GT subset from ``get_gt_img_paths()``.

    Returns:
        List of SR image stems for which computing or saving the heatmap raised.
    """
    paths = dict(get_img_paths() if big_set else get_gt_img_paths())

    results = OUTPUT / method_name
    # parents=True: OUTPUT is nested ("output/<run>"), and a plain mkdir raises
    # FileNotFoundError when the top-level "output" directory does not exist.
    results.mkdir(parents=True, exist_ok=True)

    paths = {k: v for k, v in paths.items() if not (results / f"{Path(v[1]).stem}.npy.gz").exists()}
    print(f"will compute {len(paths)} images")

    fails = []
    for hr_path, sr_path, rf_path in tqdm(paths.values(), ncols=80, desc="evaluate"):
        sr_raw_name = Path(sr_path).stem
        out_path = results / f"{sr_raw_name}.npy.gz"
        # Write to a temporary name and rename only on success: the skip check
        # above only tests for existence, so a truncated file left by a failed
        # or interrupted save would otherwise be treated as done on rerun.
        tmp_path = out_path.with_name(out_path.name + ".tmp")

        try:
            heatmap = method(hr_path, sr_path, rf_path)

            with GzipFile(tmp_path, "w") as f:
                np.save(f, heatmap)
            tmp_path.replace(out_path)
        except Exception as e:
            tmp_path.unlink(missing_ok=True)
            print(
                f"error evaluating {method_name} on {sr_raw_name} ({Path(hr_path).stem}, {Path(sr_path).stem}, {Path(rf_path).stem})"
            )
            print(e)
            fails.append(sr_raw_name)

    return fails


def parse_args():
    """Parse the script's command line.

    Returns:
        argparse.Namespace with a boolean ``all_images`` attribute, False
        unless ``--all-images`` is passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--all-images",
        default=False,
        action="store_true",
        help="Whether to process all SR images, rather than only GT ones (default: %(default)s)",
    )
    namespace = cli.parse_args()
    return namespace


def _npy_metric(name, base=None, gzip=False, ext=".npy"):
    """Return a method that loads a previously saved heatmap instead of computing one.

    Args:
        name: Sub-directory of OUTPUT to read from when ``base`` is not given.
        base: Directory holding the saved heatmaps; defaults to ``OUTPUT / name``.
        gzip: Whether the files are gzip-compressed (as written by ``evaluate``).
        ext: File-name suffix appended to the SR image stem.

    Returns:
        Callable ``metric(hr_path, sr_path, rf_path)`` returning the array stored
        for ``sr_path``; on failure it prints the SR path and re-raises.
    """
    base = base or OUTPUT / name

    def metric(hr_path, sr_path, rf_path):
        try:
            target = base / f"{Path(sr_path).stem}{ext}"
            if gzip:
                with GzipFile(target) as f:
                    return np.load(f)
            return np.load(target)
        except Exception:
            print("failed: ", sr_path)
            raise

    return metric


def _nn_ptl(path: str):
    """Load a NeuralNetwork checkpoint and wrap it as a heatmap method.

    The network fuses heatmaps previously saved by ``evaluate`` for the base
    metrics "dists", "ssm_jup" and "bd_jup", stacked in that order as input
    features. Every dash-separated component of the run name of the form
    ``no<metric>`` drops that metric from the inputs.

    Args:
        path: Checkpoint path; its ``"state_dict"`` entry is loaded after
            stripping a leading ``"model."`` from every key.

    Returns:
        Callable ``method(hr_path, sr_path, rf_path)`` returning the network
        output reshaped to the shape of the first input heatmap.

    Raises:
        ValueError: If the run name disables every base metric.
    """
    ckpt = torch.load(path, map_location=device, weights_only=True)

    # The run name is the third "/"-separated segment of ``path`` when there is
    # one; a bare file name such as "nn-20250421-gtgt-e30.ckpt" has no third
    # segment, so its stem is used as the run name instead.
    segments = path.split("/")
    name = segments[2] if len(segments) > 2 else Path(path).stem

    disabled = {part[2:] for part in name.split("-") if part.startswith("no")}
    feature_names = [f for f in ("dists", "ssm_jup", "bd_jup") if f not in disabled]
    if not feature_names:
        raise ValueError(f"run name {name!r} disables every input metric")
    # The input width is derived from the features actually kept, so it always
    # equals the number of stacked heatmaps (a "no*" component that names no
    # metric, e.g. "normalized", must not shrink the model input).
    features = [_npy_metric(f, gzip=True, ext=".npy.gz") for f in feature_names]

    model = NeuralNetwork(len(features))
    model.load_state_dict({k.removeprefix("model."): v for k, v in ckpt["state_dict"].items()})
    model.to(device)
    model.eval()

    def method(hr_path, sr_path, rf_path):
        heatmaps = [feature(hr_path, sr_path, rf_path) for feature in features]
        X = np.stack(heatmaps, axis=-1).astype("float32").reshape(-1, len(heatmaps))
        # Pure inference: disable autograd so activations for every pixel are
        # not retained for a backward pass that never happens.
        with torch.no_grad():
            res = model(torch.tensor(X, device=device)).numpy(force=True)
        return res.reshape(heatmaps[0].shape)

    return method


def main():
    """Compute and save heatmaps for every selected metric.

    Only metrics whose name is listed in ``selected`` are constructed (several
    factories load networks onto the GPU); each is run through ``evaluate``,
    which writes to ``OUTPUT / <name>``. The "nn-*" entry reads the heatmaps
    saved for "dists", "ssm_jup" and "bd_jup", so those are listed before it.
    A summary of failed files per metric is printed at the end.
    """
    args = parse_args()

    bd_jup_opts = {
        "threshold1": 0.9,
        "threshold2": 0.99,
        "erqa_block_size": 8,
        "lpips_block_size": 32,
        "erqa_stride": 8,
        "lpips_stride": 16,
        "erqa_weight": 0.4,
        "final_threshold": 0.15,
        "version": "0.0",
        "n_device": 0,
        "bin_mode": "dilation",
    }
    vryl_texture_opts = {
        "device": "cuda:0",
        "t1": 0,  # heatmap > t1
        "t2": 0.055,  # heatmap > t2
    }
    vryl_dists_opts = {
        "device": "cuda:0",
        "t1": 0,  # heatmap > t1
        "t2": 0.05,  # heatmap > t2
    }
    vryl_lpips_opts = {
        "device": "cuda:0",
        "t1": 0,  # heatmap > t1
        "t2": 0.05,  # heatmap > t2
    }

    # Factories instead of instances, so unselected metrics never load models.
    factories = {
        "LDL": lambda: get_refined_artifact_map,
        "bd_jup": lambda: CombDet(**bd_jup_opts),
        "ssim": lambda: cal_ssim,
        "ssm_jup": lambda: Structure_Similarity(device="0"),
        "vryl_texture": lambda: VRYL(**vryl_texture_opts),
        "vryl_texture_dists": lambda: VRYL_dists(**vryl_dists_opts),
        "vryl_texture_lpips": lambda: VRYL_lpips(**vryl_lpips_opts),
        "dists": lambda: dists,
        "nn-20250421-gtgt-e30": lambda: _nn_ptl("nn-20250421-gtgt-e30.ckpt"),
    }
    selected = ["bd_jup", "ssm_jup", "dists", "nn-20250421-gtgt-e30"]
    methods = {name: factories[name]() for name in selected}

    all_fails = {}
    for name, method in tqdm(methods.items(), ncols=80):
        fails = evaluate(name, method, big_set=args.all_images)
        if len(fails) > 0:
            all_fails[name] = fails

    for method_name, fails in all_fails.items():
        print(f"- {method_name}: failed on {len(fails)} files")
        print(f"  Fails: {fails}\n")


if __name__ == "__main__":
    main()
