import os
import json
import hydra
from argparse import Namespace
import datetime
from pprint import pprint
import numpy as np

import wandb

from src.utils.distributed import is_main_process


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serialises numpy arrays and numpy scalars.

    The standard ``json`` encoder rejects ``np.ndarray`` and numpy scalar types
    that do not subclass a Python builtin (e.g. ``np.float32``, ``np.int64``,
    ``np.bool_``). Arrays are converted to (nested) Python lists and numpy
    scalars to the equivalent Python scalar. Anything else is delegated to the
    base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalar -> Python scalar (np.float32 -> float, np.int64 -> int, ...)
            return obj.item()
        return super().default(obj)

def initialize_wandb(args, disabled=True):
    """Start a wandb run and upload a snapshot of the project's source code.

    Args:
        args: Configuration passed to ``wandb.init`` as ``config``.
        disabled: When True (the default) the run is created with
            ``mode="disabled"``, which is intended for debugging.

    Returns:
        The ``wandb`` module itself.
    """
    if disabled:
        # for debugging
        wandb.init(config=args, mode="disabled")
    else:
        wandb.init(config=args)

    if wandb.run is not None:
        # Top-level directories (relative to the current working directory)
        # whose contents should not be uploaded as code.
        invalid_dirs = ["__old__", "checkpoints", "logs", "outputs", "results", "wandb"]
        cwd = os.getcwd()
        # Computed once instead of once per visited file.
        excluded_roots = tuple(os.path.join(cwd, d) for d in invalid_dirs)

        def _is_excluded(path):
            # Match whole path components: "<cwd>/results/x.py" is excluded,
            # but a top-level "<cwd>/results_utils.py" is not. os.sep keeps the
            # check correct on platforms whose separator is not "/".
            return any(
                path == root or path.startswith(root + os.sep) for root in excluded_roots
            )

        wandb.run.log_code(exclude_fn=_is_excluded)
    return wandb


def wandb_log(dictionary: dict):
    """Forward ``dictionary`` to ``wandb.log`` from the main process only.

    Processes for which ``is_main_process()`` is false return without logging.
    """
    if not is_main_process():
        return
    wandb.log(dictionary)

def _get_mask_suffix(args) -> str:
    """Return the filename fragment describing the masking variant of the method.

    ``tall_mask`` and ``consensus`` end in ``ties`` when ``method.use_ties`` is
    truthy and in ``ta`` otherwise. ``consensus`` also embeds
    ``method.prun_thre_k``. ``mag_masking`` maps to ``"mag_mask"``. Every other
    method name yields ``""``.
    """
    method = args.method
    kind = method.name
    if kind == "mag_masking":
        return "mag_mask"
    if kind == "tall_mask":
        variant = "ties" if method.use_ties else "ta"
        return f"tall_mask_{variant}"
    if kind == "consensus":
        variant = "ties" if method.use_ties else "ta"
        return f"k_{method.prun_thre_k}_{variant}"
    return ""

def _get_lines_suffix(args) -> str:
    """Return ``"_lines"`` when ``args.method.apply_lines`` is truthy, else ``""``."""
    if args.method.apply_lines:
        return "_lines"
    return ""

def _get_replace_suffix(args) -> str:
    """Return the filename fragment describing layer/component replacement.

    A non-empty ``args.replace_layers`` takes precedence and records its first
    entry and its length. Otherwise a non-empty ``args.replace_components``
    records its first entry. When neither is set the result is ``""``.
    """
    layers = args.replace_layers
    if layers:
        first_layer, layer_count = layers[0], len(layers)
        return f"_replace_layer_{first_layer}_num_layers={layer_count}"

    components = args.replace_components
    if components:
        return f"_replace_component_{components[0]}"

    return ""

def _get_subspace_boosting(args):
    """Return the subspace-boosting filename fragment, or ``""`` for other methods.

    For ``method.name == "subspace_boosting"`` the fragment records
    ``method.base_method`` and ``method.svd_thresh``.
    """
    method = args.method
    if method.name == "subspace_boosting":
        return f"base_method={method.base_method}_coef={method.svd_thresh}"
    return ""

def _filename(args, model_stats: bool):
    """Build the path of the JSON file that merging results are written to.

    The name is composed of the following parts, followed by a
    ``_YYYY_MM_DD_HH_MM_SS`` timestamp:

    - an optional ``model_stats_`` prefix;
    - ``args.model``, ``args.num_tasks`` and ``args.method.full_name``;
    - the fragments returned by the ``_get_*`` suffix helpers.

    Args:
        args: Run configuration. Must provide ``model``, ``num_tasks``,
            ``method.full_name`` and every attribute read by the suffix helpers.
        model_stats: When True, prefix the file name with ``model_stats_``.

    Returns:
        A relative path of the form ``results/merging/<name>.json``.
    """
    mask_suffix = _get_mask_suffix(args)
    lines_suffix = _get_lines_suffix(args)
    replace_suffix = _get_replace_suffix(args)

    # Subspace boosting (empty string for every other method)
    subspace_boosting_coef = _get_subspace_boosting(args)

    # Model activation and stats logging
    model_stats_prefix = "model_stats_" if model_stats else ""

    # Execution time logging. Hours, minutes and seconds are all included, so
    # runs started within the same hour get distinct, accurate names. ":" is
    # avoided because it is not allowed in Windows file names.
    date_suffix = datetime.datetime.now().strftime("_%Y_%m_%d_%H_%M_%S")

    filename = (
        f"{model_stats_prefix}{args.model}_{args.num_tasks}tasks_{args.method.full_name}_{subspace_boosting_coef}_nonlinear_additions_"
        f"{mask_suffix}{lines_suffix}{replace_suffix}{date_suffix}.json"
    )

    return os.path.join("results", "merging", filename)

def log_results(final_results, args, model_stats):
    """Save ``final_results`` as JSON in two places and upload one copy to wandb.

    The first copy goes to the path built by ``_filename``, under
    ``results/merging/``. The second goes to the current Hydra run's output
    directory. The first copy is then logged to wandb as an artifact named
    ``final_results`` of type ``results``.

    Args:
        final_results: JSON-serialisable results. Numpy arrays and scalars are
            accepted through ``NumpyEncoder``.
        args: Run configuration, forwarded to ``_filename``. ``args.method.full_name``
            also names the Hydra copy.
        model_stats: Forwarded to ``_filename``; adds a ``model_stats_`` prefix.

    Raises:
        ValueError: presumably raised by ``HydraConfig.get()`` when called
            outside a Hydra-managed run. TODO confirm.
    """
    save_file = _filename(args, model_stats)
    # The results directory is not guaranteed to exist (e.g. on a fresh clone).
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with open(save_file, "w") as f:
        json.dump(final_results, f, indent=4, cls=NumpyEncoder)

    # TODO: fix parse error in hydra and wandb logging
    hydra_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    hydra_save_file = os.path.join(hydra_dir, f"{args.method.full_name}_nonlinear_additions.json")
    # Context manager so the file is flushed and closed deterministically,
    # rather than whenever the anonymous handle happens to be garbage-collected.
    with open(hydra_save_file, "w") as f:
        json.dump(final_results, f, indent=4, cls=NumpyEncoder)

    print("saved results to: ", save_file)
    print("saved results to: ", hydra_save_file)
    artifact = wandb.Artifact(name="final_results", type="results")
    artifact.add_file(save_file)
    wandb.log_artifact(artifact)