#!/usr/bin/env python3


from __future__ import annotations

import argparse
import random
import time
from pathlib import Path

import numpy as np
import torch

from exp.exp_main import ExpMain


def str2bool(value: str | bool) -> bool:
    """Convert a command-line boolean spelling to ``bool``.

    Used as the argparse ``type=`` converter for the on/off flags (e.g.
    ``--use_revin false``). A value that is already a ``bool`` is returned
    unchanged so the function is also safe to call on non-string defaults.

    Args:
        value: Raw option text, or an existing bool.

    Returns:
        ``True`` for 'true'/'1'/'yes'/'y', ``False`` for 'false'/'0'/'no'/'n'.
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling;
            argparse turns this into a clean usage error.
    """
    if isinstance(value, bool):
        return value
    # Normalise once; stray whitespace (e.g. from quoted shell args) is ignored.
    normalized = str(value).strip().lower()
    if normalized in {'true', '1', 'yes', 'y'}:
        return True
    if normalized in {'false', '0', 'no', 'n'}:
        return False
    raise argparse.ArgumentTypeError(f'Cannot interpret {value} as bool.')


def create_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the FACT Solar forecasting runner.

    Options are declared in an ordered spec table and registered in that
    order, which is also the order they appear in ``--help``. Defaults are
    rendered in the help text by ``ArgumentDefaultsHelpFormatter``.
    """
    option_specs = (
        # -- run control ------------------------------------------------
        ('--is_training', {'type': int, 'default': 1,
                           'help': '1=train then test, 0=test only'}),
        ('--seed', {'type': int, 'default': 2021, 'help': 'Random seed'}),
        ('--model_id', {'type': str, 'default': 'FACT_Solar',
                        'help': 'Run identifier'}),
        ('--itr', {'type': int, 'default': 1,
                   'help': 'How many repeated runs to execute'}),
        ('--des', {'type': str, 'default': 'Exp',
                   'help': 'Short description tag'}),
        # -- data -------------------------------------------------------
        ('--data', {'type': str, 'default': 'Solar', 'choices': ['Solar']}),
        ('--root_path', {'type': str, 'default': './datasets/solar/',
                         'help': 'Folder that stores solar_AL.txt'}),
        ('--data_path', {'type': str, 'default': 'solar_AL.txt',
                         'help': 'Solar dataset filename'}),
        ('--features', {'type': str, 'default': 'M', 'choices': ['M']}),
        ('--target', {'type': str, 'default': 'OT'}),
        ('--freq', {'type': str, 'default': 't',
                    'help': 'Time encoding frequency tag (unused for Solar)'}),
        # -- model architecture -----------------------------------------
        ('--model', {'type': str, 'default': 'FACT', 'choices': ['FACT']}),
        ('--seq_len', {'type': int, 'default': 96}),
        ('--label_len', {'type': int, 'default': 48}),
        ('--pred_len', {'type': int, 'default': 96}),
        ('--enc_in', {'type': int, 'default': 137}),
        ('--dec_in', {'type': int, 'default': 137}),
        ('--c_out', {'type': int, 'default': 137}),
        ('--d_model', {'type': int, 'default': 128}),
        ('--d_ff', {'type': int, 'default': 512}),
        ('--e_layers', {'type': int, 'default': 2}),
        ('--n_heads', {'type': int, 'default': 8}),
        ('--dropout', {'type': float, 'default': 0.1}),
        ('--activation', {'type': str, 'default': 'gelu'}),
        ('--encoder_type', {'type': str, 'default': 'transformer',
                            'choices': ['transformer']}),
        # -- FACT module switches and lambda_* coefficients -------------
        ('--use_dynfbd', {'type': str2bool, 'default': True}),
        ('--dynfbd_variant', {'type': str, 'default': 'gauss',
                              'choices': ['gauss']}),
        ('--use_fselector', {'type': str2bool, 'default': True}),
        ('--target_unit_k', {'type': int, 'default': 128}),
        ('--use_adaptive_fusion', {'type': str2bool, 'default': True}),
        ('--use_channel_mixing', {'type': str2bool, 'default': True}),
        ('--use_guided_gating', {'type': str2bool, 'default': True}),
        ('--use_revin', {'type': str2bool, 'default': True}),
        ('--use_norm', {'type': int, 'default': 0}),
        ('--lambda_coh', {'type': float, 'default': 0.01}),
        ('--lambda_phase', {'type': float, 'default': 0.01}),
        # -- optimisation -----------------------------------------------
        ('--batch_size', {'type': int, 'default': 32}),
        ('--learning_rate', {'type': float, 'default': 5e-4}),
        ('--train_epochs', {'type': int, 'default': 10}),
        ('--patience', {'type': int, 'default': 3}),
        ('--lradj', {'type': str, 'default': 'type1'}),
        ('--loss', {'type': str, 'default': 'MSE'}),
        # -- hardware and output ----------------------------------------
        # The GPU default is probed each time the parser is built.
        ('--use_gpu', {'type': str2bool,
                       'default': torch.cuda.is_available()}),
        ('--gpu', {'type': int, 'default': 0}),
        ('--use_multi_gpu', {'type': str2bool, 'default': False}),
        ('--devices', {'type': str, 'default': '0'}),
        ('--num_workers', {'type': int, 'default': 4}),
        ('--checkpoints', {'type': str,
                           'default': './outputs/checkpoints/'}),
    )

    parser = argparse.ArgumentParser(
        description='FACT anonymized release – Solar forecasting example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def setup_environment(args: argparse.Namespace) -> argparse.Namespace:
    """Finalise parsed arguments: devices, seeding and the checkpoint dir.

    Mutates *args* in place and returns the same object.

    - Sets ``args.task_name`` to ``'long_term_forecast'``.
    - Downgrades ``args.use_gpu`` to ``False`` when CUDA is unavailable.
    - Fills ``args.device_ids``. In multi-GPU mode it is parsed from the
      comma-separated ``args.devices`` and ``args.gpu`` becomes the first
      entry; if that list is blank it falls back to ``[args.gpu]``.
    - Seeds Python ``random``, NumPy and PyTorch when ``args.seed`` is set.
    - Creates ``args.checkpoints`` (with parents) if it does not exist.
    """
    args.task_name = 'long_term_forecast'

    # A GPU request is silently downgraded when CUDA is not available.
    args.use_gpu = bool(args.use_gpu) and torch.cuda.is_available()

    if args.use_multi_gpu:
        args.devices = args.devices.replace(' ', '')
        # Skip empty/whitespace entries such as a trailing comma or a tab.
        device_ids = [int(id_) for id_ in args.devices.split(',') if id_.strip()]
        if device_ids:
            args.gpu = device_ids[0]
        else:
            # A blank --devices would otherwise leave device_ids == [] while
            # args.gpu still names a device; keep the two consistent.
            device_ids = [args.gpu]
        args.device_ids = device_ids
    else:
        args.device_ids = [args.gpu]

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.use_gpu:
            torch.cuda.manual_seed_all(args.seed)

    Path(args.checkpoints).mkdir(parents=True, exist_ok=True)
    return args


def format_setting(args: argparse.Namespace, iteration: int) -> str:
    """Build the run identifier passed to the experiment's train/test calls.

    Fields are joined with underscores in this fixed order:
    model_id, model, data, ``sl<seq_len>``, ``pl<pred_len>``,
    ``dm<d_model>``, ``el<e_layers>``, des, iteration.
    """
    parts = [
        args.model_id,
        args.model,
        args.data,
        f'sl{args.seq_len}',
        f'pl{args.pred_len}',
        f'dm{args.d_model}',
        f'el{args.e_layers}',
        args.des,
        iteration,
    ]
    return '_'.join(f'{part}' for part in parts)


def main() -> None:
    """Command-line entry point: parse arguments, then train and/or test.

    With a non-zero ``--is_training``, runs ``--itr`` repetitions. Each one
    gets a fresh ``ExpMain`` and does train followed by test, and the total
    elapsed time is printed. Otherwise a single test-only pass is run.
    """
    args = setup_environment(create_parser().parse_args())

    if args.is_training:
        total_start = time.perf_counter()
        for itr in range(args.itr):
            setting = format_setting(args, itr)
            exp = ExpMain(args)
            print(f"\n>>> Training run {itr + 1}/{args.itr}: {setting}")
            exp.train(setting)

            print(f"\n>>> Testing run {itr + 1}/{args.itr}: {setting}")
            # Optional attribute: presumably set by ExpMain.train(); None if absent.
            avg_epoch_time = getattr(exp, 'last_avg_epoch_time', None)
            exp.test(setting, training_time=avg_epoch_time)

            # Drop our reference before empty_cache(). While `exp` is alive,
            # tensors it owns cannot be returned by the caching allocator.
            del exp
            if args.use_gpu:
                torch.cuda.empty_cache()
        total_time = time.perf_counter() - total_start
        print(f"\nTotal wall-clock time: {total_time:.2f}s")
    else:
        setting = format_setting(args, 0)
        exp = ExpMain(args)
        print(f"\n>>> Testing only: {setting}")
        exp.test(setting, test=1)
        del exp
        if args.use_gpu:
            torch.cuda.empty_cache()

    print('Experiment finished.')


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
