import os
import sys
import platform

current_system = platform.system()
if current_system == "Linux":
    # Work relative to this script's own folder so relative paths such as
    # 'output/...' resolve the same way no matter where the script is launched.
    current_directory = os.path.split(os.path.realpath(__file__))[0]
    os.chdir(current_directory)
    print("current_directory:", current_directory)

import argparse
import torch
import torch.nn as nn
from model.FourierGNN import FGN
from model.CNN import ConvNet

import time
import os
import numpy as np
import json
import random
import sys
import pandas as pd

print(sys.path)
from utils.utils import save_model, load_model, evaluate
from dataloader import Dataset_RUL

# torch.manual_seed(42)

# Main settings are described in the markdown file (README.md).


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True``, so any
    non-empty value would switch the flag on. Real bools (e.g. defaults) are
    passed through unchanged, and the empty string maps to False exactly as
    ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='fourier graph network for multivariate time series forecasting')
parser.add_argument('--data', type=str, default='RUL', help='data set')
parser.add_argument('--dataset', type=str, default='002', help='data set')
parser.add_argument('--model', type=str, default='FGN', help='models')
parser.add_argument('--feature_size', type=int, default=14, help='feature size')
# nargs='+' so that values given on the command line are parsed into a list,
# matching the list default; the main block iterates over every entry.
parser.add_argument('--seq_length', type=int, nargs='+', default=[20, 40, 60, 80, 100, 120, 140, 160, 180],
                    help='input length(s); the main loop runs once per value')
parser.add_argument('--pre_length', type=int, default=1, help='predict length')
parser.add_argument('--embed_size', type=int, default=128, help='hidden dimensions')
parser.add_argument('--hidden_size', type=int, default=128, help='hidden dimensions')
parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
parser.add_argument('--batch_size', type=int, default=256, help='input data batch size')
parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
parser.add_argument('--exponential_decay_step', type=int, default=10)
parser.add_argument('--validate_freq', type=int, default=1)
parser.add_argument('--early_stop', type=_str2bool, default=False)
parser.add_argument('--decay_rate', type=float, default=0.5)
parser.add_argument('--train', type=_str2bool, default=True,
                    help='train (True) or run the test branch selected by --test (False)')
parser.add_argument('--test', type=str, default="test2", help='test mode to run when --train is False')
parser.add_argument('--exp_time', type=str, default="001", help='experiment id used in the training output path')

args = parser.parse_args()
print(f'Training configs: {args}')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load_model(model_dir, seq_length, test_id, device):
    """Load the checkpoint ``<seq_length>_rul_<test_id>.pt`` from ``model_dir``.

    The resolved checkpoint path is printed. As a side effect, ``model_dir`` is
    created when it does not exist yet. Note that this definition shadows the
    ``load_model`` imported from ``utils.utils`` at the top of the file.

    Args:
        model_dir: directory holding the checkpoints; a falsy value returns None.
        seq_length: first component of the file name.
        test_id: second component of the file name (callers in this file pass
            the operating condition).
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object returned by ``torch.load``, or None when ``model_dir`` is
        falsy or the checkpoint file does not exist.
    """
    if not model_dir:
        return None
    checkpoint = os.path.join(model_dir, '_rul_'.join((str(seq_length), str(test_id))) + '.pt')
    print(checkpoint)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(checkpoint):
        return None
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    with open(checkpoint, 'rb') as handle:
        return torch.load(handle, map_location=device)

def validate(model, vali_loader):
    """Evaluate ``model`` on ``vali_loader`` and return the mean per-batch loss.

    Prints MAPE / MAE / RMSE as computed by ``evaluate`` on all predictions,
    and always leaves the model in training mode afterwards.

    Relies on the module-level ``device`` and ``forecast_loss`` (the main block
    sets the latter to ``nn.MSELoss``).

    Args:
        model: network called as ``model(x)`` on each batch.
        vali_loader: iterable of ``(x, y)`` batches.

    Returns:
        The average of ``forecast_loss`` over the batches.

    Raises:
        ValueError: if ``vali_loader`` yields no batches.
    """
    model.eval()
    cnt = 0
    loss_total = 0
    preds = []
    trues = []
    try:
        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for x, y in vali_loader:
                cnt += 1
                y = y.float().to(device)
                x = x.float().to(device)
                forecast = model(x)
                # Same (0, 2, 1) target permutation that train() and test() apply.
                y = y.permute(0, 2, 1).contiguous()

                loss = forecast_loss(forecast, y)
                loss_total += float(loss)
                preds.append(forecast.detach().cpu().numpy())
                trues.append(y.detach().cpu().numpy())
    finally:
        model.train()

    if cnt == 0:
        raise ValueError('validation loader produced no batches')
    # Concatenate the per-batch arrays directly: wrapping them in np.array()
    # first fails on recent numpy when the last batch is smaller (ragged list).
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    score = evaluate(trues, preds)
    print(f'RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
    return loss_total / cnt


def test(seq_length, test_loader, test_id, operating_condition, device):
    """Score the saved model for one (seq_length, operating_condition) pair.

    The checkpoint ``<seq_length>_rul_<operating_condition>.pt`` is loaded from
    the module-level ``result_train_file`` via ``load_model``.

    Args:
        seq_length: input window length used to pick the checkpoint.
        test_loader: iterable of ``(x, y)`` batches.
        test_id: index labels for the returned DataFrame, one per sample.
        operating_condition: condition id used to pick the checkpoint.
        device: device the batches (and the loaded model) are placed on.

    Returns:
        ``(rmse, preds, trues, df)`` where ``rmse`` is ``evaluate(...)[2]``
        (the value printed as RMSE elsewhere), ``preds``/``trues`` are the
        concatenated numpy arrays, and ``df`` has one ``'prediction'`` column
        indexed by ``test_id``.

    Raises:
        FileNotFoundError: if no checkpoint exists for this pair.
    """
    model = load_model(result_train_file, seq_length, operating_condition, device)
    if model is None:
        # load_model returns None for a missing file; fail with a clear message
        # instead of an AttributeError on ``None.eval()``.
        raise FileNotFoundError(
            f'no saved model for seq_length={seq_length}, '
            f'operating_condition={operating_condition} in {result_train_file!r}')
    model.eval()
    preds = []
    trues = []
    with torch.no_grad():
        for x, y in test_loader:
            y = y.float().to(device)
            x = x.float().to(device)
            forecast = model(x)
            y = y.permute(0, 2, 1).contiguous()
            preds.append(forecast.detach().cpu().numpy())
            trues.append(y.detach().cpu().numpy())
    # Concatenate directly; np.array() on a ragged list of batches fails on
    # recent numpy when the last batch is smaller.
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    score = evaluate(trues, preds)
    # Flatten to one value per sample. Assumes a single scalar prediction per
    # window (pre_length == 1) — TODO confirm for other configurations.
    df = pd.DataFrame({'prediction': preds.reshape(preds.shape[0],)}, index=np.array(test_id))
    return score[2], preds, trues, df

def train(result_train_file, optimizer, condition_id, exp_id):
    """Train the module-level ``model`` for ``args.train_epochs`` epochs.

    Reads these module globals, all set by the main block: ``model``,
    ``train_dataloader``, ``val_dataloader``, ``forecast_loss``, and
    ``my_lr_scheduler``. The scheduler is stepped every
    ``args.exponential_decay_step`` epochs, and validation runs every
    ``args.validate_freq`` epochs.

    Args:
        result_train_file: output directory (currently unused here).
        optimizer: optimizer over ``model``'s parameters, zeroed and stepped
            each batch.
        condition_id: operating condition being trained (currently unused).
        exp_id: repetition index (currently unused).

    Returns:
        ``(train_rmse, test_rmse)``: the square roots of the last epoch's mean
        training loss and of the most recent validation loss. Either value is
        NaN when nothing was measured (no batches, no epochs, or validation
        never ran).
    """
    # NOTE(review): no checkpoint is written here, yet test() loads
    # '<seq>_rul_<condition>.pt' from result_train_file — confirm where models
    # are saved.
    # Defaults keep the progress line and the return value defined even before
    # the first validation (validate_freq > 1) or when train_epochs == 0.
    val_loss = float('nan')
    train_rmse = float('nan')

    for epoch in range(args.train_epochs):
        epoch_start_time = time.time()
        model.train()
        loss_total = 0
        cnt = 0

        for x, y in train_dataloader:
            optimizer.zero_grad()
            cnt += 1
            y = y.float().to(device)
            x = x.float().to(device)
            forecast = model(x)
            y = y.permute(0, 2, 1).contiguous()

            loss = forecast_loss(forecast, y)
            loss.backward()
            # Step the optimizer this function was given, not a module global.
            optimizer.step()
            loss_total += float(loss)

        if (epoch + 1) % args.exponential_decay_step == 0:
            my_lr_scheduler.step()

        if (epoch + 1) % args.validate_freq == 0:
            val_loss = validate(model, val_dataloader)

        train_rmse = np.sqrt(loss_total / cnt) if cnt else float('nan')
        print('| end of epoch {:3d} | time: {:5.2f}s | train_total_loss {:5.4f} | val_loss {:5.4f}'.format(
            epoch, (time.time() - epoch_start_time), train_rmse, np.sqrt(val_loss)))

    test_rmse = np.sqrt(val_loss)

    return train_rmse, test_rmse


# Create the output directories: the per-experiment training folder (also where
# test() looks for checkpoints) and a dataset-level 'test' folder.
result_train_file = os.path.join('output', args.data, args.dataset, args.model, args.exp_time)
result_test_file = os.path.join('output', args.data, args.dataset, 'test')
for _output_dir in (result_train_file, result_test_file):
    if not os.path.exists(_output_dir):
        os.makedirs(_output_dir)
del _output_dir


def get_subset_performance(grouped_preds, grouped_trues, result, time_windows=None):
    """Print test metrics per unit subset, for individual predictors and the ensemble.

    Subset boundaries come from ``Dataset_RUL(...).get_subset_unit_index``
    applied to ``time_windows``; consecutive boundary pairs delimit each subset.

    Args:
        grouped_preds: 2-D indexable predictions; row ``i`` is one predictor
            (indexed as ``grouped_preds[i, start:end]``).
        grouped_trues: ground truth, sliced per subset.
        result: ensemble predictions, sliced per subset.
        time_windows: window lengths passed to ``get_subset_unit_index``;
            defaults to ``[30, 60, 90, 120]``.
    """
    # None sentinel instead of a shared mutable default list.
    if time_windows is None:
        time_windows = [30, 60, 90, 120]
    split_indices = Dataset_RUL(data=args.dataset, seq_len=30).get_subset_unit_index(time_windows)

    for subset, (start_index, end_index) in enumerate(zip(split_indices[:-1], split_indices[1:])):
        print(f'Subset {subset + 1}, start: {start_index}, end:{end_index}, number of units: {end_index - start_index}')
        subset_trues = grouped_trues[start_index:end_index]
        # Subset k is scored with each of the first k + 1 prediction rows.
        for i in range(subset + 1):
            subset_preds = grouped_preds[i, start_index:end_index]
            # evaluate() is called as (trues, preds) everywhere else in this
            # file; the order matters for asymmetric metrics such as MAPE.
            score = evaluate(subset_trues, subset_preds)

            print(f'TEST Subset {subset + 1} : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
        subset_ensemble = result[start_index:end_index]
        score = evaluate(subset_trues, subset_ensemble)
        print(f'TEST Subset {subset + 1} Ensemble: MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
        print("=====================================")


if __name__ == '__main__':

    if args.train:

        for seq in args.seq_length:
            train_rmse_list = []
            test_rmse_list = []
            # NOTE(review): only operating condition 2 is trained, while the
            # test branch below evaluates conditions 0-5 — confirm this is
            # intended (the full sweep would be range(6)).
            for condition_id in range(2, 3):

                for exp_id in range(5):
                    train_dataloader, val_dataloader, _ = Dataset_RUL(data=args.dataset,
                                                                      seq_len=seq,
                                                                      operating_condition=condition_id,
                                                                      renormalize=True).get_dataloader()

                    if args.model == 'FGN':
                        model = FGN(pre_length=args.pre_length,
                                    embed_size=args.embed_size,
                                    feature_size=args.feature_size,
                                    seq_length=seq,
                                    hidden_size=args.hidden_size).to(device)
                    elif args.model == 'CNN':
                        model = ConvNet(embed_size=args.embed_size, hidden_size=args.hidden_size,
                                        seq_length=seq).to(device)
                    else:
                        # Without this, `model` would be undefined below (NameError).
                        raise ValueError(f"unsupported --model {args.model!r}; expected 'FGN' or 'CNN'")

                    # Both models share the same optimizer and decay schedule.
                    # train() and validate() read model, my_lr_scheduler,
                    # forecast_loss and the dataloaders as module globals.
                    my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
                    my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
                    forecast_loss = nn.MSELoss(reduction='mean').to(device)

                    train_rmse, test_rmse = train(result_train_file, my_optim, condition_id, exp_id)
                    train_rmse_list.append(train_rmse)
                    test_rmse_list.append(test_rmse)
                print(train_rmse_list, test_rmse_list)
                print(
                    f'Average train RMSE: {np.mean(train_rmse_list):7.9f}; Average test RMSE: {np.mean(test_rmse_list):7.9f}.')

            # Persist this window length's results together with the full run config.
            my_dict = {'train_rmse_list': train_rmse_list, 'test_rmse_list': test_rmse_list}
            my_dict.update(vars(args))

            file_name = f'{args.dataset}_results_{seq}.json'
            file_path = os.path.join(result_train_file, file_name)

            with open(file_path, 'w') as f:
                json.dump(my_dict, f)

            with open(file_path, 'r') as f:
                loaded_data = json.load(f)

            print("Load dict:", loaded_data)

    elif args.test == "test2":
        # Summary CSVs go to output/<data>/<dataset>/<model>, the parent of
        # result_train_file (created by the output-directory setup above).
        # With default arguments this is the previously hard-coded 'output/RUL/002/FGN'.
        summary_dir = os.path.join('output', args.data, args.dataset, args.model)

        result_df = pd.DataFrame()

        print("Validation error while training:")
        for time_window in args.seq_length:
            df_list = []

            rmse_list = []
            print("=====================================")

            for operating_condition in range(6):
                _, test_dataloader, test_id = Dataset_RUL(data=args.dataset, seq_len=time_window,
                                                          operating_condition=operating_condition,
                                                          renormalize=True).get_dataloader()

                test_rmse, test_preds, test_trues, df = test(time_window, test_dataloader, test_id, operating_condition,
                                                             device)

                print(time_window, f'operating_condition: {operating_condition}, Validation RMSE: {test_rmse:7.9f}.')
                df_list.append(df)
                rmse_list.append(test_rmse)

            # One 'prediction' column per operating condition, aligned on test_id.
            df_all = pd.concat(df_list, axis=1)

            _, test_df = Dataset_RUL(data=args.dataset, seq_len=time_window, operating_condition=0,
                                     kmeans=False).load_data()
            id_of_test_df_sorted, _ = Dataset_RUL(data=args.dataset, seq_len=time_window, operating_condition=0,
                                                  kmeans=False).check_length()

            # Label per unit = its last 'RUL' value, keeping only units with more
            # than time_window rows. Each unit is filtered once, not twice.
            test_label = []
            for unit_id in id_of_test_df_sorted:
                unit_rows = test_df[test_df['id'] == unit_id]
                if len(unit_rows) > time_window:
                    test_label.append(unit_rows['RUL'].values[-1])
            test_label = np.array(test_label).astype(np.float32)
            print("test_label: ", test_label)
            df_all['label'] = test_label
            # Ensemble = mean over the per-condition prediction columns (NaN skipped).
            df_all['mean_prediction'] = df_all.iloc[:, :len(df_list)].mean(axis=1, skipna=True)

            print(df_all.head())
            score = evaluate(df_all['label'].values, df_all['mean_prediction'].values)
            print(f'TEST RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
            # rmse_list: one entry per condition, followed by the ensemble RMSE.
            rmse_list.append(score[2])

            df_all.to_csv(os.path.join(summary_dir, f'{args.dataset}_{time_window}_sorted.csv'), index=True)
            result_df[time_window] = rmse_list
            result_df.to_csv(os.path.join(summary_dir, f'result_{time_window}_df.csv'), index=True)

