import concurrent.futures
import os
import sys
sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../../"))
sys.path.append(os.path.abspath("../../../../"))  

from re import X
from src.models.gidnet.gidnet2 import GidNet2
import torch
import time
import csv
from src.models.gidnet.configs.experiment_config import ExpConfig
from src.models.gidnet.configs.gidnet_config import GidNetConfig
from src.models.base.autoencoder import Decoder,Encoder
from src.models.base.simulator_Nf2 import ForwardSimulator as Simulator
import src.utils.data as data
from semantic_loss_pytorch import SemanticLoss
import commons.semantic_loss as semloss
import concurrent
import numpy as np
from time import perf_counter
# Limit PyTorch to one intra-op thread per process. Under the 'spawn' start
# method each worker re-imports this module, so this line also runs in every
# worker and the pool ends up with roughly one compute thread per process.
torch.set_num_threads(1)
# Never wrap printed NumPy arrays across lines (affects str()/print of ndarrays).
np.set_printoptions(linewidth=np.inf) # type: ignore



def _build_worker_gidnet(n_seed, n_movement, n_layer_const, n_mat_const,
                         encoder_state_dict, decoder_state_dict,
                         simulator_state_dict, encoded_train_min_cpu,
                         encoded_train_max_cpu, worker_device):
    """Rebuild the pretrained models on ``worker_device`` and wrap them in a GidNet2.

    Fresh Encoder/Decoder/Simulator instances are created in this process and
    loaded from the given state dicts, so the worker never shares module
    objects with the parent process. All three are put in eval mode.
    """
    encoder = Encoder(n_layer_const, worker_device).to(worker_device)  # type: ignore
    decoder = Decoder(n_layer_const, worker_device).to(worker_device)  # type: ignore
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    encoder.eval()
    decoder.eval()

    simulator = Simulator(n_layer_const, n_mat_const).to(worker_device)
    simulator.load_state_dict(simulator_state_dict)
    simulator.eval()

    gidnet = GidNet2(
        n_layer=n_layer_const,
        n_mat=n_mat_const,
        n_seeds=n_seed,
        n_movements=n_movement,
        lambda1=1,  # fixed loss weights; pass them in if they need to vary
        lambda2=0,
        onehot_weight=1,
        encoder=encoder,
        decoder=decoder,
        simulator=simulator,
        device=worker_device,
        metamat_config=None,  # assumed fixed for these experiments
    )
    gidnet.set_enc_train_min_max(
        encoded_train_min_cpu.to(worker_device),
        encoded_train_max_cpu.to(worker_device),
    )
    return gidnet


def worker_process_task(args):
    """Run all trials of one (n_seed, n_movement, lr, idx) task in a pool worker.

    This function is executed by each process of the ProcessPoolExecutor. It
    rebuilds the models from state dicts, optionally recreates the
    SemanticLoss from the files written by the main process, then calls
    ``gidnet.train`` ``N_TRIALS`` times on the same target.

    Args:
        args: A 17-tuple, in this exact order:
            n_seed, n_movement, lr, idx, y_target_cpu,
            n_layer_const, n_mat_const, N_TRIALS, config_epochs_const,
            encoder_state_dict, decoder_state_dict, simulator_state_dict,
            encoded_train_min_cpu, encoded_train_max_cpu,
            sloss_formula_str, sloss_vtree_str,  # file paths, or None for no semantic loss
            target_device_str.

    Returns:
        A list with ``config_epochs_const`` entries. Entry ``e`` holds one row
        per trial, laid out as
        ``[lr, idx, e, trial, h[0], h[1], h[2], h[5]]`` where ``h`` is the
        epoch-``e`` element of the history returned by ``gidnet.train``.
    """
    start = perf_counter()
    (n_seed, n_movement, lr, idx, y_target_cpu,
     n_layer_const, n_mat_const, N_TRIALS, config_epochs_const,
     encoder_state_dict, decoder_state_dict, simulator_state_dict,
     encoded_train_min_cpu, encoded_train_max_cpu,
     sloss_formula_str, sloss_vtree_str,
     target_device_str) = args

    worker_device = torch.device(target_device_str)

    gidnet = _build_worker_gidnet(
        n_seed, n_movement, n_layer_const, n_mat_const,
        encoder_state_dict, decoder_state_dict, simulator_state_dict,
        encoded_train_min_cpu, encoded_train_max_cpu, worker_device,
    )

    # Recreate SemanticLoss from the .sdd/.vtree files the main process saved
    # before starting the pool; this avoids pickling the object itself.
    # The paths must be reachable from the worker's working directory.
    current_sloss = (
        SemanticLoss(sloss_formula_str, sloss_vtree_str)
        if sloss_formula_str is not None
        else None
    )

    y_target = y_target_cpu.to(worker_device)

    # Seed once per task, after every model is constructed (construction may
    # draw from the RNG), so training randomness depends only on idx. The same
    # idx therefore gets the same seed across lr / n_seed / n_movement settings.
    torch.manual_seed(idx)

    # task_results[e] collects one row per trial for epoch e.
    task_results = [[] for _ in range(config_epochs_const)]
    for trial in range(N_TRIALS):
        print(f"Worker {idx} processing trial {trial} ")

        # Hand train() a detached copy so it cannot modify y_target or its grads.
        history_sem, _ = gidnet.train(y_target.clone().detach(), config_epochs_const, lr, semloss=current_sloss)

        for epoch_idx, epoch_data in enumerate(history_sem):
            # Column order must match the CSV header written by the main process:
            # Lr, Mat_idx, Epochs, Trial, Simulator loss, Semantic loss, Onehot, Decoded mat
            task_results[epoch_idx].append(
                [lr, idx, epoch_idx, trial,
                 epoch_data[0], epoch_data[1], epoch_data[2], epoch_data[5]]
            )

        print(f"Worker {idx} finished log of trial {trial} ")

    end = perf_counter()

    print(f"Worker finished: Lr={lr}, Mat_idx={idx}, N_seeds={n_seed}, N_mov={n_movement}, Time={(end-start):.3f}")
    return task_results












if __name__ == '__main__':
    import traceback

    # 'spawn' starts every worker in a fresh interpreter, so no worker inherits
    # a CUDA context from this process. torch.multiprocessing re-exports the
    # stdlib multiprocessing API, so this also sets the default context used by
    # the ProcessPoolExecutor below.
    torch.multiprocessing.set_start_method('spawn', force=True)

    LOG = False

    exp_config = ExpConfig()
    gidnet_config = GidNetConfig()

    ## Experiment parameters
    N_TEST = 100      # number of test targets (also passed as invd_steps)
    NUM_WORKER = 10   # upper bound on worker processes

    # Set to None to execute experiments without semantic loss
    experiment_loss = semloss.SemanticExperiment.PERIODIC_2

    ## Constants
    N_TRIALS = 128    # gidnet.train runs per task
    N_LAYER_CONST = 10
    N_MAT_CONST = 7

    if experiment_loss is not None:
        output_csv_file = f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv"
    else:
        output_csv_file = "logs/[][No_loss]results.csv"

    # Compile the semantic-loss constraint once, here in the main process, and
    # save it to disk. Workers rebuild SemanticLoss from these two files.
    formula_str_sl = None
    vtree_str_sl = None
    if experiment_loss is not None:
        formula_str_sl = 'formula.sdd'
        vtree_str_sl = 'formula.vtree'

        x_sl, manager_sl, vtree_sl = semloss.construct_vars(N_LAYER_CONST, N_MAT_CONST)

        # Define the formula used by semantic loss
        formula_sl = experiment_loss.get_constraint_function(N_LAYER_CONST, N_MAT_CONST)(x_sl) # type: ignore

        manager_sl.save(formula_str_sl.encode(), formula_sl)
        vtree_sl.save(vtree_str_sl.encode())

    cpu_device_str = "cpu"
    cpu_device = torch.device(cpu_device_str)

    print(f"Main process: Using device '{cpu_device_str}' for initial loading and then for workers.")

    train_data, val_data, test_data = data.get_x_y_data(invd_steps=N_TEST, device=cpu_device_str, val_split=None)

    X_train, Y_train = train_data
    X_test, Y_test = test_data

    encoder_main = Encoder(N_LAYER_CONST, cpu_device_str).to(cpu_device)
    decoder_main = Decoder(N_LAYER_CONST, cpu_device_str).to(cpu_device)

    decoder_main.load_state_dict(torch.load('../trained_models/decoder_10layers.pt', map_location=cpu_device))
    encoder_main.load_state_dict(torch.load('../trained_models/encoder_10layers.pt', map_location=cpu_device))

    encoder_main.eval()
    decoder_main.eval()

    # The simulator checkpoint is a fully pickled module (weights_only=False),
    # so only load it from a trusted path. map_location, as for the other two
    # checkpoints, keeps this working on CPU-only hosts even if the checkpoint
    # was saved with CUDA tensors.
    simulator_main = torch.load('../trained_models/simulator.pt', map_location=cpu_device, weights_only=False).to(cpu_device)
    simulator_main.eval()

    # Workers rebuild their own modules from these state dicts.
    encoder_state_dict_cpu = encoder_main.state_dict()
    decoder_state_dict_cpu = decoder_main.state_dict()
    simulator_state_dict_cpu = simulator_main.state_dict()

    # Global min/max of the encoder output over the training set. Only the two
    # scalars are needed, so don't build an autograd graph for the whole set.
    with torch.no_grad():
        encoded_train_cpu = encoder_main(X_train.to(cpu_device))
    encoded_train_min_cpu = torch.min(encoded_train_cpu).clone().detach()
    encoded_train_max_cpu = torch.max(encoded_train_cpu).clone().detach()

    # Create logs directory if it doesn't exist
    os.makedirs(os.path.dirname(output_csv_file), exist_ok=True)

    # NOTE(review): rows carry no N_seeds / N_mov column. If gidnet_config lists
    # more than one value for either, results from different combinations are
    # indistinguishable in the CSV.
    header = ["Lr", "Mat_idx", "Epochs", "Trial", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]

    # One task per (n_seed, n_movement, lr, test target). The tuple layout must
    # match the unpacking in worker_process_task.
    # NOTE(review): every tuple carries the three state dicts, so they are
    # pickled once per submitted task; a pool initializer would send them once
    # per worker instead.
    tasks_args_list = []
    for n_seed_loop in gidnet_config.n_seeds:
        for n_movement_loop in gidnet_config.n_movements:
            for lr_loop in gidnet_config.lr:
                print(f"Preparing tasks for combination: N_seeds:{n_seed_loop}, N_mov:{n_movement_loop}, Lr:{lr_loop}")

                for idx_loop, y_test_item_cpu in enumerate(Y_test[:N_TEST]):
                    tasks_args_list.append((
                        n_seed_loop, n_movement_loop, lr_loop, idx_loop, y_test_item_cpu,
                        N_LAYER_CONST, N_MAT_CONST, N_TRIALS, gidnet_config.epochs,
                        encoder_state_dict_cpu, decoder_state_dict_cpu, simulator_state_dict_cpu,
                        encoded_train_min_cpu, encoded_train_max_cpu,
                        formula_str_sl, vtree_str_sl,  # filenames for SemanticLoss recreation
                        cpu_device_str,  # target device for the worker
                    ))

    all_results_from_workers = []

    # Cap the pool at NUM_WORKER and at the number of available CPUs
    # (os.cpu_count() may return None).
    num_workers = min(NUM_WORKER, os.cpu_count() or 1)
    print(f"Starting ProcessPoolExecutor with {num_workers} workers...")

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        future_to_task_params = {executor.submit(worker_process_task, task_args): task_args for task_args in tasks_args_list}

        # Results are collected in completion order, not submission order.
        for future in concurrent.futures.as_completed(future_to_task_params):
            task_params_done = future_to_task_params[future]
            try:
                results_for_one_task = future.result()
            except Exception as exc:
                # Best effort: report the failed task and keep the other results.
                lr_exc, midx_exc = task_params_done[2], task_params_done[3]
                print(f"Task (Lr={lr_exc}, Mat_idx={midx_exc}) generated an exception: {exc}")
                traceback.print_exc()
            else:
                # One entry per epoch, each holding that epoch's rows for every trial.
                all_results_from_workers.extend(results_for_one_task)

    print("All tasks completed. Writing results to CSV.")
    # Write all collected results to the CSV file at once
    with open(output_csv_file, "w", newline='') as file_loss_csv:
        csv_writer = csv.writer(file_loss_csv)
        csv_writer.writerow(header)
        if all_results_from_workers:
            for epoch_rows in all_results_from_workers:
                csv_writer.writerows(epoch_rows)
        else:
            print("No results were generated by workers.")

    print(f"Parallel processing finished. Results saved to {output_csv_file}")
