import numpy as np
import argparse
import time
import datetime
import data_load as datautils
from utils import init_dl_program,dict2class
from ProSAR import AugProtoCL # Assuming ProSAR.py contains AugProtoCL
# from AutoTCL_CoST1 import AugProtoCL1
from models.augclass import *
# from config import *
import os
# Default the visible GPU set to device 2, but respect a CUDA_VISIBLE_DEVICES
# value the caller already exported (setdefault never overrides it).
# This only takes effect if CUDA has not been initialised yet.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

# --- Script start time for total duration calculation ---
t_script_start = time.time()

def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` is unusable for flags that take a value:
    ``bool('False')`` is True, so any non-empty string would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (True/False), got {value!r}")


parser = argparse.ArgumentParser()

# --- Dataset and Experiment Setup ---
# Forecasting alternative (example):
# parser.add_argument('--dataset',type=str, default='ETTh1', help='The dataset name')
# parser.add_argument('--load_default',type=bool, default=False, help='load default setting for dataset')
# parser.add_argument('--archive', type=str, default='forecast_csv_univar', help='forecast_csv_univar or forecast_csv -->univar or multivar')
# --- Active setting: UEA classification ---
parser.add_argument('--dataset', type=str, default='AtrialFibrillation', help='Dataset name (e.g., Chinatown for UCR, electricity for forecast_csv)')
parser.add_argument('--load_default', type=_str2bool, default=True, help='Load default settings for dataset')  # If False, use the provided arguments
parser.add_argument('--archive', type=str, default='UEA', help='Archive type: UCR, UEA, forecast_csv_univar, forecast_csv')
# parser.add_argument('--backbone',type=str, default='cost', help='backbone type (e.g., cost, ts)')
parser.add_argument('--backbone', type=str, default='ts', help='backbone')

parser.add_argument('--gpu', type=int, default=0, help='The gpu no. used for training and inference')
parser.add_argument('--seed', type=int, default=42, help='seed')
parser.add_argument('--max_threads', type=int, default=None, help='Max threads (None for no specific limit)')
parser.add_argument('--eval', type=_str2bool, default=True, help='do eval')

# --- Core Model and Training Hyperparameters ---
parser.add_argument('--batch_size', type=int, default=16, help='The batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='The learning rate for the main model')
parser.add_argument('--input_dims', type=int, default=0, help='Input dimension (features), usually inferred from data')
parser.add_argument('--output_dims', type=int, default=320, help='The representation dimension from the encoder')
parser.add_argument('--encoder_hidden_dims', type=int, default=128, help='Hidden dimension for the main encoder')
parser.add_argument('--encoder_depth', type=int, default=3, help='Depth of the main encoder')
parser.add_argument('--max_train_length', type=int, default=256, help='The max training sequence length')
parser.add_argument('--iters', type=int, default=None, help='The training iters (alternative to epochs)')
parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
parser.add_argument('--kernels', nargs='+', type=int, default=[1, 2, 4, 8, 16, 32, 64, 128], help='Kernels for CoSTEncoder')

# --- AutoTCL Specific Arguments (Legacy/Optional) ---
# These parameters might be legacy if ProSAR has a different augmentation mechanism
parser.add_argument('--meta_lr', type=float, default=0.012, help='(AutoTCL) The augmentation learning rate')
parser.add_argument('--mask_mode', type=str, default='mask_last', help='(AutoTCL) Mask mode for embedding network')
parser.add_argument('--augmask_mode', type=str, default='mask_last', help='(AutoTCL) Mask mode for augmentation network')
parser.add_argument('--aug_dim', type=int, default=16, help='(AutoTCL) The hidden dimension for argumentation network')
parser.add_argument('--aug_depth', type=int, default=1, help='(AutoTCL) Depths of argumentation network')
parser.add_argument('--bias_init', type=float, default=0.0, help='(AutoTCL) Bias initialization for augmentation network')
parser.add_argument('--gamma_zeta', type=float, default=0.005, help='(AutoTCL) Gamma_zeta for augmentation regularization')
# Kept as a string: converted to bool inside ProSAR.py
parser.add_argument('--hard_mask', type=str, default='True', help='(AutoTCL) Whether to use hard mask (True/False)')
parser.add_argument('--gumbel_bias', type=float, default=0.4, help='(AutoTCL) Gumbel bias for mask sampling')
parser.add_argument('--ratio_step', type=int, default=1, help='(AutoTCL) Ratio step M for mask sampling')
parser.add_argument('--reg_weight', type=float, default=0.2, help='(AutoTCL) The weight of H(x) regularization')
parser.add_argument('--regular_weight', type=float, default=0.002, help='(AutoTCL) The weight of Regularization term')
parser.add_argument('--dropout', type=float, default=0.1, help='(AutoTCL) Dropout of embedding network')
parser.add_argument('--augdropout', type=float, default=0.1, help='(AutoTCL) Dropout of argumentation network')

# These might overlap with core model params, ensure they are distinct if used for AutoTCL parts
# parser.add_argument('--repr_dims', type=int, default=320, help='The representation dimension (overlaps output_dims)')
# parser.add_argument('--hidden_dims', type=int, default=64, help='The hidden dimension for embedding (overlaps encoder_hidden_dims)')
# parser.add_argument('--depth', type=int, default=10, help='Depths of embedding network (overlaps encoder_depth)')

# --- ProSAR: DTW-ProtoCL Specific Arguments ---
parser.add_argument('--num_prototypes', type=int, default=8, help='Number of time-domain prototypes')
parser.add_argument('--proto_len', type=int, default=128, help='Length of each time-domain prototype')
parser.add_argument('--learnable_prototypes', type=str, default='False', help='Whether time-domain prototypes are learned via gradient (True/False)')
parser.add_argument('--use_prototype_decoder', type=str, default='True', help='Use a decoder to get time-domain prototypes from latent centroids (True/False)')
parser.add_argument('--prototype_momentum', type=float, default=0.7, help='Momentum for EMA update of prototypes (if not learnable and use_decoder)')
parser.add_argument('--prototype_update_freq', type=int, default=1, help='Frequency (in iterations) to update prototypes')
parser.add_argument('--update_prototypes_at_start', type=str, default='False', help='Update prototypes once at the very beginning of training (True/False)')
parser.add_argument('--max_samples_for_clustering', type=int, default=5000, help='Max samples to use for each prototype update clustering step')

parser.add_argument('--clustering_layers', type=int, default=3, help='Number of layers for hierarchical clustering (passed to func as layers+1 for FINCH-like algos)')
parser.add_argument('--clustering_target_layer', type=int, default=0, help='Which layer of centroids to use from hierarchical clustering (0-indexed)')
parser.add_argument('--clustering_distance', type=str, default='euclidean', help='Distance metric for clustering')
# type=int so CLI values match the integer default.
# NOTE(review): exact semantics live in ProSAR.py; confirm the help wording there.
parser.add_argument('--ann_threshold', type=int, default=40000, help='ANN threshold used by the clustering step (integer sample count)')

# DTW Augmenter (e.g., AugSem in ProSAR) parameters.
# NOTE(review): the help texts below are derived from the parameter names and
# their group comments; confirm against the augmenter implementation.
parser.add_argument('--n_fft', type=int, default=16, help='n_fft for STFT in dtw_augmenter (for frequency domain augmentation)')
parser.add_argument('--hop_length', type=int, default=16, help='hop_length for STFT in dtw_augmenter')
parser.add_argument('--win_length', type=int, default=16, help='win_length for STFT in dtw_augmenter')
## process
parser.add_argument('--dtw_gamma1', type=float, default=0.1, help='Gamma #1 for SoftDTW in augmenter')
parser.add_argument('--dtw_gamma2', type=float, default=0.1, help='Gamma #2 for SoftDTW in augmenter')
## identify (2,5,1)
parser.add_argument('--row_threshold', type=int, default=6, help='Row threshold used when identifying aligned segments from the DTW alignment')
parser.add_argument('--column_threshold', type=int, default=6, help='Column threshold used when identifying aligned segments from the DTW alignment')
parser.add_argument('--min_continuous_aligned_length', type=int, default=3, help='Minimum length of a continuous run to be treated as aligned')
## non-aligned segments
parser.add_argument('--mask_prob', type=float, default=0.3, help='Masking probability for non-aligned segments')
parser.add_argument('--non_aligned_noise_std', type=float, default=0.2, help='Noise standard deviation for non-aligned segments')
## aligned segments
parser.add_argument('--phase_adj', type=float, default=1.0, help='Phase adjustment for aligned segments')
parser.add_argument('--magnitude_noise_std', type=float, default=0.2, help='Magnitude noise standard deviation for aligned segments')

# ... (Add other YourDTWDistanceFeature specific parameters if any) ...

# Loss function weights and temperatures
parser.add_argument('--lambda_inst_orig_v1', type=float, default=0.0, help='Weight for NCE(original, dtw_view1) for L_intra')
parser.add_argument('--lambda_inst_orig_v2', type=float, default=1.0, help='Weight for NCE(original, dtw_view2) for L_intra')
parser.add_argument('--contrast_augmented_views', type=str, default='True', help='Whether to add NCE(dtw_view1, dtw_view2) for L_intra (True/False)')
parser.add_argument('--lambda_inst_v1_v2', type=float, default=0.0, help='Weight for NCE(dtw_view1, dtw_view2) for L_intra')
parser.add_argument('--use_cluster_loss', type=str, default='True', help='Whether to use cluster contrastive loss (L_inter) (True/False)')
parser.add_argument('--weight_loss_inst', type=float, default=0.5, help='Overall weight for instance contrastive losses (lambda_intra)')
parser.add_argument('--hcl_temporal_unit', type=int, default=0, help='Temporal unit for hierarchical_contrastive_loss (if contrastive_loss_type="pool")')
parser.add_argument('--temperature_cluster', type=float, default=0.1, help='Temperature for cluster contrastive loss (L_inter)')
parser.add_argument('--temperature_nce', type=float, default=0.5, help='Temperature for instance NCE loss (L_intra components)')
parser.add_argument('--local_weight', type=float, default=1.0, help='Weight of local contrastive loss component in L_intra (if contrastive_loss_type="nopool")')
parser.add_argument('--global_weight', type=float, default=1.0, help='Weight of global contrastive loss component in L_intra (if contrastive_loss_type="nopool")')
parser.add_argument('--weight_loss_clust', type=float, default=0.5, help='Overall weight for cluster contrastive loss (lambda_inter)')
parser.add_argument('--lambda_cluster_ins', type=float, default=1.0, help='Weight for instance-wise part of L_inter')
parser.add_argument('--lambda_proto', type=float, default=1.0, help='Weight for prototype-wise part of L_inter')

# Evaluation Settings
parser.add_argument('--eval_every_epoch', type=int, default=5, help='Evaluate every N epochs')
parser.add_argument('--eval_start_epoch', type=int, default=0, help='Start evaluation after N epochs')

# ProSAR Prototype specific arguments (multivariate, update sources)
parser.add_argument(
    '--prototype_multivariate',
    type=str,  # Parsed as string, converted to bool in ProSAR.py
    default='False',  # Univariate prototypes by default
    help='Whether prototypes are multivariate (have their own variable dimension) (True/False)'
)
parser.add_argument(
    '--w_decoder_source', type=float, default=0.5,
    help="Weight for decoder-based source in prototype update (0.0 to 1.0 for Latent-to-Time-Domain Decoding Consistency)."
)
parser.add_argument(
    '--w_direct_avg_feat_source', type=float, default=0.0,
    help="Weight for direct-average-of-feature-cluster-samples source (0.0 to 1.0)."
)
parser.add_argument(
    '--w_direct_ts_cluster_source', type=float, default=0.5,
    help="Weight for direct-time-series-cluster-centroids source (0.0 to 1.0 for Input-Space Grounding)."
)

# Parameters for direct time-series clustering (if w_direct_ts_cluster_source > 0)
parser.add_argument(
    '--ts_cluster_use_kmeans', type=str, default='True',
    help="For 'direct_ts_cluster' source, use KMeans on (flattened) time series (True/False)."
)
parser.add_argument(
    '--ts_cluster_distance', type=str, default='euclidean',
    help="Distance metric for direct time-series clustering (primarily for reference if other algos are swapped in)."
)
parser.add_argument(
    '--clustering_verbose', type=str, default='False',  # Verbosity for clustering
    help="Print verbose clustering information (True/False)"
)
parser.add_argument(
    '--align_verbose', type=str, default='True',  # Verbosity for alignment
    help="Print verbose alignment information (True/False)"
)

# Contrastive loss type for instance-level contrast (L_intra)
parser.add_argument(
    '--contrastive_loss_type',
    type=str,
    default='nopool',  # 'nopool' = InfoNCE + local InfoNCE, weighted by --local_weight/--global_weight
    choices=['pool', 'nopool'],
    help="Type of instance contrastive loss: 'pool' (TS2Vec-like hierarchical) or 'nopool' (InfoNCE + local_infoNCE)"
)
parser.add_argument(
    '--hcl_alpha',
    type=float,
    default=1.0,  # Alpha for TS2Vec's hierarchical_contrastive_loss
    help="For 'pool' type (hierarchical_contrastive_loss), alpha parameter for weighting positive pairs"
)

parser.add_argument('--lr_scheduler_type', type=str, default='none',
                    help='Type of LR scheduler (e.g., "none", "cosine", "step", "multistep")')

# CosineAnnealingLR
parser.add_argument('--lr_cosine_t_max', type=int, default=50,
                    help='T_max for CosineAnnealingLR (typically the total number of epochs)')
parser.add_argument('--lr_cosine_eta_min', type=float, default=1e-4,
                    help='Minimum learning rate for CosineAnnealingLR')

# StepLR
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='Step size for StepLR scheduler')
parser.add_argument('--lr_step_gamma', type=float, default=0.1,
                    help='Gamma for StepLR scheduler (multiplicative factor)')

# MultiStepLR
parser.add_argument('--lr_multistep_milestones', type=int, nargs='+', default=None,
                    help='Epoch milestones for MultiStepLR (e.g., 50 80 110). Defaults to 0.5 and 0.75 of total epochs.')
parser.add_argument('--lr_multistep_gamma', type=float, default=0.1,
                    help='Gamma for MultiStepLR scheduler')






raw_args = parser.parse_args()  # Raw CLI values; may be overridden by dataset defaults below
if raw_args.load_default:
    # `merege_config` (sic) is expected from config.py, but the `from config import *`
    # at the top of this script is commented out. Use it if a star import already
    # provided it; otherwise import it explicitly rather than failing with a NameError.
    _merge_config = globals().get('merege_config')
    if _merge_config is None:
        try:
            from config import merege_config as _merge_config
        except ImportError as err:
            raise ImportError(
                "--load_default is enabled but `merege_config` could not be imported "
                "from config.py; restore `from config import *` at the top of this script."
            ) from err
    # Univariate mode is inferred from the archive name (e.g. 'forecast_csv_univar').
    is_univariate = 'univar' in raw_args.archive.lower() if raw_args.archive else False
    # merege_config may return either a dict of overrides or a Namespace-like object.
    merged_params_obj = _merge_config(raw_args, raw_args.dataset, univar=is_univariate)
    if isinstance(merged_params_obj, dict):
        args = argparse.Namespace(**vars(raw_args))  # Copy so raw_args stays untouched
        for k, v in merged_params_obj.items():  # Override/add with merged params
            setattr(args, k, v)
    else:  # Already a Namespace-like object
        args = merged_params_obj
else:
    args = raw_args

device = init_dl_program(
    args.gpu,
    seed=args.seed,
    max_threads=args.max_threads,
)

# --- Data Loading ---
# Every dataset-dependent name used later in the script gets a neutral default
# here; the loader branch below overwrites the subset that applies to its task.
train_data = train_labels = test_data = test_labels = None
valid_dataset_tuple = None  # Passed to model.fit for validation/evaluation
task_type = ''              # 'classification' or 'forecasting'
scaler = None               # Forecasting only
pred_lens = []              # Forecasting only
n_covariate_cols = 0        # Forecasting only

print(f"Loading dataset: {args.dataset}, Archive: {args.archive}")

if args.archive in ('UCR', 'UEA'):
    # Classification archives: loader returns (train_X, train_y, test_X, test_y).
    task_type = 'classification'
    loader = datautils.load_UCR if args.archive == 'UCR' else datautils.load_UEA
    train_data, train_labels, test_data, test_labels = loader(args.dataset)
    valid_dataset_tuple = (train_data, train_labels, test_data, test_labels)
else:
    # Forecasting: "lora" has its own loader whatever the archive name is.
    if args.dataset == "lora":
        is_univar_lora = 'univar' in args.archive.lower()
        forecast_bundle = datautils.load_forecast_csv_lora(univar=is_univar_lora)
    elif args.archive == 'forecast_csv':
        forecast_bundle = datautils.load_forecast_csv(args.dataset)
    elif args.archive == 'forecast_csv_univar':
        forecast_bundle = datautils.load_forecast_csv(args.dataset, univar=True)
    else:
        raise ValueError(f"Unknown archive type: {args.archive}")
    task_type = 'forecasting'
    data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = forecast_bundle
    train_data = data[:, train_slice]
    valid_dataset_tuple = (data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols)

if train_data is None:
    raise ValueError("Train data failed to load.")

print(f"Task type: {task_type}")
print(f"Train data shape: {train_data.shape}")


# --- Clamp the batch size to the number of training instances ---
# A single long series (train_data.shape[0] == 1) is presumably split into
# windows of max_train_length during training, so its instance count is the
# number of full windows; otherwise each entry along axis 0 is one instance.
# Note: ProSAR.py's get_dataloader also has logic for this.
if train_data.shape[0] == 1 and args.max_train_length:
    n_train_instances = train_data.shape[1] // args.max_train_length
else:
    n_train_instances = train_data.shape[0]
# max(1, ...) keeps the batch size positive even when the series is shorter
# than max_train_length; a batch size of 0 is not usable.
args.batch_size = max(1, min(args.batch_size, n_train_instances))

# --- Infer input_dims (number of feature channels) from the data ---
if train_data.ndim == 3:  # (N, L, C)
    args.input_dims = train_data.shape[-1]
elif train_data.ndim == 2 and task_type == 'forecasting':  # (C, L_total)
    args.input_dims = train_data.shape[0]  # C is the number of features
elif train_data.ndim == 2 and task_type == 'classification':  # (N, L): univariate
    args.input_dims = 1
    # Add a trailing channel axis so the model always receives (N, L, C).
    train_data = np.expand_dims(train_data, axis=-1)
    if test_data is not None and test_data.ndim == 2:
        test_data = np.expand_dims(test_data, axis=-1)
    # Rebuild the evaluation tuple whenever train_data was reshaped so it never
    # holds the stale 2-D array.
    if args.archive in ['UCR', 'UEA']:
        valid_dataset_tuple = (train_data, train_labels, test_data, test_labels)

print(f"Inferred input_dims (features): {args.input_dims}")

# --- Cap the prototype length at the training segment length ---
# Sequence length is axis 1 for both (N, L, C) and (C, L_total) layouts; the
# segment actually used in training is further limited by max_train_length.
current_sequence_length = train_data.shape[1]
if args.max_train_length is not None and args.max_train_length > 0:
    actual_segment_len = min(current_sequence_length, args.max_train_length)
else:
    actual_segment_len = current_sequence_length

if actual_segment_len < args.proto_len:
    print(f"Warning: Actual segment length ({actual_segment_len}) is less than proto_len ({args.proto_len}). Adjusting proto_len to {actual_segment_len}.")
    args.proto_len = actual_segment_len
if args.proto_len <= 0:  # e.g. an empty sequence; fall back to a small positive length
    args.proto_len = 16
    print(f"Warning: proto_len became invalid, reset to {args.proto_len}.")


# --- Model Initialization and Training ---
# Start of the "training and evaluation" timer reported at the end of the script.
t_training_prep_done = time.time()
print("Initializing ProSAR model...")
model = AugProtoCL(args=args, device=device)

print("Starting model training (fit)...")
# The layout of `results` depends on task_type; it is unpacked in the reporting
# section below.
fit_kwargs = dict(
    train_data_np=train_data,
    task_type=task_type,
    n_epochs=args.epochs,
    n_iters=args.iters,
    miverbose=True,
    valid_dataset=valid_dataset_tuple,
)
results = model.fit(**fit_kwargs)

# --- Print Final Results (based on task_type) ---
def _last_valid(history):
    """Return the final entry of a metric history as a float, or None.

    None is returned when *history* is None/empty, when its last entry is NaN,
    or when that entry cannot be converted to a float. ``len()`` is used
    instead of truthiness so NumPy arrays work as well as lists, and
    ``float()`` normalizes NumPy scalars (e.g. float32) before the NaN check.
    """
    if history is None or len(history) == 0:
        return None
    try:
        last = float(history[-1])
    except (TypeError, ValueError):
        return None
    return None if np.isnan(last) else last


if task_type == 'classification':
    # AugProtoCL.fit for classification returns: loss_log, acc_log, vx_log, vy_log
    loss_history, acc_history, vx_history, vy_history = results
    final_loss = _last_valid(loss_history)
    final_acc = _last_valid(acc_history)

    print(f"Results for '{args.dataset}' (Classification):")
    if final_loss is not None:
        print(f"  Final Training Loss: {final_loss:.4f}")
    if final_acc is not None:
        print(f"  Final Test Accuracy: {final_acc:.4f}")

elif task_type == 'forecasting':
    mses, maes = results
    final_mse = _last_valid(mses)
    final_mae = _last_valid(maes)
    if final_mse is not None and final_mae is not None:
        final_mse_info = 'MSE: %.5f, MAE: %.5f' % (final_mse, final_mae)
        print(f"Results for '{args.dataset}' (Forecasting): {final_mse_info}")
    else:
        print(f"Results for '{args.dataset}' (Forecasting): Final MSE/MAE not recorded or NaN.")

# Measured from just before model construction, so it covers init + fit + reporting.
training_time_total_main = time.time() - t_training_prep_done
print(f"\nTotal training and evaluation time: {datetime.timedelta(seconds=training_time_total_main)}")
total_script_duration = time.time() - t_script_start
print(f"Total script execution time: {datetime.timedelta(seconds=total_script_duration)}")
print("Process finished.")