import os
import torch
import torch.nn.functional as F
import numpy as np
import logging
from scipy.optimize import linear_sum_assignment


def pairwise_cos_sim(a, b):
    """Return the matrix of cosine similarities between the rows of ``a`` and ``b``.

    Each row is divided by its L2 norm plus ``1e-8`` (so all-zero rows give a
    similarity of 0 rather than NaN), and the normalised rows are multiplied
    pairwise.

    :arg a: (n1, D) tensor
    :arg b: (n2, D) tensor
    :return: (n1, n2) tensor whose entry ``[i, j]`` is the cosine similarity of
        ``a[i]`` and ``b[j]``
    """
    eps = 1e-8
    lhs = a / (a.norm(dim=1, keepdim=True) + eps)
    rhs = b / (b.norm(dim=1, keepdim=True) + eps)
    sim = lhs @ rhs.t()

    assert sim.shape == (a.shape[0], b.shape[0])
    return sim

def bipartiate_match_video(slots, slot_nums, masks):
    """Relabel per-frame slot masks so that matched slots share a track id across frames.

    For every pair of consecutive frames, the slots of frame ``t - 1`` and
    frame ``t`` are matched with ``linear_sum_assignment`` on their pairwise
    cosine similarity (maximised). A slot of frame ``t`` inherits the track id
    of the frame ``t - 1`` slot it is matched to. When frame ``t`` has more
    slots than frame ``t - 1``, the slots left unmatched inherit the id of
    their most similar frame ``t - 1`` slot, so several slots can share an id.
    Frame 0 keeps its labels (initial ids are the identity ``0..max-1``).

    :arg slots: (sum(slot_nums), D_slot) slot vectors of all frames, concatenated in frame order
    :arg slot_nums: (#frames) tensor holding the number of slots in each frame
    :arg masks: (#frames, H_target, W_target) per-frame masks whose values are slot indices of that frame
    :return: ``masks`` itself, relabelled in place with track ids; pixels whose
        value is not a slot index of their frame are set to 0
    """
    # Not named ``F``: that would shadow the module's torch.nn.functional alias.
    num_frames = masks.shape[0]

    slot_acc = 0  # offset in ``slots`` of the first slot of frame t - 1
    # Track id carried by each slot index of the previous frame (identity for frame 0).
    prev_assignment = torch.arange(slot_nums.max(), device=slot_nums.device).long()

    for t in range(1, num_frames):
        n_prev = int(slot_nums[t - 1])
        n_cur = int(slot_nums[t])

        slots_prev = slots[slot_acc: slot_acc + n_prev]
        slot_acc += n_prev
        slots_cur = slots[slot_acc: slot_acc + n_cur]

        similarity_matrix = pairwise_cos_sim(slots_prev, slots_cur)
        # detach(): Tensor.numpy() refuses tensors that require grad.
        row_ind, col_ind = linear_sum_assignment(-similarity_matrix.detach().cpu().numpy())
        # scipy returns row_ind already sorted, so the pairs are in previous-slot order.
        prev_indices = row_ind.tolist()
        indices = col_ind.tolist()

        if len(indices) < n_cur:
            # More slots than in the previous frame: attach every unmatched slot
            # to its most similar previous slot (ids may then be shared).
            # NOTE(review): raises if the previous frame has no slots (argmax over an empty dim).
            best_prev = torch.argmax(similarity_matrix, dim=0).tolist()
            matched = set(indices)
            for j in range(n_cur):
                if j not in matched:
                    prev_indices.append(best_prev[j])
                    indices.append(j)

        frame_mask = masks[t]
        new_mask = torch.zeros(frame_mask.shape, device=masks.device, dtype=torch.long)
        prev_assignment_new = prev_assignment.clone()
        for p, c in zip(prev_indices, indices):
            track_id = prev_assignment[p]
            # Compare against the original labels; masks[t] is only overwritten after the loop.
            new_mask[frame_mask == c] = track_id
            prev_assignment_new[c] = track_id

        prev_assignment = prev_assignment_new
        masks[t] = new_mask

    return masks


def set_logger(logger_path):
    """Configure the root logger to write to ``logger_path`` and to the console.

    The root logger is set to DEBUG. A ``logging.FileHandler`` at DEBUG level
    writes to ``logger_path``, and a console ``StreamHandler`` at WARN level
    is attached. The function is safe to call repeatedly: a file handler for
    the same path, and the console handler, are attached only once. Repeated
    calls therefore do not duplicate log lines.

    :arg logger_path: path of the log file (``FileHandler`` default mode ``'a'``
        appends to an existing file)
    :return: the root logger
    """
    console_name = 'set_logger.console'

    # Quieten PIL's logger to WARNING and above.
    logging.getLogger('PIL').setLevel(logging.WARNING)

    # NOTE(review): basicConfig only installs its own console handler when the
    # root logger has none yet. That handler has no level of its own, so with
    # the root at DEBUG below it echoes every record, and WARN+ records are
    # printed again by the WARN console handler. Kept to preserve console output.
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # FileHandler stores the absolute path in ``baseFilename``; compare against that.
    file_path = os.path.abspath(logger_path)
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == file_path
        for h in logger.handlers
    )
    if not has_file:
        fh = logging.FileHandler(logger_path)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    has_console = any(h.get_name() == console_name for h in logger.handlers)
    if not has_console:
        ch = logging.StreamHandler()
        ch.setLevel(logging.WARN)
        ch.set_name(console_name)
        logger.addHandler(ch)

    return logger