import argparse
import json
import logging
import pathlib
import warnings

import dotenv
import numpy as np
import torch

import eval.data as data
import eval.settings as settings
import eval.util as util
import eval.tasks as tasks

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from kronfluence.utils.dataset import DataLoaderKwargs

from tqdm import tqdm

# Let NumPy print arrays of up to ten million elements in full instead of
# abbreviating them with an ellipsis.
np.set_printoptions(threshold=10_000_000)

# Human-readable CIFAR-10 class names; the list position is used as the
# integer label when building plot titles.
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

def row_argmax(score_path: pathlib.Path, N: int, *, chunk: int = 10000) -> np.ndarray:
    """Return, for every row of a saved score matrix, the column index of its maximum.

    Args:
        score_path: File written with ``torch.save`` that holds a single
            ``(N, N)`` score tensor.
        N: Expected number of rows and columns of the score matrix.
        chunk: Number of rows reduced per ``np.argmax`` call. Must be positive.

    Returns:
        An ``int32`` array of length ``N`` where entry ``i`` is the argmax of row ``i``.

    Raises:
        ValueError: If ``chunk`` is not positive or the loaded matrix is not ``(N, N)``.
    """
    # A negative step would make the loop below run zero times and hand back
    # the uninitialised ``np.empty`` buffer, so reject it up front.
    if chunk <= 0:
        raise ValueError(f"chunk must be a positive integer, got {chunk}")

    # map_location="cpu" lets a tensor that was saved from a GPU be read on a
    # host without one; the NumPy conversion needs host memory either way.
    score = torch.load(score_path, map_location="cpu").numpy()

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if score.shape != (N, N):
        raise ValueError(
            f"Expected score matrix of shape {(N, N)} in {score_path}, got {score.shape}"
        )

    out = np.empty(N, dtype=np.int32)
    for start in range(0, N, chunk):
        end = min(N, start + chunk)
        out[start:end] = np.argmax(score[start:end, :], axis=1)
    return out

def canary_candidates(score_paths: list[pathlib.Path], N: int) -> np.ndarray:
    """Find the samples that are their own top-scoring entry in every score matrix.

    For each score file the per-row argmax is computed with ``row_argmax``; a
    row index is kept only when that argmax equals the row index itself. The
    kept sets are intersected across all files, stopping early once the
    intersection becomes empty.

    Args:
        score_paths: Score files to inspect; the first one seeds the candidate set.
        N: Number of rows/columns expected in each score matrix.

    Returns:
        Sorted array of row indices that survive the intersection.
    """
    identity = np.arange(N)

    # Seed the candidate set from the first matrix.
    survivors = np.flatnonzero(row_argmax(score_paths[0], N) == identity)

    # Narrow it down with each remaining matrix.
    for score_path in tqdm(score_paths[1:]):
        self_top = np.flatnonzero(row_argmax(score_path, N) == identity)
        survivors = np.intersect1d(survivors, self_top)
        if not survivors.size:
            break

    return survivors

# def compute_metrics(score_path: pathlib.Path, candidate_indices: np.ndarray, N: int) -> np.ndarray:
#     score = torch.load(score_path).cpu().numpy()
#     out = np.empty(candidate_indices.shape[0], dtype=np.float32)
#     for i, global_idx in tqdm(enumerate(candidate_indices)):
#         row = score[global_idx, :]          # view into score
#         saved = row[global_idx]             # save the element to exclude
#         row[global_idx] = -np.inf           # temporarily remove it
#         out[i] = saved / (np.max(row) + 1e-12)
#         row[global_idx] = saved             # restore the original value
#     del score
#     return out

def compute_metrics(score_path: pathlib.Path,
                    candidate_indices: np.ndarray,
                    N: int) -> np.ndarray:
    """Compute the self-score to best-other-score ratio for the given rows.

    For row ``i`` of the saved score matrix the metric is
    ``score[i, i] / (max_{j != i} score[i, j] + 1e-12)``.

    Args:
        score_path: File written with ``torch.save`` that holds a single
            ``(N, N)`` score tensor.
        candidate_indices: Row indices for which the ratio is returned.
        N: Expected number of rows and columns of the score matrix.

    Returns:
        ``float32`` array with one ratio per entry of ``candidate_indices``,
        in the same order.

    Raises:
        ValueError: If the loaded matrix is not ``(N, N)``.
    """
    # map_location="cpu" lets a tensor that was saved from a GPU be read on a
    # host without one; the NumPy conversion needs host memory either way.
    score = torch.load(score_path, map_location="cpu").numpy()

    # The diagonal is only the "self" entry when the matrix is square and
    # matches the dataset size, so fail loudly on anything else.
    if score.shape != (N, N):
        raise ValueError(
            f"Expected score matrix of shape {(N, N)} in {score_path}, got {score.shape}"
        )

    # Copy the diagonal before it is overwritten below.
    diag = np.diag(score).astype(np.float32, copy=True)

    # Mask the self entry so the row maximum is taken over the other samples only.
    np.fill_diagonal(score, -np.inf)
    row_max_excl = score.max(axis=1).astype(np.float32)

    # Gather only the requested rows; the small constant guards against a zero denominator.
    return diag[candidate_indices] / (row_max_excl[candidate_indices] + 1e-12)

def _register_fast_modules() -> None:
    """Register the project's fast Conv/Linear layers with kronfluence.

    Each custom layer is mapped to the same tracked-module entry that
    kronfluence already uses for ``nn.Conv2d`` / ``nn.Linear``.
    """
    # TODO: find a nicer place to register these.
    from influence.model_fast import Conv as FastConv, Linear as FastLinear
    from torch import nn
    from kronfluence.module.tracked_module import TrackedModule

    TrackedModule.SUPPORTED_MODULES[FastConv] = TrackedModule.SUPPORTED_MODULES[nn.Conv2d]
    TrackedModule.SUPPORTED_MODULES[FastLinear] = TrackedModule.SUPPORTED_MODULES[nn.Linear]


def _save_sample_image(image: torch.Tensor, label: int, index: int, metric: float, out_path: str) -> None:
    """Render one CxHxW image with its label and metric in the title and save it."""
    import matplotlib.pyplot as plt

    # Cast to float32 on CPU and move channels last (CxHxW -> HxWxC) for imshow.
    pixels = image.detach().to(torch.float32).cpu().permute(1, 2, 0).numpy()

    # The name table is CIFAR-10 specific; fall back for labels outside of it.
    class_name = CIFAR10_CLASSES[label] if 0 <= label < len(CIFAR10_CLASSES) else "unknown"

    plt.figure(figsize=(8, 8))
    plt.imshow(pixels)
    plt.title(f"Image {index} | label {class_name}({label}) | metric {metric:5f}")
    plt.axis('off')
    plt.savefig(out_path)
    plt.close()


def _save_top_scores_curve(
    scores: torch.Tensor,
    index: int,
    metric: float,
    tag: str,
    figsize: tuple[int, int],
    save_dir,
    *,
    top: int = 20,
) -> None:
    """Plot the ``top`` largest scores of row ``index`` against their rank and save the figure."""
    import matplotlib.pyplot as plt

    row = scores[index].detach().to(torch.float32).cpu().numpy()
    # np.sort returns a copy, so the loaded score matrix is left untouched.
    top_scores = np.sort(row)[::-1][:top]

    plt.figure(figsize=figsize)
    plt.title(f"Scores for {tag} metric={metric:5f})")
    plt.plot(range(len(top_scores)), top_scores)
    plt.xlabel("Rank")
    plt.ylabel("IF score")
    plt.savefig(f"{save_dir}/{tag}_dist_{index}_metric_{metric:5f}.png")
    plt.close()


def main() -> None:
    """Train models, compute pairwise influence scores, then select and plot canaries.

    Steps, in order:
      1. Load the experiment config and the raw train/validation data.
      2. Train ``--num_models`` models on the full training set (skipping
         any whose checkpoint already exists) and log their accuracy/loss.
      3. For each model, fit kronfluence factors and compute train-vs-train
         pairwise scores, saving the score matrix with ``torch.save``.
      4. Average ``compute_metrics`` over all saved score matrices, rank the
         samples, write the top-k to ``canary_selection.json`` and save
         diagnostic plots for the top-k, bottom-k and a few fixed samples.

    Raises:
        FileNotFoundError: If the config file or a model checkpoint is
            missing, or if no score file is available for canary selection.
    """
    dotenv.load_dotenv()
    args = parse_args()
    util.setup_logging()
    config_path = util.DirectoryManager.get_config_path(args.dir)
    logging.info("Using config from %s", config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    config = settings.Settings.model_validate_json(config_path.read_text())
    directory_manager = util.DirectoryManager(args.dir)

    logging.info("Base dataset: %s", config.base_dataset.name)
    logging.info("Model trainer: %s", config.model_trainer.trainer_type)

    # Load raw data
    dataset_loader = config.base_dataset.build_loader()
    dataset_loader.prepare_raw_data()
    train_images_full, train_targets_full = dataset_loader.load_train_data()
    data.validate_dataset(
        train_images_full,
        train_targets_full,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),
        num_classes=config.base_dataset.get_num_classes(),
    )
    val_images, val_targets = dataset_loader.load_val_data()
    data.validate_dataset(
        val_images,
        val_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=val_images.shape[0],
        num_classes=config.base_dataset.get_num_classes(),
    )

    # Move all raw data to the compute device once. This device is used for
    # everything below, so the CPU fallback chosen here is actually honoured.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_images = train_images_full.to(device)
    train_targets = train_targets_full.to(device)
    val_images = val_images.to(device)
    val_targets = val_targets.to(device)

    data.validate_dataset(
        train_images,
        train_targets,
        image_shape=config.base_dataset.get_image_shape(),
        num_samples=config.base_dataset.get_num_train_samples(),  # canaries replace original samples, hence size does not change
        num_classes=config.base_dataset.get_num_classes(),
    )

    model_trainer = config.model_trainer.build_trainer()
    model_trainer.images_mean_std = dataset_loader.dataset_mean_std

    # Draw at least as many seeds as models are trained here. When
    # --num_models fits within the configured count the draw is unchanged.
    rng_training = np.random.default_rng(config.global_seed)
    num_seeds = max(args.num_models, config.num_models_target + config.num_models_shadow)
    training_seeds = rng_training.integers(0, 2**32, size=num_seeds)

    ############################################################
    # Training multiple models for canary selection
    ############################################################

    for model_idx in range(args.num_models):
        model_file = directory_manager.get_influence_model_file(model_idx)
        if model_file.exists():
            logging.info("Model %d already trained, skipping", model_idx)
            continue
        current_train_images, current_train_targets = data.build_train_data(
            train_images,
            train_targets,
            membership_mask=None,
        )

        logging.info("Training model %d with seed %d", model_idx, training_seeds[model_idx])
        current_model, current_aux = model_trainer.train(
            current_train_images,
            current_train_targets,
            seed=training_seeds[model_idx],
            device=device,
        )
        model_file.parent.mkdir(parents=True, exist_ok=True)

        # Drop the "net_ema." part of every key before saving the weights.
        clean_state_dict = {
            key.replace("net_ema.", ""): value
            for key, value in current_model.state_dict().items()
        }
        torch.save(clean_state_dict, model_file)

        # Predict on the full training set and on the validation set.
        pred_base = model_trainer.predict(train_images, model=current_model, aux=current_aux)
        pred_val = model_trainer.predict(val_images, model=current_model, aux=current_aux)
        # FIXME: move to utility method?
        assert pred_base.shape == (train_images.shape[0], config.base_dataset.get_num_classes())
        assert pred_val.shape == (val_images.shape[0], config.base_dataset.get_num_classes())
        assert pred_base.dtype == torch.float32
        assert pred_val.dtype == torch.float32

        # Evaluate accuracies and losses
        accuracy_base, loss_base = evaluate_predictions(pred_base, train_targets)
        accuracy_val, loss_val = evaluate_predictions(pred_val, val_targets)

        logging.info("Model %d: Train accuracy %.4f, loss %.4f; Val accuracy %.4f, loss %.4f", model_idx, accuracy_base, loss_base, accuracy_val, loss_val)

    logging.info("Finished training all models")

    ############################################################
    # Run influence functions
    ############################################################

    dataset_mean, dataset_std = model_trainer.images_mean_std

    # The tensors already live on `device`; keep the un-normalised images for plotting.
    images_org = train_images
    targets = train_targets

    # Standardize images
    images = (images_org - dataset_mean.view(1, -1, 1, 1).to(device)) / dataset_std.view(1, -1, 1, 1).to(device)

    # Convert the dataset to the model's dtype for the rest of the process.
    images = images.to(dtype=model_trainer.architecture.model_dtype).requires_grad_(False)
    if model_trainer.architecture.pad_amount > 0:
        warnings.warn('Using padding. If not wanted, set `hyp["net"]["pad_amount"] = 0`')
        images = model_trainer._pad_images(images, model_trainer.architecture.pad_amount)

    train_dataset = torch.utils.data.TensorDataset(images, targets)
    train_dataset_viz = torch.utils.data.TensorDataset(images_org, targets)

    for model_idx in range(args.num_models):
        logging.info("Analyzing model %d", model_idx)

        scores_file = directory_manager.get_influence_scores_file(model_idx)
        if scores_file.exists():
            logging.info("Scores for model %d already computed, skipping", model_idx)
            continue

        scores_file.parent.mkdir(parents=True, exist_ok=True)

        model_file = directory_manager.get_influence_model_file(model_idx)
        if not model_file.exists():
            raise FileNotFoundError(f"Model checkpoint not found at {model_file}")

        # Prepare the trained model.
        model = model_trainer.architecture.build_model(images, device)
        model.enable_eval_tta = False  # keep the model's eval-time TTA flag off for influence computation
        model.load_state_dict(torch.load(model_file))

        # Define task and prepare model.
        task = tasks.task_selection(config.canary_task_selection)
        _register_fast_modules()
        model = prepare_model(model, task)

        influence_results_dir = directory_manager.get_influence_results_dir(model_idx)
        analyzer = Analyzer(
            analysis_name="cifar10",
            model=model,
            task=task,
            output_dir=influence_results_dir,
        )

        # Configure parameters for DataLoader.
        dataloader_kwargs = DataLoaderKwargs()
        analyzer.set_dataloader_kwargs(dataloader_kwargs)

        # Compute influence factors.
        factors_name = args.factor_strategy
        factor_args = FactorArguments(strategy=args.factor_strategy)
        analyzer.fit_all_factors(
            factors_name=factors_name,
            factor_args=factor_args,
            dataset=train_dataset,
            per_device_batch_size=None,
            overwrite_output_dir=False,  # IMPORTANT: keep off to reuse cached factors
        )

        # Compute pairwise scores (training set queried against itself).
        score_args = ScoreArguments()
        scores_name = factor_args.strategy
        analyzer.compute_pairwise_scores(
            scores_name=scores_name,
            score_args=score_args,
            factors_name=factors_name,
            query_dataset=train_dataset,
            train_dataset=train_dataset,
            per_device_query_batch_size=args.query_batch_size,
            overwrite_output_dir=False,  # IMPORTANT: keep off to reuse cached scores
        )

        scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
        logging.info("Scores shape: %s", scores.shape)
        torch.save(scores, scores_file)
        del scores

    ############################################################
    # Canary selection
    ############################################################

    N = len(train_dataset)

    # Collect the score files that are actually present.
    score_paths = []
    for model_idx in range(args.num_models):
        p = directory_manager.get_influence_scores_file(model_idx)
        if p.exists():
            score_paths.append(p)

    # Without any score file the averaging below has nothing to work with.
    if not score_paths:
        raise FileNotFoundError("No influence score files found; cannot select canaries")

    # Every training sample is a candidate (the self-argmax filter is disabled).
    candidates = np.arange(N)

    # Average the self-vs-best-other ratio over all models.
    metrics_mean = None
    for path in tqdm(score_paths):
        metric = compute_metrics(path, candidates, N).astype(np.float32)
        if metrics_mean is None:
            metrics_mean = metric
        else:
            metrics_mean += metric
    metrics_mean /= len(score_paths)

    # Rank candidates by the metric, descending: larger ratio = sharper self-peak.
    order = np.argsort(metrics_mean)[::-1]

    # Clamp k so that k == 0 selects nothing; a plain `[-0:]` slice would
    # select every sample for the bottom group.
    num_shown = max(0, min(args.topk_canaries, order.size))
    top_order = order[:num_shown]
    bottom_order = order[order.size - num_shown:]

    topk_idx_global = candidates[top_order]
    topk_metrics = metrics_mean[top_order]
    botk_idx_global = candidates[bottom_order]
    botk_metrics = metrics_mean[bottom_order]

    save_dir = directory_manager._get_influence_dir()
    canary_save_path = save_dir / "canary_selection.json"
    with open(canary_save_path, "w") as f:
        json.dump({
            "canary_indices": topk_idx_global.tolist(),
            "canary_metrics": topk_metrics.tolist(),
        }, f, indent=2)

    print(f"Saved canary selection results to {canary_save_path}")

    # DEBUG: a few fixed sample ids, restricted to those that exist in this dataset.
    debug_ids = [sample_id for sample_id in (21498, 24833, 3391) if sample_id < N]
    debug_metrics = [metrics_mean[sample_id] for sample_id in debug_ids]

    # (image file tag, curve file tag, curve figure size, indices, metrics)
    plot_groups = [
        ("highIF", "high-IF", (8, 8), topk_idx_global, topk_metrics),
        ("lowIF", "low-IF", (8, 4), botk_idx_global, botk_metrics),
        ("ID", "ID", (8, 4), debug_ids, debug_metrics),
    ]

    # One image per selected sample, drawn from the un-normalised dataset.
    for image_tag, _, _, indices, metrics in plot_groups:
        for index, metric in zip(indices, metrics):
            index, metric = int(index), float(metric)
            _save_sample_image(
                train_dataset_viz[index][0],
                int(train_dataset[index][1]),
                index,
                metric,
                f"{save_dir}/image_{image_tag}_{index}_metric_{metric:5f}.png",
            )

    # Rank curves use the first model's scores. The file was written with
    # torch.save, so it has to be read back with torch.load; viewing it as a
    # raw float32 memmap would reinterpret the container bytes as scores.
    first_scores = torch.load(score_paths[0], map_location="cpu")
    for _, curve_tag, figsize, indices, metrics in plot_groups:
        for index, metric in zip(indices, metrics):
            _save_top_scores_curve(first_scores, int(index), float(metric), curve_tag, figsize, save_dir)
    del first_scores

def evaluate_predictions(predictions: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
    """Score a batch of predictions against integer targets.

    Args:
        predictions: Per-class scores; the last dimension is reduced with
            ``argmax`` and the tensor is passed to ``cross_entropy`` as-is.
        targets: Target class index for every prediction.

    Returns:
        ``(accuracy, loss)`` as Python floats: the fraction of predictions
        whose argmax matches the target, and the mean cross-entropy.
    """
    is_correct = predictions.argmax(dim=-1).eq(targets)
    accuracy = is_correct.float().mean().item()

    mean_loss = torch.nn.functional.cross_entropy(predictions, targets)
    return accuracy, mean_loss.item()


def compute_metric(indices, scores):
    """Return the runner-up / best score ratio for each selected row.

    Args:
        indices: Row indices into ``scores``.
        scores: 2-D score array with at least two columns.

    Returns:
        A float64 array with one ratio (second-largest divided by largest
        score of the row) per selected row. NOTE: when ``indices`` is empty
        a warning is printed and the tuple ``([0], [0.0])`` is returned
        instead of an array.
    """
    # Guard clause: nothing selected.
    if len(indices) == 0:
        print("Warning no canary founded!")
        return [0], [0.0]

    # Ascending sort per row, so the two largest values sit in the last two columns.
    ranked = np.sort(scores[indices], axis=1)
    best = ranked[:, -1]
    runner_up = ranked[:, -2]
    return (runner_up / best).astype(np.float64)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the canary-selection script.

    Returns:
        Namespace with ``dir``, ``num_models``, ``factor_strategy``,
        ``query_batch_size`` and ``topk_canaries``.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    # A `--use_half_precision` store_true flag used to exist here and is
    # currently disabled.
    option_specs = (
        ("--dir", {
            "type": pathlib.Path,
            "required": True,
            "help": "Path to experiment base directory",
        }),
        ("--num_models", {
            "type": int,
            "required": True,
            "help": "Number of models to train for canary selection.",
        }),
        ("--factor_strategy", {
            "type": str,
            "default": "ekfac",
            "help": "Strategy to compute influence factors.",
        }),
        ("--query_batch_size", {
            "type": int,
            "default": 2000,
            "help": "Batch size for computing query gradients.",
        }),
        ("--topk_canaries", {
            "type": int,
            "default": 10,
            "help": "Number of top canaries to visualize.",
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


# Run the full pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
