from argparse import ArgumentParser
from copy import deepcopy
from my_utils.utils import ABLATION_STUDY_IMP_MODEL_TYPES, DATASET_NAMES, DELTAS, IMP_MODEL_TYPES, LEN_CLIPS, PRE_MODEL_TYPES, MASK_LENGTH
from my_utils.utils import load_pkl
from tqdm import tqdm
import pandas as pd
import numpy as np



def get_args():
    """Parse the command-line options of this script.

    Returns:
        The argparse namespace with two attributes:
        ``attack_index`` (int, default -1) and ``delta`` (float, default 100).
    """
    parser = ArgumentParser()
    parser.add_argument("--attack_index", default=-1, type=int)
    parser.add_argument(
        "--delta",
        default=100,
        type=float,
        help='the length of interval',
    )
    return parser.parse_args()


class CheckPredictionOfImputation():
    """
    Range-based check of impute-then-predict results.

    For every sample the value range spanned by its 'imputation', 'context',
    'target' and 'prediction' arrays is cut into ``num_class`` equal-width
    intervals. A sample counts as correct when every prediction at
    ``attack_index`` falls into one single interval and that interval is also
    the one containing the target at ``attack_index``.

    Input:
        ans: sequence of per-sample dicts holding the arrays 'imputation',
            'context', 'target' and 'prediction' (the already loaded result
            of impute_then_predict, not a path)
        attack_index: column of 'target' / 'prediction' that is evaluated
        num_class: number of equal-width intervals per sample
        print: unused; kept only so existing callers keep working
    """
    def __init__(self,ans,attack_index=-1,num_class=3,print=False):
        self.ans = ans
        self.num_sample = len(self.ans)
        # The rows iterated in run() are rows of 'prediction', so the count is
        # taken from that array; an empty result list simply has no rows.
        self.num_mask = self.ans[0]['prediction'].shape[0] if self.num_sample else 0
        self.num_class = num_class
        self.attack_index = attack_index

    @staticmethod
    def _interval_index(value, starts):
        """Return the index of the interval of ``starts`` containing ``value``, or None."""
        last = len(starts) - 2
        for c in range(last + 1):
            upper = starts[c + 1]
            # Intervals are half-open, except the last one which also includes
            # its upper edge so that the overall maximum receives a class.
            below_upper = value <= upper if c == last else value < upper
            if starts[c] <= value and below_upper:
                return c
        return None

    def run(self):
        """
        Return:
            acc: fraction of correctly classified samples
            errors: indices of the samples that were not classified correctly
            mae: mean absolute distance between the target and the centre of
                its interval, averaged over the correct samples
            mse: squared counterpart of ``mae``
            (mae and mse are 9999 when no sample is correct)
        """
        if self.num_sample == 0:
            # Nothing to evaluate: report zero accuracy with the same sentinel
            # values that are used when no sample is correct.
            return 0.0, [], 9999, 9999

        acc = 0
        errors = []
        mse = 0.0
        mae = 0.0
        for sample in tqdm(range(self.num_sample)):
            entry = self.ans[sample]
            arrays = [entry[key] for key in ('imputation', 'context', 'target', 'prediction')]
            high = max(a.max() for a in arrays)
            low = min(a.min() for a in arrays)
            starts = np.linspace(low,high,self.num_class+1)

            # The true statistic: find the class the target falls into.
            target = entry['target'][0,self.attack_index]
            target_class = self._interval_index(target, starts)

            # The impute-then-predict statistics: collect every class that
            # receives at least one prediction.
            hit_classes = set()
            for mask in range(self.num_mask):
                statistic = entry['prediction'][mask,self.attack_index]
                c = self._interval_index(statistic, starts)
                if c is not None:
                    hit_classes.add(c)

            # Correct only if all predictions agree on one class and that
            # class is the target's class.
            if len(hit_classes) == 1 and target_class in hit_classes:
                acc += 1
                center = (starts[target_class]+starts[target_class+1])/2
                mae += abs(target - center).item()
                mse += ((target - center)**2).item()
            else:
                errors.append(sample)

        if acc > 0:
            mse /= acc
            mae /= acc
        else:
            mse = mae = 9999
        acc /= self.num_sample
        return acc, errors, mae, mse


class CheckPredictionOfImputation2():
    """
    Fixed-interval check of impute-then-predict results.

    An interval length has to be specified: every statistic is shifted by
    ``-MIN``, floor-divided by the interval length, and the integer quotient
    is its class. ``delta`` is the length of the interval.

    A sample counts as correct when all of its predictions at ``attack_index``
    share one class and that class equals the class of the target.

    Each element of ``ans`` is expected to look like the record built by the
    producer of the results:

    ans.append({
        "context": context_copy.numpy(),                    # [1, T]
        "target": target_copy.numpy(),                      # [1, T]
        "imputation": imputation.detach().cpu().numpy(),    # [num_mask, T]
        "prediction": prediction.detach().cpu().numpy(),    # [num_mask, T]
        "loss_imp": loss_imp.item(),                        # scalar
        "loss_pred": loss_pred.item(),                      # scalar
        "args": self.args
    })

    """
    def __init__(self, ans:list, attack_index:int=-1, delta:float=3.0):
        self.ans = ans
        self.attack_index = attack_index
        self.delta = delta
        self.num_sample = len(self.ans)
        # An empty result list has no prediction rows.
        self.num_mask = self.ans[0]['prediction'].shape[0] if self.num_sample else 0
        self.MAX = 9999
        self.MIN = -9999

    def run(self):
        """
        Evaluate all samples without modifying ``self.ans``.

        Return:
            acc: fraction of correct samples (0 when none is correct)
            mae: mean absolute distance between the shifted target and the
                centre of its interval over the correct samples
            mse: squared counterpart of ``mae``
            errors: indices of the samples that are not correct
            (mae and mse are ``self.MAX`` when no sample is correct)
        """
        mae = 0
        mse = 0
        acc = 0
        errors = []
        rows = np.arange(self.num_mask)

        for sample_index in range(self.num_sample):
            sample = self.ans[sample_index]

            # Shift copies of the attacked column instead of the stored arrays,
            # so the inputs stay intact and run() can be called repeatedly.
            # The subtraction is done on arrays to keep the arrays' own dtype.
            shifted_target = sample['target'][:, self.attack_index] - self.MIN
            shifted_prediction = sample['prediction'][rows, self.attack_index] - self.MIN

            real_statistic = shifted_target[0].item()
            real_label = int(real_statistic // self.delta)

            # Set of classes hit by the predictions of this sample.
            imp_labels = {int(value // self.delta) for value in shifted_prediction.tolist()}

            if len(imp_labels) == 1 and real_label in imp_labels:
                acc += 1
                middle = (real_label + 0.5) * self.delta
                mae += abs(real_statistic - middle)
                mse += (real_statistic - middle)**2
            else:
                errors.append(sample_index)

        if acc == 0:
            return 0, self.MAX, self.MAX, errors
        return acc/self.num_sample, mae/acc, mse/acc, errors


if __name__ == "__main__":
    # Evaluate every stored impute-then-predict result with every interval
    # length in DELTAS and write one CSV row per (configuration, delta).

    dir_of_iap_pkl = 'result-iap-random_mask-0921'
    df_path = 'result-iap-random_mask-0921.csv'

    # Rows are collected in a plain list and converted once at the end:
    # DataFrame.append was removed in pandas 2.0 and copies the whole frame
    # on every call.
    rows = []

    for use_filter in [True]:  # alternative: [False, True]
        for rt_noise in [0.1]:
            for dataset_name in DATASET_NAMES[0:]:
                for mask in MASK_LENGTH:
                    for imp_type in ['mixer']:  # alternative: ABLATION_STUDY_IMP_MODEL_TYPES
                        for pre_type in PRE_MODEL_TYPES:  # alternative: ['resnet18']
                            for step in range(1,mask+1):
                                for len_clip in [1.0]:  # alternative: LEN_CLIPS
                                    ans = load_pkl(f'{dir_of_iap_pkl}/{dataset_name}/imp={imp_type}/mask_length={mask}/step={step}/pre={pre_type}/filter={use_filter}/rt_noise={rt_noise}/len_clip={len_clip}/ans.pkl')
                                    for delta in DELTAS:
                                        # Every delta works on its own copy so the
                                        # evaluations cannot affect each other
                                        # through shared arrays.
                                        cpi = CheckPredictionOfImputation2(deepcopy(ans),-1,delta)
                                        acc, mae, mse, _ = cpi.run()
                                        rows.append({
                                            "acc": acc,
                                            "mse": mse,
                                            "mae": mae,
                                            "mask": mask,
                                            "def_len": mask,
                                            "step": step,
                                            "attack_length": mask - step + 1,
                                            "atk_len": mask - step + 1,
                                            "method": "MIA",
                                            "delta": delta,
                                            "dataset_name": dataset_name,
                                            "pre_type": pre_type,
                                            "imp_type": imp_type,
                                            "use_filter": use_filter,
                                            "rt_noise": rt_noise,
                                            "len_clip": len_clip,
                                        })
                                        # Progress indicator: number of rows so far.
                                        print(len(rows))

    df = pd.DataFrame(rows)
    df.to_csv(df_path,index=False)
