import os
from datetime import datetime
import itertools

import numpy as np
from preprocess import add_cov_preprocess, form_dataset_config, preprocess_data
from run_training_mix import run_train_model_v3
import torch
from log import log

# Set CUDA_VISIBLE_DEVICES to device indices 0-7 for this process.
# NOTE(review): this assignment runs after `import torch`; presumably it still
# takes effect because CUDA is not initialized at import time — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
if __name__ == "__main__":
    import argparse

    def str_to_tuple(s):
        """Convert a comma-separated string such as "9,1" into a tuple of ints.

        Used as an argparse ``type=`` converter. Each comma-separated piece is
        passed to ``int``, so a piece that is not a valid integer raises
        ``ValueError``.
        """
        return tuple(int(piece) for piece in s.split(","))

    # Argument parsing, first batch of options.
    # Every option in this batch needs only a converter and a default, so the
    # options are declared as (flag, type, default) rows and registered in the
    # order listed.
    parser = argparse.ArgumentParser()
    for _flag, _conv, _default in (
        ("--dataset", str, "benchmark"),
        ("--EPOCHS", int, 300),
        ("--EPOCHS_FINE_TUNE", int, 300),
        ("--BATCH_SIZE", int, 128),
        ("--LR", float, 0.0001),
        ("--WD", float, 0.0001),
        ("--DROP_RATE", float, 0.7),
        # Original author's note next to this option (presumably other
        # accepted model names): tmaev2_guney, tmae_guney, tmae_cvt
        ("--MODEL_NAME", str, "tmaeformer_guney"),
        ("--time_period_real", int, 50),
        ("--REGENERATION_EPOCHS", int, 100),
        ("--EPOCHS_FINE_TUNE_FOR_REG", int, 500),
        ("--s1_emb_dim", int, 32),
        # Tuple-valued options are given on the command line as "a,b".
        ("--s1_emb_kernel", str_to_tuple, (9, 1)),
        ("--s1_emb_stride", int, 1),
        ("--s1_proj_kernel", str_to_tuple, (1, 20)),
        ("--s1_kv_proj_stride", str_to_tuple, (1, 3)),
        ("--s1_heads", int, 4),
        ("--s1_depth", int, 1),
        ("--tmae_hidden_bank", int, 48),
        ("--tmae_time_kernel", int, 35),
        ("--tmae_dropout", float, 0.8),
        ("--tmae_regeneration_time", float, 0.3),
        ("--DROP_RATE_init", float, 0.1),
        ("--time2_kernel", int, 10),
        ("--gpu_id", str, "0"),
        ("--L2_REG", float, 1e-3),
        ("--run_version", str, "v3"),
        ("--moe_switch", str, "time1_time2"),
    ):
        parser.add_argument(_flag, type=_conv, default=_default)


    def str2bool(v):
        """Interpret a command-line value as a boolean.

        Real ``bool`` values are returned unchanged. Strings are compared
        case-insensitively: "yes", "true", "t", "1" map to True and
        "no", "false", "f", "0" map to False.

        Raises:
            argparse.ArgumentTypeError: if the string is none of the
                recognised spellings.
        """
        if isinstance(v, bool):
            return v
        token = v.lower()
        if token in {"yes", "true", "t", "1"}:
            return True
        if token in {"no", "false", "f", "0"}:
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    # Second batch of options. Options that need only a converter and a
    # default are declared as (flag, type, default) rows; registration order
    # matches the order listed. Boolean options go through str2bool so they
    # accept spellings such as "true"/"false" on the command line.
    for _flag, _conv, _default in (
        ("--guney_origin", str2bool, True),
        ("--LR_tmae", float, 1e-4),
        ("--LR_fine_tune", float, 1e-4),
        ("--LR_tmae_fine_tune", float, 1e-4),
        ("--fine_tune_regeneration", str2bool, True),
        ("--chunk_size", int, 30),
        ("--augment_ratio", float, 5),
        ("--guney_origin_spatial_dropout_fine_tune", float, 0.5),
        ("--guney_origin_time1_dropout_fine_tune", float, 0.5),
        ("--vit_patch_size_1", int, 10),
        ("--depth", int, 1),
        ("--chunk_ratio_regen", float, 0),
        ("--channel_swap_ratio_regen", float, 0),
        ("--time_shift_ratio_regen", float, 0),
        ("--chunk_ratio_guney", float, 0.2),
        ("--channel_swap_ratio_guney", float, 0.2),
        ("--time_shift_ratio_guney", float, 0.2),
        ("--is_add_cov_preprocess", str2bool, True),
        ("--verbose", str2bool, False),
        ("--is_cross_val", str2bool, True),
        ("--channels", int, 9),
    ):
        parser.add_argument(_flag, type=_conv, default=_default)

    # The only option with a restricted value set and help text, so it is
    # registered with an explicit call.
    parser.add_argument(
        "--test_trials",
        type=int,
        default=1,
        choices=[1, 2],
        help="Number of trials to use for testing (1 or 2)",
    )

    for _flag, _conv, _default in (
        ("--moe_experts", int, 4),
        ("--moe_top_k", int, 1),
        ("--moe_version", str, "v3"),
        ("--add_aux_loss", str2bool, True),
        ("--n_spatial_filters", int, 200),
        ("--n_time1_filters", int, 120),
        ("--n_time2_filters", int, 120),
        ("--guney_origin_last_time_dropout_fine_tune", float, 0.97),
    ):
        parser.add_argument(_flag, type=_conv, default=_default)

    args = parser.parse_args()

    from ml_collections import ConfigDict

    # Re-wrap the parsed namespace as a ConfigDict; the fields assigned below
    # (device, exp_dir, log paths) are attached to it.
    args = ConfigDict(vars(args))
    if torch.cuda.is_available():
        args.device = f"cuda:{args.gpu_id}"
    else:
        args.device = "cpu"

    # Run name: model, batch size, epochs, start time, GPU id and process id.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = "_".join(
        [
            f"{args.MODEL_NAME}",
            f"BS{args.BATCH_SIZE}",
            f"EP{args.EPOCHS}",
            timestamp,
            f"gpu{args.gpu_id}",
            f"pid{os.getpid()}",
        ]
    )

    # The base directory is the dataset name plus one suffix for each setting
    # that differs from its default (time period 50, 1 test trial, 9 channels).
    dir_suffixes = []
    if args.time_period_real != 50:
        dir_suffixes.append("_" + str(args.time_period_real / 250))
    if args.test_trials == 2:
        dir_suffixes.append("_test_2")
    if args.channels != 9:
        dir_suffixes.append("_channels_" + str(args.channels))
    base_exp_dir = os.path.join("", args.dataset) + "".join(dir_suffixes)
    os.makedirs(base_exp_dir, exist_ok=True)

    # Per-run directory nested under the base directory.
    exp_dir = os.path.join(base_exp_dir, exp_name)
    os.makedirs(exp_dir, exist_ok=True)
    args.exp_dir = exp_dir

    # Log file locations inside the run directory.
    args.log_path = os.path.join(exp_dir, "log.txt")
    args.log_final_results_path = os.path.join(exp_dir, "log_final_results.txt")

    # Data preparation: build the dataset config, load the arrays, then split
    # the trailing axis into separate (target, trials) axes.
    dataset_config = form_dataset_config(args.dataset, args)

    # Models whose name contains "tmae" use the configurable regeneration
    # time as the duration; every other model uses a fixed 0.2.
    if "tmae" in args.MODEL_NAME:
        segment_duration = args.tmae_regeneration_time
    else:
        segment_duration = 0.2
    x, y = preprocess_data(args, duration=segment_duration)

    dataset_config.total_len_time_period = x.shape[1]
    # NOTE(review): the original comments described x as
    # "subject x samples x channels x trials" and y as
    # "subject x samples x trials" — axis meanings are not verifiable from
    # this script; confirm against preprocess_data.
    print(f"X shape: {x.shape}")
    print(f"Y shape: {y.shape}")

    # Keep the first four axes of x and append (target, trials); the result is
    # indexed with six axes further down.
    x = x.reshape(x.shape[:4] + (dataset_config.target, dataset_config.trials))
    # y becomes (-1, target, trials), with the leading axis inferred.
    y = y.reshape((-1, dataset_config.target, dataset_config.trials))
    print(f"Reshaped X shape: {x.shape}")

    # Cross-validation driver.
    #
    # Builds a list of held-out trial index tuples (one per fold), then for
    # each fold stacks the training trials followed by the test trials along
    # the last axis and hands the result to the training routine. The mean
    # and std returned for each fold are collected and summarised at the end.
    if args.is_cross_val:
        # Every way of holding out `test_trials` trials from all trials.
        test_combinations = list(
            itertools.combinations(range(dataset_config.trials), args.test_trials)
        )
        print(f"All test combinations: {test_combinations}")

        # Cap the number of folds at 6 by random sub-sampling.
        # NOTE(review): no random seed is set in this script, so the selected
        # folds may differ from run to run — confirm whether that is intended.
        if len(test_combinations) > 6:
            import random

            test_combinations = random.sample(test_combinations, 6)
            print(f"Randomly selected 6 combinations: {test_combinations}")

        temp_trails = len(test_combinations)
    else:
        # Single-split mode: exactly one fold.
        temp_trails = 1
        test_combinations = [(0,)]

    # One row per fold: column 0 holds the fold mean, column 1 the fold std.
    cross_val_result = np.zeros(shape=(temp_trails, 2))

    # Only run_version "v3" is handled below. Any other value trains nothing
    # and leaves the result table at zero, so say so instead of silently
    # logging a 0.0 accuracy.
    if args.run_version != "v3":
        print(
            f"Warning: run_version={args.run_version!r} is not handled; "
            "no model will be trained and the logged accuracy will be 0."
        )

    for i, test_indices in enumerate(test_combinations):
        if args.is_cross_val:
            # An explicit list makes the fancy indexing on the last axis
            # unambiguous (same selection as indexing with the tuple).
            held_out = list(test_indices)
            train_indices = [
                j for j in range(dataset_config.trials) if j not in test_indices
            ]
            x_test = x[:, :, :, :, :, held_out]
            y_test = y[:, :, held_out]
            x_train = x[:, :, :, :, :, train_indices]
            y_train = y[:, :, train_indices]
        else:
            # Single-split mode: the first `test_trials` trials are the test
            # set and all remaining trials are the training set.
            x_test = x[:, :, :, :, :, 0:args.test_trials]
            y_test = y[:, :, 0:args.test_trials]
            x_train = x[:, :, :, :, :, args.test_trials:]
            y_train = y[:, :, args.test_trials:]
            print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)

        # Training trials first, test trials last, along the trials axis.
        # np.concatenate is used rather than np.concat because the latter
        # only exists from NumPy 2.0 onwards.
        x_cross_val = np.concatenate([x_train, x_test], axis=5)
        if args.is_add_cov_preprocess:
            x_cross_val = add_cov_preprocess(x_cross_val, dataset_config)

        y_cross_val = np.concatenate([y_train, y_test], axis=2)

        if args.run_version == "v3":
            # Original author's note on this call: "with augmentation".
            cross_val_mean, cross_val_std = run_train_model_v3(
                x_cross_val, y_cross_val, args, cv_index=i
            )
            cross_val_result[i, 0] = cross_val_mean
            cross_val_result[i, 1] = cross_val_std

    # Summary line: mean of the fold means, and mean of the fold stds.
    log(
        f"Mean Accuracy after cross validation: {np.mean(cross_val_result[:,0])} +- {np.mean(cross_val_result[:,1])}",
        dataset_config.log_final_results_path,
    )
