import argparse
import json
import os
import shutil
import sys

import numpy as np
import torch
from numpy.typing import ArrayLike

from decord import VideoReader

from videos.video_distribution import get_fvd
from videos.video_quality import (
    get_lpips, get_psnr, get_ssim, get_video_smoothness, get_video_magnitude, get_video_smoothness_v2,
    get_objective_quality, get_subjective_quality, get_scene_consistency, get_scene_consistency_v2, get_scene_consistency_v3,
    get_objective_quality_v2
)


# ----------------------------- I/O helpers ----------------------------------

def get_video(video_path: str) -> torch.Tensor:
    """Decode every frame of a video and stack the frames into one float32 tensor.

    Each frame yielded by the decord ``VideoReader`` is converted to a
    ``torch.float32`` tensor and divided by 255.0, then all frames are stacked
    along a new leading (frame) axis.

    Args:
        video_path: Path of the video file handed to ``VideoReader``.

    Returns:
        The stacked frames. Presumably shaped (T, H, W, C) with values in
        [0, 1] when the decoder yields 8-bit HWC frames — TODO confirm.
    """
    reader = VideoReader(video_path)
    scaled_frames = []
    for frame in reader:
        as_float = torch.from_numpy(frame.asnumpy()).to(torch.float32)
        scaled_frames.append(as_float / 255.0)
    return torch.stack(scaled_frames)


def get_imgs(dir_path: str) -> list[str]:
    """List a directory and return the joined paths in sorted order.

    Every entry reported by ``os.listdir`` is included (no filtering by file
    type), each prefixed with ``dir_path``. Ordering is plain string sorting.

    Args:
        dir_path: Directory to list.

    Returns:
        Sorted list of ``os.path.join(dir_path, entry)`` strings.
    """
    paths = []
    for entry in os.listdir(dir_path):
        paths.append(os.path.join(dir_path, entry))
    paths.sort()
    return paths


# ----------------------------- Sheet row printer -----------------------------

def print_sheet_row(metrics: dict, include_header: bool = False):
    """
    Print a single =SPLIT("...") row for Google Sheets.
    Columns: distribution FVD → quality metrics (smoothness/objective/subjective/consistency).

    Values are read from ``metrics['distribution']`` and ``metrics['quality']``.
    Sections that are missing, ``None``, or not of the expected container type
    (for example a scalar placeholder such as ``-1`` stored under
    ``objective_quality`` or ``smoothness``) leave their cells empty instead
    of raising.

    Args:
        metrics: Nested metrics dictionary. ``quality['smoothness']`` is read
            positionally as (mse, ssim, lpips) when it is a list or tuple;
            ``quality['objective_quality']`` is read by key when it is a dict.
        include_header: When true, a header row with the column names is
            printed before the value row.
    """
    def as_dict(value):
        # Placeholders (None, -1, ...) are treated as "no data" for dict sections.
        return value if isinstance(value, dict) else {}

    def fmt(v):
        if isinstance(v, (int, float)):
            return f"{v:.4f}"
        return "" if v is None else str(v)

    metrics = as_dict(metrics)
    dist = as_dict(metrics.get('distribution'))
    q = as_dict(metrics.get('quality'))
    obj = as_dict(q.get('objective_quality'))

    # Only a real sequence can be split into the three smoothness columns.
    sm = q.get('smoothness')
    if not isinstance(sm, (list, tuple)):
        sm = []

    mse = sm[0] if len(sm) > 0 else None
    ssim = sm[1] if len(sm) > 1 else None
    lpips = sm[2] if len(sm) > 2 else None
    # Accept both spellings of the key (underscore and space).
    cta = obj.get('contrast_transfer_accuracy', obj.get('contrast_transfer accuracy'))

    cols = [
        ('FVD', dist.get('fvd')),
        ('mse', mse),
        ('ssim', ssim),
        ('lpips', lpips),
        ('flow', q.get('magnitude')),
        ('frame_dynamic_range_proxy', obj.get('frame_dynamic_range_proxy')),
        ('mtf50', obj.get('mtf50')),
        ('mtf10', obj.get('mtf10')),
        ('contrast_transfer_accuracy', cta),
        ('edge_rise_time', obj.get('edge_rise_time')),
        ('total_distortion', obj.get('total_distortion')),
        ('flare_attenuation', obj.get('flare_attenuation')),
        ('gradient_entropy', obj.get('gradient_entropy')),
        ('blur_extent', obj.get('blur_extent')),
        ('chroma_aberration', obj.get('chroma_aberration')),
        ('sequence_dynamic_range_proxy', obj.get('sequence_dynamic_range_proxy')),
        ('fmp_alias', obj.get('fmp_alias')),
        ('mmp_alias', obj.get('mmp_alias')),
        ('subjective_quality', q.get('subjective_quality')),
        ('scene_consistency', q.get('scene_consistency')),
    ]

    headers = ",".join(k for k, _ in cols)
    values = ",".join(fmt(v) for _, v in cols)

    if include_header:
        print(f'=SPLIT("{headers}", ",")')
    print(f'=SPLIT("{values}", ",")')


# --------------------------------- Main -------------------------------------

def _native(x):
    """Recursively convert NumPy / torch values into plain Python objects.

    NumPy scalars and single-element arrays/tensors become Python scalars,
    larger arrays/tensors become (nested) lists, and dicts, lists and tuples
    are converted element by element (tuples come back as lists). Anything
    else is returned unchanged. Used so the metrics can be passed to
    ``json.dump``.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu()
        return x.item() if x.numel() == 1 else x.tolist()
    if isinstance(x, np.generic):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.item() if x.size == 1 else x.tolist()
    if isinstance(x, dict):
        return {k: _native(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return [_native(v) for v in x]
    return x


def _export_frames_for_fvd(img_list, fvd_path, force_ext=''):
    """Copy a frame sequence into ``fvd_path`` using 5-digit index filenames.

    Nothing is done when ``fvd_path`` already exists. Otherwise the directory
    is created and every image except the first (index 0) is copied to
    ``<fvd_path>/<index:05d><ext>``.

    Args:
        img_list: Source image paths, in frame order.
        fvd_path: Destination directory for this sequence.
        force_ext: Extension to use for every destination file. When empty,
            the source file's own extension is kept ('.png' if it has none).

    Returns:
        Number of files that could not be copied. Each failure is reported on
        stderr and the remaining files are still processed.
    """
    if os.path.exists(fvd_path):
        return 0
    os.makedirs(fvd_path, exist_ok=True)

    failed = 0
    for idx, img in enumerate(img_list):
        if idx == 0:
            continue  # first frame is intentionally left out
        ext = force_ext or os.path.splitext(img)[1] or '.png'
        dst = os.path.join(fvd_path, f'{idx:05d}{ext}')
        try:
            shutil.copy2(img, dst)
        except OSError as err:
            # Best effort: keep going, but do not hide the failure.
            failed += 1
            print(f'Warning: could not copy {img} -> {dst}: {err}', file=sys.stderr)
    return failed


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Video evaluation')
    parser.add_argument('--root_path', type=str, required=True)  # predictions parent
    parser.add_argument('--gt_path', type=str, required=True)    # JSON file with GT dirs
    parser.add_argument('--outdir', type=str, default='./vis_depth')
    parser.add_argument('--split', type=str, default='gt')
    parser.add_argument('--action', type=str, default='free')
    parser.add_argument('--debug', type=int, default=1)
    args = parser.parse_args()

    split = args.split
    # The action sub-directory is only used for non-GT splits.
    act_dir = args.action if split != 'gt' else ''
    # Sorted so the run order is deterministic; os.listdir order is arbitrary.
    runs = sorted(
        r for r in os.listdir(args.root_path)
        if os.path.isdir(os.path.join(args.root_path, r))
    )
    debug = int(args.debug)

    print(f'Total runs: {len(runs)}')

    save_base = os.path.join('outputs', os.path.basename(os.path.normpath(args.outdir)))
    os.makedirs(save_base, exist_ok=True)

    vis_base = os.path.join(save_base, f'video-vis-{split}_{act_dir}')
    os.makedirs(vis_base, exist_ok=True)

    with open(args.gt_path, 'r', encoding='utf-8') as f:
        gt_json = json.load(f)

    preds: list[torch.Tensor] = []
    gts: list[torch.Tensor] = []

    pred_imgs: list[list[str]] = []
    gt_imgs: list[list[str]] = []

    # Prepare GT frames for FVD (folder-per-sequence with continuous filenames)
    gt_video_fvd_base = os.path.splitext(args.gt_path)[0] + '+fvd'
    os.makedirs(gt_video_fvd_base, exist_ok=True)

    for gt_base in gt_json:
        # Load GT video tensor: <gt_base>/<basename of gt_base>.mp4
        mp4_path = os.path.join(gt_base, os.path.basename(gt_base) + '.mp4')
        gts.append(get_video(mp4_path))

        # Collect GT image paths
        gt_img_list = get_imgs(os.path.join(gt_base, 'CAM_F0'))
        gt_imgs.append(gt_img_list)

        # Copy images to FVD input dir (named after the last two path components)
        name = "+".join(os.path.normpath(gt_base).split(os.sep)[-2:])
        _export_frames_for_fvd(
            gt_img_list, os.path.join(gt_video_fvd_base, name), force_ext='.jpg'
        )

    # Prepare prediction frames for FVD
    video_fvd_base = args.root_path.rstrip(os.sep) + '+fvd'
    os.makedirs(video_fvd_base, exist_ok=True)

    for run in runs:
        log_base = os.path.join(args.root_path, run, split, act_dir)

        video_path = os.path.join(log_base, 'video.mp4')
        if not os.path.exists(video_path):
            # skip silently if missing
            continue
        preds.append(get_video(video_path))

        img_dir = os.path.join(log_base, 'images')
        img_list = get_imgs(img_dir) if os.path.isdir(img_dir) else []
        pred_imgs.append(img_list)

        # Named after (up to) the last three path components of log_base.
        parts = os.path.normpath(log_base).split(os.sep)
        name = "+".join(parts[-3:])
        _export_frames_for_fvd(img_list, os.path.join(video_fvd_base, name))

    # Guard empty sets
    if not preds or not gts:
        print("No data found to evaluate; exiting.")
        sys.exit(0)

    # torch.stack requires every video in a set to have the same shape.
    preds_t = torch.stack(preds)  # (N, T, H, W, C)
    gts_t = torch.stack(gts)
    print(preds_t.shape, gts_t.shape)

    # ----------------------------- Metrics -----------------------------------
    # Metrics that are not computed keep the placeholder value -1.

    fvd = -1
    # Example usage:
    # fvd = get_fvd(preds_t, gts_t)
    # or FVD on folders:
    # fvd = get_fvd(video_fvd_base, gt_video_fvd_base)

    video_smoothness = get_video_smoothness_v2(pred_imgs)
    gt_video_smoothness = get_video_smoothness_v2(gt_imgs) if debug else -1

    video_magnitude = -1
    gt_video_magnitude = -1
    # Example:
    # video_magnitude = get_video_magnitude(pred_imgs)
    # if debug:
    #     gt_video_magnitude = get_video_magnitude(gt_imgs)

    objective_quality = -1
    gt_objective_quality = -1
    # Example:
    # objective_quality = get_objective_quality_v2(pred_imgs)
    # if debug:
    #     gt_objective_quality = get_objective_quality_v2(gt_imgs)

    subjective_quality = -1
    gt_subjective_quality = -1
    # Example:
    # subjective_quality = get_subjective_quality(pred_imgs)
    # if debug:
    #     gt_subjective_quality = get_subjective_quality(gt_imgs)

    scene_consistency = get_scene_consistency_v3(pred_imgs)
    gt_scene_consistency = get_scene_consistency_v3(gt_imgs) if debug else -1

    metrics = {
        'distribution': {
            'fvd': fvd
        },
        'quality': {
            'smoothness': video_smoothness,
            'magnitude': video_magnitude,
            'objective_quality': objective_quality,
            'subjective_quality': subjective_quality,
            'scene_consistency': scene_consistency
        }
    }
    if debug:
        # Same metrics computed on the ground-truth frames, for reference.
        metrics['gt_quality'] = {
            'smoothness': gt_video_smoothness,
            'magnitude': gt_video_magnitude,
            'objective_quality': gt_objective_quality,
            'subjective_quality': gt_subjective_quality,
            'scene_consistency': gt_scene_consistency
        }

    # ----------------------------- Convert to native -------------------------

    # Recursive, so nested tuples/lists/dicts of NumPy values are converted too.
    metrics = _native(metrics)

    # ----------------------------- Print & dump ------------------------------

    print(f"Metrics for {split} - {act_dir}:")
    for key, subdict in metrics.items():
        print(f"{key}:")
        for sub_key, sub_val in subdict.items():
            if isinstance(sub_val, (float, int)):
                print(f"  {sub_key}: {sub_val:.4f}")
            else:
                print(f"  {sub_key}: {sub_val}")

    metrics_path = os.path.join(save_base, f'video-{split}-{act_dir}.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4)

    print_sheet_row(metrics, include_header=False)
