import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import os
import h5py
#from torch.autograd import gradcheck
#from torchviz import make_dot
import math
from torch.utils.data.dataset import random_split


# Module-level loss objects used by classfx/combined_loss:
#   classification_loss_fn - cross-entropy for each classification task
#   regression_loss_fn     - MSE, used for the custom (trn_sum vs pred_sum) term
#   regressionl1_loss_fn   - Smooth-L1, used for the regression head
classification_loss_fn, regression_loss_fn, regressionl1_loss_fn = (
    nn.CrossEntropyLoss(),
    nn.MSELoss(),
    nn.SmoothL1Loss(),
)

def check_and_report_device(model, optimizer):
    """Print a warning that lists every model/optimizer parameter living on the CPU.

    Model parameters are reported by name. Parameters found in
    ``optimizer.param_groups`` are reported with a generic placeholder name. When
    nothing is on the CPU, a single confirmation line is printed instead.

    NOTE(review): only the tensors in ``param_groups`` are inspected, not the
    optimizer's state tensors (``optimizer.state``), despite the confirmation message.
    """
    on_cpu = [
        (pname, 'Model Parameter')
        for pname, p in model.named_parameters()
        if p.device.type == 'cpu'
    ]
    on_cpu += [
        ('Optimizer Tensor', 'Optimizer Parameter')
        for group in optimizer.param_groups
        for p in group['params']
        if p.device.type == 'cpu'
    ]

    if not on_cpu:
        print("All model and optimizer tensors are on the correct device.", flush=True)
        return

    print("Warning: Some tensors are on CPU:")
    for label, kind in on_cpu:
        print(f"{kind} -> {label}", flush=True)



class MultiTaskCNN(nn.Module):
    """Residual CNN with a multi-task classification head and a regression head.

    Trunk:
      * conv1..conv4 (64/128/256/512 channels), each followed by BatchNorm and ReLU.
      * A 2x2 max-pool after each of the first three stages.
      * A 1x1-conv skip projection of each pooled output, added after stages 2, 3 and 4.
      * conv5/conv6 (512 channels), then global average pooling and dropout.

    Both heads read the same 512-d pooled feature:
      * ``classifier`` -> logits of shape (B, num_tasks, num_classes)
      * ``regressor``  -> (B, num_reg_outputs) non-negative values (Softplus)

    The defaults reproduce the previously hard-coded sizes: 16 input channels,
    16 tasks x 4 classes and 2 regression outputs. Module names and construction
    order are unchanged, so existing state_dicts load and seeded initialization is
    the same.

    Args:
        in_channels: number of channels of the input image (default 16).
        num_tasks: number of classification tasks (default 16).
        num_classes: number of classes per task (default 4).
        num_reg_outputs: number of regression outputs (default 2).
    """

    def __init__(self, in_channels=16, num_tasks=16, num_classes=4, num_reg_outputs=2):
        super(MultiTaskCNN, self).__init__()
        self.num_tasks = num_tasks
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        self.conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(512)

        # 1x1 projections so the pooled output of stage k matches stage k+1's channel count.
        self.skip1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1)
        self.skip2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1)
        self.skip3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

        # The (parameter-free) activation/dropout modules are shared with the heads;
        # Sequential indices 0 and 3 hold the Linear layers, as before.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512, out_features=1024),
            self.activation,
            self.dropout,
            nn.Linear(in_features=1024, out_features=num_tasks * num_classes)
        )

        self.regressor = nn.Sequential(
            nn.Linear(in_features=512, out_features=1024),
            self.activation,
            self.dropout,
            nn.Linear(in_features=1024, out_features=num_reg_outputs),
            nn.Softplus()
        )

    def forward(self, x):
        """Return (classification logits (B, num_tasks, num_classes), regression (B, num_reg_outputs))."""
        x = self.activation(self.bn1(self.conv1(x)))
        # Pool once and feed the same tensor to both the trunk and the skip projection
        # (previously the identical max-pool was computed twice per stage).
        x = self.pool(x)
        skip = self.skip1(x)

        x = self.activation(self.bn2(self.conv2(x))) + skip
        x = self.pool(x)
        skip = self.skip2(x)

        x = self.activation(self.bn3(self.conv3(x))) + skip
        x = self.pool(x)
        skip = self.skip3(x)

        x = self.activation(self.bn4(self.conv4(x))) + skip

        x = self.activation(self.bn5(self.conv5(x)))
        x = self.activation(self.bn6(self.conv6(x)))

        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)

        # Classifier branch: one row of class logits per task.
        classification_output = self.classifier(x).view(-1, self.num_tasks, self.num_classes)

        # Regressor branch
        regression_output = self.regressor(x)

        return classification_output, regression_output




def differentiable_shift_affine(matrix, shifts, center=10):
    """Vertically translate every (H, W) slice of ``matrix`` with bilinear resampling.

    Output row i of slice (b, s) samples input row ``i - (center - shifts[b, s])``. A
    positive ``center - shifts`` therefore moves content toward larger row indices.
    Rows pulled in from outside the slice replicate the nearest border row
    (``padding_mode='border'``). The translation goes through ``affine_grid`` /
    ``grid_sample``, so the result is differentiable with respect to both ``shifts``
    and ``matrix``.

    Args:
        matrix: tensor of shape (B, S, H, W); it is cast to float32.
        shifts: tensor with B*S elements (e.g. shape (B, S)), one shift per slice.
        center: shift value that means "no translation" (default 10, the previously
            hard-coded value).

    Returns:
        float32 tensor of shape (B, S, H, W).
    """
    B, S, H, W = matrix.shape
    n = B * S

    # Ensure the matrix is a floating point tensor for interpolation
    matrix = matrix.float()

    # With align_corners=False the full height spans 2 normalized units, so one pixel
    # equals 2 / H in the affine translation term.
    translation = 2 * (center - shifts).reshape(n) / H

    # Identity scaling, translation only along Y (row 1, column 2 of each 2x3 theta).
    theta = torch.zeros(n, 2, 3, device=matrix.device)
    theta[:, 0, 0] = 1
    theta[:, 1, 1] = 1
    theta[:, 1, 2] = -translation

    grid = F.affine_grid(theta, [n, 1, H, W], align_corners=False).float()
    # reshape (not view) so non-contiguous inputs are accepted.
    shifted = F.grid_sample(matrix.reshape(n, 1, H, W), grid, mode='bilinear',
                            padding_mode='border', align_corners=False)

    return shifted.view(B, S, H, W)

def differentiable_shift_grid(matrix, shifts, center=10):
    """Vertically shift every (H, W) slice of ``matrix`` via an explicit sampling grid.

    Output row i of slice (b, s) samples input row ``i + (center - shifts[b, s])``, using
    bilinear interpolation for fractional shifts. Rows requested beyond the slice
    replicate the nearest border row (``padding_mode='border'``). The result is
    differentiable with respect to ``shifts`` and ``matrix``.

    NOTE(review): ``differentiable_shift_affine`` samples row ``i - (center - shifts)``
    for the same inputs, i.e. the opposite direction. Confirm which one is intended.

    Args:
        matrix (torch.Tensor): tensor of shape (Batch, Slices, Height, Width). It must be
            float32, because the sampling grid is float32.
        shifts (torch.Tensor): tensor with Batch*Slices elements (e.g. shape (Batch, Slices)).
        center: shift value that means "no shift" (default 10, previously hard-coded).

    Returns:
        torch.Tensor: tensor of shape (Batch, Slices, Height, Width).
    """
    B, S, H, W = matrix.size()
    device = matrix.device
    n = B * S

    flat_matrix = matrix.reshape(n, 1, H, W)
    row_offsets = (center - shifts).reshape(n, 1, 1).float()

    # Pixel coordinates: y (row) receives the shift, x (column) is unchanged.
    rows = torch.arange(H, device=device, dtype=torch.float32).view(1, H, 1)
    cols = torch.arange(W, device=device, dtype=torch.float32).view(1, 1, W)
    y = (rows + row_offsets).expand(n, H, W)
    x = cols.expand(n, H, W)

    # grid_sample expects the last grid dimension as (x, y) = (width, height). With
    # align_corners=True, -1/+1 are the centres of the first/last pixels, matching
    # index / (size - 1) * 2 - 1. max(..., 1) avoids 0/0 when a dimension has size 1.
    grid = torch.stack([x / max(W - 1, 1) * 2 - 1,
                        y / max(H - 1, 1) * 2 - 1], dim=-1)

    shifted = F.grid_sample(flat_matrix, grid, mode='bilinear', align_corners=True,
                            padding_mode='border')

    return shifted.view(B, S, H, W)

def differentiable_shift_batch(matrix, shifts, center=10):
    """Circularly shift every (H, W) slice of ``matrix`` by a whole number of rows.

    Slice (b, s) is rolled along the height axis by ``k = int(center - shifts[b, s])``
    rows. Like ``int()``, this truncates toward zero. The result is exactly
    ``torch.roll(slice, k, dims=0)``: output row i is input row ``(i - k) mod H``, so a
    positive k moves content downward and rows falling off one edge wrap to the other.

    The output is differentiable with respect to ``matrix``. The previous version
    detached ``matrix``, so callers' inputs never received gradients. The integer shift
    amounts themselves are not differentiable.

    Args:
        matrix (torch.Tensor): tensor of shape (Batch, Slices, Height, Width).
        shifts (torch.Tensor): tensor with Batch*Slices elements (e.g. shape (Batch, Slices)).
        center: shift value that means "no shift" (default 10, previously hard-coded).

    Returns:
        torch.Tensor: tensor of shape (Batch, Slices, Height, Width), same dtype as ``matrix``.
    """
    B, S, H, W = matrix.size()
    n = B * S

    flat_matrix = matrix.reshape(n, H, W)
    # Float -> int64 conversion truncates toward zero, matching int(shift.item()).
    row_shifts = (center - shifts).reshape(n, 1).long().to(matrix.device)

    # Source row per output row. Tensor '%' follows Python's sign rules, so negative
    # shifts wrap correctly.
    rows = torch.arange(H, device=matrix.device).view(1, H)
    src_rows = (rows - row_shifts) % H

    # One vectorized gather replaces the per-slice loop of permutation-matrix products.
    index = src_rows.unsqueeze(-1).expand(n, H, W)
    shifted = torch.gather(flat_matrix, 1, index)

    return shifted.view(B, S, H, W)


def check_gradients(model):
    """Print whether ``.grad`` is populated for every parameter that requires grad.

    One line per trainable parameter, in ``named_parameters()`` order:
    ``"<name> gradient: Present"`` or ``"<name> gradient: None"``.
    """
    trainable = ((pname, p) for pname, p in model.named_parameters() if p.requires_grad)
    for pname, p in trainable:
        status = 'Present' if p.grad is not None else 'None'
        print(f"{pname} gradient: {status}")

def print_gradient_norms(model):
    """Print the L2 norm of each available parameter gradient, then the global L2 norm.

    Parameters that do not require grad, or whose ``.grad`` is None, are skipped. The
    global norm is ``sqrt(sum of squared per-parameter norms)``.
    """
    squared_sum = 0
    for pname, p in model.named_parameters():
        if not p.requires_grad or p.grad is None:
            continue
        grad_norm = p.grad.detach().norm(2)
        squared_sum += grad_norm.item() ** 2
        print(f"Grad norm for {pname}: {grad_norm.item()}")

    print(f"Total norm of gradients: {squared_sum ** 0.5}")


def classfx(classification_preds, classification_targets):
    """Average ``classification_loss_fn`` over the classification tasks.

    NOTE: ``classfx`` is defined again later in this module, and that later definition
    is the one bound at import time.

    Args:
        classification_preds: logits of shape (B, T, C): T tasks with C classes each
            (T=16, C=4 for MultiTaskCNN's defaults).
        classification_targets: class indices of shape (B, T); cast with ``.long()``.

    Returns:
        Scalar tensor: the mean over the T tasks of ``classification_loss_fn`` applied to
        task t's (B, C) logits and (B,) targets.
    """
    num_tasks = classification_preds.size(1)
    classification_loss = 0
    for i in range(num_tasks):  # Loop over each pattern case
        task_pred = classification_preds[:, i, :]          # (B, C)
        task_target = classification_targets[:, i].long()  # (B,) class indices for CrossEntropyLoss
        classification_loss += classification_loss_fn(task_pred, task_target)

    # Divide once, after summing, so every task carries equal weight. Dividing inside the
    # loop (as before) weighted task t by 16**-(T - t) and was dominated by the last task.
    # TBD - consider weighted cases where we have more incorrect classs, e.g h1 more important than h2/h3
    return classification_loss / num_tasks
def print_hook(module, input, output):
    """Forward-hook helper: print the module's class name, its inputs and its output."""
    kind = type(module).__name__
    for line in (f"{kind} forward pass", f"Input: {input}", f"Output: {output}"):
        print(line)
def print_grad(grad):
    """Tensor-hook helper: print the gradient it receives (its ``str`` form)."""
    text = str(grad)
    print(text)
def one_hot_encode_targets(targets, num_classes=4):
    """One-hot encode class indices along a new trailing axis.

    Args:
        targets: tensor of class indices of any shape and numeric dtype. Values are cast
            with ``.long()`` and must lie in ``[0, num_classes)``.
        num_classes: length of the one-hot axis (default 4, the previously hard-coded value).

    Returns:
        Tensor of shape ``targets.shape + (num_classes,)`` on ``targets.device``, in the
        default floating dtype (as ``torch.zeros`` produced before).
    """
    return F.one_hot(targets.long(), num_classes).to(torch.get_default_dtype())


def register_hooks(model):
    """Attach a debug backward hook to every module that directly owns parameters.

    Each hook prints ``"<module name>: <grad_output>"`` when gradients flow through
    that module.

    Returns:
        list of hook handles. Call ``.remove()`` on each to detach them. Callers that
        ignore the return value (it used to be None) are unaffected.
    """
    handles = []
    for name, layer in model.named_modules():
        # Skip registering hooks on layers without parameters of their own
        if not list(layer.parameters(recurse=False)):
            continue

        # Bind ``name`` now via a default argument. A plain closure reads the loop
        # variable at backward time, so every hook printed the LAST module's name.
        def _print_grad_output(module, grad_input, grad_output, _name=name):
            print(f"{_name}: {grad_output}")

        handles.append(layer.register_backward_hook(_print_grad_output))
    return handles



def combined_loss(classification_preds, classification_targets,
                  regression_preds, regression_targets, trn_sum, pred_sum, alpha, beta, gamma):
    """Weighted sum of the three training objectives.

    Returns ``alpha * L_cls + beta * L_reg + gamma * L_custom``, where:
      * ``L_cls``    = ``classfx(classification_preds, classification_targets)``
      * ``L_reg``    = Smooth-L1 between ``regression_preds`` and ``regression_targets``
      * ``L_custom`` = MSE between ``trn_sum`` and ``pred_sum``
    """
    weighted_class = alpha * classfx(classification_preds, classification_targets)
    weighted_reg = beta * regressionl1_loss_fn(regression_preds, regression_targets)
    weighted_custom = gamma * regression_loss_fn(trn_sum, pred_sum)

    return weighted_class + weighted_reg + weighted_custom


def classfx(classification_preds, classification_targets):
    """Average ``classification_loss_fn`` over the classification tasks.

    Args:
        classification_preds: logits of shape (B, T, C): T tasks with C classes each
            (T=16, C=4 for MultiTaskCNN's defaults).
        classification_targets: class indices of shape (B, T); cast with ``.long()``.

    Returns:
        Scalar tensor: the mean over the T tasks of ``classification_loss_fn`` applied to
        task t's (B, C) logits and (B,) targets.
    """
    num_tasks = classification_preds.size(1)
    classification_loss = 0
    for i in range(num_tasks):  # Loop over each pattern case
        task_pred = classification_preds[:, i, :]          # (B, C)
        task_target = classification_targets[:, i].long()  # (B,) class indices for CrossEntropyLoss
        classification_loss += classification_loss_fn(task_pred, task_target)

    # Divide once, after summing, so every task carries equal weight. Dividing inside the
    # loop (as before) weighted task t by 16**-(T - t) and was dominated by the last task.
    # TBD - consider weighted cases where we have more incorrect classs, e.g h1 more important than h2/h3
    return classification_loss / num_tasks



    
def _load_datasets(device, p=80):
    """Read the synthetic feature/label .mat files into float32 tensors on ``device``.

    Sample selection and layout:
      * A sample (channel, noise, mat_ctr) is kept only when its feature file exists and
        ``label_matrix[0, mat_ctr-1, channel-1] >= 0``.
      * Channels >= 950 go to the test split; the rest go to train/val.
      * Features are the ``p`` rows of ``err_matrix_bin`` centred on row 64, transposed
        with axes (2, 0, 1).
      * Labels are ``label_matrix[:, mat_ctr-1, channel-1]``.

    Returns:
        (train_features, train_labels, test_features, test_labels) tensors.
    """
    # Read the label array into memory once and close the file, rather than keeping the
    # HDF5 handle open and doing an HDF5 read for every element lookup.
    with h5py.File('/home/scratch.user_methodology/SL/SL/ILP/synthetic/4lvl/data_cleaned_4lvls.mat', 'r') as f_lbl:
        label_matrix = f_lbl['data_set_clean'][()]

    row_lo, row_hi = int(64 - p / 2), int(64 + p / 2)
    train_features, train_labels = [], []
    test_features, test_labels = [], []

    for channel in range(1, 1025):
        for noise in range(1, 2):
            for mat_ctr in range(1, 33):
                feat_filename = ('/home/scratch.user_methodology/SL/SL/ILP/synthetic/badsynth/err_matrix_binary_4taps_badsynthchannel_'
                                 + str(channel) + '_noise_' + str(noise) + '_ctr_' + str(mat_ctr) + '.mat')

                # label_matrix is indexed zero-based: (:, mat_ctr - 1, channel - 1)
                if not (os.path.exists(feat_filename) and label_matrix[0, mat_ctr - 1, channel - 1] >= 0):
                    print('for feat name ' + str(feat_filename) + ' file DNE', flush=True)
                    continue

                # Close each feature file as soon as its slice has been read.
                with h5py.File(feat_filename, 'r') as f:
                    data1 = f['err_matrix_bin'][row_lo:row_hi, :, :]

                features = np.transpose(data1, axes=(2, 0, 1)).astype(np.float32)
                labels = label_matrix[:, mat_ctr - 1, channel - 1].astype(np.float32)

                if channel >= 950:
                    test_features.append(features)
                    test_labels.append(labels)
                else:
                    train_features.append(features)
                    train_labels.append(labels)

    def stack_to_device(arrays):
        # Stack along a new first axis; copy to the selected device (was hard-coded 'cuda').
        return torch.from_numpy(np.stack(arrays, axis=0)).float().to(device)

    return (stack_to_device(train_features), stack_to_device(train_labels),
            stack_to_device(test_features), stack_to_device(test_labels))


def _level_positions(center_point, offsets):
    """Return ``[c - o0, c - o1, c + o1, c + o0]`` per row as a (B, 4) tensor.

    Here ``c = center_point`` and ``o0``/``o1`` are ``offsets[:, 0]``/``offsets[:, 1]``.
    """
    return torch.stack([center_point - offsets[:, 0], center_point - offsets[:, 1],
                        center_point + offsets[:, 1], center_point + offsets[:, 0]], dim=1)


def _split_targets(labels, center_point):
    """Split a label batch into class targets, regression targets and target positions.

    Column layout:
      * Columns 0:16 are per-task class indices.
      * Columns 16 and 17 are turned into offsets ``10 - v``.

    Each task's target position is the ``_level_positions`` entry selected by its class
    index.

    Returns:
        (classification_targets (B, 16), regression_targets (B, 2), tra_expected_values (B, 16))
    """
    classification_targets, reg_targets = labels[:, 0:16], labels[:, 16:20]
    regression_targets = torch.stack([10 - reg_targets[:, 0], 10 - reg_targets[:, 1]], dim=1)
    positions = _level_positions(center_point, regression_targets)
    tra_expected_values = torch.sum(one_hot_encode_targets(classification_targets) * positions.unsqueeze(1), dim=2)
    return classification_targets, regression_targets, tra_expected_values


def _predicted_positions(class_weights, regression_preds, center_point):
    """Weight each task's 4 predicted positions by ``class_weights`` (B, 16, 4) and sum -> (B, 16)."""
    positions = _level_positions(center_point, regression_preds)
    return torch.sum(class_weights * positions.unsqueeze(1), dim=2)


def _shift_sums(features, tra_expected_values, pre_expected_values):
    """Shift features to target/predicted positions, multiply over slices (AND), sum over (H, W)."""
    err_train = torch.prod(differentiable_shift_affine(features, tra_expected_values), dim=1)
    err_pred = torch.prod(differentiable_shift_affine(features, pre_expected_values), dim=1)
    return torch.sum(err_train, dim=[1, 2]), torch.sum(err_pred, dim=[1, 2])


def _evaluate(model, loader, device, center_point):
    """Evaluate ``model`` on every batch of ``loader``. Call inside ``torch.no_grad()``.

    Returns:
        (accuracy, mean_mae, mean_custom_mae) as Python floats, or NaNs when the loader
        yields no batches.
          * accuracy: fraction of correct per-task argmax predictions.
          * mean_mae: per-batch mean |round(pred) - round(target)| of the regression
            outputs, averaged over the batches actually processed.
          * mean_custom_mae: mean of ``trn_sum - pred_sum``.
    """
    total_correct = 0
    total_samples = 0
    mae_sum = 0.0
    custom_errors = []

    for features, labels in loader:
        features, labels = features.to(device, dtype=torch.float), labels.to(device, dtype=torch.float)
        classification_targets, regression_targets, tra_expected_values = _split_targets(labels, center_point)

        classification_preds, regression_preds = model(features)

        # Hard (argmax) class choices at evaluation time; training uses Gumbel-softmax.
        predicted = torch.argmax(classification_preds, dim=2)
        pre_expected_values = _predicted_positions(one_hot_encode_targets(predicted), regression_preds, center_point)

        trn_sum, pred_sum = _shift_sums(features, tra_expected_values, pre_expected_values)
        # NOTE(review): this is a signed mean error, although it is reported as "MAE".
        custom_errors.append(torch.mean(trn_sum - pred_sum))

        total_correct += (predicted == classification_targets).sum().item()
        total_samples += classification_targets.numel()
        mae_sum += torch.abs(torch.round(regression_preds) - torch.round(regression_targets)).float().mean().item()

    if not custom_errors:
        nan = float('nan')
        return nan, nan, nan

    num_batches = len(custom_errors)
    return (total_correct / total_samples,
            mae_sum / num_batches,
            torch.stack(custom_errors).mean().item())


def main():
    """Train MultiTaskCNN on the synthetic data set and report metrics every epoch.

    Only the custom shifted-sum loss is active (alpha = beta = 0, gamma = 0.01). After
    each epoch the mean training loss and the validation/test metrics from ``_evaluate``
    are printed.
    """
    # Specify the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(22)
    batch_size = 32

    all_features_tensor, all_labels_tensor, all_features_test_tensor, all_labels_test_tensor = _load_datasets(device)

    entries = all_features_tensor.size()
    print(entries)
    print(all_features_tensor.size())
    print(all_labels_tensor.size())

    dataset = TensorDataset(all_features_tensor, all_labels_tensor)
    dataset_test = TensorDataset(all_features_test_tensor, all_labels_test_tensor)

    # 80/20 train/val split (draws from the seeded global RNG).
    train_size = int(entries[0] * 0.8)
    val_size = entries[0] - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # drop_last=True replaces the old "break on a short batch" checks. Only full batches
    # are used, so the averages below divide by the number of batches actually processed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, drop_last=True)

    model = MultiTaskCNN()
    print(model)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0003, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    center_point = torch.tensor(10.0, device=device)  # Adjust this based on your actual center point
    gamma = 0.01  # weight of the custom shifted-sum loss; classification/regression weights are 0

    # Training loop
    num_epochs = 1000
    for epoch in range(num_epochs):
        model.train()
        cumulative_loss = 0
        num_train_batches = 0

        for features, labels in train_loader:
            features, labels = features.to(device, dtype=torch.float), labels.to(device, dtype=torch.float)
            classification_targets, regression_targets, tra_expected_values = _split_targets(labels, center_point)

            optimizer.zero_grad()
            classification_preds, regression_preds = model(features)

            # Soft (Gumbel-softmax) class weights keep the predicted positions differentiable.
            soft_class_preds = F.gumbel_softmax(classification_preds, tau=0.1, hard=False)
            pre_expected_values = _predicted_positions(soft_class_preds, regression_preds, center_point)

            trn_sum, pred_sum = _shift_sums(features, tra_expected_values, pre_expected_values)
            loss = combined_loss(classification_preds, classification_targets,
                                 regression_preds, regression_targets, trn_sum, pred_sum, 0, 0, gamma)

            cumulative_loss += loss.item()
            num_train_batches += 1

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

        # Validation and test evaluation
        model.eval()
        with torch.no_grad():
            accuracy, mean_mae, mean_custom_mae = _evaluate(model, val_loader, device, center_point)
            test_accuracy, test_mean_mae, test_mean_custom_mae = _evaluate(model, test_loader, device, center_point)

        mean_loss = cumulative_loss / num_train_batches if num_train_batches else float('nan')
        print(f"Epoch {epoch} - Loss: {mean_loss:.4f}  Validation Accuracy: {accuracy} Mean MAE: {mean_mae} Mean Custom MAE: {mean_custom_mae} Test Accuracy: {test_accuracy} Mean MAE: {test_mean_mae} Mean Custom MAE: {test_mean_custom_mae}  ")


if __name__ == '__main__':
    main()

