import os
import sys
import platform

current_system = platform.system()
if current_system == "Linux":
    # On Linux, switch the working directory to the folder that contains this
    # script, so relative paths used later (e.g. 'output/...') resolve against
    # the script location rather than wherever it was launched from.
    current_directory = os.path.split(os.path.realpath(__file__))[0]
    os.chdir(current_directory)
    print("current_directory:", current_directory)

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from model.FourierGNN import FGN
from model.CNN import ConvNet

import time
import os
import numpy as np
import json

import sys

# Debug aid: print the module search path right before the project-local
# imports below (utils.utils, dataloader), to help diagnose import failures.
print(sys.path)
from utils.utils import save_model, evaluate
from dataloader import Dataset_RUL
from torchsummary import summary

# torch.manual_seed(42)

# The main experiment settings are described in README.md.
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes', '0'.

    ``type=bool`` does not work with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--train False`` would
    enable training.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='fourier graph network for multivariate time series forecasting')
parser.add_argument('--data', type=str, default='RUL', help='data set')
parser.add_argument('--dataset', type=str, default='003', help='data set')
parser.add_argument('--model', type=str, default='FGN', help='models')
parser.add_argument('--feature_size', type=int, default=14, help='feature size')
# nargs='+' so the option yields a list both from the default and from the CLI
# (e.g. `--seq_length 30 60`); the script iterates over it. Without nargs, a CLI
# value was a single int and `for seq in args.seq_length` failed.
parser.add_argument('--seq_length', type=int, nargs='+', default=[30, 60, 90, 120, 150, 180, 210, 240],
                    help='inout length')
parser.add_argument('--pre_length', type=int, default=1, help='predict length')
parser.add_argument('--embed_size', type=int, default=128, help='hidden dimensions')
parser.add_argument('--hidden_size', type=int, default=128, help='hidden dimensions')
parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
parser.add_argument('--batch_size', type=int, default=256, help='input data batch size')
parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
# The LR scheduler is stepped once every `exponential_decay_step` epochs (see train()).
parser.add_argument('--exponential_decay_step', type=int, default=10)
parser.add_argument('--validate_freq', type=int, default=1)
parser.add_argument('--decay_rate', type=float, default=0.5)
# Parsed with _str2bool so `--train False` really means False.
parser.add_argument('--train', type=_str2bool, default=False, help='train or test')
parser.add_argument('--test', type=str, default="test", help='train or test')

args = parser.parse_args()
print(f'Training configs: {args}')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")




def load_model(model_dir, seq_length, device, exp_id=0):
    """Load the checkpoint ``<model_dir>/<seq_length>_rul_<exp_id>.pt`` via ``torch.load``.

    Side effect: creates ``model_dir`` if it does not exist yet.

    Args:
        model_dir: directory holding the checkpoints; a falsy value disables loading.
        seq_length: window length, embedded in the file name.
        device: forwarded to ``torch.load`` as ``map_location``.
        exp_id: experiment index, embedded in the file name.

    Returns:
        The object stored in the file, or ``None`` when ``model_dir`` is falsy or
        the checkpoint file does not exist.
    """
    if not model_dir:
        return None

    checkpoint_path = os.path.join(model_dir, f'{seq_length}_rul_{exp_id}.pt')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'rb') as handle:
            return torch.load(handle, map_location=device)
    return None

def validate(model, vali_loader):
    """Evaluate ``model`` on ``vali_loader``, print metrics, and return the mean batch loss.

    Uses the module-level ``device`` and ``forecast_loss`` (set up in ``__main__``).
    Metrics from ``evaluate(trues, preds)`` over the whole loader are printed, and
    the model is switched back to training mode before returning.

    Args:
        model: network whose output is compared with ``y.permute(0, 2, 1)`` by
            ``forecast_loss``.
        vali_loader: iterable of ``(x, y)`` batches.

    Returns:
        float: sum of per-batch ``forecast_loss`` values divided by the number of batches.
    """
    model.eval()
    cnt = 0
    loss_total = 0.0
    preds = []
    trues = []
    # Evaluation needs no gradients; no_grad avoids building an autograd graph
    # for every forward pass.
    with torch.no_grad():
        for x, y in vali_loader:
            cnt += 1
            y = y.float().to(device)
            x = x.float().to(device)
            forecast = model(x)
            y = y.permute(0, 2, 1).contiguous()

            loss = forecast_loss(forecast, y)
            loss_total += float(loss)
            preds.append(forecast.cpu().numpy())
            trues.append(y.cpu().numpy())
    # Concatenate the list of batches directly. Wrapping it in np.array first fails
    # on NumPy >= 1.24 (or builds an object array) when the last batch is smaller.
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    score = evaluate(trues, preds)
    print(f'RAW : MAPE {score[0]:7.9%}; Score {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
    model.train()
    return loss_total / cnt


def test(seq_length, test_loader, device, exp_id=0):
    """Load the saved model for ``seq_length``/``exp_id`` and run it over ``test_loader``.

    The checkpoint is looked up in the module-level ``result_train_file`` directory
    (set in ``__main__``) via ``load_model``.

    Args:
        seq_length: window length selecting the checkpoint file.
        test_loader: iterable of ``(x, y)`` batches.
        device: device that inputs and the loaded model are placed on.
        exp_id: experiment index selecting the checkpoint file.

    Returns:
        tuple: ``(score[2], preds, trues)``. ``score`` is ``evaluate(trues, preds)``,
        and ``score[2]`` is what this script reports as RMSE. ``preds`` and ``trues``
        are the concatenated NumPy outputs/targets (targets permuted with
        ``permute(0, 2, 1)``).

    Raises:
        FileNotFoundError: if no checkpoint exists for ``seq_length``/``exp_id``.
    """
    model = load_model(result_train_file, seq_length, device, exp_id)
    if model is None:
        # load_model returns None for a missing file; fail with a clear message
        # instead of an AttributeError on None.eval().
        raise FileNotFoundError(
            f'No checkpoint for seq_length={seq_length}, exp_id={exp_id} in {result_train_file!r}')
    model.eval()
    preds = []
    trues = []
    with torch.no_grad():
        for x, y in test_loader:
            y = y.float().to(device)
            x = x.float().to(device)
            forecast = model(x)
            y = y.permute(0, 2, 1).contiguous()
            preds.append(forecast.cpu().numpy())
            trues.append(y.cpu().numpy())
    # Concatenate the batch list directly; np.array() on ragged batches
    # (smaller last batch) fails on recent NumPy.
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    score = evaluate(trues, preds)
    return score[2], preds, trues


def train(result_train_file, exp_id, seq_length=None):
    """Train the module-level ``model`` for ``args.train_epochs`` epochs, then save it.

    Relies on globals set up by the ``__main__`` block: ``model``,
    ``train_dataloader``, ``val_dataloader``, ``my_optim``, ``my_lr_scheduler``,
    ``forecast_loss``, plus ``args`` and ``device``.

    Args:
        result_train_file: directory passed to ``save_model``.
        exp_id: experiment index forwarded to ``save_model``.
        seq_length: window length the model was built for, forwarded to
            ``save_model`` so the checkpoint name matches what ``load_model``
            looks up. Defaults to the module-level ``seq`` loop variable of
            ``__main__`` when it exists, otherwise ``args.seq_length``.

    Returns:
        tuple: ``(train_rmse, test_rmse)``. ``train_rmse`` is the sqrt of the mean
        per-batch training loss of the last epoch. ``test_rmse`` is the sqrt of the
        most recent ``validate`` loss (NaN if validation never ran).
    """
    if seq_length is None:
        # The __main__ training loop binds `seq` at module level. Saving under that
        # single window length (not the whole args.seq_length list) keeps the file
        # name consistent with load_model(..., seq_length, ...).
        seq_length = globals().get('seq', args.seq_length)

    loss_total = 0.0
    cnt = 0
    # Bound up front so the epoch log and the return value work even before the
    # first validation (validate_freq > 1) or when train_epochs == 0.
    val_loss = float('nan')
    for epoch in range(args.train_epochs):
        epoch_start_time = time.time()
        model.train()
        loss_total = 0.0
        cnt = 0
        for x, y in train_dataloader:
            cnt += 1
            y = y.float().to(device)
            x = x.float().to(device)
            # Clear gradients left by the previous step. Without this, backward()
            # keeps adding into .grad and each update uses the sum of all past batches.
            my_optim.zero_grad()
            forecast = model(x)
            y = y.permute(0, 2, 1).contiguous()

            loss = forecast_loss(forecast, y)
            loss.backward()
            my_optim.step()
            loss_total += float(loss)

        if (epoch + 1) % args.exponential_decay_step == 0:
            my_lr_scheduler.step()
        if (epoch + 1) % args.validate_freq == 0:
            val_loss = validate(model, val_dataloader)

        train_mse = loss_total / cnt if cnt else float('nan')
        print('| end of epoch {:3d} | time: {:5.2f}s | train_total_loss {:5.4f} | val_loss {:5.4f}'.format(
            epoch, (time.time() - epoch_start_time), np.sqrt(train_mse), np.sqrt(val_loss)))
    save_model(model, result_train_file, seq_length, exp_id)
    train_mse = loss_total / cnt if cnt else float('nan')
    train_rmse = np.sqrt(train_mse)
    test_rmse = np.sqrt(val_loss)
    return train_rmse, test_rmse





def get_subset_performance(grouped_preds, grouped_trues, result, time_windows=None):
    """Print per-subset metrics for each window model and for the ensemble ``result``.

    Units are split into consecutive subsets using the boundaries returned by
    ``Dataset_RUL(...).get_subset_unit_index(time_windows)``. For subset ``k``
    (0-based), rows ``0..k`` of ``grouped_preds`` are scored individually.
    Presumably these are the window models that have predictions for those units;
    confirm against ``get_subset_unit_index``. The ensemble slice of ``result`` is
    scored after them.

    Args:
        grouped_preds: 2-D array indexed as ``[model_row, unit]``.
        grouped_trues: 1-D array of targets aligned with the unit axis.
        result: 1-D ensemble prediction aligned with the unit axis.
        time_windows: window lengths; defaults to ``[30, 60, 90, 120]``.
    """
    if time_windows is None:
        time_windows = [30, 60, 90, 120]
    split_indices = Dataset_RUL(data=args.dataset, seq_len=30).get_subset_unit_index(time_windows)

    for subset, (start_index, end_index) in enumerate(zip(split_indices[:-1], split_indices[1:])):
        print(f'Subset {subset + 1}, start: {start_index}, end:{end_index}, number of units: {end_index - start_index}')
        subset_trues = grouped_trues[start_index:end_index]
        for i in range(subset + 1):
            subset_preds = grouped_preds[i, start_index:end_index]
            # evaluate() takes (trues, preds) everywhere else in this file; the
            # arguments were previously swapped here, which distorts MAPE.
            score = evaluate(subset_trues, subset_preds)

            print(f'TEST Subset {subset + 1} : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
        subset_ensemble = result[start_index:end_index]
        score = evaluate(subset_trues, subset_ensemble)
        print(f'TEST Subset {subset + 1} Ensemble: MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
        print("=====================================")


if __name__ == '__main__':
    # Output directory output/<data>/<dataset>/<model>. Checkpoints are saved and
    # loaded here, and the per-window JSON results are written here.
    result_train_file = os.path.join('output', args.data, args.dataset, args.model)
    if not os.path.exists(result_train_file):
        os.makedirs(result_train_file)

    if args.train:
        for seq in args.seq_length:
            # Reset for every window length so the printed averages and the JSON
            # file for this `seq` only contain this window's runs (previously the
            # lists kept growing across all earlier window lengths).
            train_rmse_list = []
            test_rmse_list = []

            # Five runs (exp_id 1..5) per window length.
            for exp_id in range(1, 6):
                train_dataloader, val_dataloader, _ = Dataset_RUL(data=args.dataset,
                                                                  seq_len=seq).get_dataloader()

                if args.model == 'FGN':
                    model = FGN(pre_length=args.pre_length,
                                embed_size=args.embed_size,
                                feature_size=args.feature_size,
                                seq_length=seq,
                                hidden_size=args.hidden_size).to(device)
                    summary(model)
                elif args.model == 'CNN':
                    model = ConvNet(embed_size=args.embed_size, hidden_size=args.hidden_size, seq_length=seq).to(device)
                else:
                    # Fail fast instead of a NameError on `model` below.
                    raise ValueError(f"Unsupported --model {args.model!r}; expected 'FGN' or 'CNN'.")

                # train() reads model, the dataloaders, my_optim, my_lr_scheduler and
                # forecast_loss as module-level globals.
                my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
                my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
                forecast_loss = nn.MSELoss(reduction='mean').to(device)

                train_rmse, test_rmse = train(result_train_file, exp_id)
                train_rmse_list.append(train_rmse)
                test_rmse_list.append(test_rmse)
            print(train_rmse_list, test_rmse_list)
            print(
                f'Average train RMSE: {np.mean(train_rmse_list):7.9f}; Average test RMSE: {np.mean(test_rmse_list):7.9f}.')

            # Persist this window's results together with the full run configuration.
            my_dict = {'train_rmse_list': train_rmse_list, 'test_rmse_list': test_rmse_list}
            my_dict.update(vars(args))

            file_name = f'{args.dataset}_results_{seq}.json'
            file_path = os.path.join(result_train_file, file_name)

            with open(file_path, 'w') as f:
                json.dump(my_dict, f)

            # Read the file back to confirm what was written.
            with open(file_path, 'r') as f:
                loaded_data = json.load(f)

            print("Load dict:", loaded_data)

    elif args.test == "test":
        time_window_list = args.seq_length

        print("Validation error while training:")
        for time_window in time_window_list:
            print("=====================================")
            _, val_dataloader, _ = Dataset_RUL(data=args.dataset, seq_len=time_window).get_dataloader()
            test_rmse, test_preds, test_trues = test(time_window, val_dataloader, device, exp_id=0)
            print(time_window,  f'Validation RMSE: {test_rmse:7.9f}.')

        test_dataloader_list = Dataset_RUL(data=args.dataset, seq_len=30).get_subtest(time_window_list)
        print("test whole dataloader:")

        preds = []
        trues = []
        for dataloader in test_dataloader_list:
            # Dim 1 of the input batch is used as the window length that selects
            # which checkpoint test() loads.
            sample_batch = next(iter(dataloader))
            time_window_size = sample_batch[0].size(1)
            test_rmse, test_preds, test_trues = test(time_window_size, dataloader, device)
            print(time_window_size, sample_batch[0].size(), f'Test RMSE: {test_rmse:7.9f}.')

            preds.append(test_preds)
            trues.append(test_trues)
        print("=====================================")

        # (start, end) ranges over `preds`, presumably one group per window model;
        # e.g. [(0, 4), (4, 7), (7, 9), (9, 10)] for four windows (an earlier hard-coded value).
        group_indices = Dataset_RUL(data=args.dataset, seq_len=30).get_group_index(time_window_list)
        length_of_test_df_sorted, length_of_train_df_sorted = Dataset_RUL(data=args.dataset, seq_len=30).check_length()

        # Flatten each group's predictions and left-pad with zeros up to
        # len(length_of_test_df_sorted), so the rows line up column-by-column.
        grouped_preds = []
        for start, end in group_indices:
            group = preds[start:end]
            total_length = sum(len(x) for x in group)
            padding_length = max(0, len(length_of_test_df_sorted) - total_length)
            concatenated_array = np.concatenate(group).flatten()
            padded_array = np.pad(concatenated_array, (padding_length, 0), 'constant')
            grouped_preds.append(padded_array)
        grouped_preds = np.array(grouped_preds)  # e.g. (4, 100): one row per group

        # Targets from the first len(time_window_list) dataloaders, flattened.
        grouped_trues = trues[0:len(time_window_list)]
        grouped_trues = np.concatenate(grouped_trues).flatten()

        # Show the predictions and their std: per-column std across groups, ignoring
        # zero (padding) entries. Not used further below.
        std_nonzero = np.nanstd(np.where(grouped_preds == 0, np.nan, grouped_preds), axis=0)
        preds_and_trues = np.concatenate((grouped_preds, grouped_trues.reshape(1, -1)), axis=0)

        print("=====================================")
        # Row 0 holds the first group's predictions and the last row holds the targets.
        # evaluate() takes (trues, preds) everywhere else in this file, so pass them in
        # that order (they were swapped, which distorts order-sensitive metrics like MAPE).
        score = evaluate(preds_and_trues[-1, :], preds_and_trues[0, :])
        print(f'Individual TEST RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
        print("=====================================")

        ensemble_method = "mean"
        # ensemble_method = "weighted_mean"
        # ensemble_method = "best"

        if ensemble_method == "mean":
            # Average only the non-zero (non-padded) predictions of each column.
            nonzero_count = np.count_nonzero(grouped_preds, axis=0)
            sum_values = np.sum(grouped_preds, axis=0, dtype=np.float32)
            result = np.divide(sum_values, nonzero_count, out=np.zeros_like(sum_values), where=nonzero_count != 0)
            score = evaluate(grouped_trues, result)
            print(f'TEST RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
            print("=====================================")
            get_subset_performance(grouped_preds, grouped_trues, result, time_windows=time_window_list)

        elif ensemble_method == "weighted_mean":
            # One weight slot per window model. This was hard-coded to 4, which broke
            # broadcasting against grouped_preds for any other number of windows
            # (the default is 8).
            n_windows = len(time_window_list)
            training_rmse_lists = []
            for seq_length in time_window_list:
                # NOTE(review): dataset '001' is hard-coded here while the rest of the
                # script uses args.dataset; confirm this is intentional.
                training_loader, _, _ = Dataset_RUL(data='001', seq_len=seq_length).get_dataloader()
                training_rmse, _, _ = test(seq_length, training_loader, device)
                training_rmse_lists.append(training_rmse)

            print(training_rmse_lists)
            # Row k: inverse-RMSE weights of the first k+1 window models, normalised to
            # sum to 1 and zero-padded on the right to n_windows columns.
            weight_list = []
            for i in range(n_windows):
                weights = [1 / error for error in training_rmse_lists[:i + 1]]
                weights_normalized = np.array(weights) / np.sum(weights)
                padding_length = n_windows - len(weights_normalized)
                padded_array = np.pad(weights_normalized, (0, padding_length), 'constant')
                weight_list.append(padded_array)
            weight_array = np.array(weight_list)

            print(weight_array)
            nonzero_count = np.count_nonzero(grouped_preds, axis=0)
            print(nonzero_count)

            # For each column, pick the weight row matching how many models produced a
            # non-zero prediction.
            results = []
            for unit_preds in grouped_preds.T:
                nonzero_count = np.count_nonzero(unit_preds)
                sum_values = np.sum(unit_preds * weight_array[nonzero_count - 1], axis=0, dtype=np.float32)
                results.append(sum_values)
            result = np.array(results)
            score = evaluate(grouped_trues, result)
            print(f'TEST RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')

        elif ensemble_method == "best":
            # One-hot rows: use only the prediction of the last model with a non-zero value.
            weight_array = np.eye(len(time_window_list))

            nonzero_count = np.count_nonzero(grouped_preds, axis=0)
            print(nonzero_count)

            results = []
            for unit_preds in grouped_preds.T:
                nonzero_count = np.count_nonzero(unit_preds)
                sum_values = np.sum(unit_preds * weight_array[nonzero_count - 1], axis=0, dtype=np.float32)
                results.append(sum_values)
            result = np.array(results)
            score = evaluate(grouped_trues, result)
            print(f'TEST RAW : MAPE {score[0]:7.9%}; MAE {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
