import os
import shutil
import warnings

# Silence every Python warning for the rest of the process. This runs before the
# remaining imports so that warnings emitted while importing them are hidden too.
warnings.simplefilter("ignore")
# Presumably meant to quiet TensorFlow's native (C++) logging in case a dependency
# imports TensorFlow -- TODO confirm that "4" is an intended level.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"

import math
from glob import glob

import numpy as np
import pytorch_lightning as pl
import torch
from .analysis_utils import plot_performance_curves, attack_result
from .data_utils import CustomDataModule
from .lightning_utils import CustomWriter, LightningQMIA
from .train_mia_ray import argparser
import argparse
from torch.utils.data import TensorDataset, DataLoader, Dataset



def accuracy_under_threshold():
    """Placeholder for an accuracy-at-threshold metric; currently a no-op.

    The intended steps, per the original notes, are:
      1. get the position of the specific threshold;
      2. get the binary result.

    Returns:
        None. Nothing is computed yet.
    """
    return None

def _join_list_of_tuples(list_of_tuples):
    """Collate per-batch output tuples into one tensor per tuple position.

    Args:
        list_of_tuples: Non-empty sequence of equally long tuples, one per batch.

    Returns:
        A list with one tensor per tuple position, holding that position's
        entries from every batch.
    """
    n_fields = len(list_of_tuples[0])
    joined = []
    for field in range(n_fields):
        column = [batch[field] for batch in list_of_tuples]
        try:
            joined.append(torch.concat(column))
        except (RuntimeError, TypeError):
            # torch.concat rejects zero-dimensional tensors (RuntimeError) and
            # plain Python numbers (TypeError); build a flat tensor instead.
            joined.append(torch.Tensor(column))
    return joined


def _load_saved_predictions(prediction_output_dir):
    """Load and merge every ``*.pt`` prediction file found in a directory.

    Files are visited in sorted order so the merged result does not depend on
    the order in which the filesystem lists them.

    Returns:
        The merged predictions, or None when the directory holds no ``*.pt`` file.
    """
    merged = None
    for path in sorted(glob(os.path.join(prediction_output_dir, "*.pt"))):
        rank_predictions = torch.load(path)
        if merged is None:
            merged = rank_predictions
        else:
            for rank_part, merged_part in zip(rank_predictions, merged):
                merged_part.extend(rank_part)
    return merged


def plot_model(
    args,
    checkpoint_path,
    fig_name="best",
    infer_data=None,
    infer_label=None,
    recompute_predictions=True,
    return_mean_logstd=False,
    mia_mode="eval",
):
    """Run the quantile model over ``infer_data`` and return membership decisions.

    Predictions are written to ``*.pt`` files under ``args.root_checkpoint_path``
    by a ``CustomWriter`` callback and then read back from disk. They are
    recomputed when ``recompute_predictions`` is true or when no saved file exists.

    Args:
        args: Namespace with the hyper-parameters read below (``root_checkpoint_path``,
            ``quantile_value``, the ``LightningQMIA`` constructor settings, ...).
        checkpoint_path: File loaded with ``torch.load``; its ``"model_state_dict"``
            entry is loaded into ``lightning_model.model``.
        fig_name: Name of the prediction sub-directory (prefixed with ``raw_`` when
            ``return_mean_logstd`` is true).
        infer_data: A ``Dataset``, or a tensor that is paired with ``infer_label``
            in a ``TensorDataset``.
        infer_label: Labels used only when ``infer_data`` is not a ``Dataset``.
        recompute_predictions: Force a fresh prediction pass.
        return_mean_logstd: Store results under ``raw_predictions`` and set
            ``lightning_model.return_mean_logstd``.
        mia_mode: Forwarded to ``attack_result``; the value ``"attack"`` together with
            ``args.quantile_value >= 0`` selects the direct threshold comparison.

    Returns:
        ``(binary_res, test_target_score, test_predicted_quantile_threshold)``, or
        ``None`` on processes whose ``trainer.global_rank`` is not 0 after a
        prediction pass.

    Raises:
        ValueError: Predictions must be computed but ``infer_data``/``infer_label``
            is missing.
        FileNotFoundError: No prediction file is available to read back.
    """
    if return_mean_logstd:
        fig_name = "raw_{}".format(fig_name)
        prediction_subdir = "raw_predictions"
    else:
        prediction_subdir = "predictions"
    prediction_output_dir = os.path.join(
        args.root_checkpoint_path, prediction_subdir, fig_name
    )
    print("Saving predictions to", prediction_output_dir)

    os.makedirs(prediction_output_dir, exist_ok=True)

    if recompute_predictions or not glob(os.path.join(prediction_output_dir, "*.pt")):
        # Best-effort cleanup of old outputs, done only by the process whose
        # LOCAL_RANK is "0". When LOCAL_RANK is unset nothing is removed.
        # NOTE(review): the directory is not recreated here; the prediction
        # writer is presumably expected to create it -- confirm.
        if os.environ.get("LOCAL_RANK") == "0":
            shutil.rmtree(prediction_output_dir, ignore_errors=True)

        if not isinstance(infer_data, Dataset):
            if infer_data is None or infer_label is None:
                raise ValueError(
                    "infer_data and infer_label are required to compute predictions"
                )
            infer_data = TensorDataset(infer_data, infer_label)

        # The same dataset backs every split; only the predict loader is used.
        datamodule = CustomDataModule(
            train_dataset=infer_data,
            test_dataset=infer_data,
            val_dataset=infer_data,
            batch_size=200,
        )

        print("reloading from", checkpoint_path)

        lightning_model = LightningQMIA(
            architecture=args.architecture,
            base_architecture=args.base_architecture,
            image_size=args.image_size,
            hidden_dims=[512, 512],
            num_base_classes=args.num_classes,
            freeze_embedding=False,
            low_quantile=args.low_quantile,
            high_quantile=args.high_quantile,
            n_quantile=args.n_quantile,
            use_logscale=args.use_log_quantile,
            optimizer_params={"opt_type": args.opt},
            base_model_path=args.base_model_path,
            rearrange_on_predict=not args.use_gaussian,
            use_hinge_score=args.use_hinge_score,
            use_target_label=args.use_target_label,
            lr=args.lr,
            weight_decay=args.weight_decay,
            use_gaussian=args.use_gaussian,
            use_target_dependent_scoring=args.use_target_dependent_scoring,
            use_target_inputs=args.use_target_inputs,
            dataset=args.dataset,
        )

        # Deserialize onto the CPU so a checkpoint saved from a GPU also loads on
        # a CPU-only host; load_state_dict copies the values into the model.
        checkpoint_info = torch.load(checkpoint_path, map_location="cpu")
        lightning_model.model.load_state_dict(checkpoint_info["model_state_dict"])

        if return_mean_logstd:
            lightning_model.return_mean_logstd = True

        pred_writer = CustomWriter(
            output_dir=prediction_output_dir, write_interval="epoch"
        )

        trainer = pl.Trainer(
            max_epochs=0,
            accelerator="auto" if torch.cuda.is_available() else "cpu",
            callbacks=[pred_writer],
            devices=1,
            enable_progress_bar=True,
        )

        # The return value is ignored on purpose: results are read back from the
        # files written by pred_writer.
        trainer.predict(
            lightning_model,
            dataloaders=datamodule.predict_dataloader(),
            return_predictions=True,
        )

        trainer.strategy.barrier()
        if trainer.global_rank != 0:
            return None

    # Trainer predict in DDP does not return predictions. To use distributed
    # predicting, we instead save the prediction outputs to file then
    # concatenate manually.
    predict_results = _load_saved_predictions(prediction_output_dir)
    if predict_results is None:
        raise FileNotFoundError(
            "No prediction files (*.pt) found in {}".format(prediction_output_dir)
        )

    (
        test_predicted_quantile_threshold,
        test_target_score,
        _test_loss,
        _test_base_acc1,
        _test_base_acc5,
    ) = _join_list_of_tuples(predict_results)

    if args.quantile_value >= 0 and mia_mode == "attack":
        # Compare each score against the threshold selected by index
        # args.quantile_value along the transposed threshold tensor.
        binary_res = (
            test_target_score
            >= test_predicted_quantile_threshold.T[args.quantile_value, :]
        )
    else:
        binary_res, _scores, _threshold_list = attack_result(
            test_target_score,
            test_predicted_quantile_threshold,
            use_logscale=args.use_log_quantile,
            low_quantile=args.low_quantile,
            high_quantile=args.high_quantile,
            n_quantile=args.n_quantile,
            quantile=args.quantile_value,
            mia_mode=mia_mode,
        )

    return binary_res, test_target_score, test_predicted_quantile_threshold
    



    # plot_result = plot_performance_curves(
    #     np.asarray(private_target_score),
    #     np.asarray(test_target_score),
    #     private_predicted_score_thresholds=np.asarray(
    #         private_predicted_quantile_threshold
    #     ),
    #     public_predicted_score_thresholds=np.asarray(test_predicted_quantile_threshold),
    #     model_target_quantiles=model_target_quantiles,
    #     model_name="Quantile Regression",
    #     use_logscale=True,
    #     fontsize=12,
    #     savefig_path="./plots/{}/{}/{}/ray/use_hinge_{}/use_target_{}/{}.png".format(
    #         args.model_name_prefix + args.dataset,
    #         args.base_architecture.replace("/", "_"),
    #         args.architecture.replace("/", "_"),
    #         args.use_hinge_score,
    #         args.use_target_label,
    #         fig_name,
    #     ),
    # )
    # return plot_result


def parse_args_from_dict(arg_dict):
    """Build an ``argparse.Namespace`` whose attributes mirror ``arg_dict``.

    Args:
        arg_dict: Mapping from attribute name to value.

    Returns:
        A new ``argparse.Namespace`` with one attribute per key of ``arg_dict``.
    """
    attributes = dict(arg_dict.items())
    return argparse.Namespace(**attributes)

def infer_wrap(hyper_config, infer_data, infer_label,quantile_model_path,mia_mode):
    """Run ``plot_model`` from a plain hyper-parameter dictionary.

    Args:
        hyper_config: Dict of hyper-parameters, converted to a namespace with
            ``parse_args_from_dict``; it must provide ``return_mean_logstd``.
        infer_data: Forwarded to ``plot_model`` as ``infer_data``.
        infer_label: Forwarded to ``plot_model`` as ``infer_label``.
        quantile_model_path: Checkpoint path forwarded to ``plot_model``.
        mia_mode: Forwarded to ``plot_model`` as ``mia_mode``.

    Returns:
        Whatever ``plot_model`` returns.
    """
    args = parse_args_from_dict(hyper_config)
    use_raw_outputs = args.return_mean_logstd

    # Predictions are always recomputed; saved prediction files are never reused.
    result = plot_model(
        args,
        quantile_model_path,
        fig_name="best",
        infer_data=infer_data,
        infer_label=infer_label,
        recompute_predictions=True,
        return_mean_logstd=use_raw_outputs,
        mia_mode=mia_mode,
    )
    return result



# if __name__ == "__main__":
#     args = argparser()
#     dst_checkpoint_path = os.path.join(args.root_checkpoint_path, "best_val_loss.ckpt")

#     # plot best trial
#     plot_model(
#         args,
#         dst_checkpoint_path,
#         "best",
#         recompute_predictions=False,
#         return_mean_logstd=args.return_mean_logstd,
#     )
