import os
import platform
import argparse
import torch
import torch.nn as nn
from model.FourierGNN import FGN
from model.CNN import ConvNet
import time
import numpy as np
from utils.utils import save_model, evaluate, load_model
from dataloader_ncmapss import prepare_and_load_dataset

# On Linux, move the working directory to the folder that contains this script,
# so cwd-relative paths (such as the 'output' directory) resolve next to it.
# On other platforms the caller's working directory is left untouched.
current_system = platform.system()
if current_system == "Linux":
    current_directory = os.path.split(os.path.realpath(__file__))[0]
    os.chdir(current_directory)
    print(f"current_directory: {current_directory}")

# Command-line interface. Each entry is (flag, keyword arguments for
# add_argument). Flags are registered in this order, which is also the order
# in which --help lists them.
parser = argparse.ArgumentParser(description='RUL prediction')
for _flag, _options in (
    ('--dataset', dict(type=str, default='DS02', help='Dataset')),
    ('--sequence_length', dict(type=int, default=100, help='Sequence length')),
    ('--s', dict(type=int, default=1, help='Stride of filter')),
    ('--sampling', dict(type=int, default=100, help='Subsampling rate')),
    ('--sub', dict(type=int, default=1, help='Subsampling stride')),
    ('--model', dict(type=str, default='FGN', help='Model type')),
    ('--feature_size', dict(type=int, default=20, help='Feature size')),
    ('--embed_size', dict(type=int, default=32, help='Hidden dimensions')),
    ('--hidden_size', dict(type=int, default=128, help='Dense layer hidden dimensions')),
    ('--train_epochs', dict(type=int, default=50, help='Number of training epochs')),
    ('--batch_size', dict(type=int, default=64, help='Batch size')),
    ('--learning_rate', dict(type=float, default=0.001, help='Learning rate')),
    ('--pre_length', dict(type=int, default=1, help='Prediction length')),
    ('--operating_condition', dict(type=int, default=3, help='Operating condition')),
    ('--train', dict(action='store_true', default=False, help='Train mode')),
    ('--test', dict(type=str, default="test1", help='Test mode')),
    ('--decay_rate', dict(type=float, default=0.5, help='Decay rate')),
    ('--exponential_decay_step', dict(type=int, default=10, help='Decay step')),
    ('--validate_freq', dict(type=int, default=1, help='Validation frequency')),
):
    parser.add_argument(_flag, **_options)
# Drop the loop variables so they do not linger as module globals.
del _flag, _options

args = parser.parse_args()

# Absolute folder of this script; the N-CMAPSS data is located relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))
data_filedir = os.path.join(current_dir, 'NCMAPSS')
sample_dir_path = os.path.join(data_filedir, 'Samples_whole')
# Prefer the first CUDA device when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Despite its name this is a DIRECTORY: checkpoints are written into it by
# save_model and read back from it by load_model. The path is relative to the
# current working directory (which the Linux-only chdir above pins to the
# script folder). exist_ok avoids the race of checking then creating.
result_train_file = os.path.join('output', args.dataset, args.model)
os.makedirs(result_train_file, exist_ok=True)

def save_model(model, model_dir, seq_length, oc=2, exp_id=1):
    """Serialize ``model`` with ``torch.save`` into ``model_dir``.

    The file is named ``<seq_length>_rul_<oc>_<exp_id>.pt``. An existing file
    with the same name is overwritten. This local definition shadows the
    ``save_model`` imported from ``utils.utils`` at the top of the file.

    Args:
        model: object to serialize. Callers in this file pass the whole
            network object, not a ``state_dict``.
        model_dir: target directory, created (with parents) if missing.
            ``None`` disables saving.
        seq_length: sequence length, used as the first part of the file name.
        oc: operating-condition id embedded in the file name.
        exp_id: experiment id embedded in the file name.
    """
    if model_dir is None:
        return
    # exist_ok=True replaces the racy exists()/makedirs() check-then-create.
    os.makedirs(model_dir, exist_ok=True)
    file_name = os.path.join(model_dir, f'{seq_length}_rul_{oc}_{exp_id}.pt')
    with open(file_name, 'wb') as f:
        torch.save(model, f)

def validate(model, test_dataloaders):
    """Evaluate ``model`` on every test loader and print per-set and overall metrics.

    Uses the module-level ``device``, ``forecast_loss`` and ``evaluate``.

    Args:
        model: network to evaluate. It is put in eval mode for the pass and is
            always switched back to train mode before returning, even if an
            exception is raised.
        test_dataloaders: iterable of data loaders yielding ``(x, y)`` batches.
            Empty loaders are reported and skipped.

    Returns:
        float: mean ``forecast_loss`` over all evaluated batches, or ``nan``
        when every loader is empty. Previously that case crashed inside
        ``np.concatenate([])`` or divided by zero.
    """
    model.eval()
    total_loss = 0.0
    total_cnt = 0
    all_preds = []
    all_trues = []

    try:
        # Evaluation needs no autograd graph; no_grad saves memory and time.
        with torch.no_grad():
            for idx, vali_loader in enumerate(test_dataloaders):
                if len(vali_loader) == 0:
                    print(f'Test Set {idx} is empty, skipping...')
                    continue

                preds = []
                trues = []
                for x, y in vali_loader:
                    x = x.float().to(device)
                    # Swap target axes 1 and 2, presumably to match the model
                    # output layout expected by forecast_loss.
                    y = y.float().to(device).permute(0, 2, 1).contiguous()
                    forecast = model(x)
                    total_loss += forecast_loss(forecast, y).item()
                    total_cnt += 1
                    preds.append(forecast.cpu().numpy())
                    trues.append(y.cpu().numpy())

                preds = np.concatenate(preds, axis=0)
                trues = np.concatenate(trues, axis=0)
                all_preds.append(preds)
                all_trues.append(trues)

                score = evaluate(trues, preds)
                print(f'Test Set {idx}: MAPE {score[0]:7.9%}; Score {score[1]:7.9f}; RMSE {score[2]:7.9f}.')
    finally:
        model.train()

    if not all_preds:
        print('All test sets are empty; no overall performance to report.')
        return float('nan')

    overall_score = evaluate(np.concatenate(all_trues, axis=0), np.concatenate(all_preds, axis=0))
    print(f'Overall Test Performance: MAPE {overall_score[0]:7.9%}; Score {overall_score[1]:7.9f}; RMSE {overall_score[2]:7.9f}.')
    return total_loss / total_cnt

def train(result_train_file, optimizer, exp_id):
    """Train the module-level ``model`` for ``args.train_epochs`` epochs.

    Uses module globals assigned in the ``__main__`` block: ``model``,
    ``train_dataloader``, ``test_dataloader``, ``my_lr_scheduler`` and
    ``forecast_loss``. A checkpoint is written after every epoch, overwriting
    the same file. One more is written at the end, so something is saved even
    when ``train_epochs`` is 0.

    Args:
        result_train_file: output directory handed to ``save_model``.
        optimizer: optimizer that updates ``model``'s parameters.
        exp_id: experiment index embedded in the checkpoint file name.
    """
    # Most recent validation loss. It starts as NaN so the per-epoch log line
    # still works on epochs where validation is skipped (validate_freq > 1).
    # Previously that case raised NameError on the first such epoch.
    val_loss = float('nan')
    for epoch in range(args.train_epochs):
        epoch_start_time = time.time()
        model.train()
        loss_total = 0.0
        cnt = 0
        for x, y in train_dataloader:
            optimizer.zero_grad()
            x = x.float().to(device)
            # Swap target axes 1 and 2, presumably to match the model output layout.
            y = y.float().to(device).permute(0, 2, 1).contiguous()
            forecast = model(x)
            loss = forecast_loss(forecast, y)
            loss.backward()
            optimizer.step()
            loss_total += loss.item()
            cnt += 1

        # Decay the learning rate only once every `exponential_decay_step` epochs.
        if (epoch + 1) % args.exponential_decay_step == 0:
            my_lr_scheduler.step()
        if (epoch + 1) % args.validate_freq == 0:
            val_loss = validate(model, test_dataloader)

        # Mean batch loss; NaN instead of ZeroDivisionError for an empty loader.
        train_loss = loss_total / cnt if cnt else float('nan')
        # The losses are mean-squared errors; sqrt reports them on an RMSE scale.
        print('| end of epoch {:3d} | time: {:5.2f}s | train_total_loss {:5.4f} | val_loss {:5.4f}'.format(
            epoch, (time.time() - epoch_start_time), np.sqrt(train_loss), np.sqrt(val_loss)))
        save_model(model, result_train_file, args.sequence_length, oc=args.operating_condition, exp_id=exp_id)

    save_model(model, result_train_file, args.sequence_length, oc=args.operating_condition, exp_id=exp_id)


def _build_model():
    """Construct the network named by ``args.model`` and move it to ``device``.

    Raises:
        ValueError: for any model name other than 'FGN' or 'CNN'. Previously an
            unknown name left ``model`` unbound on the first experiment
            (NameError) or silently reused the previous experiment's model.
    """
    if args.model == 'FGN':
        return FGN(pre_length=args.pre_length, embed_size=args.embed_size,
                   feature_size=args.feature_size, seq_length=args.sequence_length,
                   hidden_size=args.hidden_size).to(device)
    if args.model == 'CNN':
        return ConvNet(embed_size=args.embed_size, hidden_size=args.hidden_size,
                       seq_length=args.sequence_length).to(device)
    raise ValueError(f"Unknown --model {args.model!r}; expected 'FGN' or 'CNN'.")


def _test_saved_models():
    """Evaluate saved checkpoints for sequence lengths 100/200 and operating conditions 0-3.

    For each combination the dataset is rebuilt and the checkpoint is loaded
    with ``load_model`` from ``result_train_file``. Per-unit and overall
    MAPE / Score / RMSE (from ``evaluate``) are then printed.

    Returns:
        dict: ``{sequence_length: {'true_values': ..., 'pred_values': ...}}``.
        Each value maps ``'operating_condition_<oc>'`` to
        ``{'unit_<idx>': ndarray}``.
    """
    test_results = {}
    for sequence_length in [100, 200]:
        test_results[sequence_length] = {'true_values': {}, 'pred_values': {}}
        print(f"Sequence Length: {sequence_length}")

        for operating_condition in range(4):
            print(f"Operating Condition: {operating_condition}")
            _, test_dataloader = prepare_and_load_dataset(
                data_dir=data_filedir, data_file='N-CMAPSS_DS02-006.h5',
                sequence_length=sequence_length, stride=1, sampling=args.sampling,
                units_file="File_DevUnits_TestUnits.csv", batch_size=args.batch_size,
                kmeans=True, kmeans_clusters=4, operating_condition=operating_condition
            )

            model = load_model(result_train_file, sequence_length, operating_condition, device)
            model.eval()

            oc_key = f'operating_condition_{operating_condition}'
            oc_preds = {}
            oc_trues = {}
            test_results[sequence_length]['pred_values'][oc_key] = oc_preds
            test_results[sequence_length]['true_values'][oc_key] = oc_trues
            # Per-unit arrays, concatenated below for the overall score. These
            # lists were never filled before, so the overall evaluation always
            # raised ValueError from np.concatenate([]).
            all_preds = []
            all_trues = []

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                for idx, vali_loader in enumerate(test_dataloader):
                    if len(vali_loader) == 0:
                        print(f'Test Set {idx} is empty, skipping...')
                        continue

                    preds = []
                    trues = []
                    for x, y in vali_loader:
                        x, y = x.float().to(device), y.float().to(device)
                        forecast = model(x)
                        y = y.permute(0, 2, 1).contiguous()
                        preds.append(forecast.cpu().numpy())
                        trues.append(y.cpu().numpy())

                    preds = np.concatenate(preds, axis=0)
                    trues = np.concatenate(trues, axis=0)
                    oc_preds[f'unit_{idx}'] = preds
                    oc_trues[f'unit_{idx}'] = trues
                    all_preds.append(preds)
                    all_trues.append(trues)

                    score = evaluate(trues, preds)
                    print(f'Test Set {idx}: MAPE {score[0]:7.9%}; Score {score[1]:7.9f}; RMSE {score[2]:7.9f}.')

            if not all_preds:
                print('All test sets are empty; no overall performance to report.')
                continue

            overall_score = evaluate(np.concatenate(all_trues, axis=0), np.concatenate(all_preds, axis=0))
            print(f'Overall Test Performance: MAPE {overall_score[0]:7.9%}; Score {overall_score[1]:7.9f}; RMSE {overall_score[2]:7.9f}.')
    return test_results


if __name__ == '__main__':
    if args.train:
        # Four independent training runs; exp_id keeps their checkpoint files apart.
        for exp_id in range(1, 5):
            # train() and validate() read these names as module globals, so they
            # are assigned here at module level rather than inside a helper.
            train_dataloader, test_dataloader = prepare_and_load_dataset(
                data_dir=data_filedir, data_file='N-CMAPSS_DS02-006.h5',
                sequence_length=args.sequence_length, stride=1, sampling=args.sampling,
                units_file="File_DevUnits_TestUnits.csv", batch_size=args.batch_size,
                kmeans=True, kmeans_clusters=4, operating_condition=args.operating_condition
            )
            # NOTE(review): this prints len() of the train loader, presumably a
            # batch count rather than a data shape.
            print(f"Training data shape: {len(train_dataloader)}")
            model = _build_model()
            my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
            my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
            forecast_loss = nn.MSELoss(reduction='mean').to(device)
            train(result_train_file, my_optim, exp_id)
    else:
        test_results = _test_saved_models()

