import os
import sys
import json
import pickle
import itertools
import random
import logging

import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# Import your project-specific modules
from network_structure import model_definition
from model_trainining_and_evaluation import train_and_evaluate_model
from load_dataset import load_dataset
from parameter_settings import parameter_settings
from loss_functions import define_loss_function


def generate_settings_combinations(original_dict):
    """
    Expand a settings dictionary into one dictionary per parameter combination.

    Every entry whose value is a ``list`` is treated as a sweep axis. The
    Cartesian product of those lists is enumerated (axes in the dictionary's
    key order, last axis varying fastest) and, for each point of the product,
    a shallow copy of ``original_dict`` is produced with the list values
    replaced by the selected elements. Entries that are not lists are carried
    over unchanged.

    Returns a list of dictionaries. With no list-valued entries the result is
    a single copy of the input; if any list is empty the result is empty.
    """
    # Collect the sweep axes: the keys holding lists and the lists themselves.
    sweep_keys = []
    sweep_axes = []
    for name, candidates in original_dict.items():
        if isinstance(candidates, list):
            sweep_keys.append(name)
            sweep_axes.append(candidates)

    def _materialize(selection):
        # Shallow copy so the caller's dictionary is never modified.
        settings = original_dict.copy()
        settings.update(zip(sweep_keys, selection))
        return settings

    return [_materialize(selection) for selection in itertools.product(*sweep_axes)]


def _select_device_and_loader_params(logger):
    """
    Decide runtime device and DataLoader parameters.
    Preference order: CUDA (with pin_memory=True) → MPS → CPU.
    Returns: (device, num_workers, pin_memory)
    """
    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
        n_gpus = torch.cuda.device_count()
        logger.info(f"CUDA detected. Number of devices: {n_gpus}")
        for i in range(n_gpus):
            logger.info(f"Device {i}: {torch.cuda.get_device_name(i)}")
        num_workers = 0  # keep main-process loading unless profiling suggests otherwise
        pin_memory = True
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        num_workers = 0
        pin_memory = False
    else:
        device = torch.device("cpu")
        num_workers = 0
        pin_memory = False

    logger.info(f"Using device: {device}")
    logger.info(f"Number of workers: {num_workers}")
    return device, num_workers, pin_memory


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators."""
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _force_cudnn_determinism():
    """Put cuDNN in deterministic mode and turn its autotuner off."""
    import torch

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _open_simulation_logger(args, sim_index):
    """Create the per-simulation logger; return it with the handlers it owns."""
    logger = logging.getLogger(f"{args.filename}_{sim_index}")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Detach AND close handlers left over from an earlier run that used the
    # same logger name, so messages are not duplicated and files are released.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    # File handler: <results_dir>/<filename>.log
    fh = logging.FileHandler(os.path.join(args.results_dir, f'{args.filename}.log'))
    # Console handler: stdout
    ch = logging.StreamHandler(sys.stdout)

    for handler in (fh, ch):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger, [fh, ch]


def _build_loaders(train_dataset, test_dataset, args, num_workers, pin_memory):
    """Build the dictionary of DataLoaders used for training and evaluation."""

    def evaluation_loader(dataset):
        # Evaluation loaders iterate in a fixed order with the fixed batch size.
        return DataLoader(
            dataset, batch_size=args.fixed_batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=pin_memory
        )

    loaders = {
        "train_loader": DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=pin_memory
        ),
        "train_loader_fixed": evaluation_loader(train_dataset),
        "test_loader": evaluation_loader(test_dataset),
        "train_loader_corrupted": evaluation_loader(
            Subset(train_dataset, args.corrupted_samples)
        ),
        "train_loader_not_corrupted": evaluation_loader(
            Subset(train_dataset, args.not_corrupted_samples)
        ),
    }

    # Add train_loader_known_corrupted if there are new_indexes.
    # len() rather than truthiness: new_indexes may be an array-like.
    if len(train_dataset.new_indexes) > 0:
        loaders["train_loader_known_corrupted"] = evaluation_loader(
            Subset(train_dataset, train_dataset.new_indexes)
        )

    return loaders


def _build_scheduler(optimizer, args):
    """Return the LR scheduler selected by args.lr_policy, or None for constant LR."""
    import torch

    lr_policy = getattr(args, "lr_policy", None)
    if lr_policy == "lr_multistep":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_milestones, gamma=args.lr_gamma
        )
    if lr_policy == "cosine_decay":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=0
        )
    return None  # constant LR


def run_simulation(params, job_name, sim_index=0):
    """
    Execute a single simulation based on the provided parameters.

    Steps: build the argument object, seed the RNGs, set up a per-simulation
    logger (file + stdout), pick the device, load the data and build the
    DataLoaders, create model / loss / SGD optimizer / optional scheduler,
    train and evaluate, then write ``args.pkl``, ``performances.pkl`` and the
    model checkpoint into ``args.results_dir``.

    Args:
        params: settings dictionary for this simulation.
        job_name: name of the job the simulation belongs to.
        sim_index: index of this simulation within the job (used in the
            logger name and forwarded to ``parameter_settings``).

    Returns:
        None. Any exception raised while running propagates to the caller;
        the log handlers are released in every case.
    """
    import torch

    # Prepare arguments and logging
    args = parameter_settings(params, job_name, sim_index)

    # Fix the random seed for reproducibility
    _seed_everything(args.seed)
    _force_cudnn_determinism()

    logger, handlers = _open_simulation_logger(args, sim_index)
    try:
        # Print the arguments at the beginning of the simulation
        logger.info(f"Starting simulation with arguments:\n{args}")

        device, num_workers, pin_memory = _select_device_and_loader_params(logger)
        # Device selection switches cudnn.benchmark on when CUDA is present,
        # which would silently undo the determinism requested above.
        _force_cudnn_determinism()

        # Data preparation
        train_dataset, test_dataset, embedding_weights, args = load_dataset(args)
        loaders = _build_loaders(train_dataset, test_dataset, args, num_workers, pin_memory)

        # Model, loss, optimizer and learning-rate schedule
        model, args = model_definition(device, args, embedding_weights=embedding_weights)
        criterion = define_loss_function(args)
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov
        )
        scheduler = _build_scheduler(optimizer, args)

        # Save simulation arguments before training starts
        with open(os.path.join(args.results_dir, "args.pkl"), "wb") as f:
            pickle.dump(args, f)

        # Train and evaluate
        model, performances, args = train_and_evaluate_model(
            model, loaders, criterion, optimizer, scheduler, device, args, logger
        )

        # Save results
        performances.torch_transformation()
        performances.plot_performances(args)
        performances.evaluate_memorization_metrics(args, window_size=5)

        with open(os.path.join(args.results_dir, "performances.pkl"), "wb") as f:
            pickle.dump(performances, f)

        # Save the trained model
        model_path = os.path.join(args.results_dir, f"{args.filename}_model.pth")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'args': vars(args)
        }, model_path)
        logger.info(f"Model saved at: {model_path}")

        logger.info(f"SIMULATION COMPLETED: {args.filename}")
    finally:
        # Detach and close the handlers even when the run fails, so the log
        # file descriptor is released and later simulations do not inherit them.
        for handler in handlers:
            logger.removeHandler(handler)
            handler.close()


def main(job_name):
    """
    Run every simulation described by a job file, one after another.

    The job configuration is read from ``programs/simulations/<job_name>.json``
    (relative to the current working directory), expanded into one settings
    dictionary per parameter combination, and each combination is passed to
    ``run_simulation`` together with its zero-based index.
    """
    # Load job configuration
    config_path = f"programs/simulations/{job_name}.json"
    with open(config_path, 'r') as config_file:
        job_config = json.load(config_file)

    # No parallelization is applied here; the combinations run in order.
    sweep = generate_settings_combinations(job_config)
    total = len(sweep)

    print(f"Total simulations to run: {total}")
    print("Running simulations sequentially...")

    for sim_index, params in enumerate(sweep):
        # The progress banner is one-based, the index handed on is zero-based.
        print(f"\nSimulation {sim_index + 1}/{total}\n")
        run_simulation(params, job_name, sim_index=sim_index)


if __name__ == '__main__':
    # If the current working directory path contains "program", move one
    # directory up before running (main() opens its job file via a relative path).
    launched_from_program_dir = "program" in os.getcwd()
    if launched_from_program_dir:
        os.chdir("..")

    print("Start!")

    # The job name is a mandatory first command-line argument.
    if len(sys.argv) <= 1:
        print("Please provide a job name as an argument.")
        sys.exit(1)

    job_name = sys.argv[1]
    main(job_name)