from cmath import phase
from email.mime import image
import os
from pickletools import uint8
import numpy as np
import torch, torchvision
from tqdm import tqdm
import cv2
import pandas as pd

import random
from torchvision import transforms
from PIL import Image

# random.seed(42)

def get_spatial_fragments(
    video,
    fragments_h=7,
    fragments_w=7,
    fsize_h=32,
    fsize_w=32,
    aligned=32,
    nmini_patches=1,
    random=False,
    fallback_type="upsample",
):
    """Cut a grid of mini-patches out of ``video`` and splice them together.

    The frame is divided into ``fragments_h x fragments_w`` cells; one
    ``fsize_h x fsize_w`` patch is cropped per cell (per temporal chunk) and
    written to the matching cell of the output, so the result has spatial
    size ``(fragments_h * fsize_h, fragments_w * fsize_w)``.

    Args:
        video: 4-D tensor with the time axis at dim 1 and height/width as the
            last two dims (used here as ``[C, T, H, W]``). The upsample path
            rescales by 255, so values are presumably in the 0-255 range.
        fragments_h, fragments_w: number of grid cells along height / width.
        fsize_h, fsize_w: patch height / width in pixels.
        aligned: number of consecutive frames that share one crop offset.
            Forced to 1 when the clip has a single frame (an image).
        nmini_patches: unused; kept for interface compatibility.
        random: if true, crop positions are drawn over the whole frame
            instead of inside each grid cell (deprecated path).
        fallback_type: ``"upsample"`` enlarges inputs that are smaller than
            the output size before sampling; any other value disables this.

    Returns:
        Tensor of shape ``video.shape[:-2] + (size_h, size_w)`` in torch's
        default dtype, on the same device as ``video``.

    Raises:
        AssertionError: if the number of frames is not a multiple of
            ``aligned``.
    """
    size_h, size_w = fragments_h * fsize_h, fragments_w * fsize_w

    # A single-frame clip (an image) has nothing to align along time.
    if video.shape[1] == 1:
        aligned = 1

    dur_t, res_h, res_w = video.shape[-3:]
    ratio = min(res_h / size_h, res_w / size_w)
    if fallback_type == "upsample" and ratio < 1:
        ovideo = video
        video = torch.nn.functional.interpolate(video / 255.0, scale_factor=1 / ratio, mode="bilinear")
        video = (video * 255.0).type_as(ovideo)
        # The grid below must describe the tensor we actually crop from.
        # Keeping the pre-upsample size here restricted sampling to the
        # top-left corner and produced negative anchors for tiny inputs.
        res_h, res_w = video.shape[-2:]

    assert dur_t % aligned == 0, "Please provide match vclip and align index"
    size = size_h, size_w

    ## make sure that sampling will not run out of the picture
    hgrids = torch.LongTensor([min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)])
    wgrids = torch.LongTensor([min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)])
    hlength, wlength = res_h // fragments_h, res_w // fragments_w

    n_chunks = dur_t // aligned
    offset_shape = (len(hgrids), len(wgrids), n_chunks)

    def _draw_offsets(span):
        # Random offsets in [0, span) when there is slack, otherwise zeros.
        if span > 0:
            return torch.randint(span, offset_shape)
        return torch.zeros(offset_shape).int()

    if random:
        print("This part is deprecated. Please remind that.")
        # Offsets are absolute positions anywhere in the frame.
        rnd_h = _draw_offsets(res_h - fsize_h)
        rnd_w = _draw_offsets(res_w - fsize_w)
    else:
        # Offsets are relative to each cell's anchor and stay inside the cell.
        rnd_h = _draw_offsets(hlength - fsize_h)
        rnd_w = _draw_offsets(wlength - fsize_w)

    target_video = torch.zeros(video.shape[:-2] + size).to(video.device)

    for i, hs in enumerate(hgrids.tolist()):
        for j, ws in enumerate(wgrids.tolist()):
            for t in range(n_chunks):
                t_s, t_e = t * aligned, (t + 1) * aligned
                h_s, h_e = i * fsize_h, (i + 1) * fsize_h
                w_s, w_e = j * fsize_w, (j + 1) * fsize_w
                off_h, off_w = int(rnd_h[i, j, t]), int(rnd_w[i, j, t])
                if random:
                    h_so, w_so = off_h, off_w
                else:
                    h_so, w_so = hs + off_h, ws + off_w
                target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
                    :, t_s:t_e, h_so:h_so + fsize_h, w_so:w_so + fsize_w
                ]
    return target_video

def get_multiview_fragments(imgs, fragments=7, fsize=32, num_view=6):
    """Merge mini-patches sampled from several views into one patch map (QMM).

    Input:  imgs, [num_view, 3, 224, 224]
    Output: QMM, [1, 3, fsize*fragments, fsize*fragments]
            (= [1,3,224,224] = [1,3,32*7,32*7] with the defaults)

    Each view is first turned into its own fragment map with
    ``get_spatial_fragments``. The ``fragments**2`` slots of the QMM are then
    split into ``num_view`` consecutive runs (the last view also takes the
    remainder), and every slot is filled with a patch drawn in shuffled order
    from that view's fragment map. Patches whose mean is >= 2.2 are skipped as
    blank (presumably a white background after input normalization - TODO
    confirm against the caller's preprocessing). When a view runs out of
    candidates, the last examined patch is reused for its remaining slots.

    Uses the module-level ``random`` generator, so results depend on its seed.

    Raises:
        AssertionError: if the input images are not 224x224.
    """
    assert imgs.shape[-2] == 224 and imgs.shape[-1] == 224
    # mean = torch.FloatTensor([123.675, 116.28, 103.53])
    # std = torch.FloatTensor([58.395, 57.12, 57.375])
    num_QMM = 1
    total_patches = fragments ** 2
    # Size the buffers from the fragment map actually produced, not from the
    # input image, so configurations with fragments * fsize != 224 also work.
    map_shape = (imgs[0].shape[0], fragments * fsize, fragments * fsize)
    mini_patch_maps = torch.zeros((num_view, *map_shape))
    QMM = torch.zeros((num_QMM, *map_shape))

    for view in range(num_view):
        # [C, H, W] -> [C, 1, H, W]: a one-frame clip for the sampler.
        img = imgs[view, ...].unsqueeze(1)
        ifrag = get_spatial_fragments(img, fragments, fragments, fsize, fsize)
        # Drop the singleton time axis again: [C, 1, H', W'] -> [C, H', W'].
        mini_patch_maps[view] = ifrag[:, 0]

    base_picks = total_patches // num_view
    for view in range(num_view):
        mini_patches_list = list(range(total_patches))
        random.shuffle(mini_patches_list)
        mini_patch_id = 0

        # pick the rest mini_patches from the last viewpoint
        if view == num_view - 1:
            num_pick_mini_patches = total_patches - (num_view - 1) * base_picks
        else:
            num_pick_mini_patches = base_picks

        # Every earlier view filled exactly `base_picks` slots, so this view
        # starts right after them. (`view * total // num_view` only matches
        # this when the rounding happens to agree; otherwise it leaves gaps
        # and runs past the last slot.)
        first_slot = view * base_picks
        view_map = mini_patch_maps[view]

        # assign the picked mini_patches to QMM
        for pick_id in range(num_pick_mini_patches):
            # destination cell in the QMM (slots run down columns first)
            QMM_id = first_slot + pick_id
            QMM_h = (QMM_id % fragments) * fsize
            QMM_w = (QMM_id // fragments) * fsize
            for merge_id in range(num_QMM):
                # if the shuffled list is exhausted, the loop is skipped and
                # the last examined patch is assigned again
                while mini_patch_id < total_patches:
                    pick_mini_patch_id = mini_patches_list[mini_patch_id]
                    pick_frag_h = (pick_mini_patch_id % fragments) * fsize
                    pick_frag_w = (pick_mini_patch_id // fragments) * fsize
                    mini_patch_id = mini_patch_id + 1
                    # accept the first non-blank patch (mean below threshold)
                    candidate = view_map[
                        :, pick_frag_h:pick_frag_h + fsize, pick_frag_w:pick_frag_w + fsize
                    ]
                    if torch.mean(candidate) < 2.2:
                        break
                QMM[merge_id][:, QMM_h:QMM_h + fsize, QMM_w:QMM_w + fsize] = view_map[
                    :, pick_frag_h:pick_frag_h + fsize, pick_frag_w:pick_frag_w + fsize
                ]

    return QMM