# -*-coding: utf-8 -*-
import argparse
import math
import time
import numpy


from models import LSTNet

import torch.nn as nn

from Optim import Optim
from utils import *
import os


def evaluate(data, X, Y, model, evaluateL2, batch_size):
    """Score ``model`` on ``(X, Y)`` and return ``(rse, correlation)``.

    Args:
        data: Dataset helper exposing ``get_batches(X, Y, batch_size, shuffle)``,
            ``scale`` (expanded to ``(batch, data.m)``), ``m`` and ``rse``.
        X: Inputs forwarded to ``data.get_batches``.
        Y: Targets forwarded to ``data.get_batches``.
        model: Module called as ``model(batch_inputs)``; put in eval mode here.
        evaluateL2: Loss callable whose ``.item()`` is accumulated over batches
            (presumably a sum-reduced squared error -- confirm with caller).
        batch_size: Batch size forwarded to ``data.get_batches``.

    Returns:
        Tuple ``(rse, correlation)`` where ``rse`` is
        ``sqrt(total_loss / n_samples) / data.rse`` and ``correlation`` is the
        mean per-column Pearson correlation over the columns whose target
        standard deviation is non-zero.

    Raises:
        ValueError: If ``data.get_batches`` yields no batches.
    """
    model.eval()
    total_loss = 0
    n_samples = 0
    # Collect per-batch tensors and concatenate once after the loop; calling
    # torch.cat inside the loop re-copies everything gathered so far each time.
    outputs = []
    targets = []
    with torch.no_grad():
        # Batches are requested without shuffling (last argument False).
        for batch_x, batch_y in data.get_batches(X, Y, batch_size, False):
            output = model(batch_x)
            outputs.append(output)
            targets.append(batch_y)

            # Loss is computed on values multiplied by the dataset scale.
            scale = data.scale.expand(output.size(0), data.m)
            total_loss += evaluateL2(output * scale, batch_y * scale).item()
            n_samples += (output.size(0) * data.m)

    if not outputs:
        raise ValueError("evaluate() received no batches to score")

    rse = math.sqrt(total_loss / n_samples) / data.rse

    predict = torch.cat(outputs).detach().cpu().numpy()
    Ytest = torch.cat(targets).detach().cpu().numpy()
    sigma_p = predict.std(axis=0)
    sigma_g = Ytest.std(axis=0)
    mean_p = predict.mean(axis=0)
    mean_g = Ytest.mean(axis=0)
    # Only columns with a non-constant target take part in the average.
    # NOTE(review): a column whose prediction is constant (sigma_p == 0) while
    # its target is not still divides 0 by 0 and turns the mean into NaN.
    index = (sigma_g != 0)
    correlation = ((predict - mean_p) * (Ytest - mean_g)).mean(axis=0) / (sigma_p * sigma_g)
    correlation = (correlation[index]).mean()

    return rse, correlation


def train(data, X, Y, model, criterion, optim, batch_size):
    """Run one training pass over ``(X, Y)`` and return the mean loss.

    Batches are drawn from ``data.get_batches`` with shuffling requested.
    For each batch the gradients are cleared, the loss is computed on
    predictions and targets multiplied by ``data.scale``, back-propagated,
    and ``optim.step()`` is called.

    Returns:
        The accumulated ``criterion`` value divided by the number of scored
        elements (rows seen times ``data.m``).
    """
    model.train()
    running_loss = 0
    elements_seen = 0
    for inputs, targets in data.get_batches(X, Y, batch_size, True):
        model.zero_grad()
        prediction = model(inputs)
        rows = prediction.size(0)
        # Expand the dataset scale so it matches this batch's row count.
        scale = data.scale.expand(rows, data.m)
        batch_loss = criterion(prediction * scale, targets * scale)
        batch_loss.backward()
        optim.step()
        running_loss += batch_loss.item()
        elements_seen += rows * data.m
    return running_loss / elements_seen


print("Starting")


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    ``True``: every non-empty string would enable the flag.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch Time series forecasting')

# Data and model selection.
parser.add_argument('--data', type=str, default='data/exchange_rate.txt',
                    help='location of the data file')
parser.add_argument('--dataName', type=str, default='stock',
                    help='name of the data file')
parser.add_argument('--model', type=str, default='LSTNet',
                    help='')

# Architecture hyper-parameters (consumed by the model constructor).
parser.add_argument('--hidCNN', type=int, default=50,
                    help='number of CNN hidden units')
parser.add_argument('--bond_dim', type=int, default=50,
                    help='number of RNN hidden units')
parser.add_argument('--window', type=int, default=24 * 7,
                    help='window size')
parser.add_argument('--CNN_kernel', type=int, default=6,
                    help='the kernel size of the CNN layers')
parser.add_argument('--dilation', type=int, default=2,
                    help='the dilation rate')
parser.add_argument('--input_channel', type=int, default=24 * 7,
                    help='The window size of the AR component')
parser.add_argument('--skip_channel', type=int, default=24 * 7,
                    help='The window size of the AR component')

# Optimizer schedule.
parser.add_argument('--clip', type=float, default=10.,
                    help='gradient clipping')
parser.add_argument('--decay', type=float, default=1.0,
                    help='gradient decay')
parser.add_argument('--start_epoch', type=int, default=200,
                    help='gradient start epoch')
parser.add_argument('--lrstop', type=float, default=1e-4,
                    help='the stop of lr decay')

# Training loop.
parser.add_argument('--epochs', type=int, default=100,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='batch size')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--wlocal', type=float, default=5.0,
                    help='the weight of local')
parser.add_argument('--wglobal', type=float, default=0.5,
                    help='the weight of global')
parser.add_argument('--seed', type=int, default=54321,
                    help='random seed')
# Parsed with _str2bool so that "--seedset False" really disables seeding.
parser.add_argument('--seedset', type=_str2bool, default=True)
parser.add_argument('--gpu', type=str, default=None)
parser.add_argument('--save', type=str, default='models/exchange_rate.pt',
                    help='path to save the final model')
# NOTE: the script later recomputes args.cuda from --gpu, so the value given
# here is not what decides whether CUDA is used.
parser.add_argument('--cuda', type=str, default=False)
parser.add_argument('--optim', type=str, default='adam')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--std', type=float, default=1e-6)
parser.add_argument('--horizon', type=int, default=3)
# Parsed with _str2bool so that "--L1Loss False" keeps the MSE criterion.
parser.add_argument('--L1Loss', type=_str2bool, default=False)
parser.add_argument('--normalize', type=int, default=2)
parser.add_argument('--output_fun', type=str, default='Linear')
parser.add_argument('--order', type=str, default='RK4')
args = parser.parse_args()

# CUDA usage is decided solely by --gpu; this overrides any --cuda value.
args.cuda = args.gpu is not None
if args.cuda:
    # Limit the GPUs visible to this process to the ones requested.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

if args.seedset:
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            # The message names --gpu because that is the flag that enables CUDA.
            print("WARNING: You have a CUDA device, so you should probably run with --gpu")
        else:
            torch.cuda.manual_seed(args.seed)

# 0.6 and 0.2 are passed positionally to Data_Utility
# (presumably train/validation fractions -- confirm in utils).
Data = Data_Utility(args.data, 0.6, 0.2, args.cuda, args.horizon, args.window, args.normalize)

# SECURITY: eval() resolves --model to a module name in this namespace and
# would execute arbitrary expressions; only run this script with trusted
# command-line input.
model = eval(args.model).Model(args, Data)

if args.cuda:
    model.cuda()

nParams = sum(p.nelement() for p in model.parameters())
print('* number of parameters: %d' % nParams)

# Training criterion is selectable; evaluation always uses summed MSE.
if args.L1Loss:
    criterion = nn.L1Loss(reduction='sum')
else:
    criterion = nn.MSELoss(reduction='sum')
evaluateL2 = nn.MSELoss(reduction='sum')

if args.cuda:
    criterion = criterion.cuda()
    evaluateL2 = evaluateL2.cuda()

# Only wrap in DataParallel when the model actually lives on a GPU: wrapping
# a CPU model while several GPUs are visible fails at the first forward pass.
if args.cuda and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

# Any finite validation loss beats the initial best.
best_val = float('inf')
optim = Optim(model.parameters(), args.optim, args.lr, args.clip, lr_decay=args.decay, start_decay_at=args.start_epoch, lr_stop=args.lrstop)

# Single source of truth for the checkpoint location used for saving and loading.
checkpoint_path = "{}_{}.pt".format(args.save, args.horizon)
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    # Create the directory up front so the first save cannot fail after an
    # entire epoch of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

try:
    print('Start training....')
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train_loss = train(Data, Data.train[0], Data.train[1], model, criterion, optim, args.batch_size)
        val_loss, val_corr = evaluate(Data, Data.valid[0], Data.valid[1], model, evaluateL2, args.batch_size)
        optim.update_learning_rate(val_loss, epoch)
        print('| end of epoch {:3d} | time: {:5.2f}s | train_loss {:5.4f} | valid rse {:5.4f} | valid corr  {:5.4f} '.format(
                epoch, (time.time() - epoch_start_time), train_loss, val_loss, val_corr))

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val:
            with open(checkpoint_path, 'wb') as f:
                torch.save(model, f)
            best_val = val_loss

        # Report test metrics every fifth epoch.
        if epoch % 5 == 0:
            test_acc, test_corr = evaluate(Data, Data.test[0], Data.test[1], model, evaluateL2, args.batch_size)
            print("test rse {:5.4f} | test corr {:5.4f}".format(test_acc, test_corr))

except KeyboardInterrupt:
    # Ctrl-C stops training but still falls through to the final evaluation.
    print('-' * 89)
    print('Exiting from training early')

# Evaluate the best checkpoint on the test split.
# NOTE: a checkpoint left at the same path by an earlier run is loaded if this
# run never saved one.
# NOTE(review): recent PyTorch releases default torch.load to weights-only
# loading, which rejects whole-module pickles -- confirm for the installed version.
if os.path.exists(checkpoint_path):
    with open(checkpoint_path, 'rb') as f:
        model = torch.load(f)
else:
    # Interrupted before any checkpoint was written: use the in-memory model
    # instead of crashing with FileNotFoundError.
    print('No checkpoint found at {}; evaluating the current model'.format(checkpoint_path))
test_acc, test_corr = evaluate(Data, Data.test[0], Data.test[1], model, evaluateL2, args.batch_size)


# Append this run's configuration and final test metrics to the per-dataset log.
argsDict = vars(args)
log_path = 'logs/{}_horizon{}.txt'.format(args.dataName, args.horizon)
# Make sure the log directory exists so the results of a finished run are not
# lost to a FileNotFoundError.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
summary = "test rse {:5.4f} | test corr {:5.4f}".format(test_acc, test_corr)

with open(log_path, 'a') as f:
    f.write('------------------ start ------------------' + '\n')
    f.writelines('{} : {}\n'.format(eachArg, value) for eachArg, value in argsDict.items())
    # Terminate the end marker with a newline so the metrics get their own line.
    f.write('------------------- end -------------------' + '\n')
    f.write(summary)
    f.write('\n\n')

print(summary)

