import torch
import torch.nn as nn
import gc

def _linreg_batch_size(args, multi_gpu, base_batch_size=128):
    """Return the mini-batch size, scaled by the GPU count when multi-GPU is on."""
    if not multi_gpu:
        return base_batch_size
    gpu_ids = getattr(args, 'gpu_ids', None)
    if gpu_ids:
        num_gpus = len(gpu_ids.split(','))
    else:
        num_gpus = torch.cuda.device_count()
    batch_size = base_batch_size * num_gpus
    print(f"Using batch size {batch_size} for linear regression with {num_gpus} GPUs")
    return batch_size


def _linreg_wrap_data_parallel(linear, args, device):
    """Wrap *linear* in nn.DataParallel; fall back to the bare module on failure."""
    try:
        gpu_ids = getattr(args, 'gpu_ids', None)
        if gpu_ids:
            # DataParallel uses indices starting from 0 after CUDA_VISIBLE_DEVICES is set
            num_gpus = len(gpu_ids.split(','))
            return nn.DataParallel(linear, device_ids=list(range(num_gpus)))
        return nn.DataParallel(linear)
    except Exception as e:  # deliberate best-effort: keep running on one device
        print(f"Failed to use DataParallel for linear model: {e}")
        print("Falling back to single GPU")
        return linear.to(device)


def _linreg_predict(linear, x, device, chunk_size=None):
    """Run *linear* over all rows of *x* without gradients and return CPU predictions.

    Inputs are cast to float32 to match the layer's parameters. When
    *chunk_size* is given, rows are moved to *device* one chunk at a time so
    the whole input never has to live on the device at once.
    """
    with torch.no_grad():
        if chunk_size is None:
            return linear(x.to(device).float()).cpu().detach()
        chunks = [
            linear(x[start:start + chunk_size].to(device).float()).cpu()
            for start in range(0, x.size(0), chunk_size)
        ]
        return torch.cat(chunks, dim=0).detach()


def parallel_linear_regression(x, y, n_samples, n_samples_train, device, args, n_epochs=500, early_stopping=50, sparse=False, verbose=False):
    """Fit one linear layer mapping ``x`` to every column of ``y`` and score it.

    A single ``nn.Linear(x.shape[1], y.shape[1])`` is trained with Adam on an
    MSE loss, so all target columns are regressed simultaneously. Rows
    ``[:n_samples_train]`` are used for training and rows
    ``[n_samples_train:n_samples]`` for validation-based early stopping.
    The returned score is computed over *all* rows of ``x`` / ``y``.

    Args:
        x: 2-D tensor of inputs, samples along dim 0.
        y: 2-D tensor of targets, samples along dim 0.
        n_samples: End index (exclusive) of the validation slice.
        n_samples_train: Number of leading rows used for training.
        device: Device the layer and the batches are moved to.
        args: Object optionally carrying ``multi_gpu`` and ``gpu_ids``
            (comma-separated string); missing attributes mean single device.
        n_epochs: Maximum number of training epochs.
        early_stopping: Stop once the best validation loss of the last
            ``early_stopping`` epochs is worse than the overall best.
        sparse: If True, return the negated per-column RMSE instead of R².
        verbose: If True, show a tqdm progress bar and print warnings.

    Returns:
        A 1-D CPU tensor. With ``sparse=True`` it is ``-RMSE`` per column.
        Otherwise it is R² per column, restricted to columns whose mean is
        finite and non-zero; if the mean R² is negative, the Pearson
        correlation between target and prediction is returned for every
        column instead.

    Raises:
        ValueError: If the training or validation slice would be empty.
    """
    n_samples_val = n_samples - n_samples_train
    if n_samples_train <= 0 or n_samples_val <= 0:
        raise ValueError(
            "parallel_linear_regression needs a non-empty training and validation "
            f"split (n_samples_train={n_samples_train}, n_samples={n_samples})"
        )

    multi_gpu = getattr(args, 'multi_gpu', False)
    batch_size = _linreg_batch_size(args, multi_gpu)

    # Targets are needed on the CPU for every metric below; copy them once.
    y_cpu = y.cpu()
    y_mean = y.mean(dim=0).cpu()

    # loaders
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x[:n_samples_train], y[:n_samples_train]),
        batch_size=min(batch_size, n_samples_train),
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x[n_samples_train:n_samples], y[n_samples_train:n_samples]),
        batch_size=min(batch_size, n_samples_val),
        shuffle=False,
    )

    # One linear layer regresses all target columns at once - explicitly float32
    linear = nn.Linear(x.shape[1], y.shape[1]).to(device).float()
    if multi_gpu:
        linear = _linreg_wrap_data_parallel(linear, args, device)

    optimizer = torch.optim.Adam(linear.parameters(), lr=0.0001, weight_decay=0)
    loss_fn = nn.MSELoss()

    # tqdm is only required for the progress bar, so import it lazily.
    if verbose:
        import tqdm
        pbar = tqdm.tqdm(range(n_epochs))
    else:
        pbar = range(n_epochs)

    # train the linear layer
    val_losses = []
    for epoch in pbar:
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()  # Force float32
            optimizer.zero_grad()
            loss = loss_fn(linear(x_batch), y_batch)
            loss.backward()
            optimizer.step()

        val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to(device).float(), y_val.to(device).float()  # Force float32
                val_loss += loss_fn(linear(x_val), y_val).item()
        mean_val_loss = val_loss / len(val_loader)
        val_losses.append(mean_val_loss)

        if verbose:
            pbar.set_description(f'Epoch {epoch+1}/{n_epochs}, Val Loss: {mean_val_loss:.4f}')
        # Stop when none of the most recent epochs matched the best loss so far.
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            if verbose:
                print("Early stopping in linear regression at epoch ", epoch)
            break

    # Predict for every row; chunk the input under DataParallel to avoid OOM.
    y_pred = _linreg_predict(linear, x, device, batch_size if multi_gpu else None)

    if sparse:
        rmse = torch.sqrt(((y_cpu - y_pred) ** 2).mean(dim=0))
        result = -rmse  # negated so that a higher score still means a better fit
    else:
        ssr = ((y_cpu - y_pred) ** 2).sum(0)
        scaling_factor = ((y_cpu - y_mean) ** 2).sum(0)

        # Score only columns whose mean is finite and non-zero. A NaN/Inf
        # anywhere in a column makes its mean non-finite, so one boolean mask
        # covers both the zero-mean and the NaN/Inf case.
        has_zero_mean = bool((y_mean == 0).any())
        has_non_finite = bool((~torch.isfinite(y_mean)).any())
        if has_zero_mean or has_non_finite:
            if verbose and has_zero_mean:
                print("   Warning: zeros found in y_mean. Removing samples.")
            if verbose and has_non_finite:
                print("   Warning: NaN or Inf values found in y. Handling them.")
            keep = torch.isfinite(y_mean) & (y_mean != 0)
            ssr = ssr[keep]
            scaling_factor = scaling_factor[keep]

        # The epsilon keeps constant columns (zero total variance) finite.
        r_squares = 1 - ((ssr + 1e-9) / (scaling_factor + 1e-9))

        if r_squares.mean() < 0:
            print(f"Warning: Negative R² value detected: {r_squares.mean()}. This may indicate poor model fit or numerical issues. Trying correlation.")
            # Use correlation as a fallback.
            # NOTE(review): this covers every column, so its length can differ
            # from the masked R² vector above - confirm callers only reduce it.
            r_squares = torch.tensor([
                torch.corrcoef(torch.stack((y_cpu[:, i], y_pred[:, i])))[0, 1]
                for i in range(y.shape[1])
            ])
        result = r_squares

    # Clean up on both return paths so cached GPU memory is released.
    del linear, optimizer, train_loader, val_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result