from typing import Dict, List, Union
from pathlib import Path
from tqdm.auto import tqdm
from joblib import delayed
from jaxlib.xla_extension import DeviceArray

import os
import h5py
import argparse
import warnings
import traceback
import numpy as np
import pandas as pd
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
import jax

from estimators.linear_node import LinearNODE
from estimators.wrapper_sindy import LinearSINDyWithHyperparamSearch
from metrics import ESTIMATION_METRICS
from utils import ProgressParallel


def check_jax_device():
    """Print every device JAX can see, followed by the default backend."""
    for device in jax.devices():
        print(f"Device: {device}, Device Type: {device.device_kind}")

    default_backend = jax.default_backend()
    print("Default JAX Device:", default_backend)

# Report the visible JAX devices once, when the module is imported.
check_jax_device()

# Registry of estimator classes, keyed by the model name used to look them up.
ESTIMATORS = dict(
    LinearSINDy=LinearSINDyWithHyperparamSearch,
    LinearNODE=LinearNODE,
)

def standardize_trajectories(xs: np.ndarray) -> np.ndarray:
    """
    Z-score every trajectory over its last (time) axis.

    xs is expected to have shape: (n_systems, n_inits, n_vars, len_solution).
    Each (system, init, var) series has its own mean subtracted and is divided
    by its own standard deviation. Series with zero standard deviation are
    divided by 1 instead, so they come out as all zeros rather than NaN.
    """
    time_axis = -1
    center = np.mean(xs, axis=time_axis, keepdims=True)
    scale = np.std(xs, axis=time_axis, keepdims=True)
    # Constant series: swap the zero divisor for 1 to avoid division by zero.
    np.putmask(scale, scale == 0, 1)
    centered = xs - center
    return centered / scale

def _evaluate(
    data: pd.DataFrame,
    row_i: int,
    model: str,
    metrics: List,
    skip_invalid_trajectories: bool,
    lambda_val: float = 0,
    n_iters: int = 5000,
    scale_dim_val: bool = True,
) -> Dict:
    """
    Fit one estimator on a single row of `data` and score its system matrix.

    Args:
        data: DataFrame with (at least) the columns "xs", "tt" and "A".
        row_i: Row label of the sample to evaluate.
        model: Key into ESTIMATORS selecting the estimator class.
        metrics: Metric names forwarded to `compute_metrics`.
        skip_invalid_trajectories: If True, a trajectory containing non-finite
            values is not fitted and its estimate is recorded as NaN.
        lambda_val: Passed to LinearNODE's `fit` as `reg_lambda`.
        n_iters: Passed to LinearNODE's `fit` as `niters`.
        scale_dim_val: Passed to LinearNODE's `fit` as `scale_dim`.

    Returns:
        Dict holding the metric scores, the estimate under "estimate", and
        every entry of `estimator.get_info()`.

    Any exception raised while fitting is printed and turned into a NaN
    estimate so that one bad sample does not abort the whole evaluation.
    """
    estimator = ESTIMATORS[model]()

    # Only LinearNODE takes the optimisation hyper-parameters; every other
    # estimator is fitted with just the trajectory and the time grid.
    if model == "LinearNODE":
        fit_kwargs = {
            "niters": n_iters,
            "reg_lambda": lambda_val,
            "scale_dim": scale_dim_val,
        }
    else:
        fit_kwargs = {}

    try:
        trajectory = data.loc[row_i, "xs"]
        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # which would silently disable the validity check.
        if skip_invalid_trajectories and not np.all(np.isfinite(trajectory)):
            raise ValueError(f"Trajectory in row {row_i} contains non-finite values.")
        estimator.fit(trajectory.T, t=data.loc[row_i, "tt"], **fit_kwargs)
        estimate = estimator.get_system_matrix()
    except Exception:
        traceback.print_exc()
        estimate = np.nan

    scores = compute_metrics(estimate, data.loc[row_i, "A"], metrics)
    scores["estimate"] = estimate
    scores.update(estimator.get_info())
    return scores


def evaluate(
    data: Union[str, Dict],
    model: str,
    metrics: List[str],
    n_jobs: int = 4,
    max_num_samples: Union[int, None] = None,
    num_time_points: Union[int, None] = None,
    skip_invalid_trajectories: bool = True,
    lambda_val: float = 0,
    niters: int = 5000,
    scale_dim_val: bool = True,
    standardize: bool = False,
) -> pd.DataFrame:
    """
    Evaluate one estimator on every sample of a dataset.

    Args:
        data: Path handed to `load_data`, or an already loaded dict. The dict
            must contain a 4-dimensional "xs" array; "tt" is also read when
            `num_time_points` is given.
        model: Key into ESTIMATORS.
        metrics: Metric names to compute for every sample.
        n_jobs: Number of parallel workers; values <= 1 (or None) run serially.
        max_num_samples: If given, only the first `max_num_samples` rows are
            evaluated.
        num_time_points: If given, every trajectory is cut to its first
            `num_time_points` time points.
        skip_invalid_trajectories: Forwarded to `_evaluate`.
        lambda_val: Forwarded to `_evaluate`.
        niters: Forwarded to `_evaluate` as `n_iters`.
        scale_dim_val: Forwarded to `_evaluate`.
        standardize: If True, trajectories are standardized along the time
            axis before fitting.

    Returns:
        DataFrame with one row per sample; the entries of each sample's score
        dict are written into columns of the same name.
    """
    if isinstance(data, str):
        data = load_data(data)
    assert isinstance(data, Dict), f"data should be of type Dict but has type: {type(data)}."
    assert len(data["xs"].shape) == 4, \
        f"Expected len(data['xs'].shape) == 4 but found len(data['xs'].shape) == {len(data['xs'].shape)}"

    # Shallow copy: the truncation / standardization below rebinds "xs" and
    # "tt", which must not leak into a dict owned by the caller.
    data = dict(data)

    if num_time_points is not None:
        data["xs"] = data["xs"][..., 0:num_time_points]
        data["tt"] = data["tt"][..., 0:num_time_points]
        print(f"Using initial {num_time_points} time points of each trajectory.")
        # After slicing, a shorter length means the data had fewer points
        # than requested.
        if data["xs"].shape[-1] < num_time_points:
            warnings.warn(
                f"Warning: num_time_points ({num_time_points}) > data['xs'].shape[-1] ({data['xs'].shape[-1]})"
            )

    if standardize:
        print("Standardizing trajectories...")
        data["xs"] = standardize_trajectories(data["xs"])

    data = dict_to_DataFrame(data)
    data.loc[:, "model"] = model
    num_samples = data.shape[0]
    if max_num_samples is not None:
        num_samples = min(max_num_samples, num_samples)
        print(f"Evaluation limited to {num_samples} samples.")

    # Arguments shared by every per-row call, in `_evaluate`'s positional order.
    shared_args = (model, metrics, skip_invalid_trajectories, lambda_val, niters, scale_dim_val)
    if n_jobs is not None and n_jobs > 1:
        list_of_scores = ProgressParallel(n_jobs=n_jobs, total=num_samples)(
            delayed(_evaluate)(data, row_i, *shared_args)
            for row_i in range(num_samples)
        )
    else:
        list_of_scores = [
            _evaluate(data, row_i, *shared_args)
            for row_i in tqdm(range(num_samples))
        ]

    print("Retrieving results:")
    for row_i in tqdm(range(num_samples)):
        for key, value in list_of_scores[row_i].items():
            if isinstance(value, np.ndarray):
                # Array-valued results (e.g. a confusion matrix) are wrapped in
                # a list so they are stored in a single cell.
                data.loc[row_i, key] = [value.astype(object)]
            elif isinstance(value, Dict):
                data.loc[row_i, key] = [value]
            else:
                data.loc[row_i, key] = value
    return data


def dict_to_DataFrame(data: Dict, max_inits: Union[int, None] = 10) -> pd.DataFrame:
    """
    Flatten a dict of per-system arrays into a DataFrame with one row per
    (system, initial condition) pair.

    Args:
        data: Dict of arrays. "xs" is required and must have shape
            (n_systems, n_inits, n_vars, len_solution).
        max_inits: Upper bound on the number of initial conditions kept per
            system (the first ones are kept). None keeps all of them.
            Defaults to 10.

    Returns:
        DataFrame with n_systems * n_inits rows and the keys of `data`, in
        sorted order, as columns.
    """
    new_data = dict()
    n_systems, n_inits, n_vars, len_solution = data["xs"].shape

    if max_inits is not None:
        n_inits = min(n_inits, max_inits)
    n_rows = n_systems * n_inits

    for key in data.keys():
        if key in ["sigma_xx", "xs"]:
            # shape: num_systems, num_x0s, n_vars, n_vars
            # or
            # shape: num_systems, num_x0s, n_vars, len_solution
            new_data[key] = data[key][:, :n_inits].reshape(n_rows, n_vars, -1)
        elif key == "x0":
            # shape: num_systems, n_vars
            # NOTE(review): stored unexpanded. If x0 really has only
            # num_systems rows, indexing it per DataFrame row below fails as
            # soon as n_inits > 1 -- confirm the expected x0 layout.
            new_data[key] = data[key]
        elif key == "tt":
            # shape: len(solution),
            # The same time grid is attached to every row.
            new_data[key] = data[key].reshape(1, -1).repeat(n_rows, axis=0)
        else:
            # shape: num_systems,
            # or
            # shape: num_systems, n_vars, n_vars
            # Per-system values are repeated once per initial condition.
            new_data[key] = data[key].repeat(n_inits, axis=0)

    sorted_keys = sorted(new_data.keys())
    df = pd.DataFrame(columns=sorted_keys)
    for row_i in range(n_rows):
        df.loc[row_i] = [new_data[key][row_i] for key in sorted_keys]
    return df


def compute_metrics(prediction: np.ndarray, ground_truth: np.array, metrics: List) -> Dict:
    """
    Score `prediction` against `ground_truth` with each requested metric.

    Args:
        prediction: Estimated quantity; may be `np.nan` when estimation failed.
        ground_truth: Reference value handed to every metric function.
        metrics: Keys into ESTIMATION_METRICS.

    Returns:
        Dict mapping each metric name to its score. A metric whose computation
        raises gets `np.nan`; if the prediction itself contains NaN (or cannot
        be checked), every metric gets `np.nan`.
    """
    # Validate the prediction once rather than once per metric. An explicit
    # raise is used because `assert` is stripped under `python -O`.
    try:
        if np.any(np.isnan(prediction)):
            raise ValueError("Found np.nan in prediction.")
    except Exception:
        traceback.print_exc()
        return {metric: np.nan for metric in metrics}

    scores = {}
    for metric in metrics:
        try:
            score = ESTIMATION_METRICS[metric](prediction, ground_truth)
            if isinstance(score, DeviceArray):
                # Convert JAX results to plain Python / NumPy values.
                score = float(score) if score.shape == () else np.array(score)
            scores[metric] = score
        except Exception:
            traceback.print_exc()
            scores[metric] = np.nan
    return scores
    
    
def load_data_old(path: str, keys: Union[None, List[str]]=None) -> Dict:
    """
    Read datasets from the HDF5 file at `path` into a dict.

    When `keys` is None every top-level key of the file is read. Each key is
    printed as it is loaded.
    """
    loaded = {}
    with h5py.File(path, 'r') as h5_file:
        selected = h5_file.keys() if keys is None else keys
        for name in selected:
            print(name)
            loaded[name] = h5_file[name][:]
    return loaded

def load_data(path: str, keys: Union[None, List[str]] = None) -> Dict:
    """
    Read the datasets 'A', 'xs' and 'tt' from the HDF5 file at `path`.

    If `keys` is given, only those of its entries that are among the allowed
    keys are read; anything else is ignored. Scalar datasets are read with
    `[()]`, all others with `[:]`.
    """
    allowed_keys = ['A', 'xs', 'tt']
    loaded = {}
    with h5py.File(path, 'r') as h5_file:
        if keys is None:
            selected = allowed_keys
        else:
            selected = [name for name in keys if name in allowed_keys]
        for name in selected:
            dset = h5_file[name]
            is_scalar = dset.shape == ()
            loaded[name] = dset[()] if is_scalar else dset[:]
    return loaded

        
        
def save_results(scores: pd.DataFrame, path: str, final: bool) -> None:
    """
    Pickle `scores` to `<path>.scores.pkl`.

    For non-final results ".intermediate" is inserted before the
    ".scores.pkl" suffix.
    """
    stem = path if final else path + ".intermediate"
    result_path = stem + ".scores.pkl"
    print(f"Saving results under {result_path}")
    scores.to_pickle(result_path)


def main(args) -> None:
    """Run the evaluation described by the parsed arguments and save the final scores."""
    evaluate_kwargs = {
        "data": args.path,
        "model": args.model,
        "metrics": args.metrics,
        "n_jobs": args.n_jobs,
        "max_num_samples": args.max_num_samples,
        "num_time_points": args.num_time_points,
        "skip_invalid_trajectories": args.skip_invalid_trajectories,
        "lambda_val": args.lambda_val,
        "niters": args.niters,
        "scale_dim_val": args.scale_dim_val,
        "standardize": args.standardize,
    }
    scores = evaluate(**evaluate_kwargs)
    save_results(scores=scores, path=args.result_path, final=True)
        

if __name__ == "__main__":

    # argparse `type=` converters. argparse calls them with the raw string of
    # each command-line token; the non-string branches only matter when they
    # are called directly.

    def int_or_str_or_none(arg) -> Union[None, int]:
        """Return ints and None unchanged; parse strings, mapping "none" to None."""
        if arg is None or isinstance(arg, int):
            return arg
        if isinstance(arg, str):
            return None if arg.lower() == "none" else int(arg)
        raise ValueError(f"Invalid type: {type(arg)}")

    def float_or_str_or_none(arg) -> Union[None, float]:
        """Return numbers as float and None unchanged; parse strings, mapping "none" to None."""
        if arg is None:
            return None
        if isinstance(arg, (float, int)):
            return float(arg)
        if isinstance(arg, str):
            return None if arg.lower() == "none" else float(arg)
        raise ValueError(f"Invalid type: {type(arg)}")

    def str_or_bool(arg) -> bool:
        """Return bools unchanged; parse the strings "true" / "false" (case-insensitive)."""
        if isinstance(arg, bool):
            return arg
        if isinstance(arg, str) and arg.lower() == "true":
            return True
        if isinstance(arg, str) and arg.lower() == "false":
            return False
        raise ValueError(f"arg {arg} should be of type str or bool but is type(arg) = {type(arg)}.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str,
        default="..."
    )
    # Required: without a model the run could never pass the estimator lookup.
    parser.add_argument("--model", type=str, required=True,
        choices=list(ESTIMATORS.keys())
    )
    # With nargs="+", `type` is applied to every token separately, so it must
    # accept a plain string; argparse itself collects the tokens into a list.
    parser.add_argument("--metrics", type=str, nargs="+", default=None, help="Metrics to use. None implies 'all available metrics'.")
    parser.add_argument("--n_jobs", type=int_or_str_or_none, default=None, help="Number of parallel jobs.")
    parser.add_argument("--max_num_samples", type=int_or_str_or_none, default=None)
    parser.add_argument("--model_dir", type=str, default=".")
    parser.add_argument("--num_time_points", type=int_or_str_or_none, default=None)
    parser.add_argument("--skip_invalid_trajectories", type=str_or_bool, default=False)
    parser.add_argument("--lambda_val", type=float_or_str_or_none, default=0)
    parser.add_argument("--niters", type=int_or_str_or_none, default=5000)
    parser.add_argument("--scale_dim_val", type=str_or_bool, default=False)
    parser.add_argument("--standardize", type=str_or_bool, default=False, help="Whether to standardize trajectories.")

    args = parser.parse_args()

    # No metrics given (or the literal "none", mirroring the other options)
    # selects every available metric.
    if args.metrics is None or [m.lower() for m in args.metrics] == ["none"]:
        args.metrics = list(ESTIMATION_METRICS.keys())
    unknown_metrics = [m for m in args.metrics if m not in ESTIMATION_METRICS.keys()]
    if unknown_metrics:
        parser.error(
            f"metric(s) {unknown_metrics} not found in ESTIMATION_METRICS: {list(ESTIMATION_METRICS.keys())}"
        )

    args.result_path = str(Path(args.model_dir) / f"{args.model}_{str(Path(args.path).name)}")

    # NOTE(review): this creates a directory named like the result prefix,
    # while save_results writes "<result_path>.scores.pkl" next to it --
    # confirm whether the directory itself is used elsewhere.
    os.makedirs(args.result_path, exist_ok=True)

    if args.n_jobs is None:
        # Per-model defaults: SINDy runs serially, everything else with 6 jobs.
        args.n_jobs = 1 if args.model == "LinearSINDy" else 6
        print(f"Automatically setting args.n_jobs = {args.n_jobs}.")

    print(args)
    main(args)