from argparse import ArgumentParser
from os.path import isdir, join
from dataset.dataset import MyDataset
import torch
from torch.utils.data import DataLoader
from imputation_model.model import PyPOTS_IMP_Model
from my_utils.utils import ABLATION_STUDY_IMP_MODEL_TYPES, DATASET_NAMES, LEN_CLIPS, MASK_LENGTH, PRE_MODEL_TYPES, PYPOTS_IMP_MODEL_TYPES,get_dataset_params, mkdir, save_json, save_pkl, setup_seed
from my_utils.utils import load_pkl
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
from pypots.data import fill_nan_with_mask


def get_args():
    """Build and parse the command-line interface for a single run.

    The integer flags ``--use_filter`` and ``--metric_on_normalized`` are
    converted to booleans (0 -> False, anything else -> True), and
    ``context_length`` / ``prediction_length`` are attached from
    ``get_dataset_params(dataset_name)``.

    Returns:
        argparse.Namespace: the parsed and augmented options.
    """
    parser = ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = [
        ("--mask_length", {"type": int, "default": 8}),
        ("--dataset_name", {"type": str, "default": 'electricity_nips'}),
        ("--imputation_model", {"type": str, "default": 'result_imputation_0806/electricity_nips/mixer/mask_length=8/filter=True/checkpoint/best.pt'}),
        ("--prediction_model", {"type": str, "default": 'result-pred-resnet34-debug-1660733364.256375/checkpoint/best.pt'}),
        ("--output_dir", {"type": str, "default": 'result-iap-debug-0813'}),
        ("--device", {"type": str, "default": "cuda:1"}),
        ("--num_sample", {"type": int, "default": 101, "help": "num_sample of dataset"}),
        ("--step", {"type": int, "help": "the step of mask", "default": 1}),
        # filter settings
        ("--use_filter", {"type": int, "default": 1}),
        ("--window", {"type": int, "default": 15}),
        ("--order", {"type": int, "default": 5}),
        # model type names (no default: None when omitted)
        ("--imp_type", {"type": str}),
        ("--pre_type", {"type": str}),
        ("--metric_on_normalized", {"type": int, "default": 1,
                                    "help": 'If true, ground truth will be normalized for metric'}),
    ]
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    parsed.use_filter = bool(parsed.use_filter)
    parsed.metric_on_normalized = bool(parsed.metric_on_normalized)
    parsed.context_length, parsed.prediction_length = get_dataset_params(parsed.dataset_name)
    return parsed


class ImputeAndPredict:
    """Impute masked copies of each context window, then predict from them.

    For every ``(context, target)`` pair yielded by ``dataloader``, the context
    is expanded into several masked copies by ``dataloader.dataset.get_mask``,
    normalized with ``dataloader.dataset.normalize``, imputed by ``imp_model``
    and finally passed to ``pred_model``.

    Args:
        imp_model: called as ``imp_model(context, mask)``; the first element of
            the returned pair is used as the imputation.
        pred_model: model supporting ``.to(device).eval()``, called as
            ``pred_model(imputation, target)``.
        dataloader: iterable of ``(context, target)`` batches whose ``.dataset``
            implements ``get_mask`` and ``normalize``.
        args: namespace read for ``device``, ``mask_length``, ``step`` and
            ``metric_on_normalized``.
    """

    def __init__(self, imp_model, pred_model, dataloader, args):
        self.imp_model = imp_model
        try:
            # Best effort: wrappers such as PyPOTS_IMP_Model may not expose
            # torch.nn.Module's .to()/.eval(); in that case use them as-is.
            self.imp_model.to(args.device).eval()
        except Exception as err:
            # Narrower than a bare ``except:`` (which also swallowed
            # KeyboardInterrupt/SystemExit) and no longer silent.
            print(f"Imputation model left as-is (.to/.eval failed): {err!r}")
        print(type(imp_model))
        self.pred_model = pred_model.to(args.device).eval()
        self.dataloader = dataloader
        self.args = args

    def get_mask(self, context):
        """Expand one context series into masked copies on ``args.device``.

        Shapes (as annotated by the author): T -> (N, T) context, (N, T) mask.
        The number of copies N is decided by ``dataset.get_mask`` from
        ``args.mask_length`` and ``args.step``.
        """
        context, mask = self.dataloader.dataset.get_mask(context, self.args.mask_length, self.args.step)
        context = context.to(self.args.device)
        mask = mask.to(self.args.device)
        return context, mask

    def normalize(self, context, target):
        """Normalize context/target with the dataset's normalizer.

        ``target`` (1 row) is first repeated to one row per context row.
        Shapes (as annotated by the author): (N, T), (1, T) ->
        (N, T), (N, T), mean (N, 1), std (N, 1).
        """
        target = target.repeat(context.shape[0], 1)
        context_normalized, target_normalized, mean, std = self.dataloader.dataset.normalize(context, target)
        return context_normalized, target_normalized, mean, std

    def impute(self, context, mask):
        """Return the imputation (first output of ``imp_model``) for (N, T) inputs."""
        imputation, _ = self.imp_model(context, mask)
        return imputation

    def predict(self, imputation, target):
        """Return ``pred_model(imputation, target)`` (used downstream as a tensor)."""
        pred = self.pred_model(imputation, target)
        return pred

    def de_normalize(self, context, imputation, prediction, mean, std, eps=1e-12):
        """
        Map normalized values back to the original scale: ``x * (std + eps) + mean``.

        Input:
            context, imputation, prediction: shape is NT
            mean, std: shape is N1
        Output:
            context, imputation, prediction: shape is NT
        """
        context = context * (std + eps) + mean
        imputation = imputation * (std + eps) + mean
        prediction = prediction * (std + eps) + mean
        return context, imputation, prediction

    def single_run(self, context, target):
        """Mask, normalize, impute and predict for one dataloader sample.

        Args:
            context: tensor with a leading batch dimension of 1 ([1, T] as
                noted in ``run``).
            target: tensor with a leading batch dimension of 1.

        Returns:
            (imputation, prediction): one row per masked copy. Values stay in
            normalized space when ``args.metric_on_normalized`` is true,
            otherwise they are mapped back with the per-copy mean/std.
        """
        # Pure inference: skip autograd bookkeeping (callers detach the
        # outputs anyway), which saves memory and time per sample.
        with torch.no_grad():
            context, target = context.to(self.args.device), target.to(self.args.device)
            context, mask = self.get_mask(context.squeeze(0))
            context, target, mean, std = self.normalize(context, target)
            imputation = self.impute(context, mask)
            prediction = self.predict(imputation, target)
            if not self.args.metric_on_normalized:
                # The de-normalized context is not returned, so discard it.
                _, imputation, prediction = self.de_normalize(
                    context, imputation, prediction, mean, std
                )
        return imputation, prediction

    def run(self):
        """Process every dataloader sample and collect per-sample results.

        Returns:
            list of dicts with keys ``context``, ``target`` (numpy arrays,
            normalized when ``args.metric_on_normalized``), ``imputation``,
            ``prediction`` (numpy arrays, one row per masked copy) and ``args``.
        """
        ans = []
        # context and target, shape are [1,T]
        for context, target in tqdm(self.dataloader):
            # _copy's shape are [1,T]
            if self.args.metric_on_normalized:
                context_copy, target_copy, _, _ = self.normalize(context, target)
                target_copy = target_copy[0:1, :].detach().cpu()
                context_copy = context_copy.detach().cpu()
            else:
                context_copy = context.clone().detach().cpu()
                target_copy = target.clone().detach().cpu()

            imputation, prediction = self.single_run(context, target)
            ans.append({
                "context": context_copy.numpy(),                    # [1, T]
                "target": target_copy.numpy(),                      # [1, T]
                "imputation": imputation.detach().cpu().numpy(),    # [num_mask, T]
                "prediction": prediction.detach().cpu().numpy(),    # [num_mask, T]
                "args": self.args
            })

        return ans


def main(args):
    """Run one impute-then-predict evaluation and pickle the results.

    Side effects: seeds RNGs via ``setup_seed()``, creates ``args.output_dir``
    if needed, writes ``args.json`` (all of ``vars(args)``) and ``ans.pkl``
    (the list returned by ``ImputeAndPredict.run``) into it.

    Args:
        args: namespace as produced by ``get_args()`` (possibly overridden by
            the batch sweep) providing dataset, model-path and device settings.
    """
    setup_seed()

    # exist_ok=True replaces the isdir()/makedirs() pair: same result, without
    # the check-then-create race when several runs share an output tree.
    os.makedirs(args.output_dir, exist_ok=True)

    save_path = join(args.output_dir, "ans.pkl")
    save_json(vars(args), join(args.output_dir, "args.json"))

    print(f"Loading dataset {args.dataset_name}...")
    dataset = MyDataset(args.dataset_name, "test", args.context_length, args.prediction_length,
                        args.num_sample, args.use_filter, args.window, args.order)
    dataloader = DataLoader(dataset, batch_size=1, drop_last=True, shuffle=True)

    # NOTE(review): the loaded objects are used directly as models, so these
    # checkpoints appear to be whole pickled modules. torch.load unpickles
    # arbitrary objects, so only load trusted files. On torch>=2.6 the default
    # weights_only=True rejects such pickles; weights_only=False may be needed
    # there. Confirm this against the installed torch version.
    print(f"Loading {args.imputation_model}...")
    imp_model = torch.load(args.imputation_model)
    print(f"Loading {args.prediction_model}...")
    pred_model = torch.load(args.prediction_model)

    # PyPOTS imputers get wrapped so they expose the (context, mask) call API.
    if args.imp_type in PYPOTS_IMP_MODEL_TYPES:
        imp_model = PyPOTS_IMP_Model(imp_model)

    iap = ImputeAndPredict(imp_model, pred_model, dataloader, args)
    ans = iap.run()

    print(f"Saving {save_path}...")
    save_pkl(ans, save_path)


if __name__ == "__main__":
    # Single run driven entirely by the CLI: replace the sweep below with
    #     main(get_args())

    # ---- Batch sweep ---------------------------------------------------------
    # Each combination gets its own output directory under ``iap_dir``. CLI
    # flags only supply values for options the sweep does not override.
    gpu_index = 1
    imp_model_dir = 'result_imputation_random_mask-0921'   # earlier run: 'result_imputation_0827'
    pred_model_dir = 'result_prediction_0827'
    iap_dir = 'result-iap-random_mask-0921'

    filter_options = [True]   # full sweep: [True, False]
    imp_types = ['mixer']     # ablation: ABLATION_STUDY_IMP_MODEL_TYPES
    rt_noises = [0.1]
    len_clips = [1.0]         # full sweep: LEN_CLIPS

    args = get_args()
    args.num_sample = 101
    args.metric_on_normalized = True

    # Lazily yields configurations in the same nesting order as a hand-written
    # loop: filter > dataset > predictor > imputer > mask_length > step >
    # rt_noise > len_clip. The step ranges over 1..mask_length inclusive.
    sweep = (
        (use_filter, dataset_name, pre_type, imp_type, mask_length, step, rt_noise, len_clip)
        for use_filter in filter_options
        for dataset_name in DATASET_NAMES[0:]
        for pre_type in PRE_MODEL_TYPES[0:]
        for imp_type in imp_types
        for mask_length in MASK_LENGTH
        for step in range(1, mask_length + 1)
        for rt_noise in rt_noises
        for len_clip in len_clips
    )

    for use_filter, dataset_name, pre_type, imp_type, mask_length, step, rt_noise, len_clip in sweep:
        imp = f'{imp_model_dir}/{dataset_name}/{imp_type}/mask_length={mask_length}/filter={use_filter}/rt_noise={rt_noise}/len_clip={len_clip}/checkpoint/best.pt'
        # Prediction checkpoints are always taken from the rt_noise=0.1 run.
        pre = f'{pred_model_dir}/{dataset_name}/{pre_type}/filter={use_filter}/rt_noise=0.1/len_clip={len_clip}/checkpoint/best.pt'
        output_dir = f'{iap_dir}/{dataset_name}/imp={imp_type}/mask_length={mask_length}/step={step}/pre={pre_type}/filter={use_filter}/rt_noise={rt_noise}/len_clip={len_clip}'

        # The same namespace is reused across runs; every varying field is
        # overwritten before each call.
        args.dataset_name = dataset_name
        args.context_length, args.prediction_length = get_dataset_params(dataset_name, len_clip)
        args.imputation_model = imp
        args.prediction_model = pre
        args.output_dir = output_dir
        args.device = f'cuda:{gpu_index}'
        args.pre_type = pre_type
        args.imp_type = imp_type
        args.step = step
        args.mask_length = mask_length
        args.use_filter = use_filter
        args.rt_noise = rt_noise
        args.len_clip = len_clip

        main(args)
        print(output_dir)