import argparse, os, json
import torch
from utilities.builder import build_proc_from_run_dir
from config import get_dataset_cfg
from exp.exp_basic import Exp_Basic, ExpConfigs
from exp.exp_memKNO import Exp_MemKNO

def _json_default(o):
    """Fallback serializer for ``json.dump(..., default=_json_default)``.

    Converts values the ``json`` module cannot encode on its own:

    * ``torch.Tensor``: a Python scalar when it holds exactly one element,
      otherwise a (nested) list built from a detached CPU copy.
    * ``numpy`` scalar (``np.generic``): the matching Python scalar.
    * ``numpy.ndarray``: a (nested) list.

    Anything else, or any value whose conversion raises, is returned as
    ``str(o)``. The conversion is best-effort on purpose: a missing library
    or a failing conversion never aborts the dump.
    """
    # Tensor branch; any failure (including a failed import) falls through.
    try:
        import torch as _torch
        if isinstance(o, _torch.Tensor):
            if o.numel() == 1:
                return o.item()
            return o.detach().cpu().tolist()
    except Exception:
        pass

    # NumPy branch; scalars and arrays are disjoint types, so the order of
    # the two checks does not matter.
    try:
        import numpy as _np
        if isinstance(o, _np.ndarray):
            return o.tolist()
        if isinstance(o, _np.generic):
            return o.item()
    except Exception:
        pass

    # Last resort: the object's string form.
    return str(o)


def _build_arg_parser():
    """Create the command-line parser for the memKNO evaluation script."""
    parser = argparse.ArgumentParser(
        description="Command-line evaluation for memKNO."
    )
    parser.add_argument(
        "--eval_mode", type=str, default="all",
        help='"all" also runs the phase-2 evaluation; any other value '
             '(e.g. "phase1") skips it.',
    )
    parser.add_argument(
        "--phase1_path", type=str,
        help="Directory containing phase1_best_rec.pth; required unless --rom is given.",
    )
    # Every output path is derived from this directory, so it cannot be omitted.
    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Run directory with the saved configs and model_tr_best.pth; "
             "all outputs are written below it.",
    )
    parser.add_argument("--gpu", type=int, default=0, help="GPU id; ignored if no CUDA.")
    parser.add_argument("--dataset", type=str, default="ns_1e-3", help="Dataset name.")
    parser.add_argument(
        "--seq_id", type=int, default=0,
        help="Sequence index used by the by-index rollout outputs.",
    )
    parser.add_argument(
        "--rollout_steps", type=int, default=15,
        help="Number of rollout steps for the visualisations and saved rollouts.",
    )
    parser.add_argument(
        "--rom", action=argparse.BooleanOptionalAction, default=False,
        help="Skip loading the phase-1 checkpoint and all phase-1/linear outputs.",
    )
    parser.add_argument(
        "--eval_long_term", action=argparse.BooleanOptionalAction, default=False,
        help="Also save a long rollout under <model_path>/long_term.",
    )
    parser.add_argument(
        "--lt_steps", type=int, default=10,
        help="Steps for the long-term rollout; also the step count of the "
             "latent time-series plot produced with --rom.",
    )
    parser.add_argument(
        "--lt_bs", type=int, default=4,
        help="Accepted for compatibility; not used by any active code path.",
    )
    parser.add_argument(
        "--lt_traj_id", type=int, default=0,
        help="Trajectory id for the long-term rollout.",
    )

    parser.add_argument(
        "--traj_id", type=int, default=0,
        help="Trajectory id for saved rollout tensors and latent plots.",
    )
    parser.add_argument(
        "--t0", type=int, default=0,
        help="Start index for saved rollout tensors and the sample-evolution plot.",
    )
    return parser


def _dump_metrics(metrics, path):
    """Write ``metrics`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, default=_json_default)


def _run_phase1_visualizations(exp, args):
    """Load the phase-1 checkpoint and write its outputs to <model_path>/vis/phase1."""
    exp.load_phase1_ckpt(
        path=os.path.join(args.phase1_path, "phase1_best_rec.pth"),
        clip_positive_symmetric=False,
    )
    out_dir_phase1 = os.path.join(args.model_path, "vis/phase1")
    os.makedirs(out_dir_phase1, exist_ok=True)
    exp.plot_Ad_spectrum(save_dir=out_dir_phase1)

    # Random rollouts with dyn_type="linear"; the original call order is kept.
    for group, subdir in (("test", "test"), ("train_eval", "train")):
        exp.visualize_random_rollout(
            group=group, batch_size=2, rollout_steps=args.rollout_steps,
            out_dir=os.path.join(out_dir_phase1, subdir), mode="all", dyn_type="linear",
        )
    exp.visualize_rollout_by_index(
        group="train_eval", seq_index=args.seq_id, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(out_dir_phase1, f"idx_{args.seq_id}"), mode="all", dyn_type="linear",
    )

    ################## Evaluate Reconstruction ##################
    for group, subdir in (("test", "recon/test"), ("train_eval", "recon/train")):
        exp.visualize_random_rollout(
            group=group, batch_size=2, rollout_steps=args.rollout_steps,
            out_dir=os.path.join(out_dir_phase1, subdir), mode="all", dyn_type="recon",
        )


def _run_phase2_evaluation(exp, args):
    """Load model_tr_best.pth, then write metrics, rollouts and latent plots.

    Raises:
        RuntimeError: if ``exp.whiten_scale`` is still ``None`` after the
            checkpoint has been loaded.
    """
    info = exp.load_from_ckpt(ckpt_path=os.path.join(args.model_path, "model_tr_best.pth"))
    print(info)
    # NOTE: an earlier revision could fit the whitening on the fly via
    # exp.fit_diag_whitening_from_phase1(group="train", max_batches=16);
    # that path is disabled, so the scale must be set by the load above.
    # An explicit raise is used because `assert` is removed under `python -O`.
    if exp.whiten_scale is None:
        raise RuntimeError(
            "exp.whiten_scale is None after loading model_tr_best.pth; "
            "cannot run the phase-2 evaluation."
        )

    out_dir_phase2 = os.path.join(args.model_path, "vis/phase2")
    os.makedirs(out_dir_phase2, exist_ok=True)

    exp._ensure_loader("train_eval")
    exp._ensure_loader("test")
    # Both evaluations run before anything is written, as in the original flow.
    train_errs = exp.evaluate(exp.train_eval_loader)
    test_errs = exp.evaluate(exp.test_loader)
    _dump_metrics(train_errs, os.path.join(args.model_path, "metrics_train.json"))
    _dump_metrics(test_errs, os.path.join(args.model_path, "metrics_test.json"))

    for group, subdir in (("test", "test"), ("train_eval", "train")):
        exp.visualize_random_rollout(
            group=group, batch_size=2, rollout_steps=args.rollout_steps,
            out_dir=os.path.join(out_dir_phase2, subdir), mode="all", dyn_type="memory",
        )
    exp.visualize_rollout_by_index(
        group="train_eval", seq_index=args.seq_id, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(out_dir_phase2, f"idx_{args.seq_id}"), mode="all", dyn_type="memory",
    )

    # Linear (phase-1) outputs are only requested when --rom is not set.
    save_linear = not args.rom
    exp.save_rollout_comparison(
        group="train_eval", seq_index=args.seq_id, rollout_steps=args.rollout_steps,
        out_dir=os.path.join(args.model_path, f"idx_{args.seq_id}"),
        save_linear=save_linear,
    )
    exp.save_rollout_comparison(
        group="train_eval", rollout_steps=args.rollout_steps,
        out_dir=os.path.join(args.model_path, "random_sample"),
        save_linear=save_linear,
    )
    exp.save_rollout_tensors(
        out_dir=os.path.join(args.model_path, "saved_tensors/phase2"),
        traj_id=args.traj_id, t0=args.t0, rollout_steps=args.rollout_steps,
        phase="phase2",
    )
    if save_linear:
        exp.save_rollout_tensors(
            out_dir=os.path.join(args.model_path, "saved_tensors/phase1"),
            traj_id=args.traj_id, t0=args.t0, rollout_steps=args.rollout_steps,
            phase="phase1",
        )

    ############################## Latent Visualizations ##############################
    # The latent plots are skipped entirely for the "sst" dataset.
    if args.dataset != "sst":
        if save_linear:
            latent_dir = os.path.join(args.model_path, "latent/linear_vs_memory")
            exp.plot_dyn_energy_stats(exp.train_eval_loader, save_dir=latent_dir)
            exp.visualize_sample_evolution(
                group="test", traj_id=args.traj_id, t0=args.t0, save_dir=latent_dir,
                bg_mode="landscape", bg_alpha=0.75,
            )
        else:
            # This plot always starts at t0=0 and runs for --lt_steps steps;
            # it does not use --t0 or --rollout_steps.
            exp.plot_lowdim_time_series(
                time_scale=4.0 if args.dataset == "ns_1e-3" else 1.0,
                traj_id=args.traj_id, t0=0, steps=args.lt_steps,
                use_center=True, use_whiten=True, use_projector=True,
                save_dir=os.path.join(args.model_path, "rom/latent_modes"), fname_prefix="ytime",
                topk_by_var=8,
            )

    ############################## Long Term Visualizations ##############################
    if args.eval_long_term:
        out_dir_lt = os.path.join(args.model_path, "long_term")
        # NOTE: batched long-trajectory evaluation (exp.evaluate_long_trajs /
        # exp.evaluate_long_by_indices) was disabled in an earlier revision;
        # only a single long rollout is saved here.
        exp.save_rollout_tensors(
            out_dir=out_dir_lt, traj_id=args.lt_traj_id, t0=0,
            rollout_steps=args.lt_steps, phase="phase2",
        )


def main():
    """Parse the command line and run the memKNO evaluation.

    Steps, in order:

    1. Parse and validate the arguments. A missing ``--model_path``, or a
       missing ``--phase1_path`` without ``--rom``, exits with the usual
       argparse usage error before any other work starts.
    2. Build the data processor and the experiment from the configs stored
       in ``--model_path``.
    3. Unless ``--rom`` is given, load the phase-1 checkpoint and write the
       phase-1 visualisations.
    4. If ``--eval_mode`` is ``"all"``, load the trained model and write
       metrics, rollouts and latent plots.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # The phase-1 checkpoint is only read when --rom is off, so the path is
    # only mandatory in that case.
    if not args.rom and args.phase1_path is None:
        parser.error("--phase1_path is required unless --rom is given")

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    # The returned config is not used here; the call is kept so it still runs
    # as before (presumably it rejects unknown dataset names - TODO confirm).
    get_dataset_cfg(name=args.dataset)
    proc = build_proc_from_run_dir(run_dir=args.model_path, dataset=args.dataset)

    cfgs = Exp_Basic.load_all_configs(args.model_path)
    model_cfg_dict = cfgs["model_cfg"]
    # Work on a copy so the loaded config mapping is not mutated.
    exp_cfg_dict = dict(cfgs["exp_cfg"])
    # A device stored as a string is replaced by the device selected above.
    if isinstance(exp_cfg_dict.get("device"), str):
        exp_cfg_dict["device"] = device
    exp_cfg = ExpConfigs(**exp_cfg_dict)

    exp = Exp_MemKNO(args=None, exp_cfg=exp_cfg, model_cfg=model_cfg_dict, data_processor=proc)

    if not args.rom:
        _run_phase1_visualizations(exp, args)

    if args.eval_mode == "all":
        _run_phase2_evaluation(exp, args)


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

