import torch
import numpy as np
import torch.multiprocessing as mp
import gc

from objective_func import obj_F_square_vec, obj_accord
from util import make_reproducibility, extract_sparse_rows_from_csr

def single_gpu_obj_ftn(X, crow, col, val, star_crow, star_col, star_val, lamb, F_square_tensor, accord_tensor, opt_squared_tensor, unfinished_indices, device, batch_size):
    """
    Update objective functions for each iteration using single GPU

    The rows listed in ``unfinished_indices`` are processed in batches of
    ``batch_size``. For every batch the three objective values are computed on
    ``device`` and written in place into the output tensors at those row ids.

    Args:
        X: 2-D tensor; its shape is unpacked as ``(n, p)``.
        crow, col, val: CSR arrays of the current estimate (omega), as
            consumed by ``extract_sparse_rows_from_csr``.
        star_crow, star_col, star_val: CSR arrays of the reference matrix
            (omega_star) that omega is compared against.
        lamb: forwarded unchanged to ``obj_F_square_vec`` and ``obj_accord``.
        F_square_tensor, accord_tensor, opt_squared_tensor: output tensors,
            indexed by row id and updated in place (they are indexed with
            CPU indices and receive CPU values).
        unfinished_indices: 1-D tensor of row ids to update.
        device: CUDA device used for all computation.
        batch_size: number of rows per batch (step of the batching loop).

    Returns:
        None. Results are delivered through the three output tensors.
    """
    make_reproducibility(2025)
    torch.cuda.set_device(device)
    n, p = X.shape

    with torch.no_grad():
        X_device = X.to(device, non_blocking=True)
        crow_dev       = crow.to(device,    non_blocking=True)
        col_dev        = col.to(device,     non_blocking=True)
        val_dev        = val.to(device,     non_blocking=True)
        star_crow_dev  = star_crow.to(device,non_blocking=True)
        star_col_dev   = star_col.to(device, non_blocking=True)
        star_val_dev   = star_val.to(device, non_blocking=True)

        num_indices = len(unfinished_indices)
        for i in range(0, num_indices, batch_size):
            index_slice = unfinished_indices[i:i+batch_size]
            # Host-side ids for the scatter writes below. Taken once from the
            # slice so the ids are not copied back from the GPU three times
            # per batch (a no-op when the slice already lives on the CPU).
            batch_indices_cpu = index_slice.cpu()
            batch_indices = index_slice.to(device, non_blocking=True)
            # extract batch of omega (sparse matrix)
            omega_batch = extract_sparse_rows_from_csr(crow_dev, col_dev, val_dev,
                                                        p,
                                                        batch_indices,
                                                        device=device
                                                    )
            # temp = omega_batch @ X^T ; d_batch = -(omega_batch @ X^T @ X) / n
            temp = torch.sparse.mm(omega_batch, X_device.T)
            d_batch = - torch.mm(temp, X_device) / n

            omega_dense = omega_batch.to_dense()

            # update objective functions
            updated_F_batch = obj_F_square_vec(X_device, omega_dense, d_batch, lamb, batch_indices)
            updated_accord = obj_accord(X_device, temp, omega_dense, lamb, batch_indices)

            # free the intermediates before the dense omega_star rows are materialised
            del temp, omega_batch, d_batch
            torch.cuda.empty_cache()

            # calculate difference between omega and omega_star
            # (in place: omega_dense now holds omega - omega_star)
            omega_dense -= extract_sparse_rows_from_csr(star_crow_dev, star_col_dev, star_val_dev,
                                                        p,
                                                        batch_indices,
                                                        device=device
                                                    ).to_dense()
            # squared L2 norm of each row of the difference
            updated_opt_squared = torch.linalg.vector_norm((omega_dense), dim=1)**2

            # move value of objective functions to cpu
            F_square_tensor[batch_indices_cpu] = updated_F_batch.cpu()
            accord_tensor[batch_indices_cpu] = updated_accord.cpu()
            opt_squared_tensor[batch_indices_cpu] = updated_opt_squared.cpu()

            del index_slice, batch_indices, batch_indices_cpu, omega_dense, updated_F_batch, updated_accord, updated_opt_squared
            torch.cuda.empty_cache()

        torch.cuda.empty_cache()
        gc.collect()


def worker_multi_gpu(rank, device, X, crow, col, val, star_crow, star_col, star_val, lamb, batch_size, indices_un, F_square_tensor, accord_tensor, opt_squared_tensor):
    """
    Worker function for multi-GPU processing. Updates objective function value using specified method

    Processes the row ids in ``indices_un`` in batches of ``batch_size`` on
    ``device`` and writes the results in place into the three output tensors.

    Args:
        rank: worker index; only used to offset the random seed (2025 + rank).
        device: CUDA device this worker computes on.
        X: 2-D tensor; its shape is unpacked as ``(n, p)``.
        crow, col, val: CSR arrays of the current estimate (omega), as
            consumed by ``extract_sparse_rows_from_csr``.
        star_crow, star_col, star_val: CSR arrays of the reference matrix
            (omega_star).
        lamb: forwarded unchanged to ``obj_F_square_vec`` and ``obj_accord``.
        batch_size: number of rows per batch (step of the batching loop).
        indices_un: 1-D tensor with the row ids assigned to this worker.
        F_square_tensor, accord_tensor, opt_squared_tensor: output tensors,
            indexed by row id and updated in place.

    Returns:
        None. Results are delivered through the three output tensors.
    """
    make_reproducibility(2025+rank)
    torch.cuda.set_device(device)
    n, p = X.shape

    with torch.no_grad():
        X_device = X.to(device, non_blocking=True)

        for i in range(0, len(indices_un), batch_size):
            batch_indices = indices_un[i:i+batch_size]
            # Host-side ids for the scatter writes below, computed once per
            # batch instead of once per output tensor.
            batch_indices_cpu = batch_indices.cpu()

            # extract batch of omega (sparse matrix)
            # NOTE(review): this call omits ``device=`` while the omega_star
            # extraction below passes it; the helper's default is not visible
            # here -- confirm both calls behave as intended.
            omega_batch = extract_sparse_rows_from_csr(
                                    crow, col, val,
                                    p,
                                    batch_indices
                                ).to(device)

            # temp = omega_batch @ X^T ; d_batch = -(omega_batch @ X^T @ X) / n
            temp = torch.sparse.mm(omega_batch, X_device.T)
            d_batch = - torch.mm(temp, X_device) / n

            omega_dense = omega_batch.to_dense()

            # update objective functions
            updated_F_batch = obj_F_square_vec(X_device, omega_dense, d_batch, lamb, batch_indices)
            updated_accord = obj_accord(X_device, temp, omega_dense, lamb, batch_indices)

            # free the intermediates before the dense omega_star rows are materialised
            del temp, omega_batch, d_batch
            torch.cuda.empty_cache()

            # calculate difference between omega and omega_star
            # (in place: omega_dense now holds omega - omega_star)
            omega_dense -= extract_sparse_rows_from_csr(star_crow, star_col, star_val,
                                                        p,
                                                        batch_indices,
                                                        device=device
                                                    ).to_dense().to(device)
            # squared L2 norm of each row of the difference
            updated_opt_squared = torch.linalg.vector_norm((omega_dense), dim=1)**2

            # move value of objective functions to cpu
            F_square_tensor[batch_indices_cpu] = updated_F_batch.cpu()
            accord_tensor[batch_indices_cpu] = updated_accord.cpu()
            opt_squared_tensor[batch_indices_cpu] = updated_opt_squared.cpu()

            del batch_indices, batch_indices_cpu, omega_dense, updated_F_batch, updated_accord, updated_opt_squared
            torch.cuda.empty_cache()

        torch.cuda.empty_cache()
        gc.collect()
        

def multi_gpu_obj_ftn(X, crow, col, val, star_crow, star_col, star_val, lamb, F_square_tensor, accord_tensor, opt_squared_tensor, unfinished_indices, devices, batch_size):
    """
    Update objective functions for each iteration using multi GPU

    ``unfinished_indices`` is split into contiguous chunks of
    ``ceil(len(unfinished_indices) / len(devices))`` ids. One
    ``worker_multi_gpu`` process is started per non-empty chunk, and the call
    blocks until every started worker has exited.

    Args:
        X, crow, col, val, star_crow, star_col, star_val, lamb, batch_size:
            forwarded unchanged to every worker.
        F_square_tensor, accord_tensor, opt_squared_tensor: output tensors the
            workers write into.
        unfinished_indices: sliceable sequence of row ids to update.
        devices: sequence of devices; the i-th chunk goes to ``devices[i]``.

    Returns:
        None.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code,
            since its slice of the output tensors was then not (fully) updated.

    NOTE(review): the workers write to the output tensors from separate
    processes; the parent only observes those writes if the tensors' storage
    is shared between processes -- confirm against the process start method
    and how the caller allocates these tensors.
    """
    num_devices = len(devices)
    num_unfinished = len(unfinished_indices)
    if num_unfinished == 0:
        # Nothing to update: do not pay for process start-up on any device.
        return
    chunk_size = (num_unfinished + num_devices - 1) // num_devices

    processes = []
    try:
        for rank, device in enumerate(devices):
            idx_start = rank * chunk_size
            if idx_start >= num_unfinished:
                # Ceil-division can leave trailing devices without any ids
                # (e.g. 5 ids on 4 devices); skip them instead of spawning
                # idle workers. All later ranks would be empty as well.
                break
            idx_end = min(idx_start + chunk_size, num_unfinished)
            indices_un = unfinished_indices[idx_start:idx_end]

            p_proc = mp.Process(target=worker_multi_gpu, args=(
                rank, device, X, crow, col, val, star_crow, star_col, star_val, lamb, batch_size, indices_un, F_square_tensor, accord_tensor, opt_squared_tensor
            ))
            p_proc.start()
            processes.append((rank, device, p_proc))
    finally:
        # Always wait for whatever was started, even if a later start() failed.
        for _, _, proc in processes:
            proc.join()

    # A crashed worker must not pass silently: its rows keep stale values.
    failed = [
        f"rank {rank} on {device} (exit code {proc.exitcode})"
        for rank, device, proc in processes
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError("multi-GPU objective update failed: " + ", ".join(failed))


def update_obj_ftn(X, crow, col, val, star_crow, star_col, star_val, lamb, F_square_tensor, accord_tensor, opt_squared_tensor, unfinished_indices, devices, batch_size, *, multi_gpu_threshold=20000):
    """
    Update objective functions, choosing the single- or multi-GPU path

    When fewer than ``multi_gpu_threshold`` indices remain, the work runs in
    this process on ``devices[0]`` via ``single_gpu_obj_ftn``; otherwise it is
    spread over all ``devices`` via ``multi_gpu_obj_ftn``.

    Args:
        X, crow, col, val, star_crow, star_col, star_val, lamb, batch_size:
            forwarded unchanged to the selected implementation.
        F_square_tensor, accord_tensor, opt_squared_tensor: output tensors,
            forwarded to the selected implementation.
        unfinished_indices: row ids still to be updated; its length selects
            the path.
        devices: sequence of devices; only ``devices[0]`` is used on the
            single-GPU path.
        multi_gpu_threshold: minimum number of unfinished indices at which the
            multi-GPU path is used. Defaults to 20000, the previously
            hard-coded value.

    Returns:
        None.
    """
    if len(unfinished_indices) < multi_gpu_threshold:
        device = devices[0]
        single_gpu_obj_ftn(X, crow, col, val, star_crow, star_col, star_val, lamb,
                           F_square_tensor, accord_tensor, opt_squared_tensor,
                           unfinished_indices, device, batch_size)
        return

    multi_gpu_obj_ftn(X, crow, col, val, star_crow, star_col, star_val, lamb,
                      F_square_tensor, accord_tensor, opt_squared_tensor,
                      unfinished_indices,
                      devices, batch_size)