import gc
import json
import os
import datetime
import subprocess
from huggingface_hub import login
import torch
from transformers import AutoModelForCausalLM

# Defaults shared with the launch scripts -- keep the values consistent with them.
DEFAULT_HF_CACHE = "/root/autodl-tmp/huggingface"
# Empty string means "no token configured"; callers can pass one explicitly.
DEFAULT_HF_TOKEN = ""

# Per-task training settings as (task name, epochs, batch size). Both lookup
# dicts below are derived from this single table so they cannot drift apart.
_TASK_TRAINING_SETTINGS = (
    ('C-STANCE', 5, 8),
    ('FOMC', 3, 8),
    ('MeetingBank', 7, 8),
    ('ScienceQA', 3, 8),
    ('NumGLUE-cm', 5, 8),
    ('NumGLUE-ds', 5, 8),
    ('20Minuten', 7, 8),
)
task_specific_epochs = {task: epochs for task, epochs, _ in _TASK_TRAINING_SETTINGS}
task_specific_batch_sizes = {task: batch for task, _, batch in _TASK_TRAINING_SETTINGS}


def setup_conda_environments():
    """
    Ensure all necessary conda environments are set up
    """
    print("\n===== Checking and setting up conda environments =====")

    # Check if conda is available
    try:
        subprocess.run("conda --version", shell=True,
                       check=True, stdout=subprocess.DEVNULL)
    except:
        print("Warning: conda command not found, please ensure conda is installed and added to PATH")
        return False

    # Check if necessary environments exist
    required_envs = ["lmeval", "bigcode", "safety-eval"]
    missing_envs = []

    for env in required_envs:
        try:
            result = subprocess.run(
                f"conda env list | grep {env}",
                shell=True,
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            if not result.stdout.strip():
                missing_envs.append(env)
        except:
            missing_envs.append(env)

    # If environments are missing, provide creation commands
    if missing_envs:
        print(f"Warning: The following conda environments do not exist: {', '.join(missing_envs)}")
        print("Please use the following commands to create environments:")

        if "lmeval" in missing_envs:
            print("""
conda create -n lmeval python=3.10
conda activate lmeval
pip install lm-eval
pip install numpy==1.24.3  # Fix numpy version conflict
            """)

        return False

    return True


def create_experiment_dir(base_path="./experiments/"):
    """
    Create a timestamped experiment directory with its standard layout.

    Resulting layout::

        <base_path>/continual_learning_<YYYYmmdd_HHMMSS>/finetunedModels/

    All directories are created with ``exist_ok=True``, so two calls within
    the same second return (and reuse) the same directory.

    Args:
        base_path: Parent directory for experiments; created if missing.

    Returns:
        str: Path of the experiment directory (not the ``finetunedModels`` subdir).
    """
    os.makedirs(base_path, exist_ok=True)

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_dir = os.path.join(base_path, "continual_learning_" + stamp)

    for needed in (experiment_dir, os.path.join(experiment_dir, "finetunedModels")):
        os.makedirs(needed, exist_ok=True)

    return experiment_dir


def setup_hf_cache(cache_dir=None):
    """
    Point all HuggingFace caches at one root directory and set runtime env vars.

    Creates the root directory, then exports cache locations plus a few
    encoding/locale/warning variables into ``os.environ`` (inherited by any
    subprocess launched afterwards).

    NOTE(review): ``transformers`` and ``huggingface_hub`` are imported at the
    top of this module; they may have already resolved their cache paths at
    import time, so these variables may only take full effect in child
    processes -- confirm before relying on them in-process.

    Args:
        cache_dir: Root cache directory; ``DEFAULT_HF_CACHE`` when None/empty.
    """
    root = cache_dir if cache_dir else DEFAULT_HF_CACHE

    os.makedirs(root, exist_ok=True)

    env_overrides = {
        # Cache locations, all under the same root.
        'HF_HOME': root,
        'TRANSFORMERS_CACHE': os.path.join(root, 'transformers'),
        'HF_DATASETS_CACHE': os.path.join(root, 'datasets'),
        'HUGGINGFACE_HUB_CACHE': os.path.join(root, 'hub'),
        'XDG_CACHE_HOME': root,
        # Force online mode.
        'TRANSFORMERS_OFFLINE': "0",
        # PYTHONWARNINGS / PYTHONIOENCODING are read at interpreter startup,
        # so they apply to child Python processes.
        'PYTHONWARNINGS': "ignore::UserWarning",
        'PYTHONIOENCODING': "utf-8",
        'LC_ALL': "C.UTF-8",
        'LANG': "C.UTF-8",
    }
    os.environ.update(env_overrides)

    print(f"HuggingFace cache directory set: {root}")


def huggingface_login(token=None):
    """
    Log in to HuggingFace and persist the token for tools that read it.

    The token is taken from, in order: the ``token`` argument,
    ``DEFAULT_HF_TOKEN``, then the ``HF_TOKEN`` and
    ``HUGGING_FACE_HUB_TOKEN`` environment variables. If none of these
    yields a token, nothing is written, so an existing saved token and the
    environment variables are left intact, and False is returned.

    Args:
        token: API token; falls back as described above when None/empty.

    Returns:
        bool: Whether login was successful
    """
    token = (token or DEFAULT_HF_TOKEN
             or os.environ.get('HF_TOKEN')
             or os.environ.get('HUGGING_FACE_HUB_TOKEN'))
    if not token:
        print("HuggingFace login skipped: no token provided")
        return False

    try:
        print("Logging in to HuggingFace...")
        # Ensure token is written to config file
        token_path = os.path.expanduser("~/.huggingface/token")
        os.makedirs(os.path.dirname(token_path), exist_ok=True)
        # Create the file as 0o600 from the start so the secret is never
        # readable by others, even briefly. os.open's mode only applies on
        # creation, so chmod afterwards also tightens a pre-existing file.
        fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(token)
        os.chmod(token_path, 0o600)

        # Also set environment variables (inherited by subprocesses)
        os.environ['HF_TOKEN'] = token
        os.environ['HUGGING_FACE_HUB_TOKEN'] = token

        # Explicit login
        login(token=token)
        print("HuggingFace login successful")
        return True
    except Exception as e:
        # Best-effort: report and let the caller decide how to proceed.
        print(f"HuggingFace login failed: {e}")
        return False


def save_model_state_to_disk(state_dict, directory):
    """
    Save ``state_dict`` to ``<directory>/previous_model_state.pt``.

    The data is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interrupted save therefore never leaves a
    truncated ``previous_model_state.pt`` for ``load_model_state_from_disk``
    to pick up when resuming, and any earlier complete file survives.

    Args:
        state_dict: Object to serialize with ``torch.save``.
        directory: Existing directory to write into.

    Returns:
        str: Path of the saved state file.
    """
    state_path = os.path.join(directory, "previous_model_state.pt")
    tmp_path = state_path + ".tmp"
    print(f"Saving model state to: {state_path}")
    try:
        torch.save(state_dict, tmp_path)
        os.replace(tmp_path, state_path)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return state_path


def load_model_state_from_disk(directory):
    """
    Load the state saved as ``<directory>/previous_model_state.pt``.

    Tensors are mapped onto the CPU regardless of where they were saved from.

    Args:
        directory: Directory that may contain ``previous_model_state.pt``.

    Returns:
        The deserialized object, or None if the file does not exist.
    """
    state_file = os.path.join(directory, "previous_model_state.pt")
    if not os.path.exists(state_file):
        return None

    print(f"Loading model state from disk: {state_file}")
    # NOTE(review): torch.load unpickles the file; only use it on directories
    # this code wrote itself.
    return torch.load(state_file, map_location="cpu")


def save_merge_config(result_dir, merge_method, base_model, task_names,
                      start_task, end_task, scaling_coef=1.0, use_default_scaling=False,
                      task_vector_from_base=True,
                      continue_experiment=False, prev_experiment_dir=None,
                      final_model_only=False):
    """
    Write ``merge_config.json`` describing a model-merging run.

    When continuing an experiment and the config already exists, the file is
    left untouched. Otherwise it is (re)written.

    Args:
        result_dir: Directory that receives ``merge_config.json``.
        merge_method: Merge method name.
        base_model: Base model path recorded in the config.
        task_names: Full list of task names.
        start_task: Index of the first task in this run.
        end_task: Index of the last task in this run (inclusive).
        scaling_coef: Coefficient recorded when ``use_default_scaling`` is False.
        use_default_scaling: Record the method's built-in default coefficient
            instead of ``scaling_coef`` (1.0 for methods without one).
        task_vector_from_base: Whether task vectors are computed from the base model.
        continue_experiment: Whether this run continues a previous experiment.
        prev_experiment_dir: Previous experiment directory (recorded only when continuing).
        final_model_only: Whether only the final model is saved.

    Returns:
        str: Config file path (whether or not it was written by this call).
    """
    config_path = os.path.join(result_dir, "merge_config.json")

    # Resuming with an existing config: keep what the earlier run recorded.
    if continue_experiment and os.path.exists(config_path):
        return config_path

    # Built-in coefficients per method. SWA uses no scaling coefficient and is
    # recorded with the placeholder string "NULL".
    method_default_scaling = {
        "adaptive_iso": 1.0,
        "task_arithmetic": 0.4,
        "magmax": 0.8,
        "ties_merge": 0.4,
        "dare": 0.4,
        "swa": "NULL",
    }
    if use_default_scaling:
        recorded_scaling = method_default_scaling.get(merge_method, 1.0)
    else:
        recorded_scaling = scaling_coef

    selected_tasks = task_names[start_task:end_task + 1]
    previous_dir = prev_experiment_dir if continue_experiment else None

    config_info = {
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "base_model": base_model,
        "tasks": selected_tasks,
        "task_vector_from_base": task_vector_from_base,
        "merge_method": merge_method,
        "merge_params": {
            "scaling_coef": recorded_scaling,
            "use_default_scaling": use_default_scaling,
        },
        "continue_from": previous_dir,
        "final_model_only": final_model_only,
    }

    with open(config_path, "w") as handle:
        json.dump(config_info, handle, indent=2)
    print(f"Merge config saved to: {config_path}")

    return config_path


# Merge methods whose cumulative state is restored by a loader in the
# ``merging`` package: method -> (module, loader function, label used in logs).
_CUMULATIVE_STATE_LOADERS = {
    'task_arithmetic': ('merging.task_arithmetic', 'load_cumulative_vector_from_disk',
                        'cumulative task vector'),
    'magmax': ('merging.magmax', 'load_magmax_vector_from_disk',
               'MagMax cumulative task vector'),
    'ties_merge': ('merging.ties_merge', 'load_ties_vector_from_disk',
                   'TIES cumulative task vector'),
    'dare': ('merging.dare', 'load_dare_vector_from_disk',
             'DARE cumulative task vector'),
    'swa': ('merging.swa', 'load_swa_state_from_disk',
            'SWA cumulative state'),
}


def _restore_adaptive_iso_state(result_dir, prev_task_idx, previous_model_state, cache_dir):
    """Restore adaptive_iso merged weights into ``previous_model_state["current_state"]``."""
    # First try to load from saved state file
    state_dict = load_model_state_from_disk(result_dir)
    if state_dict is not None:
        previous_model_state["current_state"] = state_dict
        print("Successfully loaded previous merged model state from state file")
        return True

    # Fallback: the checkpoint saved after the previous task
    prev_model_path = os.path.join(result_dir, f"model_after_task_{prev_task_idx+1}")
    if not os.path.isdir(prev_model_path):
        print("Previous merged model not found, will start from base model")
        return False

    print(f"Found model from previous task: {prev_model_path}")
    print(f"Loading previous merged model state: {prev_model_path}")
    try:
        # Loaded on the CPU (no device_map) so the restored tensors live on
        # the same device as those from the state-file path above
        # (map_location="cpu"). A GPU-resident state dict would also keep GPU
        # memory pinned after the model object is deleted.
        prev_model = AutoModelForCausalLM.from_pretrained(
            prev_model_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=os.path.join(cache_dir, "transformers")
        )
        previous_model_state["current_state"] = prev_model.state_dict()
    except Exception as e:
        print(f"Warning: Failed to load previous model: {e}")
        print("Will start from base model")
        return False

    print("Successfully loaded previous merged model state")
    # Drop the model wrapper; the state dict keeps the tensors alive.
    del prev_model
    gc.collect()
    return True


def _restore_cumulative_state(merge_method, result_dir):
    """Run the ``merging`` loader registered for ``merge_method``; return whether it loaded."""
    import importlib

    module_name, loader_name, label = _CUMULATIVE_STATE_LOADERS[merge_method]
    print(f"Attempting to load {label}...")
    # Imported lazily so only the selected method's module is loaded.
    loader = getattr(importlib.import_module(module_name), loader_name)
    if loader(result_dir):
        print(f"Successfully loaded {label}")
        return True
    print(f"{label[:1].upper()}{label[1:]} not found, will accumulate from scratch")
    return False


def restore_merge_state(merge_method, result_dir, prev_task_idx, previous_model_state, cache_dir):
    """
    Restore previous merge state based on different merge methods

    Args:
        merge_method: Merge method name
        result_dir: Result save directory
        prev_task_idx: Index of previous task (negative means there is none)
        previous_model_state: Dict that receives the restored weights under
            "current_state" (adaptive_iso only; the other methods' loaders
            manage their own state)
        cache_dir: Cache directory (its "transformers" subdir is used when a
            checkpoint has to be loaded)

    Returns:
        bool: Whether state restoration was successful
    """
    if prev_task_idx < 0:
        print("No previous task, no need to restore state")
        return False

    if merge_method == 'adaptive_iso':
        return _restore_adaptive_iso_state(
            result_dir, prev_task_idx, previous_model_state, cache_dir)

    if merge_method in _CUMULATIVE_STATE_LOADERS:
        return _restore_cumulative_state(merge_method, result_dir)

    print(f"Unknown merge method: {merge_method}")
    return False