# train and eval prediction model
from argparse import ArgumentParser
from os.path import isdir, join
import numpy as np
from dataset.dataset import MyDataset
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from my_utils.utils import LEN_CLIPS, PRE_MODEL_TYPES, mkdir, save_json, Logger, get_dataset_params, DATASET_NAMES
from my_utils.utils import get_metric, setup_seed
from my_utils.utils import GenerateData2 as GD
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
from tqdm import tqdm
import random
import pandas as pd
from prediction_model.model import MLPMixer as Mixer
from prediction_model.model import MLPModel3 as MLP
from prediction_model.model import LSTM3 as LSTM
from prediction_model.model import TemporalConvNet as TCN
from prediction_model.model import GRU as GRU
from prediction_model.model import RNN as RNN
from prediction_model.model import LinearModel, ResNet18, Transformer, ResNet34
import time

def get_args():
    """Parse the command-line options of the prediction train/eval script.

    Every option is declared as a ``(flag, type, default, help)`` row and
    registered in the listed order, so ``--help`` output keeps that order.

    Returns:
        argparse.Namespace: the parsed options (read from ``sys.argv``).
    """
    option_table = [
        ("--dataset_name", str, "electricity_nips", None),
        # a timestamp keeps default output directories of separate runs apart
        ("--output_dir", str, "result-pred-resnet34-debug-" + str(time.time()), None),
        ("--model_type", str, "resnet34", None),
        ("--device", str, "cuda:1", None),
        ("--num_epoch", int, 100, None),

        # not set very frequently
        ("--epoch_gap", int, 100, None),
        ("--batch_size", int, 32, None),
        ("--lr", float, 1e-3, None),
        ("--lr_decay_rate", float, 0.95, None),
        ("--only_eval", int, 0, None),
        ("--checkpoint", str, "", None),
        ("--num_train_batch", int, 1500, None),
        ("--num_test_batch", int, 100, None),
        ("--num_workers", int, 4, None),
        ("--optimizer", str, "adam", None),
        ("--record_path", str, "record-prediction.csv", None),

        # Model settings
        ("--hidden_dim", int, 16, None),
        ("--num_layer", int, 2, None),
        ("--drop_out", float, 0.1, None),
        ("--norm", str, "ln", "normalization method"),
        ("--pred_weight", float, 1.0, None),
        ("--mode", str, 'padding', 'mode for rnn-based model'),
        ("--activation", str, 'relu', 'activation for mlp'),

        # random training
        ("--rt_noise", float, 0.1, "Noise for Random training"),

        # len_clip
        ("--len_clip", float, 1.0, None),

        # filter
        ("--use_filter", int, 1, None),
        ("--window", int, 15, None),
        ("--order", int, 5, None),
    ]

    parser = ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()


def post_process_args(args):
    """Fill in the derived fields of ``args`` in place and return it.

    Sets ``context_length`` / ``prediction_length`` from
    ``get_dataset_params(dataset_name, len_clip)``, forces a single epoch for
    evaluation-only runs, derives the ``train`` flag, lower-cases the string
    options ``norm``, ``model_type`` and ``mode``, and turns ``use_filter``
    into a real bool.
    """
    lengths = get_dataset_params(args.dataset_name, args.len_clip)
    args.context_length, args.prediction_length = lengths

    eval_only = bool(args.only_eval)
    if eval_only:
        # evaluation needs just one pass over the test data
        args.num_epoch = 1
    args.train = not eval_only

    for option in ("norm", "model_type", "mode"):
        setattr(args, option, getattr(args, option).lower())

    args.use_filter = bool(args.use_filter)
    return args


def draw(target,prediction,fig_path,summary_writer,epoch,mode):
    """Plot ground truth against prediction for up to 12 series and log the figure.

    The first ``min(target.shape[0], 12)`` rows of ``target`` (blue, 'gt') and
    ``prediction`` (red, 'pred') are drawn on a 3x4 grid of subplots, and the
    figure is sent to ``summary_writer`` under the tag ``mode`` at step
    ``epoch``. ``fig_path`` is accepted but not used here.

    Returns:
        bool: always ``True``.
    """
    n_rows, n_cols = 3, 4
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 4.5))
    steps = np.arange(target.shape[-1])

    n_panels = min(target.shape[0], n_rows * n_cols)
    for panel in range(n_panels):
        grid_row, grid_col = divmod(panel, n_cols)
        ax = axes[grid_row][grid_col]
        ax.plot(steps, target[panel], '-b', label='gt')
        ax.plot(steps, prediction[panel], '-r', label='pred')
        ax.legend()

    tag = f'{mode}'
    summary_writer.add_figure(tag, fig, epoch)
    return True


def get_loss(tgt, pred):
    """Return the mean squared error between ``tgt`` and ``pred`` as a scalar tensor."""
    return torch.nn.functional.mse_loss(tgt, pred, reduction="mean")


def run_epoch(model,mode,data_loader,optimizer,logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler):
    """Run one pass over ``data_loader`` in training or evaluation mode.

    For every batch the data is moved to ``args.device``, normalized with
    ``data_loader.dataset.normalize``, fed through ``model(context, target)``
    and scored with ``get_loss`` / ``get_metric``. In ``'train'`` mode Gaussian
    noise scaled by ``args.rt_noise`` is added to the context, the optimizer is
    stepped per batch, the LR scheduler is stepped once per epoch, and a
    checkpoint is written every ``args.epoch_gap`` epochs. One randomly chosen
    batch per epoch is plotted via ``draw``.

    Args:
        model: the network; called as ``model(context, target)``.
        mode: ``'train'`` enables gradients and optimization; any other value
            runs evaluation.
        data_loader: sized iterable of ``(context, target)`` batches whose
            ``dataset`` provides ``normalize``.
        optimizer: optimizer stepped in train mode (unused otherwise).
        logger: object with a ``log(str)`` method.
        epoch: zero-based epoch index used for logging and checkpoint names.
        fig_dir: directory used to build the figure path passed to ``draw``.
        checkpoint_dir: directory receiving ``epoch_{epoch+1}.pt``.
        args: namespace providing ``device``, ``rt_noise`` and ``epoch_gap``.
        summary_writer: TensorBoard writer receiving scalars and the figure.
        lr_scheduler: scheduler stepped once per training epoch.

    Returns:
        The epoch's mean MSE as reported by ``get_metric``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        # fail with a clear message instead of random.randint's "empty range" error
        raise ValueError("data_loader is empty: run_epoch needs at least one batch")

    is_train = mode == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    mse = mae = epoch_loss = 0.0
    count = 0

    with torch.set_grad_enabled(is_train):
        # randomly select a batch to draw
        draw_iter_index = random.randint(0, num_batches - 1)
        # tqdm wraps the loader itself (not enumerate) so it can read the length
        for idx, (context, target) in enumerate(tqdm(data_loader)):
            count = idx + 1
            context, target = context.to(args.device), target.to(args.device)
            context, target, _mean, _std = data_loader.dataset.normalize(context, target)

            if is_train:
                # randn_like already allocates on context's device
                context = context + args.rt_noise * torch.randn_like(context)

            prediction = model(context, target)

            loss = get_loss(target, prediction)

            if is_train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            metric = get_metric(prediction,target)
            mae += metric['mae']
            mse += metric['mse']

            if idx == draw_iter_index:
                draw(
                    target.detach().cpu().numpy(),
                    prediction.detach().cpu().numpy(),
                    join(fig_dir,f"epoch_{epoch}_{mode}.png"),
                    summary_writer,
                    epoch,
                    mode
                )

    # average the accumulated values over the batches actually processed
    epoch_loss /= count
    mse /= count
    mae /= count

    logger.log("Epoch=%d, mode=%s, loss=%.6f" % (
        epoch,
        mode,
        epoch_loss
    ))

    logger.log("Epoch=%d, mode=%s, MAE=%.6f, MSE=%.6f" % (
        epoch,
        mode,
        mae,
        mse
    ))

    summary_writer.add_scalar(f'MSE/{mode}', mse, epoch)
    summary_writer.add_scalar(f'MAE/{mode}', mae, epoch)
    summary_writer.add_scalar(f'Loss/{mode}', epoch_loss, epoch)

    if is_train:
        lr_scheduler.step()

        if (epoch + 1) % args.epoch_gap == 0:
            # save from CPU, then move the model back for the next epoch
            torch.save(model.cpu(),join(checkpoint_dir,f"epoch_{epoch+1}.pt"))
            model.to(args.device)

    return mse
        

def main(args):
    """Train and/or evaluate one prediction model configured by ``args``.

    Steps: seed the RNGs, create the output / figure / checkpoint directories,
    dump ``args`` to ``args.json``, generate the dataset, build the model (or
    load it from ``args.checkpoint``), then for ``args.num_epoch`` epochs run an
    optional training pass followed by a test pass. The model with the lowest
    test MSE is saved as ``best.pt`` and the last one as ``final.pt``. Finally
    one row holding all arguments plus ``best_mse`` is appended to the CSV at
    ``args.record_path`` (the file is created when it does not exist yet).

    Args:
        args: namespace produced by ``get_args`` and ``post_process_args``.
            ``args.best_mse`` is added to it as a side effect.

    Raises:
        NotImplementedError: for an unknown ``args.model_type`` (when no
            checkpoint is given) or an unknown ``args.optimizer``.
    """
    setup_seed()

    # create dirs
    fig_dir = join(args.output_dir, "fig")
    checkpoint_dir = join(args.output_dir, "checkpoint")
    for path in (args.output_dir, fig_dir, checkpoint_dir):
        mkdir(path)

    # save args
    save_json(vars(args), join(args.output_dir, "args.json"))

    # generate data
    gd = GD(args.dataset_name,args.use_filter,args.window,args.order)
    gd.generate()

    train_logger = Logger(join(args.output_dir, "train.txt"))
    test_logger = Logger(join(args.output_dir, "test.txt"))
    summary_writer = SummaryWriter(args.output_dir)

    if args.checkpoint == "":
        if args.model_type == "mixer":
            model = Mixer(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer)
        elif args.model_type == "mlp":
            model = MLP(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer, args.norm, args.drop_out, args.activation)
        elif args.model_type == "lstm":
            model = LSTM(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode)
        elif args.model_type == "gru":
            model = GRU(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode)
        elif args.model_type == "rnn":
            model = RNN(args.hidden_dim, args.num_layer, args.drop_out, args.pred_weight, args.mode)
        elif args.model_type == "tcn":
            # NOTE: a stray ``model = DeepAR(...)`` line used to follow here; DeepAR is
            # not imported in this file, so it made every tcn run fail with NameError.
            model = TCN(args.context_length, args.prediction_length, args.hidden_dim, args.num_layer, args.drop_out)
        elif args.model_type == 'resnet18':
            model = ResNet18(args.context_length, args.prediction_length)
        elif args.model_type == 'resnet34':
            model = ResNet34(args.context_length, args.prediction_length)
        else:
            raise NotImplementedError(f"model_type = {args.model_type} not implemented")
    else:
        test_logger.log(f'Loading checkpoint: {args.checkpoint}')
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        model = torch.load(args.checkpoint)
    test_logger.log(args.output_dir)
    test_logger.log(f"Model type is {type(model)}")
    model.to(args.device)

    test_dataloader = DataLoader(
        MyDataset(args.dataset_name,"test",args.context_length,args.prediction_length,args.num_test_batch*args.batch_size,args.use_filter,args.window,args.order),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers
    )
    train_dataloader = DataLoader(
        MyDataset(args.dataset_name,"train",args.context_length,args.prediction_length,args.num_train_batch*args.batch_size,args.use_filter,args.window,args.order),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers
    ) if args.train else None

    # the optimizer name is validated even for evaluation-only runs
    if args.optimizer == 'adam':
        optimizer = Adam(model.parameters(),lr=args.lr) if args.train else None
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),lr=args.lr) if args.train else None
    else:
        raise NotImplementedError(f'optimizer={args.optimizer} is not implemented')
    # ExponentialLR rejects a None optimizer, so only build it when training;
    # the test pass of run_epoch never touches the scheduler.
    lr_scheduler = ExponentialLR(optimizer, args.lr_decay_rate, verbose=True) if args.train else None
    test_mse_best = 1e12

    for epoch in tqdm(range(args.num_epoch)):
        if args.train:
            run_epoch(model,"train",train_dataloader,optimizer,train_logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler)

        test_mse = run_epoch(model,"test",test_dataloader,None,test_logger,epoch,fig_dir,checkpoint_dir,args, summary_writer, lr_scheduler)

        if test_mse < test_mse_best:
            test_mse_best = test_mse
            # save from CPU, then move the model back for the next epoch
            torch.save(model.cpu(), join(checkpoint_dir, "best.pt"))
            model.to(args.device)

    torch.save(model.cpu(),join(checkpoint_dir, "final.pt"))

    # flush and release the event file; main() may be called many times per process
    summary_writer.close()

    # append this run to the shared record file, creating it on first use
    args.best_mse = test_mse_best
    run_record = pd.DataFrame([vars(args)])
    try:
        previous_records = pd.read_csv(args.record_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        previous_records = None
    if previous_records is None:
        df = run_record
    else:
        df = pd.concat([previous_records, run_record], axis=0, join='outer')
    df.to_csv(args.record_path, index=False)


if __name__ == "__main__":
    # ---------------------------------------------------------- single run
    # args = get_args()
    # args = post_process_args(args)
    # main(args)

    # ---------------------------------------------------------- batch sweep
    root_dir = 'result_prediction_0827_traffic'

    # per-model overrides: (hidden_dim, num_layer, drop_out, num_epoch)
    sweep_settings = {
        'mlp': (129, 4, 0.1, 60),
        'mixer': (33, 16, 0.5, 20),
        'lstm': (129, 4, 0.1, 50),
        'gru': (129, 4, 0.1, 50),
        'resnet18': (-1, -1, -1, 20),
    }

    # for dataset_name in DATASET_NAMES:
    for dataset_name in ['traffic_nips']:
        for model_type in PRE_MODEL_TYPES:
            for use_filter in [True]:
                for rt_noise in [0.1, 0.0]:
                    for len_clip in LEN_CLIPS:
                        # start from the command-line defaults, then override per run
                        args = get_args()

                        args.device = "cuda:2"
                        args.rt_noise = rt_noise
                        args.len_clip = len_clip
                        args.dataset_name = dataset_name
                        args.model_type = model_type
                        args.use_filter = use_filter
                        args.output_dir = f'{root_dir}/{dataset_name}/{model_type}/filter={use_filter}/rt_noise={rt_noise}/len_clip={len_clip}'

                        if model_type not in sweep_settings:
                            raise NotImplementedError(f'pre type = {model_type} is not implemented')
                        (
                            args.hidden_dim,
                            args.num_layer,
                            args.drop_out,
                            args.num_epoch,
                        ) = sweep_settings[model_type]

                        args = post_process_args(args)
                        main(args)
