from os.path import join
from pypots.data import fill_nan_with_mask
from pypots.imputation import SAITS
from pypots.imputation import LOCF, BRITS, Transformer
import numpy as np
from my_utils.utils import DATASET_CLASSIFICATION, DATASET_NAMES, MASK_LENGTH, PYPOTS_IMP_MODEL_TYPES, get_dataset_params, mkdir
from dataset.dataset import MyDataset
import torch
from torch.utils.data import DataLoader


# Experiment configuration. `window` and `order` are forwarded to MyDataset
# (their meaning is defined there); `num_train` is passed to MyDataset as the
# sample-count argument and `num_epoch`/`device` go to every PyPOTS imputer.
window = 15
order = 5
num_train = 100
num_epoch = 5
device = "cuda:1"
root_dir = "result_imputation_0806"


def _collect_train_data(train_dataset, mask_length):
    """Normalize and mask every training sample of ``train_dataset``.

    Args:
        train_dataset: a ``MyDataset`` yielding ``(context, target)`` pairs.
        mask_length: forwarded to ``train_dataset.get_mask``.

    Returns:
        ``(intact, masks)`` NumPy arrays: the per-sample contexts and masks
        returned by ``get_mask``, concatenated along dim 0 with a trailing
        feature axis added. Presumably shaped (n_samples, n_steps, 1) --
        NOTE(review): confirm against ``MyDataset.get_mask``.
    """
    intact, masks = [], []
    # Batch size 1: normalize() and get_mask() are applied one sample at a time.
    for context, target in DataLoader(train_dataset, 1):
        context, target, _, _ = train_dataset.normalize(context, target)
        context.squeeze_(0)  # drop the DataLoader batch axis (size 1)
        context, mask = train_dataset.get_mask(context, mask_length)
        intact.append(context)
        masks.append(mask)
    intact = torch.cat(intact).unsqueeze(-1).numpy()
    masks = torch.cat(masks).unsqueeze(-1).numpy()
    return intact, masks


def _build_imputer(imp_type):
    """Return a freshly constructed (untrained) PyPOTS imputer for ``imp_type``.

    Every branch returns its model, so a recognized type can never leave the
    caller holding a stale model from a previous loop iteration.

    Raises:
        NotImplementedError: if ``imp_type`` is not 'saits', 'brits' or
            'transformer'.
    """
    if imp_type == 'saits':
        return SAITS(n_layers=2, d_model=32, d_inner=16, n_head=4, d_k=8, d_v=8,
                     dropout=0.1, epochs=num_epoch, device=device)
    if imp_type == 'brits':
        return BRITS(rnn_hidden_size=32, epochs=num_epoch, device=device)
    if imp_type == 'transformer':
        return Transformer(n_layers=2, d_model=32, d_inner=16, n_head=4, d_k=8, d_v=8,
                           dropout=0.1, epochs=num_epoch, device=device)
    raise NotImplementedError(f'imp_type = {imp_type} not implemented')


# Train one imputer per (dataset, imputer type, mask length, filter flag) and
# save it under
# <root_dir>/<dataset>/<imp_type>/mask_length=<m>/filter=<f>/checkpoint/best.pt.
# for dataset_name in DATASET_NAMES[::-1]:
for dataset_name in DATASET_CLASSIFICATION:
    for imp_type in PYPOTS_IMP_MODEL_TYPES:
        for mask_length in MASK_LENGTH:
            for use_filter in [True]:
                context_length, prediction_length = get_dataset_params(dataset_name)
                train_dataset = MyDataset(dataset_name, 'train', context_length, prediction_length,
                                          num_train, use_filter, window, order)
                intact_train, masks_train = _collect_train_data(train_dataset, mask_length)
                # NOTE(review): fill_nan_with_mask presumably writes NaN where the
                # mask marks a value as missing; copy() keeps intact_train untouched.
                masked_data_train = fill_nan_with_mask(intact_train.copy(), masks_train)

                model = _build_imputer(imp_type)
                model.fit(masked_data_train)

                tgt_dir = f'{root_dir}/{dataset_name}/{imp_type}/mask_length={mask_length}/filter={use_filter}/checkpoint'
                mkdir(tgt_dir)
                # Pickles the whole model object (not just a state_dict).
                torch.save(model, join(tgt_dir, 'best.pt'))
