import numpy as np
import pandas as pd
from pathlib import Path
import sys
import hydra
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from omegaconf import OmegaConf
from weaver.dataset import VerificationDataset, ClusteringDataset
from weaver.models import Model
import wandb
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import warnings


# Root directory under which `train` writes the per-run result CSVs.
# The path is relative, so it resolves against the current working directory.
FIGURES_DIR = Path("figures3")
# Created as a side effect of importing this module; `exist_ok=True` keeps
# repeated imports/runs from failing when the directory already exists.
FIGURES_DIR.mkdir(exist_ok=True)

def get_test_models_indices(data, model, fit_cfg):
    """
    Get the train model to be used for each test problem.

    fit_cfg.fit_type:
    - wclosest_to_train: use the closest train sample to the test sample.
    - search_weights: search for the best weights across all the train problems.

    If we fitted one model then do not do anything.

    Args:
        data: Dataset object; ``train_data[0]`` / ``test_data[0]`` give the
            number of train / test problems, and ``closest_train_idxs`` gives,
            per test problem, the ranked candidate train problems.
        model: Model wrapper; ``model_class``, ``model_type``, ``is_trained``,
            ``clusters``, ``problem_idx_to_group_idx`` and ``calculate_metrics``
            are read depending on the branch taken.
        fit_cfg: Config object whose ``fit_type`` selects the strategy.

    Returns:
        - the result of ``model.clusters.find_test_set_clusters(data)`` when
          there is no train data and the model is ``per_dataset_cluster``;
        - a list with one ``None`` per test problem when no per-problem
          selection applies (no train data, a single fitted model, majority
          vote, or no trained model to search over);
        - otherwise an integer NumPy array with one train index per test
          problem.

    Raises:
        NotImplementedError: if ``fit_cfg.fit_type`` is not recognised.
        ValueError: if a test problem has no candidate train problems.
    """
    num_train_problems = len(data.train_data[0])
    num_test_problems = len(data.test_data[0])

    # `[] * n` is always `[]`; emit one placeholder per test problem so callers
    # can index the result by test problem, like the search_weights fallback.
    no_selection = [None] * num_test_problems

    if num_train_problems == 0:
        if model.model_class in ["per_dataset_cluster"]:
            return model.clusters.find_test_set_clusters(data)
        return no_selection

    # If we fitted one model then do not do anything
    if model.model_class not in ["per_problem", "cluster", "per_dataset_cluster"]:
        return no_selection

    # If the model is majority vote then do not do anything
    if model.model_type in ["majority_vote"]:
        return no_selection

    if model.model_class in ["cluster", "per_dataset_cluster"]:
        # Assign clusters to the test set and reorder the closest train idxs based on the clusters:
        all_closest_train_idxs = model.clusters.find_test_set_clusters(data)
    else:
        all_closest_train_idxs = data.closest_train_idxs

    # Use the closest train sample where the metric used was defined in data
    if fit_cfg.fit_type == "wclosest_to_train":
        best_train_indices = np.full(num_test_problems, np.nan)
        for idx in range(num_test_problems):
            ranked_train_idxs = all_closest_train_idxs[idx]
            if len(ranked_train_idxs) == 0:
                # Without this guard the previous iteration's index would be
                # silently reused (or the name would be unbound).
                raise ValueError(f"No candidate train problems for test problem {idx}")

            if model.model_class in ["per_problem", "cluster"]:
                # Walk the ranking until a candidate with a trained model is
                # found. If none is trained, the last candidate is kept.
                for c_train_idx in ranked_train_idxs:
                    c_group_idx = model.problem_idx_to_group_idx(c_train_idx)
                    if model.is_trained[c_group_idx]:
                        break
            else:
                assert model.is_trained
                c_train_idx = ranked_train_idxs[0]

            best_train_indices[idx] = c_train_idx

    elif fit_cfg.fit_type == "search_weights":
        # For each test set problem, find the best train problem and use its weight.
        all_trained_models = np.array([model.is_trained[i] for i in range(num_train_problems)])
        trained_models_idxs = np.where(all_trained_models)[0]
        num_trained_models = len(trained_models_idxs)

        # If there are no trained models, there is nothing to search over
        if num_trained_models == 0:
            return no_selection

        scores = np.full((num_test_problems, num_trained_models), np.nan)
        X_test, y_test = data.test_data

        for test_idx in range(num_test_problems):
            # Score the test problem with every trained model
            for col, problem_idx in enumerate(trained_models_idxs):
                outputs = model.calculate_metrics(X_test[test_idx], y_test[test_idx], problem_idx=problem_idx)
                scores[test_idx, col] = outputs["top1_positive"]

        # np.argmax treats NaN as the maximum, which would select a model whose
        # metric is undefined; rank NaN scores last instead.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        best_train_indices = trained_models_idxs[np.argmax(scores, axis=1)]

    else:
        raise NotImplementedError(f"Unknown fit type: {fit_cfg.fit_type}")

    return best_train_indices.astype(int)


def train_and_evaluate(data, model, fit_cfg):
    """Train and evaluate a model (either per_problem or per_dataset).

    The model is fitted according to ``model.model_class`` (``per_problem``,
    ``per_dataset``, ``per_dataset_cluster`` or ``cluster``), scored on every
    train problem and then on every test problem. Models whose ``model_type``
    is ``weak_supervision`` or ``unsupervised`` are additionally re-fitted on
    the test set when it differs from the train set.

    Args:
        data: Dataset object. This function reads ``train_data``, ``test_data``,
            ``train_idx``, ``test_idx``, ``train_answers``, ``test_answers``,
            ``train_split``, ``closest_train_idxs``, ``distances`` and
            ``assignments`` (and ``same_train_test`` when present).
        model: Model wrapper exposing ``fit``, ``calculate_metrics``,
            ``is_trained``, ``model_class`` and ``model_type``; cluster models
            also expose ``clusters``.
        fit_cfg: Fit configuration, forwarded to ``get_test_models_indices``.

    Returns:
        Tuple ``(all_results, all_test_results)`` of pandas DataFrames with one
        row per evaluated train / test problem. When no train problem yields a
        result, ``all_results`` holds a single all-NaN placeholder row.

    Raises:
        NotImplementedError: if ``model.model_class`` is not recognised.
        ValueError: if a per-problem/cluster model trained on the same problems
            as the test set would be evaluated with another problem's model.
        AssertionError: if a single-model class is not trained at test time,
            or if any test ``top1_positive`` value is NaN.
    """
    all_results, all_test_results = [], []

    X_train, y_train = data.train_data
    X_test, y_test = data.test_data
    x_train_indices = data.train_idx
    x_test_indices = data.test_idx

    train_answers = data.train_answers
    test_answers = data.test_answers
    print(f"Number of train problems: {len(X_train)}", flush=True)
    print(f"Number of test problems: {len(X_test)}", flush=True)

    # Majority vote operates on the answers rather than on the verifier scores.
    if model.model_type in ["majority_vote"]:
        X_train = train_answers
        X_test = test_answers

    # ---------------------------------------------------------------------------------------
    # Train model
    num_train_problems = len(X_train)
    if num_train_problems > 0:
        if model.model_class in ["per_problem"]:
            print("Fitting model on each train problem.", flush=True)
            for idx in range(num_train_problems):
                print(f"Training model {idx} of {num_train_problems}", flush=True)
                X, y = X_train[idx], y_train[idx]
                if model.model_type in ["logistic_regression", "naive_bayes"] and len(np.unique(y)) == 1:
                    # do not train on problems with only one class
                    continue
                model.fit(X, y, group_idx=idx)
        elif model.model_class in ["per_dataset",]:
            print("Fitting model on all train data", flush=True)
            model.fit(X_train, y_train)
        elif model.model_class == "per_dataset_cluster":
            print("Fitting model on all train data", flush=True)
            cluster_idxs = model.clusters.train_cluster_idxs
            model.fit(X_train, y_train, difficulties=cluster_idxs)
        elif model.model_class == "cluster":
            for idx in range(len(model.clusters.train_clusters)):
                cluster_idxs = model.clusters.train_clusters[idx]
                X, y = X_train[cluster_idxs], y_train[cluster_idxs]
                if model.model_type in ["logistic_regression", "naive_bayes"] and len(np.unique(y)) == 1:
                    # do not train on problems with only one class
                    continue
                original_cluster_idxs = x_train_indices[cluster_idxs]
                print(f"Fitting model over cluster {idx} using N={len(original_cluster_idxs)} problems.", flush=True)
                model.fit(X, y, group_idx=idx)
        else:
            raise NotImplementedError(f"Unknown model class: {model.model_class}")
    else:
        print("No train data to fit model", flush=True)

    # ---------------------------------------------------------------------------------------
    # Evaluate on train set
    print("Evaluating on train set...", flush=True)
    for idx in range(num_train_problems):
        sample_idx = data.train_idx[idx]
        # Only per-problem/cluster models need to know which fitted model to use.
        problem_idx = idx if model.model_class in ["per_problem", "cluster"] else None

        # Handle different return values based on model type
        if model.model_class == "per_dataset_cluster":
            cluster_idxs = model.clusters.train_cluster_idxs[idx]
            outputs = model.calculate_metrics(X_train[idx], y_train[idx], difficulties=cluster_idxs)
        else:
            outputs = model.calculate_metrics(X_train[idx], y_train[idx], problem_idx=problem_idx)

        # Train problems without a defined selection score are not reported.
        if np.isnan(outputs["top1_positive"]):
            continue

        # Logging cluster if we are using a cluster model:
        if model.model_class == "cluster":
            outputs["cluster_id"] = model.problem_idx_to_group_idx(problem_idx)
        elif model.model_class == "per_dataset_cluster":
            outputs["cluster_id"] = cluster_idxs

        outputs["problem"] = sample_idx
        outputs["set"] = "train"
        outputs["difficulty"] = data.assignments[sample_idx]
        all_results.append(outputs)

    if len(all_results) == 0:
        # Placeholder row so downstream column lookups on the train frame work.
        all_results.append({"sample_accuracy": np.nan,
                            "top1_positive": np.nan,
                            "prediction_accuracy": np.nan,
                            "top1_tp": np.nan,
                            "top1_fp": np.nan,
                            "top1_tn": np.nan,
                            "top1_fn": np.nan,
                            "difficulty": np.nan})

    all_results = pd.DataFrame(all_results)

    # ---------------------------------------------------------------------------------------
    # Calculate the number of trained models:
    if model.model_class in ["per_problem", "cluster"]:
        num_trained_models = sum(model.is_trained.values())
    elif model.model_class in ["per_dataset", "per_dataset_cluster"]:
        num_trained_models = 1 if model.is_trained else 0
    else:
        raise NotImplementedError(f"Unknown model class: {model.model_class}")

    # ---------------------------------------------------------------------------------------

    # Evaluate on test set
    print("\n\nEvaluating on test set...", flush=True)
    test_model_indices = get_test_models_indices(data, model, fit_cfg)
    model.is_test = True

    # Fit once on all test data when using weak supervision:
    if model.model_type in ["weak_supervision", "unsupervised"]:
        print("\n\nFitting on test set", flush=True)
        # Short-circuits: the element-wise comparison only runs when the split
        # is 1.0 and the shapes already match.
        train_not_test = not(float(data.train_split) == 1.0) or not(X_train.shape == X_test.shape) or np.any(X_train != X_test)
        if train_not_test:
            print(f"Test set differs from Train set, Fitting WS model on test set using N={len(X_test)} problems.", flush=True)
            if model.model_class == "per_dataset":
                model.model.is_test = True
                model.fit(X_test, y_test)
                # Set flag to not fit model again when calculating metrics
                model.fit_when_calculating_metrics = False
            elif model.model_class == "cluster":
                # Note: this will fit a model using all the data in the test set cluster.
                # it needs to be updated because at test time we may not have all of X_test available.
                # which may be used to drop verifiers, calculated balance, etc.
                for idx in range(len(model.clusters.test_clusters)):
                    print(f"Fitting model over cluster {idx} in test set using N={len(X_test[model.clusters.test_clusters[idx]])} problems.", flush=True)
                    cluster_idxs = model.clusters.test_clusters[idx]
                    X_test_tmp, y_test_tmp = X_test[cluster_idxs], y_test[cluster_idxs]
                    model.models[idx].is_test = True
                    model.fit(X_test_tmp, y_test_tmp, group_idx=idx)
                    # Do not fit when calculating metrics
                    model.models[idx].fit_when_calculating_metrics = False
            elif model.model_class == "per_dataset_cluster":
                cluster_idxs = model.clusters.test_cluster_idxs
                model.model.is_test = True
                model.fit(X_test, y_test, difficulties=cluster_idxs)
                # Having fit on the test set, do not fit when calculating metrics
                original_cluster_idxs = x_test_indices[cluster_idxs]
                print(f"Fitted model over clusters {np.unique(cluster_idxs)} using N={len(original_cluster_idxs)} problems.", flush=True)
                model.model.fit_when_calculating_metrics = False
            else:
                pass

    # ---------------------------------------------------------------------------------------
    print("\n\n Metrics on test set...", flush=True)
    num_test_problems = len(X_test)
    for idx in range(num_test_problems):
        sample_idx = data.test_idx[idx]

        ranked_train_idxs = data.closest_train_idxs[idx]
        if model.model_class in ["per_problem", "cluster"]:
            c_train_idx = test_model_indices[idx]
            # Sanity check: when train and test are the same problems and every
            # problem has a trained model, each problem must use its own model.
            # A dataset without the `same_train_test` attribute skips the check.
            if idx != c_train_idx and getattr(data, "same_train_test", False) and num_trained_models == num_test_problems:
                raise ValueError(f"Using train problem {c_train_idx} model for test problem: {idx}")
        elif model.model_class == "per_dataset_cluster":
            c_train_idx = []
        else:
            assert model.is_trained, "Model is not trained"
            c_train_idx = ranked_train_idxs[0]

        dist_ = data.distances[idx][c_train_idx]

        # closest problem in train set
        problem_idx = c_train_idx if model.model_class in ["per_problem", "cluster"] else None

        if model.model_class == "per_dataset_cluster":
            cluster_idxs = model.clusters.test_cluster_idxs[idx]
            outputs = model.calculate_metrics(X_test[idx], y_test[idx], difficulties=cluster_idxs)
        else:
            outputs = model.calculate_metrics(X_test[idx], y_test[idx], problem_idx=problem_idx)

        # Weak-supervision runs additionally record the class balance
        if model.model_type == "weak_supervision":
            if y_test[idx] is not None:
                class_balance = np.mean(y_test[idx])
                outputs["class_balance"] = class_balance

        if model.model_class == "cluster":
            outputs["cluster_id"] = model.problem_idx_to_group_idx(problem_idx)
        elif model.model_class == "per_dataset_cluster":
            outputs["cluster_id"] = cluster_idxs

        # Log results
        outputs["problem"] = sample_idx
        outputs["set"] = "test"
        outputs["close_train_idx"] = problem_idx
        outputs["distance"] = dist_
        outputs["difficulty"] = data.assignments[sample_idx]
        all_test_results.append(outputs)

    all_test_results = pd.DataFrame(all_test_results)

    # assert no test values are nan
    # (comparing the values against the function `np.isnan` is always False and
    # would never trigger, so use a real element-wise NaN test).
    assert not all_test_results["top1_positive"].isna().any(), \
        "Found NaN top1_positive values in the test results"
    return all_results, all_test_results


def _difficulty_top1_accuracy(df):
    """Return (sum of top1_tp + sum of top1_tn) / number of rows, or NaN if it cannot be computed."""
    try:
        num_problems = len(df["top1_tp"])
        return (df["top1_tp"].sum() + df["top1_tn"].sum()) / num_problems
    except (KeyError, TypeError, ZeroDivisionError):
        # Missing columns, non-numeric contents or an empty slice.
        return np.nan


def log_per_difficulty_results(df_train, df_test):
    """Logs per-difficulty results to console and Weights & Biases.

    Args:
        df_train: Per-problem train results; must contain the ``difficulty``,
            ``sample_accuracy`` and ``top1_positive`` columns.
        df_test: Per-problem test results with the same columns.

    For every difficulty level found in either frame, the mean sample
    accuracy, mean selection accuracy (``top1_positive``), problem count and
    top-1 accuracy are printed, and logged to wandb when a run is active.
    The top-1 accuracy is NaN when ``top1_tp`` / ``top1_tn`` are unavailable.
    """
    all_train_difficulty_levels = df_train["difficulty"].unique()
    all_test_difficulty_levels = df_test["difficulty"].unique()
    all_difficulty_levels = np.sort(np.unique(np.concatenate([all_train_difficulty_levels, all_test_difficulty_levels])))

    print(f"\nDifficulty levels: {all_difficulty_levels}", flush=True)
    for difficulty in all_difficulty_levels:
        df_train_diff = df_train[df_train["difficulty"] == difficulty]
        df_test_diff = df_test[df_test["difficulty"] == difficulty]

        train_acc = df_train_diff["sample_accuracy"].mean()
        test_acc = df_test_diff["sample_accuracy"].mean()
        train_select_acc = df_train_diff["top1_positive"].mean()
        test_select_acc = df_test_diff["top1_positive"].mean()

        # Computed independently so that a failure on one split does not
        # discard a valid value for the other.
        train_top1_acc = _difficulty_top1_accuracy(df_train_diff)
        test_top1_acc = _difficulty_top1_accuracy(df_test_diff)

        num_train = len(df_train_diff['sample_accuracy'])
        num_test = len(df_test_diff['sample_accuracy'])

        print(f"Results for Difficulty Level: {difficulty}:", flush=True)
        print(f" Train Problems: {num_train},\tSelect Acc.: {train_select_acc:.3f},"
              f"\tSample Acc.: {train_acc:.3f}, \t Top1- Acc.: {train_top1_acc:.3f}", flush=True)
        print(f" Test Problems: {num_test}, \tSelect Acc.: {test_select_acc:.3f},"
              f"\tSample Acc.: {test_acc:.3f}, \t Top1- Acc.: {test_top1_acc:.3f}", flush=True)

        if wandb.run:
            wandb.log({
                f"epoch_train_accuracy_difficulty_{difficulty}": train_acc,
                f"epoch_test_accuracy_difficulty_{difficulty}": test_acc,
                f"epoch_train_select_accuracy_difficulty_{difficulty}": train_select_acc,
                f"epoch_test_select_accuracy_difficulty_{difficulty}": test_select_acc,
                f"epoch_train_samples_difficulty_{difficulty}": num_train,
                f"epoch_test_samples_difficulty_{difficulty}": num_test,
                f"epoch_train_top1_accuracy_difficulty_{difficulty}": train_top1_acc,
                f"epoch_test_top1_accuracy_difficulty_{difficulty}": test_top1_acc,
            })


@hydra.main(config_path="../configs", config_name="supervised", version_base=None)
def main(args) -> None:
    """Hydra entry point: apply debug overrides, then run `train` with optional wandb logging."""
    if args.get("debug", False):
        # Shrink the run and disable logging when debugging.
        debug_data_overrides = {
            "train_split": 0.2,
            "train_queries": 10,  # number of queries to sample from train split
            "train_samples": 10,  # number of samples to sample from train split
            "same_train_test": True,
        }
        for option, override in debug_data_overrides.items():
            setattr(args.data_cfg, option, override)
        args.logging = "none"

    # Evaluated after the debug overrides, which may have switched logging off.
    use_wandb = args.logging == "wandb"

    if use_wandb:
        wandb.init(**args.wandb_cfg, config=OmegaConf.to_container(args, resolve=True))

    train(args)

    if use_wandb:
        wandb.finish()


def _top1_confusion_summary(df):
    """Sum the top-1 confusion counts of *df*; return (tp, fp, fn, tn, accuracy), all NaN if unavailable."""
    try:
        tp = df["top1_tp"].sum()
        fp = df["top1_fp"].sum()
        fn = df["top1_fn"].sum()
        tn = df["top1_tn"].sum()
        acc = (tp + tn) / (tp + tn + fp + fn)
    except (KeyError, TypeError, ZeroDivisionError):
        # Missing columns, non-numeric contents or a zero denominator.
        return np.nan, np.nan, np.nan, np.nan, np.nan
    return tp, fp, fn, tn, acc


def train(args):
    """Build the dataset and model from *args*, train/evaluate, and report results.

    Args:
        args: Run configuration. ``data_cfg`` is expanded into
            ``VerificationDataset``, ``model_cfg`` into ``Model`` (and its
            ``cluster_cfg`` into ``ClusteringDataset`` for cluster model
            classes), and ``fit_cfg`` is passed to ``train_and_evaluate``.

    Results are printed to the console. When a wandb run is active, aggregate
    and per-cluster metrics are logged and the per-problem result frames are
    written as CSV under ``FIGURES_DIR/<run id>/`` and uploaded as an artifact.

    Raises:
        NotImplementedError: if ``args.model_cfg.model_class`` is not recognised.
    """
    data = VerificationDataset(**args.data_cfg)

    # Print the number of train problems, number of samples, and number of verifiers
    # (adjacent f-strings, so source indentation does not leak into the output).
    print(
        f"Train problems: {data.train_data[0].shape[0]}, "
        f"Train samples: {data.train_data[0].shape[1]}, "
        f"Test problems: {data.test_data[0].shape[0]}, "
        f"Test samples: {data.test_data[0].shape[1]}, "
        f"Verifiers: {len(data.verifier_names)}",
        flush=True,
    )

    clusters = None
    if args.model_cfg.model_class == "cluster":
        clusters = ClusteringDataset(**args.model_cfg.cluster_cfg)
        clusters.compute_clusters(data, mode="train")
        num_models = len(clusters.train_clusters)
    elif args.model_cfg.model_class == "per_problem":
        num_models = len(data.train_data[0])
    elif args.model_cfg.model_class == "per_dataset":
        num_models = None
    elif args.model_cfg.model_class == "per_dataset_cluster":
        clusters = ClusteringDataset(**args.model_cfg.cluster_cfg)
        clusters.compute_clusters(data, mode="train")
        num_models = 1
    else:
        raise NotImplementedError(f"Unknown model class: {args.model_cfg.model_class}")

    if args.data_cfg.reward_threshold is not None:
        data.binarize_verifiers(clusters, split="train")
        data.binarize_verifiers(clusters, split="test")

    model = Model(data.verifier_names, clusters, **args.model_cfg, num_models=num_models)

    df_train, df_test = train_and_evaluate(data, model, args.fit_cfg)

    # Let's look at the Top1 results (each split is summarised independently,
    # so a problem with one split does not blank out the other):
    top1_tp_train, top1_fp_train, top1_fn_train, top1_tn_train, top1_acc_train = _top1_confusion_summary(df_train)
    top1_tp_test, top1_fp_test, top1_fn_test, top1_tn_test, top1_acc_test = _top1_confusion_summary(df_test)

    print("\nOverall Results: Model: ", args.model_cfg.model_type, flush=True)
    print(f"Verifiers {len(data.verifier_names)}:", data.verifier_names, flush=True)
    print(f"\nTrain: N problems: {len(df_train['sample_accuracy'])},"
          f"\tSelect Acc.: {df_train['top1_positive'].mean():.3f},"
          f"\tSample Acc.: {df_train['sample_accuracy'].mean():.3f},"
          f"\tTop1-Acc.: {top1_acc_train:.3f},"
          f"\n\tTop1-TP: {top1_tp_train},"
          f"\tTop1-TN: {top1_tn_train},"
          f"\tTop1-FP: {top1_fp_train},"
          f"\tTop1-FN: {top1_fn_train}", flush=True)
    print(f"\nTest: N problems: {len(df_test['sample_accuracy'])},"
          f"\tSelect Acc.: {df_test['top1_positive'].mean():.3f},"
          f"\tSample Acc.: {df_test['sample_accuracy'].mean():.3f},"
          f"\tTop1-Acc.: {top1_acc_test:.3f},"
          f"\n\tTop1-TP: {top1_tp_test},"
          f"\tTop1-TN: {top1_tn_test},"
          f"\tTop1-FP: {top1_fp_test},"
          f"\tTop1-FN: {top1_fn_test}", flush=True)

    if wandb.run:
        metrics_to_log = ["sample_accuracy", "top1_positive", "prediction_accuracy",
                          "top1_tp", "difficulty"]
        for df, name in [(df_train, "train"), (df_test, "test")]:
            for key in df.columns:
                if key not in metrics_to_log:
                    continue
                # `top1_positive` is reported under the name "select_accuracy".
                key_name = "select_accuracy" if key == "top1_positive" else key
                wandb.log({
                    f"epoch_{name}_{key_name}": df[key].values.mean()
                })
        # Also log the aggregate top-1 confusion counts
        wandb.log({
            "epoch_train_top1_tp": top1_tp_train,
            "epoch_train_top1_tn": top1_tn_train,
            "epoch_train_top1_fp": top1_fp_train,
            "epoch_train_top1_fn": top1_fn_train,
            "epoch_test_top1_tp": top1_tp_test,
            "epoch_test_top1_tn": top1_tn_test,
            "epoch_test_top1_fp": top1_fp_test,
            "epoch_test_top1_fn": top1_fn_test,
        })

    # Log per-cluster results:
    if model.model_class in ["cluster", "per_dataset_cluster"]:
        metrics_to_log = ["sample_accuracy", "top1_positive", "prediction_accuracy",
                          "top1_tp", "difficulty", "cluster_id"]

        train_clusters = np.unique(df_train['cluster_id'], return_counts=True)
        test_clusters = np.unique(df_test['cluster_id'], return_counts=True)
        print(f"Train clusters and problem counts: {train_clusters}, \n Test clusters and problem counts: {test_clusters}", flush=True)

        # Per-cluster confusion counts and accuracies for each split
        for df, name in [(df_train, "train"), (df_test, "test")]:
            for cluster_id in np.sort(list(model.clusters.train_clusters.keys())):
                df_cluster = df[df["cluster_id"] == cluster_id]
                if len(df_cluster) == 0:
                    print(f"No problems in cluster {cluster_id} for {name} set", flush=True)
                    continue

                num_problems = len(df_cluster)
                num_samples = df_cluster["num_samples"].sum() // num_problems

                samples_tp = df_cluster["sample_tp"].sum()
                samples_fp = df_cluster["sample_fp"].sum()
                samples_tn = df_cluster["sample_tn"].sum()
                samples_fn = df_cluster["sample_fn"].sum()

                top1_tp = df_cluster["top1_tp"].sum()
                top1_fp = df_cluster["top1_fp"].sum()
                top1_tn = df_cluster["top1_tn"].sum()
                top1_fn = df_cluster["top1_fn"].sum()

                samples_acc = (samples_tp + samples_tn) / (samples_tp + samples_tn + samples_fp + samples_fn)
                top1_acc = (top1_tp + top1_tn) / (top1_tp + top1_tn + top1_fp + top1_fn)
                print(f"\n{name}  Cluster {cluster_id}: N problems: {num_problems}, N samples: {num_samples}")
                print(f"  Samples: Acc. {samples_acc:.3f}, TP {samples_tp}, TN {samples_tn}, FP {samples_fp}, FN {samples_fn}")
                print(f"  Top1: Acc. {top1_acc:.3f}, TP {top1_tp}, TN {top1_tn}, FP {top1_fp}, FN {top1_fn}")

                if not wandb.run:
                    continue

                wandb.log({
                    f"epoch_{name}_sample_tp/cluster_{cluster_id}": samples_tp,
                    f"epoch_{name}_sample_tn/cluster_{cluster_id}": samples_tn,
                    f"epoch_{name}_sample_fp/cluster_{cluster_id}": samples_fp,
                    f"epoch_{name}_sample_fn/cluster_{cluster_id}": samples_fn,
                    f"epoch_{name}_top1_tp/cluster_{cluster_id}": top1_tp,
                    f"epoch_{name}_top1_tn/cluster_{cluster_id}": top1_tn,
                    f"epoch_{name}_top1_fp/cluster_{cluster_id}": top1_fp,
                    f"epoch_{name}_top1_fn/cluster_{cluster_id}": top1_fn,
                })

                # Per-cluster means of the selected metrics. Columns outside
                # `metrics_to_log` are skipped entirely, so no stale value from
                # a previous column is logged under another column's name.
                for key in df_cluster.columns:
                    if key not in metrics_to_log:
                        continue
                    key_name = "select_accuracy" if key == "top1_positive" else key
                    wandb.log({
                        f"epoch_{name}_{key_name}/cluster_{cluster_id}": df_cluster[key].values.mean()
                    })

    if wandb.run:

        # Log verifiers:
        wandb.log({
            "verifiers": data.verifier_names
        })

        # Log artifacts df_train and df_test
        run_id = wandb.run.id
        train_file = FIGURES_DIR / run_id / "df_train.csv"
        test_file = FIGURES_DIR / run_id / "df_test.csv"

        # Both files share the same per-run directory.
        train_file.parent.mkdir(parents=True, exist_ok=True)

        df_train.to_csv(train_file, index=False)
        df_test.to_csv(test_file, index=False)

        artifact = wandb.Artifact(name="data_results", type="dataset")
        artifact.add_file(local_path=str(train_file), name="train")
        artifact.add_file(local_path=str(test_file), name="test")
        wandb.run.log_artifact(artifact)

    log_per_difficulty_results(df_train, df_test)


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()