import copy
import fnmatch
import functools
import itertools
import json
import logging
import subprocess
import sys
import time
import warnings
from os import mkdir, scandir
from os.path import exists, join, dirname, abspath
from pathlib import Path
from typing import Tuple

from omegaconf import OmegaConf

# Add parent directory to path to find teacher_student package
sys.path.insert(0, dirname(dirname(abspath(__file__))))

# Import train_teacher_student function (not module) to avoid circular import
from teacher_student.train_teacher_student import train_teacher_student

# LOCAL SWEEPERS_________________________________________________________________________

def independent_sweep(exp_dir):
    """
    Decorator factory that runs ``func`` once per value of an independent
    (one-parameter-at-a-time) sweep.

    The configuration files are read from ``exp_dir`` when the decorator is
    created (not when the decorated function is called):
    - no sweep file found   -> ``func(config=config, save_dir=exp_dir)``
    - sweep file found      -> ``func(config=config(sweep), save_dir=...)``
                               for every swept value

    func (function)-> None: callable accepting the keyword arguments
    ``config`` and ``save_dir``. Any positional/keyword arguments given to the
    decorated function are forwarded to ``func`` as well.

    The folder structure for the sweeps is created here and passed to func.
    This is how it looks:

    ./experiments
    |__exp1
    |__exp2 <-- experiment folder "exp_dir"
        |__config.yaml <-- starting config file
        |__config_sweep.yaml <-- sweep instructions
        |__sweeps
            |__sweep_param1
            |__sweep_param2
            |__sweep_param3
                |__sweep_param3_val1
                |__sweep_param3_val2
                    |__config.yaml <-- sweep (modified) config file
                    |__results

    Args:
        exp_dir (str): experiment folder

    Raises:
        AssertionError: (when the decorated function is called) a swept
            parameter is not present in the base config.
    """

    config, sweep_config = find_yml_files(dir=exp_dir)
    # set logging
    logger_info = logging.getLogger("logger_info")

    def decorator_repeat(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):

            if sweep_config is None:
                # Same calling convention as the sweep branch, so forwarded
                # arguments are not silently dropped when there is no sweep.
                func(*args, **kwargs, config=config, save_dir=exp_dir)
                return

            for cat in sweep_config:
                for param in sweep_config[cat]:

                    # Explicit raise instead of `assert`: still enforced
                    # when Python runs with -O.
                    if param not in config[cat]:
                        raise AssertionError(f"Parameters {param} not in config file!")

                    sweep_param_path = join(exp_dir, "sweeps", f"sweep_{param}")

                    for val in sweep_config[cat][param]:

                        # logging
                        logger_info.info(f"Sweep at {cat}-{param}: {val}")

                        # Deep copy: DictConfig.copy() is shallow, so writing
                        # into a nested node would also change the base
                        # config and leak this value into later sweeps.
                        config_sweep = copy.deepcopy(config)
                        config_sweep[cat][param] = val

                        # create the saving dir (and any missing parents)
                        save_dir = join(sweep_param_path, f"sweep_{param}_{val}")
                        Path(save_dir).mkdir(parents=True, exist_ok=True)

                        # execute function
                        func(*args, **kwargs, config=config_sweep, save_dir=save_dir)

                        # save config_sweep as yaml file
                        config_sweep_path = join(save_dir, "config.yaml")
                        OmegaConf.save(config_sweep, config_sweep_path)

        return wrapper
    return decorator_repeat





def combination_sweep(exp_dir, mode="combination"):
    """
    Decorator factory that runs ``func`` over a parameter sweep read from
    ``exp_dir``, according to the specified sweep mode:

    - "independent": One-at-a-time parameter sweep (like independent_sweep)
    - "combination": Sweep across all possible combinations of parameter values

    func (function)-> None: callable accepting the keyword arguments
    ``config`` and ``save_dir``. Any positional/keyword arguments given to the
    decorated function are forwarded to ``func`` as well.

    The configuration files are read (and ``update_config`` is applied to the
    base config) when the decorator is created. If no sweep file is found,
    ``func`` is called once with the base config and ``save_dir=exp_dir``.

    The folder structure for the sweeps is created here and passed to func.

    For independent mode, the folder structure looks like:

    ./experiments
    |__exp1
    |__exp2 <-- experiment folder "exp_dir"
        |__config.yaml <-- starting config file
        |__config_sweep.yaml <-- sweep instructions
        |__sweeps
            |__sweep_param1
            |__sweep_param2
            |__sweep_param3
                |__sweep_param3_val1
                |__sweep_param3_val2
                    |__config.yaml <-- sweep (modified) config file
                    |__results

    For combination mode, the folder structure looks like:

    ./experiments
    |__exp1
    |__exp2 <-- experiment folder "exp_dir"
        |__config.yaml <-- starting config file
        |__config_sweep.yaml <-- sweep instructions
        |__combinations
            |__combo_param1_val1_param2_val1
            |__combo_param1_val1_param2_val2
                |__config.yaml <-- sweep (modified) config file
                |__results

    Args:
        exp_dir (str): experiment folder
        mode (str): sweep mode, either "independent" or "combination"

    Raises:
        ValueError: (when the decorated function is called with a sweep
            present) ``mode`` is not one of the supported modes.
        AssertionError: (when the decorated function is called) a swept
            parameter is not present in the base config.
    """

    config, sweep_config = find_yml_files(dir=exp_dir)
    # NOTE(review): placeholders are resolved once on the base config, i.e.
    # *before* sweep values are applied; values derived by update_config are
    # not recomputed per sweep.
    config = update_config(config)

    # set logging
    logger_info = logging.getLogger("logger_info")

    def swept_parameters():
        """Yield (category, parameter, values) for every swept parameter."""
        for cat in sweep_config:
            for param in sweep_config[cat]:
                # Explicit raise instead of `assert`: still enforced under -O.
                if param not in config[cat]:
                    raise AssertionError(f"Parameters {param} not in config file!")
                yield cat, param, sweep_config[cat][param]

    def run_and_save(func, args, kwargs, config_sweep, save_dir):
        """Create save_dir, run func on the swept config, then store it."""
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        func(*args, **kwargs, config=config_sweep, save_dir=save_dir)
        OmegaConf.save(config_sweep, join(save_dir, "config.yaml"))

    def decorator_repeat(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):

            if sweep_config is None:
                # Same calling convention as the sweep branches, so forwarded
                # arguments are not silently dropped when there is no sweep.
                func(*args, **kwargs, config=config, save_dir=exp_dir)
                return

            if mode not in ("independent", "combination"):
                raise ValueError(f"Unknown sweep mode: {mode}. Use 'independent' or 'combination'.")

            # Validate every swept parameter before any run is started.
            swept = list(swept_parameters())

            if mode == "independent":
                for cat, param, values in swept:
                    sweep_param_path = join(exp_dir, "sweeps", f"sweep_{param}")
                    # parents=True: "<exp_dir>/sweeps" may not exist yet
                    Path(sweep_param_path).mkdir(parents=True, exist_ok=True)

                    for val in values:
                        logger_info.info(f"Sweep at {cat}-{param}: {val}")

                        # Deep copy: DictConfig.copy() is shallow, so writing
                        # into a nested node would also change the base
                        # config and leak this value into later sweeps.
                        config_sweep = copy.deepcopy(config)
                        config_sweep[cat][param] = val

                        save_dir = join(sweep_param_path, f"sweep_{param}_{val}")
                        run_and_save(func, args, kwargs, config_sweep, save_dir)
                return

            # mode == "combination"
            combinations_path = join(exp_dir, "combinations")
            Path(combinations_path).mkdir(parents=True, exist_ok=True)

            # Parameters are tracked as (category, name) entries rather than
            # by bare name, so the same name swept in two categories does not
            # overwrite one of the two sweeps.
            for combo in itertools.product(*(values for _, _, values in swept)):
                config_sweep = copy.deepcopy(config)
                combo_name_parts = []
                combo_log_parts = []

                for (cat, param, _), val in zip(swept, combo):
                    config_sweep[cat][param] = val
                    combo_name_parts.append(f"{param}_{val}")
                    combo_log_parts.append(f"{cat}-{param}: {val}")

                logger_info.info("Combination sweep with " + ", ".join(combo_log_parts))

                combo_dir = join(combinations_path, "combo_" + "_".join(combo_name_parts))
                run_and_save(func, args, kwargs, config_sweep, combo_dir)

        return wrapper
    return decorator_repeat





# PARALLEL JOBS CLUSTER______________________________________________________________________________________________

def generate_slurm_job_array_script(exp_dir, home_exp_dir, experiment_id, combinations_file, max_concurrent_jobs=6, 
                                   walltime="5-00:00:00", gpu_mem="24g", mem_per_cpu="10g", scratch_path=None):
    """
    Generate SLURM job array script for parallel sweep execution with optional scratch support.

    The array size is the number of entries under the ``'combinations'`` key
    of ``combinations_file``. The script is written to
    ``<exp_dir>/run_sweep_array.sh`` and ``<exp_dir>/slurm_logs`` is created
    (together with any missing parent directories) for the job output files.

    Args:
        exp_dir (str): experiment directory (could be scratch or home)
        home_exp_dir (str): home experiment directory (for config files);
            only used by the scratch variant of the script
        experiment_id (str): experiment ID for naming
        combinations_file (str): path to JSON file containing all combinations
        max_concurrent_jobs (int): maximum concurrent jobs to run
        walltime (str): SLURM walltime limit
        gpu_mem (str): GPU memory requirement
        mem_per_cpu (str): CPU memory requirement
        scratch_path (str): scratch path if using scratch; only tested
            against None to select the scratch variant of the script

    Returns:
        str: path to generated SLURM script

    Raises:
        ValueError: the combinations file contains no combinations.
    """

    # Read combinations to determine array size
    with open(combinations_file, 'r') as f:
        combinations_data = json.load(f)

    total_jobs = len(combinations_data['combinations'])
    if total_jobs == 0:
        # "--array=0--1" is not a valid array specification; fail early
        # instead of writing a script that cannot be submitted.
        raise ValueError(f"No combinations found in {combinations_file}; nothing to submit.")

    # Create SLURM script content with scratch support
    if scratch_path is not None:
        # Scratch-enabled script
        script_content = f"""#!/bin/bash
#SBATCH --job-name=sweep_{experiment_id}
#SBATCH --output={exp_dir}/slurm_logs/sweep_%A_%a.out
#SBATCH --error={exp_dir}/slurm_logs/sweep_%A_%a.err
#SBATCH --array=0-{total_jobs-1}%{max_concurrent_jobs}
#SBATCH --ntasks=1
#SBATCH --time={walltime}
#SBATCH --gpus=1
#SBATCH --mem-per-cpu={mem_per_cpu}
#SBATCH --gres=gpumem:{gpu_mem}

set -euo pipefail

echo "[$(date)] Job started on $(hostname)"
echo "Job ID: $SLURM_JOB_ID"
echo "Array Task ID: $SLURM_ARRAY_TASK_ID"

# ───────────────────────────────────────────────
# SCRATCH SETUP
# ───────────────────────────────────────────────
PROJ_HOME="$HOME/CF"
HOME_EXP="{home_exp_dir}"
SCRATCH_EXP="{exp_dir}"

echo "[$(date)] Home exp folder : $HOME_EXP"
echo "[$(date)] Scratch folder  : $SCRATCH_EXP"

# Copy combinations data to scratch if not already there (first job in array does this)
if [[ ! -f "$SCRATCH_EXP/combinations_data.json" ]]; then
    echo "[$(date)] Copying combinations data to scratch..."
    rsync -av "$HOME_EXP/" "$SCRATCH_EXP/" || echo "Warning: rsync failed, continuing..."
fi

# ───────────────────────────────────────────────
# ENVIRONMENT
# ───────────────────────────────────────────────
module load stack/2024-06
module load gcc/12.2.0
module load python_cuda/3.11.6

# Activate virtual environment
source "$PROJ_HOME/venv/bin/activate"

if [[ -z "${{VIRTUAL_ENV:-}}" ]]; then
    echo "[$(date)] Failed to activate Python environment!" >&2
    exit 1
fi
echo "[$(date)] Python env: $VIRTUAL_ENV"

# ───────────────────────────────────────────────
# RUN
# ───────────────────────────────────────────────
cd "$SCRATCH_EXP"

echo "[$(date)] Running combination $SLURM_ARRAY_TASK_ID..."
python -c "
import sys
sys.path.append('$PROJ_HOME/teacher_student/teacher_student')
from exp_control import run_single_combination
run_single_combination('$SCRATCH_EXP', '$SCRATCH_EXP/combinations_data.json', $SLURM_ARRAY_TASK_ID)
"

# ───────────────────────────────────────────────
# WRAP-UP
# ───────────────────────────────────────────────
deactivate
echo "[$(date)] Python environment deactivated"
echo "[$(date)] Job finished – results are in $SCRATCH_EXP"
"""
    else:
        # Home directory script (no scratch)
        script_content = f"""#!/bin/bash
#SBATCH --job-name=sweep_{experiment_id}
#SBATCH --output={exp_dir}/slurm_logs/sweep_%A_%a.out
#SBATCH --error={exp_dir}/slurm_logs/sweep_%A_%a.err
#SBATCH --array=0-{total_jobs-1}%{max_concurrent_jobs}
#SBATCH --ntasks=1
#SBATCH --time={walltime}
#SBATCH --gpus=1
#SBATCH --mem-per-cpu={mem_per_cpu}
#SBATCH --gres=gpumem:{gpu_mem}

echo "[$(date)] Job started on $(hostname)"
echo "Job ID: $SLURM_JOB_ID"
echo "Array Task ID: $SLURM_ARRAY_TASK_ID"

# Load modules
module load stack/2024-06
module load gcc/12.2.0
module load python_cuda/3.11.6

# Activate virtual environment
source "$HOME/CF/venv/bin/activate"

# Run the specific combination for this array task
cd "$HOME/CF/teacher_student/teacher_student"
python -c "
import sys
sys.path.append('$HOME/CF/teacher_student/teacher_student')
from exp_control import run_single_combination
run_single_combination('{exp_dir}', '{combinations_file}', $SLURM_ARRAY_TASK_ID)
"

deactivate
echo "[$(date)] Job finished"
"""

    # Create slurm_logs directory; parents=True also creates exp_dir itself
    # when it does not exist yet.
    slurm_logs_dir = join(exp_dir, "slurm_logs")
    Path(slurm_logs_dir).mkdir(parents=True, exist_ok=True)

    # Write script file. The scratch template contains non-ASCII characters,
    # so the encoding is fixed instead of depending on the locale.
    script_path = join(exp_dir, "run_sweep_array.sh")
    with open(script_path, 'w', encoding="utf-8") as f:
        f.write(script_content)

    return script_path


def run_single_combination(exp_dir, combinations_file, task_id):
    """
    Train one entry of a prepared sweep (entry point for a SLURM array task).

    The entry at index ``task_id`` of the ``'combinations'`` list in
    ``combinations_file`` is applied on top of the stored ``'base_config'``,
    ``update_config`` is run on the result, training is launched with
    ``<exp_dir>/combinations/<combination name>`` as save directory, and the
    config is finally written there as ``config.yaml``.

    Args:
        exp_dir (str): experiment directory
        combinations_file (str): path to JSON file containing all combinations
        task_id (int): SLURM array task ID (index of combination to run)
    """

    with open(combinations_file, 'r') as handle:
        sweep_data = json.load(handle)

    # Entry selected by the array index, plus a fresh config built from the base
    selected = sweep_data['combinations'][task_id]
    config = OmegaConf.create(sweep_data['base_config'])

    # Write each swept value into the category it belongs to
    for name, value in selected['params'].items():
        category = selected['categories'][name]
        config[category][name] = value

    # Placeholders are resolved only after the swept values are in place
    config = update_config(config)

    # Output folder for this combination
    save_dir = join(exp_dir, "combinations", selected['name'])
    if not exists(save_dir):
        Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Run the training; data_dir/cluster fall back to '' / True when absent
    train_teacher_student(
        config=config,
        save_dir=save_dir,
        data_dir=sweep_data.get('data_dir', ''),
        cluster=sweep_data.get('cluster', True),
    )

    # Store the config next to the results
    OmegaConf.save(config, join(save_dir, "config.yaml"))




def submit_parallel_sweep(exp_dir, home_exp_dir, data_dir, cluster, scratch_path=None, mode="combination", 
                         max_concurrent_jobs=6, walltime="5-00:00:00", gpu_mem="24g", mem_per_cpu="10g", submit_jobs=True):
    """
    Submit a parallel sweep using SLURM job arrays.

    This function replaces the decorator pattern for parallel sweeps with a direct approach.

    Behaviour:
    - no sweep file found: a single training run is executed directly.
    - mode "combination": every combination of swept values is written to
      ``<exp_dir>/combinations_data.json``; if ``submit_jobs`` is true a SLURM
      array script is generated and submitted with ``sbatch`` and the job ID
      is stored in ``<exp_dir>/job_id.txt``. Submission failures are logged
      and printed, not raised.
    - mode "independent": runs sequentially through ``combination_sweep``.

    Args:
        exp_dir (str): experiment folder (could be scratch or home)
        home_exp_dir (str): home experiment directory (for config files)
        data_dir (str): data directory path
        cluster (bool): whether running on cluster
        scratch_path (str): scratch path if using scratch
        mode (str): sweep mode, either "independent" or "combination"
        max_concurrent_jobs (int): maximum concurrent SLURM jobs
        walltime (str): SLURM walltime limit
        gpu_mem (str): GPU memory requirement
        mem_per_cpu (str): CPU memory requirement
        submit_jobs (bool): whether to actually submit jobs or just prepare

    Raises:
        ValueError: unknown ``mode``, or the same parameter name is swept in
            more than one category (the combinations file is keyed by
            parameter name and cannot represent that).
        AssertionError: a swept parameter is not present in the base config.
    """

    # Read config files from home directory (where they actually exist)
    config_dir = home_exp_dir if scratch_path is not None else exp_dir
    config, sweep_config = find_yml_files(dir=config_dir)

    logger_info = logging.getLogger("logger_info")

    if sweep_config is None:
        print("No sweep configuration found. Running single experiment...")
        # Run single experiment
        train_teacher_student(config=config, save_dir=exp_dir, data_dir=data_dir, cluster=cluster)
        return

    if mode == "independent":
        # For independent mode, fall back to sequential execution
        logger_info.warning("Independent mode not yet parallelized, falling back to sequential")
        print("Independent mode not yet parallelized, falling back to sequential execution")

        # Use the existing combination_sweep decorator for sequential execution.
        # NOTE(review): combination_sweep reads its config files from exp_dir,
        # not from config_dir - confirm this is intended when scratch is used.
        @combination_sweep(exp_dir, mode="independent")
        def run_sweep(config, save_dir):
            train_teacher_student(config=config, save_dir=save_dir, data_dir=data_dir, cluster=cluster)

        run_sweep()
        return

    if mode != "combination":
        raise ValueError(f"Unknown sweep mode: {mode}. Use 'independent' or 'combination'.")

    # Last path component of exp_dir; trailing separators are stripped first
    # so that "path/to/exp/" does not produce an empty experiment ID.
    experiment_id = exp_dir.replace('\\', '/').rstrip('/').split('/')[-1]

    # Plain Python containers: the swept values end up in a JSON file, which
    # cannot serialize OmegaConf nodes (e.g. nested lists).
    sweep_plain = OmegaConf.to_container(sweep_config, resolve=True)

    # Extract all parameters and their values
    param_values = {}
    param_categories = {}

    for cat, params in sweep_plain.items():
        for param, values in params.items():
            # Explicit raise instead of `assert`: still enforced under -O.
            if param not in config[cat]:
                raise AssertionError(f"Parameters {param} not in config file!")
            if param in param_categories:
                # Previously the later category silently replaced the earlier one.
                raise ValueError(
                    f"Parameter '{param}' is swept in both '{param_categories[param]}' and '{cat}'; "
                    "swept parameter names must be unique across categories."
                )
            param_values[param] = values
            param_categories[param] = cat

    # Generate all possible combinations
    param_names = list(param_values)
    value_combinations = list(itertools.product(*(param_values[param] for param in param_names)))

    # Create combinations directory (and exp_dir itself if it is missing)
    combinations_path = join(exp_dir, "combinations")
    Path(combinations_path).mkdir(parents=True, exist_ok=True)

    # Prepare combinations data for job array
    combinations_data = {
        'base_config': OmegaConf.to_container(config),
        'combinations': [],
        'data_dir': data_dir,
        'cluster': cluster
    }

    # Process each combination
    for combo_idx, combo in enumerate(value_combinations):
        combo_params = dict(zip(param_names, combo))
        combo_name_parts = [f"{param}_{val}" for param, val in combo_params.items()]
        combo_log_parts = [
            f"{param_categories[param]}-{param}: {val}" for param, val in combo_params.items()
        ]

        combinations_data['combinations'].append({
            'id': combo_idx,
            'params': combo_params,
            'categories': param_categories,
            'name': "combo_" + "_".join(combo_name_parts),
            'log_description': ", ".join(combo_log_parts)
        })

    # Save combinations data
    combinations_file = join(exp_dir, "combinations_data.json")
    with open(combinations_file, 'w', encoding="utf-8") as f:
        json.dump(combinations_data, f, indent=2)

    logger_info.info(f"Generated {len(value_combinations)} parameter combinations")
    print(f"Generated {len(value_combinations)} parameter combinations")
    print(f"Data directory: {data_dir}")

    if not submit_jobs:
        print(f"Prepared {len(value_combinations)} combinations for parallel execution")
        print(f"Run script: {join(exp_dir, 'run_sweep_array.sh')}")
        return

    # Generate and submit SLURM job array
    script_path = generate_slurm_job_array_script(
        exp_dir, home_exp_dir, experiment_id, combinations_file, max_concurrent_jobs,
        walltime, gpu_mem, mem_per_cpu, scratch_path
    )

    logger_info.info(f"Generated SLURM script: {script_path}")

    # Submit job array. Best effort by design: any failure here is reported
    # to the log and to stdout instead of being raised.
    try:
        result = subprocess.run(['sbatch', script_path],
                                capture_output=True, text=True, cwd=exp_dir)

        if result.returncode == 0:
            # The job ID is taken as the last whitespace-separated token
            # of sbatch's standard output.
            job_id = result.stdout.strip().split()[-1]
            logger_info.info(f"Submitted job array with ID: {job_id}")

            # Save job ID for monitoring
            with open(join(exp_dir, "job_id.txt"), 'w') as f:
                f.write(job_id)

            print("Job array submitted successfully!")
            print(f"Job ID: {job_id}")
            print(f"Total combinations: {len(value_combinations)}")
            print(f"Max concurrent jobs: {max_concurrent_jobs}")
            print("Monitor progress with: squeue -u $USER")

        else:
            logger_info.error(f"Failed to submit job: {result.stderr}")
            print(f"Error submitting job: {result.stderr}")

    except Exception as e:
        logger_info.error(f"Error submitting job: {str(e)}")
        print(f"Error submitting job: {str(e)}")





# HELPERS________________________________________________________________________________________________


def find_yml_files(dir:str)-> Tuple[dict, dict]:
    """
    Look for configuration and sweep file in a directory (dir)

    Only regular files directly inside ``dir`` are considered. A file whose
    name matches ``*sweep*.yaml`` is loaded as the sweep configuration; any
    other file matching ``*config*.yaml`` is loaded as the base configuration.
    The sweep pattern is checked first so that a file such as
    ``config_sweep.yaml`` (which matches both patterns) is never taken as the
    base configuration. Entries are visited in name order, so when several
    files match the same pattern the alphabetically last one is used.

    Args:
        dir (str): path to look for file, usually experiment folder

    Raises:
        FileNotFoundError: dir doesn't contain the config file

    Warns:
        UserWarning: dir doesn't contain the sweep file

    Returns:
        Tuple[dict, dict]: config_dict, sweep_dict (sweep_dict is None when
        no sweep file is found)
    """
    config_control_string = "config"
    sweep_control_string = "sweep"

    config, sweep_config = None, None

    with scandir(dir) as entries:

        # scandir yields entries in arbitrary order; sort for determinism
        for entry in sorted(entries, key=lambda e: e.name):
            if not entry.is_file():
                continue

            if fnmatch.fnmatch(entry.name,f"*{sweep_control_string}*.yaml"):
                sweep_config = OmegaConf.load(entry.path)

            elif fnmatch.fnmatch(entry.name,f"*{config_control_string}*.yaml"):
                config = OmegaConf.load(entry.path)

    if config is None:
        raise FileNotFoundError("No configuration file found")

    if sweep_config is None:
        warnings.warn("No available sweep found")

    return config, sweep_config



def update_config(config: dict)->dict:
    """
    Resolve the class-count placeholders of a config, in place.

    When ``config["control"]`` defines ``max_num_classes`` (not None), both
    task label lists are truncated to that many entries and the
    ``num_classes`` / ``num_classes_per_head`` entries of the model kwargs are
    set from it. Without that setting the config is returned untouched.

    Args:
        config (dict): config with placeholders

    Returns:
        dict: the same config object, updated
    """
    class_limit = config["control"].get("max_num_classes", None)

    # nothing to resolve
    if class_limit is None:
        return config

    # keep at most `class_limit` labels per task
    for labels_key in ("task0_labels", "task1_labels"):
        config[labels_key] = config[labels_key][:class_limit]

    # the full model covers both tasks, hence twice the limit
    config["full_kwargs"]["num_classes"] = 2*class_limit

    # each teacher handles a single task
    for teacher_key in ("teacher0_kwargs", "teacher1_kwargs"):
        config[teacher_key]["num_classes"] = class_limit

    config["student_kwargs"]["num_classes_per_head"] = class_limit

    return config



# update_tasks_mnist function moved to task_utils.py to avoid circular imports
