import torch
from argparse import ArgumentParser
from tqdm import tqdm
from my_utils.utils import DATASET_CLASSIFICATION, EPOCHS_TSC, TSC_MODEL_TYPES, save_json, Logger, setup_seed, mkdir
from os.path import join
from classification_model.model import MLPModel as MLP
from classification_model.model import ResNet
from classification_model.model import FCN
from classification_model.mixer import MLPMixer
from dataset.dataset import DatasetTSC
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import time

def get_args():
    """Parse the command-line options of the classification script.

    Every option is an optional ``--name value`` flag described by one row of
    the table below: (name, type used to convert the value, default).

    Returns:
        The ``argparse`` namespace produced by parsing ``sys.argv``.
    """
    option_table = (
        ('dataset_name', str, 'ProximalPhalanxTW'),
        ('output_dir', str, 'result-classify-debug'),
        ('device', str, 'cuda:0'),
        ('num_epoch', int, 20),
        ('num_layer', int, 4),
        ('hidden_dim', int, 96),
        ('batch_size', int, 16),
        ('num_workers', int, 2),
        ('use_filter', int, 1),
        ('window', int, 15),
        ('order', int, 5),
        ('model_type', str, 'mixer'),
        ('checkpoint', str, ''),
        ('train', int, 1),
        ('lr', float, 1e-3),
        ('rt_noise', float, 0.0),
        ('dropout', float, 0.3),
        ('lr_decay_rate', float, 0.98),
        ('optimizer', str, 'adam'),
        ('normalization_dataset', str, 'instance'),
        ('normalization_model', str, 'ln'),
        ('record_path', str, 'record-classification.csv'),
        ('activation', str, 'relu'),
    )

    parser = ArgumentParser()
    for option_name, option_type, option_default in option_table:
        parser.add_argument(f'--{option_name}', type=option_type, default=option_default)

    return parser.parse_args()

def post_process_args(args):
    """Normalize parsed options in place and return the same namespace.

    ``use_filter`` and ``train`` are converted to real booleans based on their
    truthiness, and ``num_epoch`` is forced to 1 whenever ``train`` is false.
    """
    args.use_filter = bool(args.use_filter)
    args.train = bool(args.train)

    if not args.train:
        # Without training a single pass over the loop is enough.
        args.num_epoch = 1

    return args

def _build_dataloader(args, mode, shuffle, drop_last):
    """Create the DataLoader over the ``mode`` ('train' or 'test') split of ``args.dataset_name``."""
    dataset = DatasetTSC(
        dataset_name=args.dataset_name,
        mode=mode,
        normalization_method=args.normalization_dataset,
        use_filter=args.use_filter,
        window=args.window,
        order=args.order,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=shuffle,
        drop_last=drop_last,
    )


def _build_model(args, context_length, num_class, logger):
    """Load the model from ``args.checkpoint`` when it is set, otherwise construct ``args.model_type``."""
    if args.checkpoint != '':
        logger.log(f'Loading checkpoint {args.checkpoint}...')
        # SECURITY: the checkpoint is a fully pickled model object, and unpickling
        # can execute arbitrary code, so only load checkpoints you trust.
        # weights_only=False is explicit because PyTorch 2.6 changed the default to
        # True, which rejects whole-model pickles. map_location='cpu' makes loading
        # independent of the device the checkpoint was saved from; the caller moves
        # the model to args.device afterwards.
        return torch.load(args.checkpoint, map_location='cpu', weights_only=False)

    if args.model_type == 'mlp':
        return MLP(context_length, args.hidden_dim, num_class, args.num_layer,
                   args.normalization_model, args.dropout, args.activation)
    if args.model_type == 'resnet18':
        return ResNet(args.model_type, num_class)
    if args.model_type == 'fcn':
        return FCN(args.hidden_dim, args.num_layer, args.dropout, num_class, args.activation)
    if args.model_type == 'mixer':
        return MLPMixer(context_length, args.hidden_dim, args.num_layer, num_class, args.dropout)
    raise NotImplementedError(f'Model type: {args.model_type} is not implemented')


def _build_optimizer(args, model):
    """Return ``(optimizer, scheduler)`` for ``args.optimizer`` with exponential learning-rate decay."""
    optimizer_classes = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
        'adadelta': torch.optim.Adadelta,
        'adamw': torch.optim.AdamW,
    }
    if args.optimizer not in optimizer_classes:
        raise NotImplementedError(f'Optimizer: {args.optimizer} is not implemented')

    optimizer = optimizer_classes[args.optimizer](model.parameters(), lr=args.lr)
    # NOTE(review): `verbose` is reported as deprecated in newer PyTorch releases;
    # confirm against the installed version before upgrading.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay_rate, verbose=True)
    return optimizer, scheduler


def _append_record(args):
    """Append one row holding every attribute of ``args`` to the CSV at ``args.record_path``."""
    row = pd.DataFrame([vars(args)])
    try:
        history = pd.read_csv(args.record_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # No usable record file yet: start a new one. Any other failure (parse
        # error, permissions, ...) propagates instead of silently overwriting
        # previously recorded runs.
        records = row
    else:
        records = pd.concat([history, row])
    records.to_csv(args.record_path, index=False)


def _run_experiment(args, ck_dir, test_logger, train_logger, summary_writer):
    """Build data/model/optimizer, run all epochs and return the best test accuracy."""
    train_dataloader = _build_dataloader(args, 'train', shuffle=True, drop_last=True)
    # Keep the last partial batch of the test split so accuracy covers every
    # test sample (and a test set smaller than one batch is still evaluated).
    test_dataloader = _build_dataloader(args, 'test', shuffle=False, drop_last=False)

    context_length, num_class = test_dataloader.dataset.get_info()

    run_info = [
        ('dataset_name', args.dataset_name),
        ('num_class', num_class),
        ('context_length', context_length),
        ('use_filter', args.use_filter),
        ('length_train', len(train_dataloader.dataset)),
        ('length_test', len(test_dataloader.dataset)),
    ]
    if args.use_filter:
        run_info += [('window', args.window), ('order', args.order)]
    for key, value in run_info:
        test_logger.log('{:<20} == {}'.format(key, value))

    model = _build_model(args, context_length, num_class, test_logger)
    test_logger.log(f'Model type is {str(type(model))}')
    model.to(args.device)

    optimizer = scheduler = None
    if args.train:
        optimizer, scheduler = _build_optimizer(args, model)

    best_acc = 0
    for epoch in range(args.num_epoch):
        if args.train:
            run_epoch('train', epoch, model, train_dataloader, optimizer, scheduler,
                      train_logger, summary_writer, args)
        _, acc = run_epoch('test', epoch, model, test_dataloader, None, None,
                           test_logger, summary_writer, args)
        if acc > best_acc:
            best_acc = acc
            if args.train:
                # Checkpoints are stored from CPU, then the model goes back to
                # the training device for the next epoch.
                model.cpu()
                torch.save(model, join(ck_dir, 'best.pt'))
                model.to(args.device)

    if args.train:
        torch.save(model.cpu(), join(ck_dir, 'final.pt'))
    return best_acc


def main(args):
    """Run one classification experiment described by ``args``.

    Steps, in order:
      1. call ``setup_seed()``, create ``args.output_dir`` and its
         ``checkpoint`` sub-directory, and dump ``args`` to ``args.json``;
      2. build the train/test dataloaders and the model (from
         ``args.checkpoint`` if given, else from ``args.model_type``);
      3. for ``args.num_epoch`` epochs, optionally train, then evaluate on the
         test split; when training, save ``best.pt`` whenever test accuracy
         improves and ``final.pt`` at the end;
      4. store the best test accuracy in ``args.acc`` and append all of
         ``args`` as one row to the CSV at ``args.record_path``.

    Args:
        args: namespace as produced by ``get_args`` / ``post_process_args``.
            It is mutated: ``args.acc`` is set to the best test accuracy.

    Raises:
        NotImplementedError: for an unknown ``args.model_type`` or, when
            training, an unknown ``args.optimizer``.
    """
    setup_seed()

    mkdir(args.output_dir)
    ck_dir = join(args.output_dir, 'checkpoint')
    mkdir(ck_dir)

    save_json(vars(args), join(args.output_dir, 'args.json'))

    test_logger = Logger(join(args.output_dir, 'test.txt'))
    train_logger = Logger(join(args.output_dir, 'train.txt'))

    summary_writer = SummaryWriter(args.output_dir)
    try:
        best_acc = _run_experiment(args, ck_dir, test_logger, train_logger, summary_writer)
    finally:
        # Flush and release the event file even if the run fails; main() may be
        # called many times in one process.
        summary_writer.close()

    args.acc = best_acc
    _append_record(args)
    


def compute_loss(y, y_hat):
    '''
    Mean cross-entropy of the predicted logits against the class labels.

    y.shape     = [B,]            (target class labels)
    y_hat.shape = [B, num_class]  (unnormalized logits)

    Returns a scalar tensor (default 'mean' reduction).
    '''
    return torch.nn.functional.cross_entropy(input=y_hat, target=y)


def run_epoch(
    mode,
    epoch,
    model,
    dataloader,
    optimizer,
    scheduler,
    logger,
    summary_writer,
    args,
):
    """Run one pass over ``dataloader`` and report mean loss and accuracy.

    Args:
        mode: ``'train'`` enables gradients, parameter updates and the
            scheduler step; any other value runs the model in eval mode
            without gradients. Also used as the tag suffix for the scalars.
        epoch: epoch index, used as the step for the logged scalars.
        model: the network; called as ``model(context)`` to get logits.
        dataloader: iterable of ``(context, label)`` batches.
        optimizer: optimizer stepped after every batch (only used in train mode).
        scheduler: LR scheduler stepped once per epoch (only used in train mode).
        logger: object with a ``log(str)`` method.
        summary_writer: object with ``add_scalar(tag, value, step)``.
        args: namespace providing ``device`` and ``rt_noise``.

    Returns:
        ``(loss, acc)``: the per-sample mean loss and the fraction of correctly
        classified samples. If the dataloader yields nothing, a warning is
        logged and ``(nan, 0.0)`` is returned instead of dividing by zero.
    """
    time_start = time.time()
    train = mode == 'train'
    if train:
        model.train()
    else:
        model.eval()

    num_sample = num_true = 0
    loss_sum = 0.0

    with torch.set_grad_enabled(train):
        for context, label in dataloader:
            context = context.to(args.device)
            label = label.to(args.device)

            if train:
                # Gaussian noise scaled by rt_noise is added to training inputs;
                # randn_like already allocates on context's device.
                context = context + args.rt_noise * torch.randn_like(context)

            logits = model(context)
            loss = compute_loss(label, logits)

            if train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = label.shape[0]
            # compute_loss returns a batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            num_sample += batch_size
            num_true += (logits.argmax(dim=-1) == label).sum().item()

    if train:
        scheduler.step()

    if num_sample == 0:
        # E.g. a split smaller than one batch combined with drop_last=True.
        logger.log(f'Mode={mode}\tEpoch={epoch}\tdataloader yielded no samples; reporting Loss=nan ACC=0')
        total_loss, acc = float('nan'), 0.0
    else:
        total_loss = loss_sum / num_sample
        acc = num_true / num_sample

    summary_writer.add_scalar(f'Loss/{mode}', total_loss, epoch)
    summary_writer.add_scalar(f'ACC/{mode}', acc, epoch)

    duration = time.time() - time_start
    logger.log('Mode={:<5}\tEpoch={:<3}\tLoss={:<6.3f}\tACC={:.3f}\tTime={:<.3f}(s)'.format(mode, epoch, total_loss, acc, duration))

    return total_loss, acc


if __name__ == '__main__':
    # Local imports: only this sweep driver needs them.
    from copy import copy
    from itertools import product

    # Single-run alternative: main(post_process_args(get_args()))

    # Batch run: sweep every dataset x model type x filter flag x noise level.
    # To evaluate saved models instead, set `train = False` and point
    # `checkpoint` at '<output_dir>/checkpoint/best.pt' on each run's args.
    base_args = get_args()
    tsc_dir = 'result-tsc-0826'
    sweep = product(DATASET_CLASSIFICATION, TSC_MODEL_TYPES, [True, False], [0.1, 0.0])
    for dataset_name, model_type, use_filter, rt_noise in sweep:
        # Fresh copy per run: main() stores its result in args.acc, which would
        # otherwise leak into the args.json written by the next run.
        args = copy(base_args)
        args.dataset_name = dataset_name
        args.num_epoch = EPOCHS_TSC[model_type]
        args.model_type = model_type
        args.use_filter = use_filter
        args.rt_noise = rt_noise
        args.output_dir = f'{tsc_dir}/{dataset_name}/{model_type}/filter={use_filter}/rt_noise={rt_noise}'
        args = post_process_args(args)
        main(args)
        print(args.output_dir)