
import os
import sys
sys.path.insert(0, './')
import json
import time
import pickle
import numpy as np
import math
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from MIA.quantile_utils import LightningQMIA
from MIA.MIA import MIA

import ray
from ray.air.config import CheckpointConfig
from ray.tune import CLIReporter

from ray.tune.integration.pytorch_lightning import (
    TuneReportCheckpointCallback,
)  # TuneReportCallback,
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

from MIA.quantile_utils import (
    CustomWriter,
    LightningQMIA,
    CustomDataModule,
    plot_performance_curves,
)

from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, Subset

from MIA.QuantileMIA.image_QMIA.train_mia import quantile_mia_fit_wrap
from MIA.QuantileMIA.image_QMIA.plot_results import infer_wrap

import ray.tune as tune
from glob import glob

import pytorch_lightning as pl
from sklearn import metrics

# Resource settings copied into QMIAModel.hyper_config under the keys
# "num_cpus", "num_gpus" and "num_concurrent_trials" respectively.
NUM_CPUS_PER_GPU = 8
GPUS_PER_TRIAL = 1  # Current code does not support multi gpu per trial
NUM_CONCURRENT_TRIALS = 8


class QMIAModel(MIA):
    def __init__(self, name="QMIA", threshold=0.5, metric=None, mia_mode="attack",
                 low_quantile=-4, high_quantile=0, n_quantile=200, use_logscale=True,
                 hidden_dims=None, learning_rate=1e-4, weight_decay=1e-4,
                 epochs=30, batch_size=128, device=None, image_size=224, num_classes=10,
                 low_lr=1e-6, high_lr=1e-3, dataset="cifar10", base_arch="resnet", base_model_path=None,
                 base_name="resnet", model_root="./results", load_epoch=None, total_epoch=None, exp_name=None):
        """Set up the configuration for a quantile-regression MIA.

        Builds ``self.hyper_config`` (the dict handed to the training and
        inference helpers), the quantile grid ``self.quantile_list`` and the
        checkpoint paths.

        Args:
            name, threshold, metric, mia_mode: Forwarded to ``MIA.__init__``.
                ``mia_mode == "attack"`` sets ``self.attack`` to True and
                initialises ``hyper_config["quantile_value"]`` to -1; any
                other mode uses 0.95.
            low_quantile, high_quantile, n_quantile: Bounds and size of the
                quantile grid.
            use_logscale: If True the grid is ``1 - logspace(low, high, n)``
                (sorted); otherwise ``linspace(low, high, n)`` (sorted).
            hidden_dims: Stored as ``hyper_config["hidden_dims"]``. ``None``
                (the default) means an empty list.
            learning_rate: Currently unused; ``hyper_config["lr"]`` is a
                fixed constant.
            weight_decay, epochs, batch_size, low_lr, high_lr, dataset,
            model_root, exp_name: Stored on the instance and/or in
                ``hyper_config``.
            device: Torch device; defaults to CUDA when available, else CPU.
            image_size: Stored as ``hyper_config["image_size"]``.
            num_classes: Stored as both ``"num_base_classes"`` and
                ``"num_classes"`` in ``hyper_config``.
            base_arch: Base architecture name; also part of the checkpoint
                directory ``model_root/dataset/base_arch``.
            base_model_path: Directory containing the base model checkpoint.
                Required despite the ``None`` default.
            base_name: Stem of the base checkpoint file name.
            load_epoch, total_epoch: When equal, the checkpoint file is
                ``{base_name}.ckpt``; otherwise ``{base_name}_{load_epoch}.ckpt``.
                ``load_epoch`` is also stored as ``base_epochs``.

        Raises:
            TypeError: If ``base_model_path`` is not provided.
        """
        super().__init__(name, threshold, metric, mia_mode)

        if base_model_path is None:
            # Previously this surfaced as an opaque TypeError from os.path.join.
            raise TypeError("QMIAModel requires 'base_model_path' (directory of the base model checkpoint)")

        self.device = (
            torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            if device is None
            else device
        )

        self.mia_mode = mia_mode
        self.attack = mia_mode == "attack"

        # The final-epoch checkpoint carries no epoch suffix.
        if load_epoch == total_epoch:
            suffix = f'{base_name}.ckpt'
        else:
            suffix = f'{base_name}_{load_epoch}.ckpt'

        base_epochs = load_epoch
        base_model_path = os.path.join(base_model_path, suffix)

        # input hyper-parameters
        self.low_lr = low_lr
        self.high_lr = high_lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        # A fresh list per instance: a shared mutable default would leak
        # between instances through hyper_config.
        self.hidden_dims = [] if hidden_dims is None else hidden_dims

        self.prediction_output_dir = model_root
        self.root_checkpoint_path = os.path.join(
            self.prediction_output_dir, dataset, base_arch
        )

        # fixed hyper-parameters
        self.architecture = 'facebook/convnext-tiny-224'
        self.opt = "adamw"
        # NOTE(review): if MIA.__init__ stores `metric` on self.metric, this
        # assignment overwrites it -- confirm that is intended.
        self.metric = "ptl/val_loss"
        self.mode = "min"
        self.num_tune_samples = 20
        self.return_mean_logstd = False
        self.use_gaussian = True
        self.use_hinge_score = True
        self.use_target_label = False
        self.use_target_inputs = False
        self.use_target_dependent_scoring = False
        self.hyper_tune = False  # if True, will use ray tune to find the best hyper-parameters
        self.model_root = model_root
        self.batch_size = batch_size

        # dataloader:
        self.shadow_datamodule = None
        self.dataset = dataset  # default dataset

        self.base_architecture = base_arch  # default base architecture
        self.base_model_path = base_model_path
        self.base_epochs = base_epochs
        self.exp_name = exp_name

        # Location of the trained quantile model; fit() uses the same path.
        self.quantile_model_path = os.path.join(
            self.root_checkpoint_path,
            f"best_val_loss_{self.base_architecture}_{self.base_epochs}_{self.mia_mode}.ckpt",
        )

        self.hyper_config = {
            "architecture": self.architecture,
            "base_architecture": self.base_architecture,
            "dataset": self.dataset,
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "mode": self.mode,
            "metric": self.metric,
            "image_size": image_size,
            "tune_batch_size": False,  # TODO: fix this
            "model_root": self.model_root,
            "num_base_classes": num_classes,
            "low_quantile": low_quantile,
            "high_quantile": high_quantile,
            "n_quantile": n_quantile,
            "use_log_quantile": use_logscale,
            "opt": self.opt,
            # NOTE(review): the `learning_rate` argument is not used here.
            "lr": 0.00001,
            "scheduler": "cosine",  # TODO: fix this
            "base_model_path": self.base_model_path,
            "model_name_prefix": "bespoke",
            "use_gaussian": self.use_gaussian,
            "use_hinge_score": self.use_hinge_score,
            "use_target_label": self.use_target_label,
            "use_target_inputs": self.use_target_inputs,
            "use_target_dependent_scoring": self.use_target_dependent_scoring,
            "grad_clip": 1.0,
            "weight_decay": self.weight_decay,
            "root_checkpoint_path": self.root_checkpoint_path,
            "return_mean_logstd": self.return_mean_logstd,
            "low_lr": self.low_lr,
            "high_lr": self.high_lr,
            "hidden_dims": self.hidden_dims,
            "num_tune_samples": self.num_tune_samples,
            "hyper_tune": self.hyper_tune,
            "num_gpus": GPUS_PER_TRIAL,
            "num_cpus": NUM_CPUS_PER_GPU,
            "num_shadow_models": 5,  # default number of shadow models
            "num_concurrent_trials": NUM_CONCURRENT_TRIALS,
            "shadow_dataloader": None,
            "exp_name": exp_name,
            "base_epochs": base_epochs,
            "mia_mode": mia_mode,
            "num_classes": num_classes,
        }

        # best model
        self.best_model_path = None

        # Running state filled by infer() and consumed by output().
        self.score = None
        self.label = None
        self.scores_threshold = None
        self.count = 1

        if use_logscale:
            grid = 1 - torch.logspace(
                low_quantile, high_quantile, n_quantile, requires_grad=False
            )
        else:
            grid = torch.linspace(
                low_quantile, high_quantile, n_quantile, requires_grad=False
            )
        # Sorted ascending, shaped [1, n_quantile].
        self.quantile_list = torch.sort(grid)[0].reshape([1, -1])

        # In attack mode fit() later replaces this with the selected index.
        self.hyper_config["quantile_value"] = -1 if self.mia_mode == "attack" else 0.95

    def fit(self, model, fit_data_loaders, **kwargs):
        """Train (or reuse) the quantile model and, in attack mode, calibrate it.

        The quantile model is trained through ``quantile_mia_fit_wrap`` unless
        a checkpoint already exists at ``self.quantile_model_path``. In attack
        mode the trained model is then run on the held-out shadow loaders and
        the threshold column with the highest accuracy is stored in
        ``hyper_config["quantile_value"]``.

        Args:
            model: Not used by this method.
            fit_data_loaders: Mapping of data loaders. Attack mode reads
                ``"shadow_member"`` and ``"shadow_nonmember"`` (element 0 for
                training, element 1 for calibration); other modes read
                ``"member_train"`` and ``"nonmember_train"``.
            **kwargs: Ignored.

        Returns:
            None.
        """
        if self.attack:
            train_data_generator = fit_data_loaders["shadow_member"][0]
            test_data_generator = fit_data_loaders["shadow_nonmember"][0]
        else:
            train_data_generator = fit_data_loaders["member_train"]
            test_data_generator = fit_data_loaders["nonmember_train"]

        self.hyper_config["shadow_dataloader"] = train_data_generator

        # Always define the checkpoint path: it is read below and by infer()
        # regardless of the hyper_tune flag.
        self.quantile_model_path = os.path.join(
            self.root_checkpoint_path,
            f"best_val_loss_{self.base_architecture}_{self.base_epochs}_{self.mia_mode}.ckpt",
        )

        # Skip training when a checkpoint from a previous run is present.
        if not self.hyper_tune and not os.path.exists(self.quantile_model_path):
            quantile_mia_fit_wrap(
                self.hyper_config,
                test_data_generator,
                shadow_test_generator=train_data_generator,
            )

        if not self.attack:
            return

        member_dataset = fit_data_loaders["shadow_member"][1].dataset
        nonmember_dataset = fit_data_loaders["shadow_nonmember"][1].dataset

        # Cap the member side at the size of the smaller of the two sets.
        private_idx = min(len(member_dataset), len(nonmember_dataset))
        subset_indices = list(range(private_idx))

        dataset = ConcatDataset(
            [Subset(member_dataset, subset_indices), nonmember_dataset]
        )
        subset_train_labels = np.array(member_dataset.targets)[subset_indices]
        labels = np.concatenate((subset_train_labels, nonmember_dataset.targets))

        _, scores, scores_threshold = infer_wrap(
            self.hyper_config,
            dataset,
            labels,
            self.quantile_model_path,
            mia_mode=self.mia_mode,
        )

        # Members occupy the first `private_idx` rows of the concatenation.
        # NOTE(review): slicing the thresholds by row assumes one threshold
        # row per sample -- confirm against infer_wrap.
        result = self.get_auc_related(
            scores[:private_idx],
            scores[private_idx:],
            scores_threshold[:private_idx],
            scores_threshold[private_idx:],
            torch.arange(private_idx),
            torch.arange(private_idx, len(dataset)),
            get_best=True,
        )

        # Store the *index* into the quantile grid (not the quantile value),
        # as a plain int like the initial value of this config entry.
        self.hyper_config["quantile_value"] = int(result["quantile_list_idx"])
                
        

        



    def get_auc_related(self,private_target_scores,public_target_scores,
                        private_thresholds,public_thresholds,
                        member_idx, nonmember_idx, get_best=False, quantile_idx=0):
        """Compute confusion counts, accuracy and ROC statistics per threshold.

        A sample is predicted "member" when its score is >= its threshold.
        Each threshold column yields one confusion matrix; the reported
        operating point is the column with the highest accuracy when
        ``get_best`` is True, otherwise column ``quantile_idx``.

        Args:
            private_target_scores: Scores of the true members.
            public_target_scores: Scores of the true non-members.
            private_thresholds, public_thresholds: Thresholds, either
                [n, n_thresholds] or [1, n_thresholds] depending on whether
                the threshold is sample dependent.
            member_idx, nonmember_idx: Positions of members / non-members in
                the combined ``"predict"`` vector; their lengths also size
                the per-group prediction vectors.
            get_best: Select the accuracy-maximising column and report AUC.
            quantile_idx: Column to report when ``get_best`` is False.

        Returns:
            dict with keys ``"auc"`` (``None`` unless ``get_best``),
            ``"best_accuracy"``, ``"predict"``, ``"member_pred"``,
            ``"nonmember_pred"``, ``"tpr01fpr"`` (TPR at FPR 0.001),
            ``"tpr001fpr"`` (TPR at FPR 0.0001), ``"tp"``, ``"fn"``,
            ``"tn"``, ``"fp"`` and ``"quantile_list_idx"`` (the column used).
        """
        from sklearn.metrics import auc

        prior = 0.0

        private_col = private_target_scores.reshape([-1, 1])
        public_col = public_target_scores.reshape([-1, 1])

        # One count per threshold column.
        true_positives = (private_col >= private_thresholds).sum(0) + prior
        false_negatives = (private_col < private_thresholds).sum(0) + prior
        true_negatives = (public_col < public_thresholds).sum(0) + prior
        false_positives = (public_col >= public_thresholds).sum(0) + prior

        # nan_to_num maps 0/0 (an empty group) to 0.
        tpr = np.nan_to_num(
            true_positives / (true_positives + false_negatives)
        )
        tnr = np.nan_to_num(
            true_negatives / (true_negatives + false_positives)
        )

        accs = (true_positives + true_negatives) / (
            true_positives + true_negatives + false_negatives + false_positives
        )
        best_acc_idx = np.argmax(accs)

        # Column the reported operating point refers to.
        selected_idx = best_acc_idx if get_best else quantile_idx

        # ROC curve over all threshold columns, ordered by increasing FPR.
        fpr = 1 - tnr
        order = np.argsort(fpr)
        fpr_sorted = fpr[order]
        tpr_sorted = tpr[order]

        def tpr_fpr(fpr_target):
            # TPR at a target FPR: clamp outside the observed range,
            # interpolate linearly inside it.
            if fpr_target <= fpr_sorted[0]:
                return tpr_sorted[0]
            if fpr_target >= fpr_sorted[-1]:
                return tpr_sorted[-1]
            return float(np.interp(fpr_target, fpr_sorted, tpr_sorted))

        def predict_at(scores, thresholds, size):
            # 0/1 vector of "member" decisions at the selected column.
            flagged = torch.where(scores.reshape(-1) >= thresholds[:, selected_idx])
            pred = torch.zeros(size)
            pred[flagged] = 1
            return pred

        member_pred = predict_at(
            private_target_scores, private_thresholds, len(member_idx)
        )
        nonmember_pred = predict_at(
            public_target_scores, public_thresholds, len(nonmember_idx)
        )

        # Scatter both groups back to their original positions.
        result = torch.ones(len(member_idx) + len(nonmember_idx))
        result[member_idx] = member_pred
        result[nonmember_idx] = nonmember_pred

        # AUC is only reported when the best column is being searched for.
        roc_auc = auc(fpr_sorted, tpr_sorted) if get_best else None

        return {
            "auc": roc_auc,
            "best_accuracy": accs[selected_idx],
            "predict": result,
            "member_pred": member_pred,
            "nonmember_pred": nonmember_pred,
            "tpr01fpr": tpr_fpr(0.001),
            "tpr001fpr": tpr_fpr(0.0001),
            "tp": true_positives[selected_idx],
            "fn": false_negatives[selected_idx],
            "tn": true_negatives[selected_idx],
            "fp": false_positives[selected_idx],
            "quantile_list_idx": selected_idx,
        }
            


    def infer(self, model, data, label):
        """Score one batch with the quantile model and accumulate the results.

        Scores, thresholds and member/non-member labels are appended to
        ``self.score``, ``self.scores_threshold`` and ``self.label`` so that
        ``output()`` can evaluate them later. Membership labels are
        synthesised on the assumption that the first half of ``data`` are
        members and the second half non-members.

        Args:
            model: Not used by this method.
            data: Samples to score; only ``len(data)`` is used here, the
                rest is handled by ``infer_wrap``.
            label: Passed through to ``infer_wrap``.

        Returns:
            ``(binary_result, scores)`` in attack mode (after printing the
            batch accuracy), ``(None, None)`` otherwise.

        Raises:
            RuntimeError: If no quantile model path is set, i.e. ``fit()``
                has not been called.
        """
        if getattr(self, "quantile_model_path", None) is None:
            raise RuntimeError("QMIAModel.fit() must be called before infer()")

        binary_result, scores, scores_threshold = infer_wrap(
            self.hyper_config, data, label, self.quantile_model_path, self.mia_mode
        )

        # Member / non-member labels: first half 1, second half 0.
        # NOTE(review): for an odd len(data) this yields one label fewer
        # than there are scores -- confirm batches are always even-sized.
        m_nm_size = len(data) // 2
        m_nm_labels = torch.cat([
            torch.ones(m_nm_size, dtype=torch.long),
            torch.zeros(m_nm_size, dtype=torch.long),
        ])

        # Accumulate in every mode; output() reads this state in attack
        # mode as well.
        if self.label is None:
            self.label = m_nm_labels
        else:
            self.label = torch.cat([self.label, m_nm_labels])

        if self.score is None:
            self.score = scores
        else:
            self.score = torch.cat(
                [torch.atleast_1d(self.score), torch.atleast_1d(scores)]
            )

        if self.scores_threshold is None:
            self.scores_threshold = scores_threshold
        else:
            self.scores_threshold = torch.cat(
                [torch.atleast_1d(self.scores_threshold), torch.atleast_1d(scores_threshold)]
            )

        self.count += 1

        if self.mia_mode != "attack":
            return None, None

        from sklearn import metrics

        # Same half/half assumption, sized from the predictions.
        half = len(binary_result) // 2
        expected = [1] * half + [0] * half
        print(metrics.accuracy_score(expected, binary_result))
        return binary_result, scores
        
    def output(self):
        """Evaluate everything accumulated by ``infer()``.

        Splits the accumulated scores and thresholds into members
        (label 1) and non-members (label 0) and passes them to
        ``get_auc_related``. In attack mode the column stored in
        ``hyper_config["quantile_value"]`` is evaluated; otherwise the
        accuracy-maximising column is searched for.

        Returns:
            The dict produced by ``get_auc_related``.

        Raises:
            RuntimeError: If nothing has been accumulated yet.
        """
        if self.label is None or self.score is None or self.scores_threshold is None:
            raise RuntimeError(
                "output() called before any scores were accumulated by infer()"
            )

        member_idx = torch.where(self.label == 1)[0]
        nonmember_idx = torch.where(self.label == 0)[0]

        is_attack = self.mia_mode == "attack"

        return self.get_auc_related(
            private_thresholds=self.scores_threshold[member_idx],
            private_target_scores=self.score[member_idx],
            public_thresholds=self.scores_threshold[nonmember_idx],
            public_target_scores=self.score[nonmember_idx],
            member_idx=member_idx,
            nonmember_idx=nonmember_idx,
            get_best=not is_attack,
            # 0 is get_auc_related's default; it is ignored when get_best.
            quantile_idx=self.hyper_config["quantile_value"] if is_attack else 0,
        )


        