import torch
from tqdm import tqdm
from utils.distributed_train_util import process_input, process_output
#for multiview
def validation(val_loader, model, local_rank):
    """Run the model over ``val_loader`` and return averaged accuracies.

    The model is switched to eval mode and every batch is evaluated under
    ``torch.no_grad()``. Each metric function returns one value per batch;
    those values are summed and divided by ``len(val_loader)``, so the
    result is a mean over batches, not over individual samples.

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        model: Callable model; it is left in eval mode on return.
        local_rank: Forwarded to ``process_input`` / ``process_output``
            (presumably the device index of this process -- TODO confirm).

    Returns:
        Dict with the keys ``'total_laban_acc'``, ``'horizontal_laban_acc'``
        and ``'vertical_laban_acc'``.
    """
    model.eval()

    # (result key, per-batch metric) pairs, in the order they are reported.
    metric_fns = (
        ('total_laban_acc', compute_accurate_num),
        ('horizontal_laban_acc', compute_horizontal_accurate_num),
        ('vertical_laban_acc', compute_vertical_accurate_num),
    )
    running = {key: 0 for key, _ in metric_fns}

    for batch_inputs, batch_targets in tqdm(val_loader):
        with torch.no_grad():
            batch_inputs = process_input(batch_inputs, local_rank, non_blocking=True)
            batch_targets = process_output(batch_targets, local_rank, non_blocking=True)
            batch_outputs = model(batch_inputs)
            for key, metric_fn in metric_fns:
                running[key] += metric_fn(outputs=batch_outputs, targets=batch_targets)

    num_batches = len(val_loader)
    return {key: float(score) / num_batches for key, score in running.items()}
            

def compute_accurate_num(outputs, targets):
    """Return the masked classification accuracy of one batch for ``'laban_0'``.

    Despite the name, the result is a ratio, not a count: the number of
    valid positions whose arg-max prediction equals the arg-max of the
    target, divided by the number of valid positions, and then divided by
    ``len(outputs)``.

    Args:
        outputs: Mapping holding the model scores under ``'laban_0'``.
            Classes are reduced along dim 2 (assumes a
            ``(batch, time, classes)`` layout -- TODO confirm).
        targets: Mapping holding ``'annotation_laban_0'`` (reduced with
            arg-max along dim 2) and ``'laban_0_mask'`` (reduced with max
            along dim 2 and truncated to int; non-zero marks a valid
            position).

    Returns:
        The accuracy as a float. ``0.0`` is returned when the mask selects
        no position at all, instead of raising ``ZeroDivisionError``.
    """
    output = outputs['laban_0']
    target = targets['annotation_laban_0']
    mask = targets['laban_0_mask']

    preds_idx = torch.argmax(output, dim=2)
    targets_idx = torch.argmax(target, dim=2)
    mask_idx = torch.max(mask, dim=2)[0].int()

    valid_count = torch.sum(mask_idx).item()
    if valid_count == 0:
        # Fully masked batch: nothing to score, so report 0 rather than crash.
        return 0.0

    corrects = ((preds_idx == targets_idx) * mask_idx).sum().item()
    # NOTE(review): len(outputs) is the number of entries in the outputs
    # mapping (presumably one per prediction head/view). Only 'laban_0' is
    # scored here, so extra entries scale the result down -- confirm intent.
    return (corrects / valid_count) / len(outputs)

def decode_horizontal(indices):
    """Translate class indices into horizontal direction labels.

    The 26 class indices form three consecutive groups: 0-8 and 9-17 each
    run through 'Place' followed by the eight directions, while 18-25 holds
    only the eight directions (no 'Place').

    Args:
        indices: Iterable of elements exposing ``.item()`` (e.g. a 1-D
            tensor of class indices).

    Returns:
        List with one direction label per element.

    Raises:
        KeyError: If an index is outside 0-25.
    """
    directions = (
        'Place', 'Forward', 'Left Forward', 'Left', 'Left Backward',
        'Backward', 'Right Backward', 'Right', 'Right Forward',
    )
    # Two full groups, then one group that omits 'Place'.
    labels_in_order = directions + directions + directions[1:]
    label_by_index = dict(enumerate(labels_in_order))

    decoded = []
    for element in indices:
        decoded.append(label_by_index[element.item()])
    return decoded

def decode_vertical(indices):
    """Translate class indices into vertical level labels.

    Indices 0-8 map to 'Low', 9-17 to 'High' and 18-25 to 'Normal'.

    Args:
        indices: Iterable of elements exposing ``.item()`` (e.g. a 1-D
            tensor of class indices).

    Returns:
        List with one level label per element.

    Raises:
        KeyError: If an index is outside 0-25.
    """
    level_ranges = (
        ('Low', range(0, 9)),
        ('High', range(9, 18)),
        ('Normal', range(18, 26)),
    )
    level_by_index = {}
    for level, index_range in level_ranges:
        level_by_index.update(dict.fromkeys(index_range, level))

    decoded = []
    for element in indices:
        decoded.append(level_by_index[element.item()])
    return decoded


def compute_horizontal_accurate_num(outputs, targets):
    """Return the masked horizontal-direction accuracy of one batch.

    Predictions and targets for ``'laban_0'`` are reduced to class indices
    with arg-max along dim 2, flattened, restricted to the positions whose
    mask equals 1, and decoded with ``decode_horizontal``. Despite the
    name, the result is a ratio, not a count: the fraction of valid
    positions whose decoded labels agree, divided by ``len(outputs)``.

    Args:
        outputs: Mapping holding the model scores under ``'laban_0'``
            (assumes a ``(batch, time, classes)`` layout -- TODO confirm).
        targets: Mapping holding ``'annotation_laban_0'`` and
            ``'laban_0_mask'``.

    Returns:
        The accuracy as a float. ``0.0`` is returned when no position is
        valid, instead of raising ``ZeroDivisionError``.
    """
    output = outputs['laban_0']
    target = targets['annotation_laban_0']
    mask = targets['laban_0_mask']

    preds_flat = torch.argmax(output, dim=2).int().view(-1)
    targets_flat = torch.argmax(target, dim=2).int().view(-1)
    mask_flat = torch.max(mask, dim=2)[0].int().view(-1)

    valid_index = mask_flat == 1
    preds_directions = decode_horizontal(preds_flat[valid_index])
    targets_directions = decode_horizontal(targets_flat[valid_index])
    if not targets_directions:
        # Fully masked batch: nothing to score, so report 0 rather than crash.
        return 0.0

    matches = sum(p == t for p, t in zip(preds_directions, targets_directions))
    corrects = matches / len(targets_directions)
    # NOTE(review): len(outputs) is the number of entries in the outputs
    # mapping; only 'laban_0' is scored here -- confirm the intended divisor.
    return corrects / len(outputs)

def compute_vertical_accurate_num(outputs, targets):
    """Return the masked vertical-level accuracy of one batch.

    Predictions and targets for ``'laban_0'`` are reduced to class indices
    with arg-max along dim 2, flattened, restricted to the positions whose
    mask equals 1, and decoded with ``decode_vertical``. Despite the name,
    the result is a ratio, not a count: the fraction of valid positions
    whose decoded labels agree, divided by ``len(outputs)``.

    Args:
        outputs: Mapping holding the model scores under ``'laban_0'``
            (assumes a ``(batch, time, classes)`` layout -- TODO confirm).
        targets: Mapping holding ``'annotation_laban_0'`` and
            ``'laban_0_mask'``.

    Returns:
        The accuracy as a float. ``0.0`` is returned when no position is
        valid, instead of raising ``ZeroDivisionError``.
    """
    output = outputs['laban_0']
    target = targets['annotation_laban_0']
    mask = targets['laban_0_mask']

    preds_flat = torch.argmax(output, dim=2).int().view(-1)
    targets_flat = torch.argmax(target, dim=2).int().view(-1)
    mask_flat = torch.max(mask, dim=2)[0].int().view(-1)

    valid_index = mask_flat == 1
    preds_directions = decode_vertical(preds_flat[valid_index])
    targets_directions = decode_vertical(targets_flat[valid_index])
    if not targets_directions:
        # Fully masked batch: nothing to score, so report 0 rather than crash.
        return 0.0

    matches = sum(p == t for p, t in zip(preds_directions, targets_directions))
    corrects = matches / len(targets_directions)
    # NOTE(review): len(outputs) is the number of entries in the outputs
    # mapping; only 'laban_0' is scored here -- confirm the intended divisor.
    return corrects / len(outputs)
        

