import argparse
import os
import time

import numpy as np
from torch.utils.data import DataLoader

from utils import *
from dataloader import UCR_load, semi_setting, Dataset
from models import MLPClassifier, MaskedDualTemporalAutoencoder

import warnings
# Suppress every Python warning for the remainder of the run (no category or
# module filter is given, so the 'ignore' action applies to all of them).
warnings.filterwarnings('ignore')

# Command-line interface of the script. Every option has a default, so the
# script can be launched without arguments; the parsed values are exposed
# through the module-level `args` namespace used by the rest of the file.
parser = argparse.ArgumentParser(description='Masked Dual-Temporal Autoencoders')

# Experiment setup
parser.add_argument('--seed', type=int, default=42, help='The random seed (defaults to 42)')
parser.add_argument('--dataset', type=str, default='CBF', help="The dataset name (defaults to 'CBF')")
parser.add_argument('--label_ratio', type=float, default=0.1, help='The proportion of labeled instances within the dataset (defaults to 0.1)')

# Model hyper-parameters
parser.add_argument('--max_len', type=int, default=1024, help='The maximum length of the incoming sequence (defaults to 1024)')
parser.add_argument('--embed_dim', type=int, default=64, help='The dimension of representation obtained from dual-temporal encoder (defaults to 64)')
parser.add_argument('--depth', type=int, default=4, help='The depths in multi-resolution sub-encoder (defaults to 4)')
parser.add_argument('--hidden_dim', type=int, default=256, help='The dimension of hidden layers used in MDTA (defaults to 256)')
parser.add_argument('--dropout', type=float, default=0.1, help='The dropout value (defaults to 0.1)')
parser.add_argument('--num_head', type=int, default=8, help='The number of heads in the multi-head attention of transformer-based sub-encoder (defaults to 8)')
parser.add_argument('--num_layer', type=int, default=3, help='The number of transformer blocks in transformer-based sub-encoder (defaults to 3)')

# Optimisation / runtime
parser.add_argument('--lr', type=float, default=1e-3, help='The learning rate (defaults to 1e-3)')
parser.add_argument('--batch_size', type=int, default=10, help='The batch size (defaults to 10)')
parser.add_argument('--epoch', type=int, default=1000, help='The maximum training epochs (defaults to 1000)')
parser.add_argument('--patience', type=int, default=50, help='The patience epoch for early stopping (defaults to 50)')
parser.add_argument('--gpu', type=int, default=0, help='The gpu no. used for training and inference (defaults to 0)')

# Parses sys.argv; argparse exits the process on unknown or malformed options.
args = parser.parse_args()

# Echo the run configuration to stdout.
print("Dataset:", args.dataset)
print("Arguments:", str(args))

# init_dl_program comes from the project's utils module; it receives the GPU
# index and the seed and its return value is used as the torch device below.
device = init_dl_program(args.gpu, seed=args.seed)

# The log file is placed under results/. Create the directory up front so a
# fresh checkout does not fail when the log file is opened there
# (harmless if setup_logger already takes care of it).
os.makedirs('results', exist_ok=True)
logger = setup_logger(name='MDTA', log_file='results/'+args.dataset+'_'+str(args.seed)+'_'+str(args.label_ratio)+'.log')

logger.info('Dataset: %s' % (args.dataset))
# Log the ratio with full precision: '%.1f' would round e.g. 0.05 to 0.1 and
# disagree with the value embedded in the log file name above.
logger.info('Label ratio: %s' % (args.label_ratio))
logger.info('Seed: %i' % (args.seed))
logger.info('Patience: %i' % (args.patience))

# Load the train/test arrays of the requested dataset.
data_name = args.dataset
train_x, train_y, test_x, test_y = UCR_load(data_name)

# Number of distinct labels found in the training labels.
num_classes = np.unique(train_y).size

# Build the semi-supervised splits. Note that test_x / test_y are rebound to
# the values returned by semi_setting, while train_x / train_y keep the
# objects returned by UCR_load.
(
    train_labeled_x, train_labeled_y,
    train_unlabeled_x, train_unlabeled_y,
    val_x, val_y,
    test_x, test_y,
) = semi_setting(
    train_x, train_y, test_x, test_y,
    normalization=True,
    label_ratio=args.label_ratio,
)

# Loader options shared by both training loaders: reshuffle every epoch and
# discard a trailing batch smaller than batch_size.
_train_loader_opts = dict(batch_size=args.batch_size, shuffle=True, drop_last=True)

# Labeled subset only.
train_labeled_set = Dataset(train_labeled_x, train_labeled_y)
train_labeled_loader = DataLoader(train_labeled_set, **_train_loader_opts)

# NOTE(review): this loader wraps train_x / train_y as bound by UCR_load, not
# the labeled/unlabeled splits produced by semi_setting -- confirm that using
# the pre-split arrays here is intended.
train_set = Dataset(train_x, train_y)
train_loader = DataLoader(train_set, **_train_loader_opts)

# NOTE(review): drop_last=True also applies to validation, so a trailing
# partial batch is skipped and a split smaller than batch_size yields no
# batches at all -- confirm this is intended.
val_set = Dataset(val_x, val_y)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)

# The whole test split is served as one single batch.
test_set = Dataset(test_x, test_y)
test_loader = DataLoader(test_set, batch_size=len(test_x), shuffle=False)

# Repeat the full train/evaluate cycle several times with freshly constructed
# models and collect the per-run metrics for the summary written afterwards.

# Every path in save_path lives under model/; create the directory up front
# so a fresh checkout does not fail if those paths are written to.
os.makedirs('model', exist_ok=True)

accs = []    # test accuracy of each run
times = []   # processing time reported by each run
epochs = []  # best epoch reported by each run
for i in range(5):
    # A new classifier head and a new autoencoder are built for every run.
    classifier = MLPClassifier(args.embed_dim, num_classes, hidden_dims=args.hidden_dim)
    classifier.to(device)

    # Model
    model = MaskedDualTemporalAutoencoder(
        feat_dim = train_labeled_x.shape[2], # V
        max_len = args.max_len,
        d_model = args.embed_dim,
        n_heads = args.num_head,
        num_layers = args.num_layer,
        dim_feedforward = args.hidden_dim,
        dropout = args.dropout,
        pos_encoding = 'fixed',
        activation = 'gelu',
        norm= 'BatchNorm',
        freeze=False)

    model.to(device)

    # Three per-run file paths tagged init / best / last, in that order.
    # How they are used is decided inside semi_adaptive_train (not shown here).
    path_prefix = 'model/MDTA_' + args.dataset + '_label_' + str(args.label_ratio)
    save_path = [path_prefix + '_' + tag + '_adaptive_' + str(i) + '.pt'
                 for tag in ('init', 'best', 'last')]

    # The third return value is deliberately NOT named `time`: doing so would
    # rebind the module-level `time` module imported at the top of the file
    # and break any later use of it.
    test_acc, best_epoch, elapsed = semi_adaptive_train(train_labeled_loader, train_loader, val_loader, test_loader,
                                model, classifier, args.epoch, args.lr, device, save_path, args.patience)

    logger.info('BEST EPOCH: %i' % best_epoch)
    logger.info('TEST ACC: %.3f' % test_acc)
    logger.info('PROCESSING TIME: %.4f' % elapsed)
    logger.info('----------------------------------')

    accs.append(test_acc)
    times.append(elapsed)
    epochs.append(best_epoch)

import csv

# Append one summary row for this configuration to the shared results file:
# dataset, method name, label ratio, mean accuracy, accuracy standard
# deviation (np.std default, i.e. population std), mean processing time and
# mean best epoch.
# The divisor is the number of collected runs rather than a literal, so it
# cannot drift from the repetition count of the training loop.
n_runs = len(accs)
# `with` guarantees the file is closed even if writing the row raises.
with open('repeat_results.csv', 'a', newline='') as f:
    csv.writer(f).writerow([
        args.dataset,
        'MDTA',
        args.label_ratio,
        sum(accs) / n_runs,
        np.std(accs),
        sum(times) / n_runs,
        sum(epochs) / n_runs,
    ])