import glob
import os
import torch
def get_batch_id_per_frame(im_idx_out):
    """Expand per-clip last-frame indices into one clip id per frame.

    ``im_idx_out[i]`` is read as the inclusive index of the last frame that
    belongs to clip ``i``; clip ``i`` therefore owns every frame after the
    previous clip's last index up to and including ``im_idx_out[i]``.

    Args:
        im_idx_out: iterable of 0-dim integer tensors (e.g. a 1-D integer
            tensor) holding the last frame index of each clip.

    Returns:
        A ``torch.LongTensor`` with one entry per frame containing the index
        of the clip that frame belongs to. An empty input yields an empty
        tensor.
    """
    segments = []
    next_frame = 0  # first frame index not yet assigned to a clip
    for clip_idx, last in enumerate(im_idx_out):
        last_frame = last.item()
        segments.append(torch.LongTensor([clip_idx] * (last_frame + 1 - next_frame)))
        next_frame = last_frame + 1
    if not segments:
        # torch.cat refuses an empty list, so answer the empty case directly.
        return torch.LongTensor([])
    return torch.cat(segments)


def _keep_active_windows(window_matrix, pair_mask):
    """Restrict a window matrix to the selected pairs and drop empty windows.

    The matrix is indexed along its second axis by ``pair_mask``; rows that
    sum to zero over the remaining pairs are removed.
    """
    per_pair = window_matrix.T[pair_mask]
    active = per_pair.sum(axis=0) != 0
    return per_pair[:, active].T


def create_new_batch(batch, batch_we_want, boxes_batch, pairs_1, batch_id):
    """Extract one clip from a multi-clip batch as a stand-alone batch.

    Every per-box, per-pair and per-frame entry belonging to clip
    ``batch_we_want`` is copied, and index-like fields are shifted so that
    their smallest value becomes 0.

    Args:
        batch: dict of tensors describing several clips. Must contain
            ``pred_labels``, ``bboxes`` (column 0 is the frame index),
            ``out_im_idxes``, ``ids``, ``pair_idxes``, ``pair_human_ids``,
            ``pair_object_ids``, ``im_idxes``, ``windows`` and
            ``windows_out``; ``frames``/``original_frames`` and
            ``patch_tokens``/``cls_tokens``/``binary_masks`` are optional.
        batch_we_want: integer index of the clip to extract.
        boxes_batch: clip index of every box.
        pairs_1: clip index of every pair.
        batch_id: clip index of every frame.

    Returns:
        ``(new_batch, n_interactions)`` where ``new_batch`` is the single-clip
        dict and ``n_interactions`` is the number of rows left in
        ``new_batch["windows"]``.

    Note:
        The ``.min()`` based re-indexing requires the clip to own at least
        one box and one pair; torch raises on an empty selection.
    """
    box_mask = boxes_batch == batch_we_want
    pair_mask = pairs_1 == batch_we_want
    frame_mask = batch_id == batch_we_want

    new_batch = {}
    new_batch["pred_labels"] = batch["pred_labels"][box_mask].clone()
    new_batch["bboxes"] = batch["bboxes"][box_mask].clone()
    # Column 0 of a box is its frame index; make it local to this clip.
    new_batch["bboxes"][:, 0] = new_batch["bboxes"][:, 0] - new_batch["bboxes"][:, 0].min()

    # out_im_idxes stores the inclusive last frame index of each clip, so this
    # clip starts one past the previous clip's last frame. Subtracting only
    # the previous last index would report one frame too many for every clip
    # except the first.
    last_frame = batch["out_im_idxes"][batch_we_want]
    first_frame = batch["out_im_idxes"][batch_we_want - 1] + 1 if batch_we_want > 0 else 0
    new_batch["out_im_idxes"] = (last_frame - first_frame)[None].clone()

    new_batch["ids"] = batch["ids"][box_mask].clone()
    new_batch["ids"] = new_batch["ids"] - new_batch["ids"].min()

    # Pair-level index fields are all selected and re-based the same way.
    for key in ("pair_idxes", "pair_human_ids", "pair_object_ids", "im_idxes"):
        selected = batch[key][pair_mask].clone()
        new_batch[key] = selected - selected.min()

    if "frames" in batch:
        new_batch["frames"] = batch["frames"][frame_mask].clone()
        new_batch["original_frames"] = batch["original_frames"][batch_we_want]

    if "patch_tokens" in batch:
        new_batch["patch_tokens"] = batch["patch_tokens"][frame_mask].clone()
        new_batch["cls_tokens"] = batch["cls_tokens"][frame_mask].clone()
        new_batch["binary_masks"] = batch["binary_masks"][box_mask].clone()

    windows = _keep_active_windows(batch["windows"], pair_mask)
    new_batch["windows"] = windows.clone()
    n_interactions = windows.shape[0]

    new_batch["windows_out"] = _keep_active_windows(batch["windows_out"], pair_mask).clone()

    if "frames" not in new_batch:
        # No real frames were available: fall back to a random placeholder so
        # the key is always present. It must not replace frames copied above.
        new_batch["frames"] = torch.rand(6, 3, 224, 224).to(new_batch["bboxes"].device)

    return new_batch, n_interactions

def map_batch_to_npairs(batch):
    """Assign every frame, box and pair of a batch to its clip.

    Args:
        batch: dict with ``out_im_idxes`` (last frame index of each clip),
            ``bboxes`` (column 0 is the frame index of the box) and
            ``pair_idxes`` (column 0 is a box index).

    Returns:
        ``(pairs_per_batch, pairs_1, boxes_batch, batch_id)`` where
        ``batch_id`` is the clip index of every frame, ``boxes_batch`` the
        clip index of every box, ``pairs_1`` the clip index of every pair
        (taken from the pair's first box) and ``pairs_per_batch`` a list with
        the number of pairs in each clip up to the highest clip that has a
        pair. The list is empty when the batch has no pairs.
    """
    batch_id = get_batch_id_per_frame(batch["out_im_idxes"])

    # Frame indices are stored in the box tensor; cast them so they can be
    # used to index batch_id on its own device.
    box_frames = batch["bboxes"][:, 0].to(dtype=torch.long, device=batch_id.device)
    boxes_batch = batch_id[box_frames]

    # Same device handling for the pair -> box lookup as for the boxes above.
    first_box_of_pair = batch["pair_idxes"][:, 0].to(device=batch_id.device)
    pairs_1 = boxes_batch[first_box_of_pair]

    # Tensor reductions instead of the Python builtins, which would walk the
    # tensor one element at a time.
    n_clips = int(pairs_1.max()) + 1 if pairs_1.numel() else 0
    pairs_per_batch = [(pairs_1 == clip).sum() for clip in range(n_clips)]
    return pairs_per_batch, pairs_1, boxes_batch, batch_id

def do_one_batch(batch, filtered_base_path=None):
    """Split a batch into single-clip batches and optionally save them.

    Only clips with exactly 6 frames are extracted. The extracted clips are
    keyed by their interaction count as returned by ``create_new_batch``, so
    a later clip with the same count replaces an earlier one.

    Args:
        batch: multi-clip batch dict, as accepted by ``map_batch_to_npairs``
            and ``create_new_batch``.
        filtered_base_path: directory in which each clip is saved as
            ``<n_interactions>.pt``. Counts that already have a file there
            are skipped. ``None`` disables saving. The directory is created
            when it does not exist.
    """
    pairs_per_batch, pairs_1, boxes_batch, batch_id = map_batch_to_npairs(batch)

    batch_npairs = {}
    for batch_we_want in range(len(pairs_per_batch)):
        n_frames = (batch_id == batch_we_want).sum()
        if n_frames == 6:
            batch_i, n_interactions = create_new_batch(batch, batch_we_want, boxes_batch, pairs_1, batch_id)
            batch_npairs[n_interactions] = batch_i

    if filtered_base_path is None:
        return

    # torch.save does not create missing directories.
    os.makedirs(filtered_base_path, exist_ok=True)

    existing_filtered_batches = set()
    for existing_path in glob.glob(os.path.join(filtered_base_path, "*.pt")):
        stem = os.path.basename(existing_path).split(".")[0]
        try:
            existing_filtered_batches.add(int(stem))
        except ValueError:
            # A file whose name is not an interaction count was not written
            # here; ignore it instead of aborting the whole run.
            continue

    for k, v in batch_npairs.items():
        if k in existing_filtered_batches:
            continue
        print(k, v.keys())
        save_path = os.path.join(filtered_base_path, f"{k}.pt")
        print(save_path)
        torch.save(v, save_path)
def do_all_batches():
    """Run ``do_one_batch`` on every ``./batches/*.pt`` file, in sorted order.

    Results are written to ``./batches_filtered``.
    """
    filtered_base_path = "./batches_filtered"
    for path in sorted(glob.glob("./batches/*.pt")):
        do_one_batch(torch.load(path), filtered_base_path)



if __name__ == '__main__':
    # Script entry: split a single saved batch. No output directory is passed,
    # so nothing is written to disk. Call do_all_batches() here instead to
    # process every file under ./batches.
    from configs.paths import project_path

    batch_path = project_path + "/output/sttran_batch.pt"
    batch = torch.load(batch_path)
    do_one_batch(batch)
