import concurrent.futures
import os
import sys
sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../../"))
sys.path.append(os.path.abspath("../../../../"))  

from re import X
from methods.gidnet.gidnet import GidNet
import torch
import time
import csv
from configs.metamat_ds_config import MetamatDsConfig
from configs.experiment_config import ExpConfig
from configs.gidnet_config import GidNetConfig
from data_utils.load_data import get_x_rt_data
from base_models.autoencoder import Encoder,Decoder
from base_models.simulator_Nf import ForwardSimulator
from semantic_loss_pytorch import SemanticLoss
import commons.semantic_loss as semloss
import concurrent
import numpy as np
from time import perf_counter
# Limit PyTorch to a single intra-op CPU thread in this process. Under the
# 'spawn' start method the child processes re-import this module, so this
# statement also runs in every pool worker.
# NOTE(review): presumably intended to avoid CPU oversubscription when many
# worker processes run side by side -- confirm.
torch.set_num_threads(1)
# Disable width-based line wrapping when numpy arrays are converted to text.
# NOTE(review): presumably so arrays written into the results CSV are not
# broken up by the default 75-character wrap -- confirm.
np.set_printoptions(linewidth=np.inf) # type: ignore



# Worker function: This will be executed by each process in the pool
def worker_process_task(args):
    """Run every trial of one experiment task and return its per-epoch rows.

    The function is meant to be submitted to a process pool, so it receives
    all of its inputs as one flat tuple and rebuilds every model locally from
    state dicts instead of relying on objects shared with the parent process.

    Args:
        args: Tuple with the following fields, in this exact order:

            0.  ``metamat_config`` -- forwarded unchanged to ``GidNet``.
            1.  ``n_seed`` -- value passed to ``GidNet`` as ``n_seeds``.
            2.  ``n_movement`` -- value passed to ``GidNet`` as ``n_movements``.
            3.  ``lr`` -- learning rate passed to ``gidnet.train``.
            4.  ``idx`` -- index of the target; used as the torch RNG seed and
                written to every result row.
            5.  ``y_target_cpu`` -- first target tensor given to ``gidnet.train``.
            6.  ``x_target_cpu`` -- second target tensor given to ``gidnet.train``.
            7.  ``n_layer_const`` -- layer count used to build the models.
            8.  ``n_mat_const`` -- material count used to build the models.
            9.  ``N_TRIALS`` -- number of independent ``gidnet.train`` runs.
            10. ``config_epochs_const`` -- ``n_epoch`` for each run; also the
                length of the returned list.
            11. ``encoder_state_dict`` -- weights loaded into ``Encoder``.
            12. ``decoder_state_dict`` -- weights loaded into ``Decoder``.
            13. ``simulator_state_dict`` -- weights loaded into
                ``ForwardSimulator``.
            14. ``encoded_train_min_cpu`` -- lower bound for seed sampling.
            15. ``encoded_train_max_cpu`` -- upper bound for seed sampling.
            16. ``sloss_formula_str`` -- first argument of ``SemanticLoss``,
                or ``None`` to run without a semantic loss.
            17. ``sloss_vtree_str`` -- second argument of ``SemanticLoss``.
            18. ``target_device_str`` -- torch device string for this worker.

    Returns:
        A list with ``config_epochs_const`` entries. Entry ``e`` is the list
        of rows recorded at epoch ``e`` (one row per trial), each row being
        ``[lr, idx, epoch, trial, h[0], h[1], h[2], h[5]]`` where ``h`` is the
        epoch entry returned by ``gidnet.train``.
        NOTE(review): the four ``h`` fields are presumably simulator loss,
        semantic loss, one-hot term and decoded material, matching the CSV
        header used by the caller -- confirm against ``GidNet.train``.

    Raises:
        IndexError: if ``gidnet.train`` returns more entries than
            ``config_epochs_const``.
    """
    start = perf_counter()
    (_metamat_config, n_seed, n_movement, lr, idx,
     y_target_cpu, x_target_cpu,
     n_layer_const, n_mat_const, N_TRIALS, config_epochs_const,
     encoder_state_dict, decoder_state_dict, simulator_state_dict,
     encoded_train_min_cpu, encoded_train_max_cpu,
     sloss_formula_str, sloss_vtree_str,
     target_device_str) = args

    worker_device = torch.device(target_device_str)

    # Build fresh model instances in this process and load the weights that
    # were shipped as state dicts; all three are only evaluated here, so they
    # are switched to eval mode.
    encoder = Encoder(n_layer_const, n_mat_const, worker_device).to(worker_device)  # type: ignore
    decoder = Decoder(n_layer_const, n_mat_const, worker_device).to(worker_device)  # type: ignore
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    encoder.eval()
    decoder.eval()

    simulator = ForwardSimulator(n_layer_const, layer_config=2).to(worker_device)
    simulator.load_state_dict(simulator_state_dict)
    simulator.eval()

    # The semantic loss is rebuilt from the two names handed over by the
    # caller rather than being pickled; a None formula disables it.
    if sloss_formula_str is not None:
        current_sloss = SemanticLoss(sloss_formula_str, sloss_vtree_str)
    else:
        current_sloss = None

    y_target = y_target_cpu.to(worker_device)
    x_target = x_target_cpu.to(worker_device)
    encoded_train_min = encoded_train_min_cpu.to(worker_device)
    encoded_train_max = encoded_train_max_cpu.to(worker_device)

    gidnet = GidNet(
        n_layer=n_layer_const,
        n_mat=n_mat_const,
        n_seeds=n_seed,
        n_movements=n_movement,
        lambda1=1,
        lambda2=0,
        onehot_weight=1,
        encoder=encoder,
        decoder=decoder,
        simulator=simulator,
        device=worker_device,
        metamat_config=_metamat_config,
    )
    gidnet.set_enc_train_min_max(encoded_train_min, encoded_train_max)

    # Seed once per task (not per trial) with the target index, so the trials
    # of one task draw different seeds but the task as a whole is repeatable.
    torch.manual_seed(idx)

    # task_results[e] collects the rows of epoch e across all trials.
    task_results = [[] for _ in range(config_epochs_const)]
    for trial in range(N_TRIALS):
        print(f"Worker {idx} processing trial {trial} ")

        # Sample the starting points uniformly inside the
        # [enc_train_min, enc_train_max] box.
        seeds = gidnet.enc_train_min + (gidnet.enc_train_max - gidnet.enc_train_min) * torch.rand(gidnet.n_seeds, n_layer_const*3).to(worker_device)  # type: ignore
        history_sem = gidnet.train(y_target, x_target, seeds, n_epoch=config_epochs_const, lr=lr, semloss=current_sloss, LOG=False)  # type: ignore

        for epoch_idx, epoch_data in enumerate(history_sem):
            # Column order must match the CSV header written by the caller.
            row = [lr, idx, epoch_idx, trial, epoch_data[0], epoch_data[1], epoch_data[2], epoch_data[5]]
            task_results[epoch_idx].append(row)

        print(f"Worker {idx} finished log of trial {trial} ")

    end = perf_counter()

    print(f"Worker finished: Lr={lr}, Mat_idx={idx}, N_seeds={n_seed}, N_mov={n_movement}, Time={(end-start):.3f}")
    return task_results












if __name__ == '__main__':
    # Driver: load the trained models once, build one task per
    # (n_seeds, n_movements, lr, test target) combination, run the tasks in a
    # process pool and write every returned row to a single CSV file.
    import traceback

    # Use 'spawn' so workers start from a fresh interpreter instead of a fork
    # of this process.
    torch.multiprocessing.set_start_method('spawn', force=True)

    # ---- CREATE LOG FILES -------
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("logs/", exist_ok=True)

    exp_config = ExpConfig()
    gidnet_config = GidNetConfig()
    metamat_config = MetamatDsConfig()

    ## Experiments parameters
    N_TEST = 500
    NUM_WORKER = 10

    # Set to None to execute experiments without semantic loss
    experiment_loss = semloss.SemanticExperiment.PERIODIC_2

    ## CONSTANTS, Do not change
    N_TRIALS = 128
    N_LAYER_CONST = 5
    N_MAT_CONST = 5

    if experiment_loss is not None:
        output_csv_file = f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv"
    else:
        output_csv_file = "logs/[][No_loss]results.csv"

    # When a semantic loss is requested, compile the constraint once here and
    # save it to disk; workers only receive the two file names and rebuild
    # their own SemanticLoss object from them.
    formula_str_sl = None
    vtree_str_sl = None
    if experiment_loss is not None:
        formula_str_sl = 'formula.sdd'
        vtree_str_sl = 'formula.vtree'

        x_sl, manager_sl, vtree_sl = semloss.construct_vars(N_LAYER_CONST, N_MAT_CONST)

        # Define the formula used by semantic loss
        formula_sl = experiment_loss.get_constraint_function(N_LAYER_CONST, N_MAT_CONST)(x_sl)  # type: ignore

        manager_sl.save(formula_str_sl.encode(), formula_sl)
        vtree_sl.save(vtree_str_sl.encode())

    cpu_device_str = "cpu"
    cpu_device = torch.device(cpu_device_str)

    print(f"Main process: Using device '{cpu_device_str}' for initial loading and then for workers.")

    # invd_num_test must be = 500
    X_train, X_test, Y_train, Y_test = get_x_rt_data(metamat_config, exp_config.invd_num_test)

    X_test = torch.tensor(X_test, dtype=torch.float32)
    X_train = torch.tensor(X_train, dtype=torch.float32)
    Y_train = torch.tensor(Y_train, dtype=torch.float32)
    Y_test = torch.tensor(Y_test, dtype=torch.float32)

    encoder_main = Encoder(N_LAYER_CONST, N_MAT_CONST, cpu_device_str).to(cpu_device)
    decoder_main = Decoder(N_LAYER_CONST, N_MAT_CONST, cpu_device_str).to(cpu_device)
    decoder_main.load_state_dict(torch.load('../../trained_models/autoencoder/decoderLR=0.001_BS=1024_L2W-RW=3-1.pt', map_location=cpu_device))
    encoder_main.load_state_dict(torch.load('../../trained_models/autoencoder/encoderLR=0.001_BS=1024_L2W-RW=3-1.pt', map_location=cpu_device))
    encoder_main.eval()
    decoder_main.eval()

    simulator_main = ForwardSimulator(N_LAYER_CONST, layer_config=2).to(cpu_device)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only use this with checkpoint files from a trusted source.
    simulator_main.load_state_dict(torch.load('../../trained_models/simulator_noscale/Nf_175200train_43800val_5matlay_150e_1024b_0.005lr.pt', map_location=cpu_device, weights_only=False))
    simulator_main.eval()

    # State dicts are what gets shipped to the workers.
    encoder_state_dict_cpu = encoder_main.state_dict()
    decoder_state_dict_cpu = decoder_main.state_dict()
    simulator_state_dict_cpu = simulator_main.state_dict()

    # Per-dimension min/max of the encoded training set. Only the values are
    # needed, so skip building an autograd graph over the whole training set.
    with torch.no_grad():
        encoded_train_cpu = encoder_main(X_train.to(cpu_device))
        encoded_train_min_cpu = torch.min(encoded_train_cpu, dim=0).values.clone()
        encoded_train_max_cpu = torch.max(encoded_train_cpu, dim=0).values.clone()

    # Make sure the directory of the output file exists.
    os.makedirs(os.path.dirname(output_csv_file), exist_ok=True)

    header = ["Lr", "Mat_idx", "Epochs", "Trial", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]

    # Build one argument tuple per task. The field order below is the
    # contract with worker_process_task, which unpacks it positionally.
    tasks_args_list = []
    for n_seed_loop in gidnet_config.n_seeds:
        for n_movement_loop in gidnet_config.n_movements:
            for lr_loop in gidnet_config.lr:
                print(f"Preparing tasks for combination: N_seeds:{n_seed_loop}, N_mov:{n_movement_loop}, Lr:{lr_loop}")

                for idx_loop, y_test_item_cpu in enumerate(Y_test[:N_TEST]):
                    current_task_params = (
                        metamat_config,
                        n_seed_loop, n_movement_loop, lr_loop, idx_loop, y_test_item_cpu, X_test[idx_loop],
                        N_LAYER_CONST, N_MAT_CONST, N_TRIALS, gidnet_config.epochs,
                        encoder_state_dict_cpu, decoder_state_dict_cpu, simulator_state_dict_cpu,
                        encoded_train_min_cpu, encoded_train_max_cpu,
                        formula_str_sl, vtree_str_sl,
                        cpu_device_str,
                    )
                    tasks_args_list.append(current_task_params)

    all_results_from_workers = []

    # Never start more processes than NUM_WORKER or than the CPUs reported by
    # the OS (cpu_count() may return None, hence the fallback to 1).
    num_workers = min(NUM_WORKER, os.cpu_count() or 1)
    print(f"Starting ProcessPoolExecutor with {num_workers} workers...")

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Map each future back to its arguments so failures can be reported
        # with the parameters of the task that caused them.
        future_to_task_params = {
            executor.submit(worker_process_task, task_args): task_args
            for task_args in tasks_args_list
        }

        for future in concurrent.futures.as_completed(future_to_task_params):
            task_params_done = future_to_task_params[future]
            try:
                # One list per epoch, each holding the rows of all trials.
                results_for_one_task = future.result()
            except Exception as exc:
                # Tuple layout: (metamat_config, n_seed, n_movement, lr, idx, ...)
                lr_exc, midx_exc = task_params_done[3], task_params_done[4]
                print(f"Task (Lr={lr_exc}, Mat_idx={midx_exc}) generated an exception: {exc}")
                traceback.print_exc()
            else:
                all_results_from_workers.extend(results_for_one_task)

    print("All tasks completed. Writing results to CSV.")
    # Write all collected results to the CSV file at once
    with open(output_csv_file, "w", newline='') as file_loss_csv:
        csv_writer = csv.writer(file_loss_csv)
        csv_writer.writerow(header)
        if all_results_from_workers:
            for epoch_rows in all_results_from_workers:
                csv_writer.writerows(epoch_rows)
        else:
            print("No results were generated by workers.")

    print(f"Parallel processing finished. Results saved to {output_csv_file}")
