# -*- coding: utf-8 -*-

import argparse
import os
import math
import warnings
import numpy as np
from datetime import datetime

import torch
import wandb

from optim.generalWay import *
from dataloader.tempData import *

# Silence every Python warning for the whole process to keep training logs readable.
# NOTE(review): this also hides deprecation / numerical warnings raised by libraries.
warnings.filterwarnings("ignore")


def str2bool(v):
    """Convert a command-line string into a boolean.

    Intended as an argparse ``type=`` so that ``--flag False`` really yields
    ``False``. Plain ``type=bool`` treats every non-empty string as true.

    Args:
        v: Raw command-line value, or an already-boolean value.

    Returns:
        bool: ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0
        (case-insensitive). Boolean inputs are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean spelling.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_option(argv=None):
    """Parse command line options used for training.

    The function exposes most hyperparameters and dataset choices. ``main``
    accepts these datasets:

    - ``CricketX``, ``UWaveGestureLibraryAll``, ``InsectWingbeatSound``,
      ``EpilepticSeizure``, ``MFPT`` and ``XJTU``, loaded with ``load_ucr2018``.
    - ``Heartbeat``, ``NATOPS`` and ``SelfRegulationSCP2``, loaded with
      ``load_multi_ts``.

    Boolean options are parsed with ``str2bool``, so ``--lip False`` or
    ``--lip 0`` really turn the option off. ``--no-cuda`` may also be given
    as a bare flag.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``. Passing a list is useful for tests and notebooks.

    Returns:
        argparse.Namespace: Parsed arguments ready for ``main``.
    """
    parser = argparse.ArgumentParser('argument for training')
    parser.add_argument('--save_freq', type=int, default=200, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=32, help='batch_size')
    parser.add_argument('--K', type=int, default=4, help='Number of augmentation for each sample')
    parser.add_argument('--alpha', type=float, default=0.5, help='Past-future split point')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--feature_size', type=int, default=64, help='feature_size')
    parser.add_argument('--num_workers', type=int, default=0, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000, help='number of training epochs')
    parser.add_argument('--patience', type=int, default=100, help='training patience')
    parser.add_argument('--aug_type', type=str, default='none', help='Augmentation type')

    parser.add_argument('--class_type', type=str, default='3C', help='Classification type')
    parser.add_argument('--gpu', type=str, default='0', help='gpu id')

    parser.add_argument('--learning_rate', type=float, default=2e-3, help='learning rate')
    parser.add_argument('--weight_rampup', type=int, default=30, help='weight rampup')

    parser.add_argument('--dataset_name', type=str, default='CricketX', help='dataset')
    parser.add_argument('--nb_class', type=int, default=3, help='class number')
    parser.add_argument('--ucr_path', type=str, default='./datasets/', help='Data root for dataset.')
    parser.add_argument('--ckpt_dir', type=str, default='./ckpt/', help='Data path for checkpoint.')
    parser.add_argument('--backbone', type=str, default='Our')
    parser.add_argument('--model_name', type=str, default='SemiTeacher',
                        choices=['SupCE', 'SemiTime', 'SemiTeacher', 'Teacher', 'PI', 'MTL', 'TapNet'],
                        help='choose method')
    parser.add_argument('--label_ratio', type=float, default=0.1, help='label ratio')
    parser.add_argument('--save_dir', type=str, default="gradient", help='save visualization')
    parser.add_argument('--usp_weight', type=float, default=1, help='usp weight')
    parser.add_argument('--ema_decay', type=float, default=0.99, help='weight')
    parser.add_argument('--model_select', type=str, default='TCN', help='Training model type')
    parser.add_argument('--nhid', type=int, default=128, help='feature_size')
    parser.add_argument('--levels', type=int, default=8, help='feature_size')
    parser.add_argument('--ksize', type=int, default=3, help='kernel size')
    parser.add_argument('--dropout', type=float, default=0.05, help='dropout applied to layers (default: 0.05)')

    # type=bool would turn the string "False" into True; str2bool parses it properly.
    parser.add_argument('--lip', type=str2bool, default=True, help='Whether to limit the Lipschitz constant')
    parser.add_argument('--saliency', type=str2bool, default=False, help='Whether to use series saliency')
    parser.add_argument('--lambda_lp', type=float, default=1, help='lipschitz weight')
    parser.add_argument('--L', type=int, default=0, help='lipschitz constant')
    parser.add_argument('--iter', type=int, default=50, help='iteration')

    # CUDA settings. Accepts either a bare "--no-cuda" or an explicit boolean value.
    parser.add_argument('--no-cuda', type=str2bool, nargs='?', const=True, default=False,
                        help='Disables CUDA training.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')

    # Regularization / stopping
    parser.add_argument('--wd', type=float, default=1e-3,
                        help='Weight decay (L2 loss on parameters). default: 1e-3')
    parser.add_argument('--stop_thres', type=float, default=1e-9,
                        help='Stop threshold for training error differences. Default:1e-9')

    # Network architecture
    parser.add_argument('--use_cnn', type=str2bool, default=True,
                        help='whether to use CNN for feature extraction. Default:True')
    parser.add_argument('--use_lstm', type=str2bool, default=True,
                        help='whether to use LSTM for feature extraction. Default:True')
    parser.add_argument('--use_rp', type=str2bool, default=False,
                        help='Whether to use random projection')
    parser.add_argument('--rp_params', type=str, default='-1,3',
                        help='Parameters for random projection: number of random projections, sub-dimension for each')
    parser.add_argument('--use_metric', action='store_true', default=False,
                        help='whether to use metric learning for class representation. Default:False')
    parser.add_argument('--metric_param', type=float, default=2e-3,
                        help='Metric parameter for prototype distances between classes.')
    parser.add_argument('--filters', type=str, default="256,256,128",
                        help='filters used for CNN layers. Default:256,256,128')
    parser.add_argument('--kernels', type=str, default="8,5,3",
                        help='kernels used for the CNN layers. Default:8,5,3')
    parser.add_argument('--dilation', type=int, default=1,
                        help='the dilation used for the first CNN layer. If -1 => automatic.')
    parser.add_argument('--layers', type=str, default="500,300",
                        help='layer settings of mapping function, e.g.: 500,300')
    parser.add_argument('--lstm_dim', type=int, default=128, help='Dimension of LSTM Embedding.')

    # Additional hyperparameters for transformations/regularization
    parser.add_argument('--tv_coeff', type=float, default=2, help='Coefficient of TV')
    parser.add_argument('--tv_beta', type=float, default=1, help='TV beta value')
    parser.add_argument('--l1_coeff', type=float, default=1e-2, help='L1 regularization')
    parser.add_argument('--factor', type=int, default=7, help='Factor to upsampling')
    parser.add_argument('--img_path', type=str, default='examples/fl.jpg', help='image path')
    parser.add_argument('--lambda_Rm', type=float, default=1, help='lambda_Rm')
    parser.add_argument('--lambda_Rs', type=float, default=1, help='lambda_Rs')
    parser.add_argument('--lambda_reg', type=float, default=1e-5, help='lambda_reg')
    parser.add_argument('--beta_reg', type=float, default=1e-3, help='beta_reg')
    parser.add_argument('--diag_alpha', type=float, default=1, help='diag_alpha for Sobolev')
    parser.add_argument('--sobolev_norm_type', type=str, default='h-2', help='(h-1 or h-2)')

    parser.add_argument('--Saliency_dir', type=str, default='results/saliency/', help='saliency path')
    parser.add_argument('--Mask_dir', type=str, default='results/mask/', help='mask path')
    parser.add_argument('--use_flag', type=str2bool, default=True, help='user-defined boolean')
    parser.add_argument('--visualize_3d', type=str2bool, default=True, help='user-defined boolean flag')
    parser.add_argument('--visualize_ts', type=str2bool, default=True, help='user-defined boolean flag')
    parser.add_argument('--visualize_tsne', type=str2bool, default=True, help='user-defined boolean flag')

    return parser.parse_args(argv)

def _resolve_rp_params(rp_params, x_train):
    """Return integer random-projection parameters sized to ``x_train``'s shape.

    A negative first entry requests automatic sizing:
    ``[3, floor(x_train.shape[2] / 1.5)]``. Otherwise the entries are copied
    and the second one is replaced with ``x_train.shape[1]``. The input list
    is never modified.
    """
    if rp_params[0] < 0:
        resolved = [3, math.floor(x_train.shape[2] / (3 / 2))]
    else:
        resolved = list(rp_params)
        # NOTE(review): this branch reads shape[1] while the automatic branch
        # reads shape[2]; confirm which axis is the intended feature dimension.
        resolved[1] = math.floor(x_train.shape[1])
    return [int(v) for v in resolved]


def main():
    """Run the training experiment configured on the command line.

    For every seed in ``Seeds`` and every run index in ``Runs``, the function:

    1. loads the selected dataset;
    2. trains via ``train_SemiMean``;
    3. writes per-run and per-seed text logs under a timestamped directory in
       ``./ckpt``.

    A mean/std accuracy summary across seeds is written to
    ``final_summary.txt`` under ``./results/...`` and printed.

    Raises:
        NotImplementedError: If ``--model_name`` does not contain
            ``'SemiTeacher'``; no other training routine is wired up here.
        ValueError: If ``--dataset_name`` is not one of the supported datasets.
    """
    # 1) Parse configuration
    opt = parse_option()

    # Only the SemiTeacher routine is wired up below. Fail fast with a clear
    # message instead of crashing later with an UnboundLocalError on acc_test
    # after directories, log files and a W&B run have already been created.
    if 'SemiTeacher' not in opt.model_name:
        raise NotImplementedError(
            f"--model_name '{opt.model_name}' is not supported by this script; "
            f"only 'SemiTeacher' has a training routine."
        )

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu

    # 2) Pre-process some parameters: comma-separated CLI strings -> lists
    opt.sparse = True
    opt.layers = [int(v) for v in opt.layers.split(",")]
    opt.kernels = [int(v) for v in opt.kernels.split(",")]
    opt.filters = [int(v) for v in opt.filters.split(",")]
    opt.rp_params = [float(v) for v in opt.rp_params.split(",")]
    # opt.rp_params is overwritten for every seed below. Keep the CLI values so
    # each seed resolves its projection sizes from the same input; otherwise
    # a second seed would no longer see the "< 0 means automatic" request.
    rp_params_cli = list(opt.rp_params)

    # Sobolev regularization is treated as disabled when lambda_reg == 0
    if opt.lambda_reg == 0:
        sub_exp = 'baseline'
        sobolev_dir = 'no-sobolev'
    else:
        sub_exp = 'exp-cls'
        sobolev_dir = opt.sobolev_norm_type

    # Seeds to sweep and number of repeated runs per seed
    Seeds = [2001]
    Runs = range(1)

    model_paras = f"label{opt.label_ratio}"
    log_dir = os.path.join(
        './results',
        f"useFlag_{opt.use_flag}",
        f"diagAlpha_{opt.diag_alpha}",
        sub_exp,
        opt.dataset_name,
        opt.model_name,
        sobolev_dir,
        model_paras
    )
    os.makedirs(log_dir, exist_ok=True)

    # Prompt batch size
    print("Batch size:", opt.batch_size)

    # 3) Data augmentation (example only). These hard-coded lists replace
    # whatever --aug_type was passed on the command line.
    aug1 = ['jitter', 'cutout']
    aug2 = ['G0', 'time_warp']
    if aug1 == aug2:
        opt.aug_type = [aug1]
    elif isinstance(aug1, list):
        opt.aug_type = aug1 + aug2
    else:
        opt.aug_type = [aug1, aug2]

    ACCs_seed = {}
    MAX_EPOCHs_seed = {}

    for seed in Seeds:
        np.random.seed(seed)
        torch.manual_seed(seed)
        opt.seed = seed
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")

        ckpt_dir = os.path.join(
            './ckpt',
            f"useFlag_{opt.use_flag}",
            f"diagAlpha_{opt.diag_alpha}",
            sub_exp,
            opt.model_name,
            opt.dataset_name,
            sobolev_dir,
            model_paras,
            f"seed_{seed}_{time_str}"
        )
        print("path: {}".format(ckpt_dir))
        os.makedirs(ckpt_dir, exist_ok=True)
        opt.ckpt_dir = ckpt_dir  # Assign back to opt for subsequent use

        print(f"[INFO] Dataset={opt.dataset_name}, Seed={seed}, use_flag={opt.use_flag}, diag_alpha={opt.diag_alpha}, "
              f"lambda_reg={opt.lambda_reg}, sobolev={sobolev_dir}")

        # 4.1) Data loading (also sets opt.nb_class from the dataset)
        if opt.dataset_name in [
            "CricketX", "UWaveGestureLibraryAll", "InsectWingbeatSound",
            "EpilepticSeizure", "MFPT", "XJTU"
        ]:
            x_train, y_train, x_val, y_val, x_test, y_test, opt.nb_class, _ = load_ucr2018(opt.ucr_path, opt.dataset_name)
        elif opt.dataset_name in ["Heartbeat", "NATOPS", "SelfRegulationSCP2"]:
            x_train, y_train, x_val, y_val, x_test, y_test, opt.nb_class, _ = load_multi_ts(opt.ucr_path, opt.dataset_name)
        else:
            raise ValueError(f"Dataset '{opt.dataset_name}' is not recognized or not supported here.")

        # Resolve rp_params from the original CLI values, not last seed's result
        opt.rp_params = _resolve_rp_params(rp_params_cli, x_train)

        print("[INFO] rp_params:", opt.rp_params, "| beta_reg:", opt.beta_reg)

        # 4.2) Log file for current seed
        seed_summary_file = os.path.join(ckpt_dir, f"seed_{seed}_summary.txt")
        with open(seed_summary_file, 'w', encoding='utf-8') as sf:
            sf.write(f"===== Seed: {seed} Summary =====\n")
            sf.write(f"Dataset: {opt.dataset_name}\n")
            sf.write(f"Model: {opt.model_name}\n")
            sf.write(f"Augmentation: {opt.aug_type}\n")
            sf.write(f"Label Ratio: {opt.label_ratio}\n")
            sf.write(f"lambda_reg: {opt.lambda_reg}\n")
            sf.write(f"sobolev_norm_type: {opt.sobolev_norm_type}\n")
            sf.write(f"use_flag: {opt.use_flag}\n")
            sf.write(f"diag_alpha: {opt.diag_alpha}\n")
            sf.write(f"Directory TimeStamp: {time_str}\n\n")

        # Track results for each run in this seed
        ACCs_run = {}
        MAX_EPOCHs_run = {}

        # 4.3) Multiple runs for the same seed
        for run in Runs:
            run_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")

            # W&B config must be serializable: drop the module handle stored in opt.wb
            config_dict = dict(vars(opt))
            config_dict.pop("wb", None)

            wandb_run_name = f"useFlag_{opt.use_flag}-alpha_{opt.diag_alpha}-run_{run}-seed_{seed}"
            wandb.init(
                project=f"{opt.dataset_name}_Project_new",
                group=f"seed_{seed}",
                name=wandb_run_name,
                config=config_dict,
                mode="disabled"
            )

            # Everything after init runs under try/finally so the W&B run is
            # closed even if training or log writing raises.
            try:
                opt.wb = wandb

                run_log_file = os.path.join(ckpt_dir, f"run_{run_time_str}_seed_{seed}_run_{run}.txt")
                with open(run_log_file, 'w', encoding='utf-8') as rf:
                    rf.write("===== Run Info =====\n")
                    rf.write(f"TimeStamp: {run_time_str}\n")
                    rf.write(f"Seed: {seed}\n")
                    rf.write(f"Run: {run}\n")
                    rf.write(f"Dataset: {opt.dataset_name}\n")
                    rf.write(f"Model: {opt.model_name}\n")
                    rf.write(f"Sobolev: {opt.sobolev_norm_type}\n")
                    rf.write(f"use_flag: {opt.use_flag}\n")
                    rf.write(f"diag_alpha: {opt.diag_alpha}\n\n")

                # model_name was validated above, so this is the SemiTeacher routine
                acc_test, acc_unlabel, epoch_max = train_SemiMean(
                    x_train, y_train, x_val, y_val, x_test, y_test, opt)

                # 4.3.4) Output and record
                print(f"[RUN INFO] Seed={seed}, Run={run}, use_flag={opt.use_flag}, diag_alpha={opt.diag_alpha}, "
                      f"Dataset={opt.dataset_name}, #Train={x_train.shape[0]}, #Test={x_test.shape[0]}, "
                      f"Classes={opt.nb_class}, lambda_reg={opt.lambda_reg}, sobolev={sobolev_dir}, "
                      f"Acc_Test={acc_test:.2f}, Acc_Unlabel={acc_unlabel:.2f}, EpochMax={epoch_max}")

                with open(run_log_file, 'a', encoding='utf-8') as rf:
                    rf.write(f"Final Test Accuracy: {acc_test:.4f}\n")
                    rf.write(f"Final Unlabeled Accuracy: {acc_unlabel:.4f}\n")
                    rf.write(f"Best Epoch: {epoch_max}\n")
            finally:
                wandb.finish()

            ACCs_run[run] = acc_test
            MAX_EPOCHs_run[run] = epoch_max

        # Summary for this seed
        seed_acc_mean = round(np.mean(list(ACCs_run.values())), 2)
        seed_acc_std = round(np.std(list(ACCs_run.values())), 2)
        seed_epoch_max = np.max(list(MAX_EPOCHs_run.values()))

        ACCs_seed[seed] = seed_acc_mean
        MAX_EPOCHs_seed[seed] = seed_epoch_max

        with open(seed_summary_file, 'a', encoding='utf-8') as sf:
            sf.write("\n===== Seed Results =====\n")
            sf.write(f"Runs: {list(Runs)}\n")
            sf.write(f"Run Accuracies: {ACCs_run}\n")
            sf.write(f"Mean Accuracy: {seed_acc_mean}\n")
            sf.write(f"Std Accuracy : {seed_acc_std}\n")
            sf.write(f"Max Epoch (best among runs): {seed_epoch_max}\n")
            sf.write("====================================\n\n")

        print(f"[SEED SUMMARY] Seed={seed}, use_flag={opt.use_flag}, diag_alpha={opt.diag_alpha}, "
              f"MeanAcc={seed_acc_mean}, StdAcc={seed_acc_std}, MaxEpoch={seed_epoch_max}")

    # 5) Final summary across all seeds (mean of per-seed mean accuracies)
    ACCs_seed_values = list(ACCs_seed.values())
    ACCs_seed_mean = round(np.mean(ACCs_seed_values), 2)
    ACCs_seed_std = round(np.std(ACCs_seed_values), 2)
    MAX_EPOCHs_seed_max = np.max(list(MAX_EPOCHs_seed.values()))

    # Summary file
    final_summary_file = os.path.join(log_dir, "final_summary.txt")
    with open(final_summary_file, 'w', encoding='utf-8') as ff:
        ff.write("===== Overall Summary across all Seeds =====\n")
        ff.write(f"Dataset: {opt.dataset_name}\n")
        ff.write(f"Model: {opt.model_name}\n")
        ff.write(f"Label Ratio: {opt.label_ratio}\n")
        ff.write(f"lambda_reg: {opt.lambda_reg}\n")
        ff.write(f"sobolev_norm_type: {opt.sobolev_norm_type}\n")
        ff.write(f"use_flag: {opt.use_flag}\n")
        ff.write(f"diag_alpha: {opt.diag_alpha}\n")
        ff.write(f"Seeds: {Seeds}\n")
        ff.write(f"Mean Accuracy across seeds: {ACCs_seed_mean}\n")
        ff.write(f"Std Accuracy across seeds : {ACCs_seed_std}\n")
        ff.write(f"Max Epoch across seeds    : {MAX_EPOCHs_seed_max}\n")
        ff.write("============================================\n")

    print(f"\n[FINAL SUMMARY] use_flag={opt.use_flag}, diag_alpha={opt.diag_alpha}, "
          f"Dataset={opt.dataset_name}, Model={opt.model_name}, sobolev={sobolev_dir}, "
          f"lambda_reg={opt.lambda_reg}, MeanAccAcrossSeeds={ACCs_seed_mean}, "
          f"StdAccAcrossSeeds={ACCs_seed_std}, MaxEpochAcrossSeeds={MAX_EPOCHs_seed_max}")


if __name__ == "__main__":
    # Script entry point: importing this module does not start training.
    main()
