import torch
import numpy as np
import torch.multiprocessing as mp
import gc

from algorithm import proxg_update, proxg_update_row, semi_newt_update, bouligand_newt_update
from util import make_reproducibility, extract_sparse_rows_from_csr, update_sparse_rows_from_csr

def worker(rank, X, crow, col, val, finished_chunk, lamb, rho, tau_init, update_indices, update_method):
    """
    Update a batch of omega rows on GPU ``rank`` and return them in CSR form.

    Steps, as implemented below:
      1. Seed with ``make_reproducibility(2025 + rank)`` and select ``cuda:{rank}``.
      2. Extract the rows ``update_indices`` of the CSR matrix
         ``(crow, col, val)`` (``X.shape[1]`` columns) and compute
         ``d = -(omega @ X.T @ X) / X.shape[0]``.
      3. Label each entry of ``d + omega``:
         - +1 if it is ``> lamb``;
         - -1 if it is ``< -lamb``;
         - +/-0.5 (the sign of the entry) if its magnitude equals ``lamb`` exactly;
         - 0 otherwise.
         The entry in column ``update_indices[i]`` of local row ``i`` (the
         diagonal) is always labelled 0.
      4. With ``update_method == 'B-semi'``, rows with fewer than ``X.shape[0]``
         nonzero labels get a Newton-type update: ``semi_newt_update`` when no
         label is +/-0.5, otherwise ``bouligand_newt_update``. All remaining
         rows are updated together by ``proxg_update``. Any other
         ``update_method`` sends every row to ``proxg_update``.

    Args:
        rank: GPU index; also offsets the random seed.
        X: data matrix, moved to the GPU. Presumably (n_samples, n_features) —
            TODO confirm against caller.
        crow, col, val: CSR arrays of the full omega matrix.
        finished_chunk: per-row flags for the rows in ``update_indices``.
        lamb: threshold that defines the active set.
        rho: forwarded to the Newton-type updates.
        tau_init: forwarded to ``proxg_update``.
        update_indices: global row indices of omega to update.
        update_method: ``'B-semi'`` enables the Newton/proximal split; it is also
            forwarded to ``proxg_update``.

    Returns:
        ``((crow, col, val), finished, n_newton_rows, n_prox_rows, LCP_count)``:
        - the CSR triple holds the updated rows, on CPU, in ``update_indices`` order;
        - ``finished`` is a CPU copy of the per-row flags;
        - ``LCP_count`` is the number of rows that used ``bouligand_newt_update``.
    """
    make_reproducibility(2025 + rank)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')

    # local (0-based) positions of the rows within this batch
    local_rows = torch.arange(update_indices.shape[0], device=device)

    X = X.to(device)
    finished = finished_chunk.to(device)

    # extract the rows `update_indices` of omega from its CSR arrays as a sparse tensor
    omega = extract_sparse_rows_from_csr(
        crow, col, val,
        X.shape[1],
        update_indices
    ).to(device)

    # d = -(omega @ X^T @ X) / n, evaluated as (omega @ X^T) @ X so the only
    # intermediate is (batch_rows, X.shape[0]) and X^T X is never formed
    temp = torch.sparse.mm(omega, X.T)
    d =  - torch.matmul(temp, X) / X.shape[0]
    del temp 

    update_indices = update_indices.to(device)

    # calculate active set and inactive set (+ semi active set):
    # +1 / -1 where d + omega is above +lamb / below -lamb, +/-0.5 where
    # |d + omega| == lamb exactly (semi-active), 0 otherwise
    result = d + omega
    actives = (result > lamb).float() - (result < -lamb).float()
    mask = (result.abs() == lamb)
    actives[mask] = 0.5 * result[mask].sign()

    # deactivate diagonal elements (column index == the row's global index)
    actives[local_rows, update_indices] = 0.0
    # number of nonzero labels (active or semi-active) per row
    actives_sum = (actives.abs() != 0).sum(dim=1)

    # divide the rows according to the update method: Newton or proximal.
    # With 'B-semi', rows whose nonzero-label count is below X.shape[0] use Newton.
    if update_method ==  'B-semi' : 
        newt_set = local_rows[actives_sum < X.shape[0]]
        prox_set = local_rows[actives_sum >= X.shape[0]]
    else : 
        newt_set = torch.empty(0, dtype=torch.long)
        prox_set = local_rows

    # dense copy; the row assignments below write into it
    omega = omega.to_dense()

    LCP_count = 0
    with torch.no_grad():
        for i in newt_set:
            if (actives[i].abs() == 0.5).sum().item() == 0:
                # no semi-active entries: semismooth Newton method
                omega[i], finished[i] = semi_newt_update(
                    X, omega[i], d[i], actives[i],
                    update_indices[i].unsqueeze(0), lamb, rho, device)
            else:
                # semi-active entries present: Bouligand semismooth Newton method
                # NOTE(review): this branch leaves finished[i] unchanged — confirm intended
                LCP_count += 1
                omega[i] = bouligand_newt_update(
                    X, omega[i], d[i], actives[i],
                    update_indices[i].unsqueeze(0), lamb, rho, device)

        if prox_set.numel() > 0:
            # proximal gradient method, applied to all proximal rows in one call
            # NOTE(review): finished[prox_set] is not updated here, whereas the
            # per-row variant kept below did update it — confirm intended
            omega[prox_set] = proxg_update(
                X, omega[prox_set], d[prox_set], update_indices[prox_set], lamb, tau_init, update_method, device)
            # for i in prox_set:
            #     omega[i], finished[i] = proxg_update_row(
            #         X, omega[i], d[i], update_indices[i], lamb, tau_init, update_method, device)


    # convert updated omega to sparse CSR format and move it to CPU for return
    csr_upd = omega.to_sparse_csr()
    crow_upd = csr_upd.crow_indices().cpu()
    col_upd = csr_upd.col_indices().cpu()
    val_upd = csr_upd.values().cpu()
    result_omega = (crow_upd, col_upd, val_upd)
    result_finished = finished.cpu()

    # drop GPU references before releasing the CUDA cache
    # (omega, d and csr_upd are still referenced at this point)
    del result, actives, actives_sum, finished, X
    torch.cuda.empty_cache()
    
    return result_omega, result_finished, newt_set.numel(), prox_set.numel(), LCP_count 

def worker_batches(rank, X_shared, crow, col, val, finished, return_dict, lamb, rho, tau_init, indices_un, update_method, batch_size):
    """
    Update this rank's share of omega rows, ``batch_size`` rows at a time, on GPU ``rank``.

    Each consecutive slice of ``indices_un`` is passed to ``worker``. After
    every slice, the per-row flags returned by ``worker`` are written back into
    ``finished`` and the CUDA cache is emptied.

    Results are stored as ``return_dict[rank]``, a dict with these keys:
      - ``"updates"``: maps the first row index of each batch to that batch's
        ``(crow, col, val)`` CSR triple, in batch order;
      - ``"newt_total"``, ``"prox_total"``, ``"LCP_total"``: the per-batch
        counts summed over all batches.
    """
    make_reproducibility(2025 + rank)

    batch_csr = {}
    n_newton, n_prox, n_lcp = 0, 0, 0

    for start in range(0, len(indices_un), batch_size):
        rows = indices_un[start:start + batch_size]
        rows_finished = finished[rows]

        csr_triple, rows_finished_new, b_newt, b_prox, b_lcp = worker(
            rank, X_shared,
            crow, col, val, rows_finished,
            lamb, rho, tau_init, rows, update_method,
        )

        batch_csr[rows[0].item()] = csr_triple
        n_newton += b_newt
        n_prox += b_prox
        n_lcp += b_lcp
        finished[rows] = rows_finished_new

        # drop references so the cached GPU blocks can actually be released
        del csr_triple, rows_finished_new, rows_finished, rows
        torch.cuda.empty_cache()

    return_dict[rank] = {
        "updates": batch_csr,
        "newt_total": n_newton,
        "prox_total": n_prox,
        "LCP_total": n_lcp,
    }

def multi_gpu(world_size, X_shared, crow, col, val, finished, num_unfinished, return_dict, lamb, rho, tau_init, unfinished_indices, update_method, batch_size) :
    """
    Launch one process per GPU to update the omega rows listed in ``unfinished_indices``.

    The first ``num_unfinished`` entries of ``unfinished_indices`` are split
    into ``world_size`` contiguous chunks of ``ceil(num_unfinished / world_size)``
    rows. Rank ``r`` processes chunk ``r`` with ``worker_batches``, which stores
    its results in ``return_dict[r]``. Trailing ranks may receive an empty chunk.

    Blocks until every process has exited.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code.
            Without this check, a crashed rank leaves ``return_dict[rank]``
            missing (or holding results from an earlier call). The caller would
            then splice a CSR that does not line up with ``unfinished_indices``.
    """
    chunk_size = (num_unfinished + world_size - 1) // world_size
    processes = []

    for rank in range(world_size):
        # contiguous [idx_start, idx_end) slice for this rank; empty when past the end
        idx_start = rank * chunk_size
        idx_end = min((rank + 1) * chunk_size, num_unfinished)
        indices_un = unfinished_indices[idx_start:idx_end]

        p_worker = mp.Process(
            target=worker_batches,
            args=(rank, X_shared, crow, col, val,
                  finished,
                  return_dict, lamb, rho, tau_init, indices_un, update_method, batch_size),
        )
        p_worker.start()
        processes.append(p_worker)

    for p_worker in processes:
        p_worker.join()

    # processes were appended in rank order, so the list position is the rank
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        details = ", ".join(f"rank {rank} (exit code {code})" for rank, code in failed)
        raise RuntimeError(f"multi_gpu: worker process(es) failed: {details}")

def update_omega(k, world_size, X_shared, crow, col, val, finished, lamb, rho, tau_init, unfinished_indices, update_method, newt_tensor, prox_tensor, LCP_tensor, return_dict, batch_size) :
    """
    Update the omega rows ``unfinished_indices`` on one or several GPUs.

    With fewer than 3000 unfinished rows, a single ``worker`` call on GPU 0
    runs in this process. Otherwise ``multi_gpu`` spreads the rows over
    ``world_size`` processes. The per-batch CSR results collected in
    ``return_dict`` are then concatenated (rank order, then batch order) into
    one CSR covering ``unfinished_indices`` in order.

    Side effects:
        ``newt_tensor[k]``, ``prox_tensor[k]`` and ``LCP_tensor[k]`` receive the
        row counts reported by the workers. In the single-GPU path
        ``finished[unfinished_indices]`` is overwritten here. In the multi-GPU
        path the worker processes write into ``finished`` themselves; this
        process sees those writes only if ``finished`` lives in shared memory —
        confirm with the caller. Stale entries for ranks ``0..world_size-1`` are
        removed from ``return_dict`` before the workers are launched.

    Returns:
        ``(crow, col, val)`` of omega with the updated rows spliced in by
        ``update_sparse_rows_from_csr``.

    Raises:
        RuntimeError: in the multi-GPU path, if a rank produced no results or
            the number of collected rows differs from ``len(unfinished_indices)``.
    """
    num_unfinished = len(unfinished_indices)

    if num_unfinished < 3000:
        # single-GPU path: run the worker directly in this process on GPU 0
        updated_omega, updated_finished, newt_count, prox_count, LCP_count = worker(
            0,
            X_shared, crow, col, val,
            finished[unfinished_indices],
            lamb, rho, tau_init,
            unfinished_indices, update_method
        )
        # unpack CSR from worker
        crow_upd, col_upd, val_upd = updated_omega

        # update counters & finished
        newt_tensor[k] = newt_count
        prox_tensor[k] = prox_count
        LCP_tensor[k] = LCP_count
        finished[unfinished_indices] = updated_finished

        # splice updated rows into full CSR
        crow, col, val = update_sparse_rows_from_csr(
            crow, col, val,
            unfinished_indices,
            crow_upd, col_upd, val_upd
        )

    else:
        # return_dict is reused across calls: drop old per-rank results so a
        # rank that fails this time cannot be mistaken for one that succeeded
        for rank in range(world_size):
            return_dict.pop(rank, None)

        # multi-GPU path: launch workers
        multi_gpu(world_size,
                  X_shared, crow, col, val,
                  finished,
                  num_unfinished, return_dict,
                  lamb, rho, tau_init,
                  unfinished_indices,
                  update_method,
                  batch_size)

        # aggregate counters
        total_newt = total_prox = total_LCP = 0
        crow_chunks, col_chunks, val_chunks = [], [], []

        # return_dict keys are ranks 0..world_size-1; rank order == row order
        for rank in range(world_size):
            worker_data = return_dict.get(rank)
            if worker_data is None:
                # skipping a rank would misalign every later row with unfinished_indices
                raise RuntimeError(f"update_omega: no results from GPU worker rank {rank}")
            total_newt += worker_data["newt_total"]
            total_prox += worker_data["prox_total"]
            total_LCP  += worker_data["LCP_total"]

            # each value of worker_data["updates"] is a (crow_upd, col_upd, val_upd) batch
            for crow_upd, col_upd, val_upd in worker_data["updates"].values():
                crow_chunks.append(crow_upd)
                col_chunks.append(col_upd)
                val_chunks.append(val_upd)

        newt_tensor[k] = total_newt
        prox_tensor[k] = total_prox
        LCP_tensor[k] = total_LCP

        # build a single CSR for all unfinished_indices rows
        # 1) row pointers: shift each chunk's crow[1:] by the nnz accumulated so far
        crow_parts = [torch.zeros(1, dtype=crow.dtype)]
        nnz_offset = 0
        for crow_upd in crow_chunks:
            crow_parts.append(crow_upd[1:].to(crow.dtype) + nnz_offset)
            nnz_offset += int(crow_upd[-1])
        upd_crow_full = torch.cat(crow_parts)

        n_rows_collected = upd_crow_full.numel() - 1
        if n_rows_collected != num_unfinished:
            raise RuntimeError(
                f"update_omega: collected {n_rows_collected} updated rows, "
                f"expected {num_unfinished}"
            )

        # 2) concat col/val
        upd_col_full = torch.cat(col_chunks, dim=0)
        upd_val_full = torch.cat(val_chunks, dim=0)

        # splice updated rows into full CSR
        crow, col, val = update_sparse_rows_from_csr(
            crow, col, val,
            unfinished_indices,
            upd_crow_full, upd_col_full, upd_val_full
        )

        gc.collect()

    return crow, col, val