# Train and evaluate the imputation models (MLP-Mixer, MLP, Transformer).
from argparse import ArgumentParser
from os.path import isdir, join
import numpy as np
from dataset.dataset import MyDataset
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from my_utils.utils import DATASET_CLASSIFICATION, LEN_CLIPS, MASK_LENGTH, get_dataset_params, save_json, Logger, DATASET_NAMES, mkdir
from my_utils.utils import GenerateData2 as GD
from matplotlib import pyplot as plt
from tqdm import tqdm
import random
from imputation_model.model import MLPMixer as Mixer
from imputation_model.model import MLPModel as MLP
from imputation_model.model import MyTransformer
import os
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
import time
from einops import repeat, rearrange, reduce


def get_metric(imputation, intact, mask, mask_length):
    '''
    Score an imputation only on the positions hidden from the model.

    All tensors have shape [num_mask, L]. mask == 1 means the value was
    preserved, so errors are counted only where mask == 0. Each row's error
    sum is divided by `mask_length`, and the per-row values are averaged.

    Returns:
        dict with Python floats under "MAE" and "MSE".
    '''
    hidden_err = (imputation - intact) * (1 - mask)
    per_row_abs = hidden_err.abs().sum(dim=-1) / mask_length
    per_row_sq = hidden_err.pow(2).sum(dim=-1) / mask_length
    return {
        "MAE": per_row_abs.mean().item(),
        "MSE": per_row_sq.mean().item(),
    }


def get_args(argv=None):
    """Build the experiment's command-line parser and parse the settings.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. The default ``None`` keeps the original behaviour
            (argparse reads ``sys.argv[1:]``). Passing a list lets a batch
            driver or a test build settings independently of ``sys.argv``.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--num_epoch", type=int, default=20)
    arg_parser.add_argument("--device", type=str, default="cuda:1")
    arg_parser.add_argument("--dataset_name", type=str,
        default='MiddlePhalanxOutlineCorrect')

    # The default output dir is made unique by appending the current Unix time.
    arg_parser.add_argument("--output_dir", type=str,
        default='result_imp_debug-MiddlePhalanxOutlineCorrect-' + str(time.time()))

    # main() accepts 'mixer', 'mlp' or 'my_trans'.
    arg_parser.add_argument("--model_type", type=str, default='mixer')

    # Model settings
    arg_parser.add_argument("--drop_out", type=float, default=0.3)
    arg_parser.add_argument("--num_layer", type=int, default=4)
    arg_parser.add_argument("--norm", type=str, default="ln")
    arg_parser.add_argument("--hidden_dim", type=int, default=129)
    arg_parser.add_argument("--nhead", type=int, default=1, help='nhead for transformer')

    # Filter settings. --use_filter is an int flag (0 = off) that
    # post_process_args converts to a bool.
    arg_parser.add_argument("--use_filter", type=int, default=0)
    arg_parser.add_argument("--window", type=int, default=15)
    arg_parser.add_argument("--order", type=int, default=5)

    # Number of samples drawn for the train / test datasets.
    arg_parser.add_argument("--num_train_sample", type=int, default=10000)
    arg_parser.add_argument("--num_test_sample", type=int, default=300)

    # Others. --only_eval is an int flag: non-zero disables training.
    arg_parser.add_argument("--only_eval", type=int, default=0)
    arg_parser.add_argument("--checkpoint", type=str, default="")
    arg_parser.add_argument("--epoch_gap", type=int, default=500, help="How often epochs to save checkpoint")
    arg_parser.add_argument("--mask_length", type=int, default=20)
    arg_parser.add_argument("--num_workers", type=int, default=4)
    arg_parser.add_argument("--lr", type=float, default=1e-3)
    arg_parser.add_argument("--lr_decay_rate", type=float, default=0.92)
    arg_parser.add_argument("--record_path", type=str, default="record-imputation.csv")
    arg_parser.add_argument("--batch_size", type=int, default=32)

    # Random training: scale of the Gaussian noise added to the model input
    # during training.
    arg_parser.add_argument("--rt_noise", type=float, default=0.1)

    # Clip length factor, forwarded to get_dataset_params.
    arg_parser.add_argument("--len_clip", type=float, default=1.0)

    return arg_parser.parse_args(argv)


def post_process_args(args):
    """Derive run-time settings from the parsed CLI arguments.

    Mutates *args* in place and returns it:
      - ``train`` is True unless ``only_eval`` is truthy;
      - ``use_filter`` is coerced to a bool;
      - ``context_length`` / ``prediction_length`` are taken from
        ``get_dataset_params(dataset_name, len_clip)``;
      - in evaluation-only mode ``num_epoch`` is forced to 1.
    """
    args.train = not args.only_eval
    args.use_filter = bool(args.use_filter)
    lengths = get_dataset_params(args.dataset_name, args.len_clip)
    args.context_length, args.prediction_length = lengths
    if not args.train:
        args.num_epoch = 1
    return args


def get_loss(ctx, imp, imp_tmp):
    """Return the training loss: the sum of two mean-squared errors against *ctx*.

    ``imp`` and ``imp_tmp`` are the two outputs returned by the model
    (``imp_tmp`` is presumably an intermediate estimate — confirm in the
    model code). Each one is compared with *ctx* over every position using
    the default mean reduction.
    """
    mse = torch.nn.functional.mse_loss
    return mse(ctx, imp) + mse(ctx, imp_tmp)


def draw(context,imputation,fig_path,mask_length,summary_writer,epoch,mode):
    """Log a 3x4 grid of ground truth vs. imputation to TensorBoard.

    Input:
        context.shape == imputation.shape == [num_mask, T]
    Up to 12 rows, evenly spaced along the first axis, are plotted
    (blue '-b' = ground truth, red '-r' = imputation). The figure is
    logged with ``summary_writer.add_figure`` under tag *mode* at step *epoch*.

    ``fig_path`` and ``mask_length`` are accepted but not used here; the
    figure is not written to disk.

    Returns:
        True.
    """
    n_rows, n_cols = 3, 4
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 10))

    n_series = imputation.shape[0]
    n_panels = min(n_rows * n_cols, n_series)
    time_axis = np.arange(context.shape[-1])
    # Evenly spread row indices (truncated to ints) so the grid samples the whole batch.
    chosen = np.linspace(0, n_series - 1, n_panels).astype(int)

    for panel, series in enumerate(chosen):
        ax = axes[panel // n_cols][panel % n_cols]
        ax.plot(time_axis, context[series], '-b', label='gt')
        ax.plot(time_axis, imputation[series], '-r', label='imputation')
        ax.legend()

    summary_writer.add_figure(f'{mode}', fig, epoch)
    return True


def run_epoch(model,mode,data_loader,optimizer,logger,epoch,fig_dir,checkpoint_dir,args,summary_writer, lr_scheduler):
    """Run one pass over *data_loader* in training (``mode == 'train'``) or evaluation mode.

    For each batch this function:
      - moves the batch to ``args.device`` and normalizes it with
        ``data_loader.dataset.normalize``;
      - draws one random mask with ``args.mask_length`` zeros (hidden
        positions), shared by the whole batch;
      - in train mode, adds ``args.rt_noise``-scaled Gaussian noise to the
        model input only;
      - runs the model and accumulates loss, MAE and MSE.
    One randomly chosen batch is plotted via ``draw``.

    In train mode the optimizer steps once per batch and the LR scheduler
    once per epoch. Every ``args.epoch_gap`` epochs the whole model is saved
    to ``checkpoint_dir/epoch_{epoch+1}.pt``.

    Returns:
        float: mean over batches of the MSE on hidden (mask == 0) positions.

    Raises:
        ValueError: if *data_loader* yields no batches (e.g. fewer samples
            than ``batch_size`` with ``drop_last=True``).
    """
    enable_grad = mode == 'train'
    if enable_grad:
        model.train()
    else:
        model.eval()

    num_batches = len(data_loader)
    if num_batches == 0:
        # Without this check, random.randint(0, -1) would fail with an opaque
        # "empty range" error and the averages below would have no batches.
        raise ValueError(
            f"run_epoch(mode={mode!r}): data loader yields no batches; "
            f"check the number of samples against batch_size"
        )

    # Unshuffled mask template: `mask_length` zeros (hidden) then ones
    # (observed). It is loop-invariant, so it is built once and only permuted
    # per batch.
    base_mask = torch.cat([
        torch.zeros(args.mask_length),
        torch.ones(args.context_length - args.mask_length),
    ])

    with torch.set_grad_enabled(enable_grad):
        epoch_loss = mse = mae = 0.0
        count = 0
        # randomly select a batch to draw
        draw_iter_index = random.randint(0, num_batches - 1)
        # Wrap the loader itself (not an enumerate object) so tqdm can read
        # len(data_loader) and show a real progress bar.
        for idx, (context, target) in enumerate(tqdm(data_loader)):
            count += 1
            # [B, T]
            context, target = context.to(args.device), target.to(args.device)
            context, target, mean, std = data_loader.dataset.normalize(context, target)

            # Same random hidden positions for every series in the batch: [B, T]
            mask = base_mask[torch.randperm(base_mask.shape[0])]
            mask = repeat(mask, 't -> b t', b=context.shape[0]).to(context.device)

            # Random training: perturb only the model input, never the loss target.
            if enable_grad:
                ctx_rt = context + args.rt_noise * torch.randn_like(context)
            else:
                ctx_rt = context

            imputation, imp_tmp = model(ctx_rt, mask)
            loss = get_loss(context, imputation, imp_tmp)

            if enable_grad:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            # Metrics are read-only, so keep them off the autograd tape.
            with torch.no_grad():
                metric = get_metric(imputation, context, mask, args.mask_length)
            mae += metric['MAE']
            mse += metric['MSE']

            if idx == draw_iter_index:
                draw(
                    context.detach().cpu().numpy(),
                    imputation.detach().cpu().numpy(),
                    join(fig_dir,f"epoch_{epoch}_{mode}.png"),
                    args.mask_length,
                    summary_writer,
                    epoch,
                    mode
                )

        epoch_loss /= count
        mae /= count
        mse /= count

        logger.log("Epoch=%d, mode=%s, loss=%.6f" % (
            epoch,
            mode,
            epoch_loss
        ))

        logger.log("Epoch=%d, mode=%s, MAE=%.6f, MSE=%.6f" % (
            epoch,
            mode,
            mae,
            mse
        ))

        summary_writer.add_scalar(f'MSE/{mode}', mse, epoch)
        summary_writer.add_scalar(f'MAE/{mode}', mae, epoch)
        summary_writer.add_scalar(f'Loss/{mode}', epoch_loss, epoch)

        if enable_grad:
            lr_scheduler.step()

            if (epoch + 1) % args.epoch_gap == 0:
                # Save a CPU copy of the whole module, then move it back.
                model.eval()
                model.cpu()
                torch.save(model,join(checkpoint_dir,f"epoch_{epoch+1}.pt"))
                model.to(args.device)

        return mse


def _build_model(args):
    """Return a fresh model for ``args.model_type``, or load ``args.checkpoint`` if it is set.

    Raises:
        NotImplementedError: for an unknown ``model_type`` (fresh-model path only).
    """
    if args.checkpoint:
        # NOTE(review): torch.load unpickles arbitrary Python objects, so only
        # point --checkpoint at trusted files. This script saves whole modules
        # (torch.save(model)). On torch releases where weights_only defaults
        # to True (2.6+), loading such a file may need weights_only=False.
        return torch.load(args.checkpoint)
    if args.model_type == 'mixer':
        return Mixer(args.context_length, args.hidden_dim, args.num_layer)
    if args.model_type == 'mlp':
        return MLP(args.context_length, args.hidden_dim, args.num_layer, args.norm, args.drop_out)
    if args.model_type == 'my_trans':
        return MyTransformer(args.context_length, args.hidden_dim, args.num_layer, args.drop_out, args.nhead)
    raise NotImplementedError(f'model_type = {args.model_type} not implemented')


def _make_loader(args, split, num_sample):
    """Build a shuffled, drop-last DataLoader over ``MyDataset`` for *split*."""
    return DataLoader(
        MyDataset(args.dataset_name, split, args.context_length, args.prediction_length,
                  num_sample, args.use_filter, args.window, args.order),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )


def _append_record(args, best_mse):
    """Append all settings plus ``best_mse`` as one row of ``args.record_path``.

    Creates the CSV if it does not exist yet. Previously the first run crashed
    here with FileNotFoundError, after all training had finished.
    """
    args.best_mse = best_mse
    table = pd.DataFrame([vars(args)])
    if os.path.isfile(args.record_path):
        table = pd.concat([pd.read_csv(args.record_path), table], axis=0, join='outer')
    table.to_csv(args.record_path, index=False)


def main(args):
    """Train and/or evaluate one imputation experiment described by *args*.

    Steps:
      - create ``output_dir`` with ``fig/`` and ``checkpoint/`` sub-directories;
      - dump the arguments to ``args.json`` and generate the imputation data;
      - build or load the model and run ``args.num_epoch`` epochs;
      - keep the best test-MSE model in ``checkpoint/best.pt`` and the last
        one in ``checkpoint/final.pt``;
      - append a summary row to ``args.record_path``.
    """
    # create dirs for logs
    fig_dir = join(args.output_dir, "fig")
    checkpoint_dir = join(args.output_dir, "checkpoint")
    for path in (args.output_dir, fig_dir, checkpoint_dir):
        mkdir(path)

    # The context manager flushes and closes the event file even if a run
    # raises. main() is called many times by the batch driver, so unclosed
    # writers would otherwise pile up.
    with SummaryWriter(args.output_dir) as summary_writer:
        save_json(vars(args), join(args.output_dir, "args.json"))

        # generate data for imputation
        gd = GD(args.dataset_name, args.use_filter, args.window, args.order)
        gd.generate()

        train_logger = Logger(join(args.output_dir, "train.txt"))
        test_logger = Logger(join(args.output_dir, "test.txt"))

        model = _build_model(args)
        test_logger.log(type(model))
        test_logger.log(args.output_dir)
        model.to(args.device)

        # Test loader is built first, matching the original construction order.
        test_dataloader = _make_loader(args, "test", args.num_test_sample)
        train_dataloader = _make_loader(args, "train", args.num_train_sample) if args.train else None

        optimizer = Adam(model.parameters(), lr=args.lr) if args.train else None
        lr_scheduler = ExponentialLR(optimizer, args.lr_decay_rate, verbose=True) if args.train else None
        # +inf guarantees that the first finite test MSE writes a best.pt.
        test_best_mse = float('inf')

        for epoch in tqdm(range(args.num_epoch)):
            if args.train:
                run_epoch(model, "train", train_dataloader, optimizer, train_logger, epoch,
                          fig_dir, checkpoint_dir, args, summary_writer, lr_scheduler)
            test_mse = run_epoch(model, "test", test_dataloader, None, test_logger, epoch,
                                 fig_dir, checkpoint_dir, args, summary_writer, lr_scheduler)

            if test_mse < test_best_mse:
                test_best_mse = test_mse
                torch.save(model.cpu(), join(checkpoint_dir, "best.pt"))
                model.to(args.device)

    torch.save(model.cpu(), join(checkpoint_dir, "final.pt"))
    _append_record(args, test_best_mse)


if __name__ == "__main__":
    # Single-run mode: drive one experiment entirely from the command line.
    #   args = post_process_args(get_args())
    #   main(args)

    # Batch mode: sweep every dataset and mask length with fixed settings.
    # Each combination writes to its own directory under `root_dir`.
    model_type = 'mixer'
    root_dir = 'result_imputation_random_mask-0921'
    filter_grid = [True]
    noise_grid = [0.1]
    clip_grid = [1.0]  # use LEN_CLIPS here to also sweep clip lengths

    for dataset_name in DATASET_NAMES + DATASET_CLASSIFICATION:
        for mask_length in MASK_LENGTH:
            for use_filter in filter_grid:
                for rt_noise in noise_grid:
                    for len_clip in clip_grid:
                        # Start from the CLI defaults, then pin the sweep values.
                        # All keys already exist on the Namespace, so their
                        # order in vars(args) (and in the CSV) is unchanged.
                        args = get_args()
                        overrides = {
                            'rt_noise': rt_noise,
                            'len_clip': len_clip,
                            'num_epoch': 20,
                            'num_layer': 4,
                            'hidden_dim': 129,
                            'num_train_sample': 5000,
                            'num_test_sample': 300,
                            'num_workers': 4,
                            'lr': 1e-3,
                            'dataset_name': dataset_name,
                            'model_type': model_type,
                            'use_filter': use_filter,
                            'mask_length': mask_length,
                            'output_dir': f'{root_dir}/{dataset_name}/{model_type}/mask_length={mask_length}/filter={use_filter}/rt_noise={rt_noise}/len_clip={len_clip}',
                        }
                        for key, value in overrides.items():
                            setattr(args, key, value)
                        main(post_process_args(args))
        