import contextlib
import io
import itertools
import math
import multiprocessing
import os
import pickle
import traceback

import torch
from anal.log import LenT
from tool.args import get_general_args, post_processing_args, pre_processing_args
from tool.util import init_wandb
from main_stab import main
from datetime import datetime
import pandas as pd


##################################################################
""" CUSTOM CONFIG """
# Hyper-parameter grids: the __main__ block turns every combination of the
# *_LIST values below (via itertools.product) into one queued task.
SIGMA_W_LIST = [1.0] # other values tried: 1.25, 1.5
SIGMA_B_LIST = [0.1]
T_LIST = [10]  # forwarded to run_experiment as args.T
Z_INIT_LIST = ['gaussian']  # forwarded as args.z_init
L_LIST = [30]  # forwarded as args.n_layers
METHOD_LIST = ['pcn', 'spcn']  # method names understood by run_experiment
DATASET_ARCH_LIST = [('random', 'fc'), ('mnist', 'fc')]  # (dataset, arch) pairs
BSZ_LIST = [128]  # forwarded as args.bsz
N_GPU = 4  # number of worker processes; worker i sets CUDA_VISIBLE_DEVICES=i
N_RUNS = 1  # how many times each combination is run by a worker
ETA_LIST = [0.05]  # forwarded as args.eta
WANDB_ENTITY = 'bml_pc'  # forwarded as args.wandb_entity
TRAIN = False  # forwarded as args.train
PROJ_NAME = 'table_explosion'  # forwarded as args.proj_name
RESULTS_FILE = "results.txt"  # plain-text log of per-task "Complete"/"Error" lines
# Serialises the read-modify-write of the results CSV in log_to_csv.
# NOTE(review): created at import time, so it is only shared between workers
# under the 'fork' start method; with 'spawn' each child re-imports this
# module and gets its own independent lock.
csv_lock = multiprocessing.Lock()
MARKER = 'lyapA'  # prefix of the CSV file name and of every run's args.exp
# Timestamped CSV file name, evaluated once at import time.
# NOTE(review): under 'spawn' each child re-evaluates this and may get a
# different minute, i.e. a different file.
CSV_NAME = f"{MARKER}_{datetime.now().strftime('%y%m%d_%H%M')}.csv"

# Checkpoint options
# LAST_CLS = True
LOSS_SUM = True  # forwarded as args.loss_sum
# ACT_FIRST = True

###

# N_RUNS = 1
# N_GPU = 1
# Z_INIT_LIST = ['use_db']
# L_LIST = [3]  # New list to be included in tasks
# METHOD_LIST = ['pcn']  # Models to run
# DATASET_ARCH_LIST = [('cifar10', 'cnn')]


##################################################################

def run_experiment(sigma_w, sigma_b, method, dataset_arch, z_init, t, l, eta, bsz):
    """Run a single experiment with the provided arguments and log it to the CSV.

    Starts from ``pre_processing_args()``, overrides it with the given
    hyper-parameters and the module-level config constants, post-processes the
    namespace, initialises wandb, calls ``main`` and appends the results via
    ``log_to_csv``.

    Args:
        sigma_w, sigma_b, eta: stored on ``args`` and also passed to ``main`` as
            the single-element lists ``sigma_ws`` / ``sigma_bs`` / ``etas``.
        method: one of 'spcn', 'spcn-r', 'pcd2', 'pcn', 'pcn+r'. It is mapped to
            ``args.method``; 'spcn' and 'pcn+r' additionally enable
            ``args.w_reg``, ``args.z_reg`` and ``args.b_reg``.
        dataset_arch: ``(dataset, arch)`` pair stored as ``args.dataset`` /
            ``args.arch``.
        z_init: stored as ``args.z_init``.
        t: stored as ``args.T``.
        l: stored as ``args.n_layers``.
        bsz: stored as ``args.bsz``.

    Returns:
        Whatever ``main`` returns.

    Raises:
        ValueError: if ``method`` is not one of the names listed above. Raising
            here avoids silently running with the default ``args.method`` while
            ``args.exp`` is labelled with an unrecognised name.
    """
    # method name -> (value for args.method, whether to enable w/z/b regularisers)
    method_settings = {
        'spcn': ('pcd', True),
        'spcn-r': ('pcd', False),
        'pcd2': ('pcd2', False),
        'pcn': ('pc', False),
        'pcn+r': ('pc', True),
    }
    if method not in method_settings:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(method_settings)}"
        )

    args = pre_processing_args()

    # Assign experiment parameters and the module-level config to args
    args.sigma_w = sigma_w
    args.sigma_b = sigma_b
    args.dataset, args.arch = dataset_arch[0], dataset_arch[1]
    args.eta = eta
    args.wandb_entity = WANDB_ENTITY
    args.train = TRAIN
    args.proj_name = PROJ_NAME
    # args.epochs = 30 #l * 10
    args.z_init = z_init
    args.T = t
    # args.last_cls = LAST_CLS
    args.loss_sum = LOSS_SUM
    # args.act_first = ACT_FIRST
    args.n_layers = l
    args.bsz = bsz
    args.latent_dim = 100
    args.lyap = True

    # Architecture/dataset-specific settings; other combinations keep the
    # defaults produced by pre_processing_args.
    if args.arch == 'fc':
        args.comp_eta = True
        if args.dataset == 'mnist':
            args.act = 'tanh'
        elif args.dataset == 'random':
            args.act = 'linear'
    elif args.arch == 'cnn':
        args.step_eta = True
        args.act = 'relu'

    args.method, use_regularisers = method_settings[method]
    if use_regularisers:
        args.w_reg = True
        args.z_reg = True
        args.b_reg = True

    args.exp = f"{MARKER}_{method}_{args.dataset}_{args.arch}_{args.z_init}_T_{args.T}_L_{args.n_layers}_sw_{args.sigma_w}_sb_{args.sigma_b}_et_{args.eta}_ep_{args.epochs}_bsz_{args.bsz}_{datetime.now().strftime('%y%m%d_%H%M')}"
    # Post-process args and initialize wandb
    args = post_processing_args(args)
    init_wandb(args)

    # Run the main experiment; per the original note, the model is stored at
    # args.chkpt_path.
    lenT = main(args, sigma_ws=[sigma_w], sigma_bs=[sigma_b], etas=[eta])

    # Both log_to_csv parameters receive the same post-``main`` namespace; no
    # snapshot of args is taken before ``main`` runs.
    log_to_csv(args, args)

    return lenT


def log_to_csv(args, args_cpy):
    """Append one row describing a finished experiment to ``./run_results/<CSV_NAME>``.

    The row contains the stringified ``util_z_lyap_*`` / ``util_d_lyap_*``
    attributes read from ``args``, plus one column per attribute of
    ``args_cpy`` (list values are stringified so they fit in one cell). The
    ``args_cpy`` attributes are written second, so if ``args_cpy`` also carries
    those ``util_*`` attributes, its raw values replace the stringified ones.

    The existing CSV (if any) is re-read, concatenated with the new row and
    rewritten while holding ``csv_lock``. ``pd.concat`` aligns rows whose
    column sets differ.

    Args:
        args: namespace that must provide util_z_lyap_m/s and util_d_lyap_m/s
            (presumably filled in by ``main`` — TODO confirm).
        args_cpy: namespace whose ``__dict__`` supplies the remaining columns.
    """
    with csv_lock:
        results_dir = './run_results'
        csv_file_path = os.path.join(results_dir, CSV_NAME)
        # to_csv does not create parent directories; make sure it exists.
        os.makedirs(results_dir, exist_ok=True)

        # Accuracy (util_best_acc, util_acc) and weight-Lyapunov
        # (util_w_lyap_m/s) columns are currently not logged.
        data = {
            'util_z_lyap_m': [str(args.util_z_lyap_m)],
            'util_z_lyap_s': [str(args.util_z_lyap_s)],
            'util_d_lyap_m': [str(args.util_d_lyap_m)],
            'util_d_lyap_s': [str(args.util_d_lyap_s)],
        }

        for key, value in args_cpy.__dict__.items():
            data[key] = [str(value)] if isinstance(value, list) else [value]

        new_row_df = pd.DataFrame(data)

        if not os.path.exists(csv_file_path):
            new_row_df.to_csv(csv_file_path, index=False)
        else:
            existing_df = pd.read_csv(csv_file_path)
            updated_df = pd.concat([existing_df, new_row_df], ignore_index=True)
            updated_df.to_csv(csv_file_path, index=False)


def experiment_worker(queue, gpu_id, lock, poll_timeout=1.0):
    """Run experiment configs taken from ``queue`` on GPU ``gpu_id`` until it is drained.

    Each task is a 9-tuple
    ``(sigma_w, sigma_b, method, dataset_arch, z_init, t, l, eta, bsz)`` and is
    run ``N_RUNS`` times with stdout suppressed. A "Complete" or "Error" line
    (the latter with the traceback) is appended to the results file via
    ``log_to_file`` under ``lock``.

    Args:
        queue: ``multiprocessing.Queue`` of task tuples, fully populated before
            the workers start.
        gpu_id: index written to ``CUDA_VISIBLE_DEVICES`` for this process.
        lock: lock passed through to ``log_to_file``.
        poll_timeout: seconds to wait for a task before treating the queue as
            drained. A blocking ``get`` with a timeout is used instead of
            ``empty()`` + ``get_nowait()``, which races between workers and can
            report an empty queue right after a put.
    """
    # Local import: the ``queue`` parameter shadows the stdlib module name.
    from queue import Empty

    # Pin this process to its GPU before any task touches CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    while True:
        try:
            task = queue.get(timeout=poll_timeout)
        except Empty:
            break  # no task arrived within poll_timeout: the queue is drained

        # Unpack outside the try so the error message never sees unbound names.
        sigma_w, sigma_b, method, dataset_arch, z_init, t, l, eta, bsz = task
        description = f"{method}, {dataset_arch}, L={l}, sigma_w={sigma_w}, sigma_b={sigma_b}, z_init={z_init}"

        try:
            for _ in range(N_RUNS):  # Run each experiment N_RUNS times
                with contextlib.redirect_stdout(io.StringIO()):
                    run_experiment(sigma_w, sigma_b, method, dataset_arch, z_init, t, l, eta, bsz)
        except Exception:
            # Keep the worker alive so the remaining tasks still run.
            log_to_file(lock, f"Error on GPU {gpu_id}: {description} - {traceback.format_exc()}")
        else:
            log_to_file(lock, f"Complete on GPU {gpu_id}: {description}")


def log_to_file(lock, message, path=None):
    """Append ``message`` plus a newline to a text file while holding ``lock``.

    Args:
        lock: lock serialising writers across processes (the script passes a
            ``multiprocessing.Lock``); any lock-like context manager works.
        message: text to append; a trailing newline is added.
        path: file to append to; defaults to ``RESULTS_FILE`` when ``None``.
    """
    target = RESULTS_FILE if path is None else path
    # Explicit UTF-8 so tracebacks containing non-ASCII text cannot fail on a
    # non-UTF-8 locale.
    with lock, open(target, "a", encoding="utf-8") as f:
        f.write(message + "\n")


if __name__ == "__main__":
    # One task per combination of the config lists (the tuple order matches
    # the unpacking in experiment_worker).
    combinations = list(itertools.product(SIGMA_W_LIST, SIGMA_B_LIST, METHOD_LIST, DATASET_ARCH_LIST, Z_INIT_LIST, T_LIST, L_LIST, ETA_LIST, BSZ_LIST))

    task_queue = multiprocessing.Queue()
    for combo in combinations:
        task_queue.put(combo)

    # A single lock shared by all workers to serialise appends to RESULTS_FILE.
    # It is created once, outside the loop, so it also exists when there are
    # no combinations.
    file_lock = multiprocessing.Lock()

    # Create one worker per GPU
    workers = []
    for gpu_id in range(N_GPU):
        worker = multiprocessing.Process(target=experiment_worker, args=(task_queue, gpu_id, file_lock))
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()

    print("All experiments completed.")
