import argparse
import torch
import numpy as np
from tqdm import *
import torch.optim as optim
from utils import KpiReader
from load_ett import _get_data
from models import StackedVAGT
from logger import Logger
from torch.utils.data import Dataset, DataLoader


# coeff = 0.5



class Trainer(object):
    """Training driver for a VAGT model.

    Owns the Adam optimizer, the per-epoch loss history, checkpoint
    save/restore and the beta annealing that is applied after every epoch.
    """

    def __init__(self, vagt, trainloader, log_path='log_trainer', log_file='loss', epochs=20,
                 batch_size=1024, learning_rate=0.001, coeff=0.2, checkpoints='kpi_model.path',
                 checkpoints_interval=1, device=torch.device('cuda:0')):
        """
        :param vagt: model to train; it is moved to ``device`` here
        :param trainloader: iterable yielding ``(data, label)`` batches
        :param log_path: directory handed to ``Logger``
        :param log_file: file name handed to ``Logger``
        :param epochs: exclusive upper bound of the epoch loop
        :param batch_size: stored on the instance only; the effective batch
            size is read from each batch during training
        :param learning_rate: learning rate of the Adam optimizer
        :param coeff: weight of the variational term (-LLH + beta * KL)
            relative to the MAE forecasting loss
        :param checkpoints: path prefix of the checkpoint files
        :param checkpoints_interval: save every N epochs; values <= 0 disable saving
        :param device: torch device used for the model and for every batch
        """
        self.trainloader = trainloader
        self.log_path = log_path
        self.log_file = log_file
        self.start_epoch = 0
        self.epochs = epochs
        self.device = device
        self.batch_size = batch_size
        self.vagt = vagt
        self.vagt.to(device)
        self.learning_rate = learning_rate
        self.coeff = coeff
        self.checkpoints = checkpoints
        self.checkpoints_interval = checkpoints_interval
        # Report the parameter count; formatting the parameters() generator
        # itself only prints an opaque "<generator object ...>" repr.
        print('Model parameters: {}'.format(sum(p.numel() for p in self.vagt.parameters())))
        self.optimizer = optim.Adam(self.vagt.parameters(), self.learning_rate)
        self.epoch_losses = []
        self.loss = {}
        self.logger = Logger(self.log_path, self.log_file)

    def save_checkpoint(self, epoch, checkpoints):
        """Write model, optimizer, beta and loss history to disk.

        :param epoch: zero-based index of the epoch that just finished
        :param checkpoints: path prefix; the file written is
            ``<checkpoints>_epochs<epoch + 1>.pth``
        """
        torch.save({'epoch': epoch + 1,
                    'beta': self.vagt.beta,
                    'state_dict': self.vagt.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'losses': self.epoch_losses},
                   checkpoints + '_epochs{}.pth'.format(epoch + 1))

    def load_checkpoint(self, start_ep, checkpoints):
        """Restore training state from ``<checkpoints>_epochs<start_ep>.pth``.

        On success ``self.start_epoch`` is set to the epoch stored in the
        checkpoint so that ``train_model`` continues from there. Loading is
        best effort: on any failure a message is printed and training starts
        from epoch 0.

        :param start_ep: epoch number embedded in the checkpoint file name
        :param checkpoints: path prefix of the checkpoint files
        """
        path = checkpoints + '_epochs{}.pth'.format(start_ep)
        try:
            print("Loading Chechpoint from ' {} '".format(path))
            # map_location lets a checkpoint saved on one device be restored on another.
            checkpoint = torch.load(path, map_location=self.device)
            self.vagt.beta = checkpoint['beta']
            self.vagt.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.epoch_losses = checkpoint['losses']
            # Assigned last so that a failure above never leaves a resumed epoch behind.
            self.start_epoch = checkpoint['epoch']
            print("Resuming Training From Epoch {}".format(self.start_epoch))
        except FileNotFoundError:
            print("No Checkpoint Exists At '{}', Starting Fresh Training".format(checkpoints))
            self.start_epoch = 0
        except Exception as exc:
            # Keep the best-effort behaviour, but do not disguise a corrupt or
            # incompatible checkpoint as a missing one.
            print("Failed To Load Checkpoint '{}' ({}), Starting Fresh Training".format(path, exc))
            self.start_epoch = 0

    def train_model(self):
        """Run the loop from ``self.start_epoch`` up to ``self.epochs``.

        Per batch the optimized loss is
        ``coeff * (-llh + beta * kld_z) + mae_loss``. After each epoch the
        mean loss, log-likelihood and KL term are printed and logged, a
        checkpoint is written when the interval is hit, and beta is annealed
        towards ``self.vagt.max_beta``.
        """
        self.vagt.train()
        for epoch in range(self.start_epoch, self.epochs):
            losses, llhs, kld_zs = [], [], []
            print("Running Epoch : {}".format(epoch + 1))
            # Wrapping the loader itself (not an enumerate object) lets tqdm
            # read its length when the loader defines one.
            for data, label in tqdm(self.trainloader):
                # unsqueeze(3) inserts a new axis at index 3 -- presumably the
                # batches arrive as 3-D tensors; TODO confirm against the loader.
                data = data.unsqueeze(3).to(self.device)
                # Follow the configured device instead of a hard-coded .cuda().
                label = label.unsqueeze(3).to(self.device)

                batch_size = data.size(0)
                self.optimizer.zero_grad()
                (_, z_mean_posterior_forward, z_logvar_posterior_forward,
                 z_mean_prior_forward, z_logvar_prior_forward,
                 x_mu, x_logsigma, h_out) = self.vagt(data)

                # Both variational terms are normalised by the batch size.
                llh = self.vagt.loss_LLH(data.squeeze(-1),
                                         x_mu.squeeze(2).squeeze(-1),
                                         x_logsigma.squeeze(2).squeeze(-1)) / batch_size
                kld_z = self.vagt.loss_KL(z_mean_posterior_forward, z_logvar_posterior_forward,
                                          z_mean_prior_forward, z_logvar_prior_forward) / batch_size
                mae_loss = self.vagt.mae_loss(h_out, label)

                loss = self.coeff * (-llh + self.vagt.beta * kld_z) + mae_loss

                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())
                llhs.append(llh.item())
                kld_zs.append(kld_z.item())

            meanloss = np.mean(losses)
            meanllh = np.mean(llhs)
            meanz = np.mean(kld_zs)
            # Plain Python floats keep NumPy scalar objects out of the
            # checkpointed loss history.
            self.epoch_losses.append(float(meanloss))
            print("Epoch {} : Average Loss: {} Loglikelihood: {} KL of z: {}, Beta: {}".format(
                epoch + 1, meanloss, meanllh, meanz, self.vagt.beta))
            self.loss['Epoch'] = epoch + 1
            self.loss['Avg_loss'] = meanloss
            self.loss['Llh'] = meanllh
            self.loss['KL_z'] = meanz
            self.logger.log_trainer(epoch + 1, self.loss)
            if (self.checkpoints_interval > 0
                    and (epoch + 1) % self.checkpoints_interval == 0):
                self.save_checkpoint(epoch, self.checkpoints)

            # Anneal beta after every epoch, capped at max_beta.
            self.vagt.beta = float(np.minimum(
                (self.vagt.beta + 0.01) * np.exp(self.vagt.anneal_rate * (epoch + 1)),
                self.vagt.max_beta))

        print("Training is complete!")




def main():
    """Parse the command line, build the ETT data loaders and the StackedVAGT
    model, optionally restore a checkpoint, and run training."""
    import os
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    parser = argparse.ArgumentParser()
    # GPU
    parser.add_argument('--gpu_id', type=int, default=0)
    # Dataset options
    parser.add_argument('--dataset', default='ETTh1', type=str)
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--root_path', type=str, default='./data/ETT/',
                        help='root path of the data file')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--seq_len', type=int, default=48, help='input sequence length of Informer encoder')
    parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--cols', type=str, nargs='+', help='certain cols from the data files as the input features')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
    parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--x_dim', type=int, default=7)
    parser.add_argument('--win_len', type=int, default=48, help='the same as seq_len')
    parser.add_argument('--coeff', type=float, default=0.5)
    # Model options for VAGT
    parser.add_argument('--z_dim', type=int, default=25)
    parser.add_argument('--h_dim', type=int, default=50)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--layer_xz', type=int, default=2)
    parser.add_argument('--layer_h', type=int, default=3)
    parser.add_argument('--q_len', type=int, default=1, help='for conv1D padding in qTransformer')
    parser.add_argument('--embd_h', type=int, default=128)
    parser.add_argument('--embd_s', type=int, default=256)
    parser.add_argument('--vocab_len', type=int, default=256)
    # Training options for VAGT
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--learning_rate', type=float, default=0.0002)
    parser.add_argument('--beta', type=float, default=0.0)
    parser.add_argument('--max_beta', type=float, default=1.0)
    parser.add_argument('--anneal_rate', type=float, default=0.05)

    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--checkpoints_interval', type=int, default=10)

    parser.add_argument('--checkpoints_path', type=str, default='model/debug')
    parser.add_argument('--checkpoints_file', type=str, default='')
    parser.add_argument('--log_path', type=str, default='log_trainer/debug')
    parser.add_argument('--log_file', type=str, default='')

    args = parser.parse_args()

    # Pick the device first and only then touch the CUDA runtime, so that the
    # CPU fallback is reachable on machines without CUDA and the current CUDA
    # device follows --gpu_id instead of always being device 0.
    if torch.cuda.is_available() and args.gpu_id >= 0:
        device = torch.device('cuda:%d' % args.gpu_id)
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    seed = 10
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.log_path, exist_ok=True)
    os.makedirs(args.checkpoints_path, exist_ok=True)

    # TODO Saving path names, for updating later...
    # The checkpoint and log files share one default name derived from the
    # model hyper-parameters.
    run_tag = ('x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_'
               'win_len-{}_q_len-{}_vocab_len-{}').format(args.x_dim, args.z_dim, args.h_dim,
                                                          args.layer_xz, args.layer_h, args.embd_h,
                                                          args.n_head, args.win_len, args.q_len,
                                                          args.vocab_len)
    if args.checkpoints_file == '':
        args.checkpoints_file = run_tag
    if args.log_file == '':
        args.log_file = run_tag

    # Known datasets override --data_path and --target.
    data_parser = {
        'ETTh1': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTm1': {'data': 'ETTm1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    }
    if args.dataset in data_parser:
        data_info = data_parser[args.dataset]
        args.data_path = data_info['data']
        args.target = data_info['T']
    # Keep the full frequency string and reduce --freq to its last character
    # (e.g. '15min' -> 'n', '3h' -> 'h').
    args.detail_freq = args.freq
    args.freq = args.freq[-1:]

    _, train_loader = _get_data(args, 'train')
    # NOTE(review): the validation and test loaders are built but never used
    # below; the calls are kept in case _get_data has side effects -- confirm
    # and drop them if it does not.
    _, val_loader = _get_data(args, 'val')
    _, test_loader = _get_data(args, 'test')

    # For models init
    stackedvagt = StackedVAGT(layer_xz=args.layer_xz, layer_h=args.layer_h, n_head=args.n_head,
                              x_dim=args.x_dim, z_dim=args.z_dim, h_dim=args.h_dim,
                              embd_h=args.embd_h, embd_s=args.embd_s, beta=args.beta,
                              q_len=args.q_len, vocab_len=args.vocab_len, win_len=args.seq_len,
                              horizon=args.pred_len, dropout=args.dropout,
                              anneal_rate=args.anneal_rate, max_beta=args.max_beta,
                              device=device).to(device)

    # Start train
    trainer = Trainer(stackedvagt, train_loader, log_path=args.log_path, epochs=args.epochs,
                      log_file=args.log_file, batch_size=args.batch_size,
                      learning_rate=args.learning_rate, coeff=args.coeff,
                      checkpoints=os.path.join(args.checkpoints_path, args.checkpoints_file),
                      checkpoints_interval=args.checkpoints_interval, device=device)
    trainer.load_checkpoint(args.start_epoch, trainer.checkpoints)
    trainer.train_model()


if __name__ == '__main__':
    # Script entry point: install a blanket "ignore" warning filter, then
    # hand control to main().
    from warnings import filterwarnings
    filterwarnings("ignore")
    main()