#%%
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from lovely_numpy import lo
import lovely_numpy
lovely_numpy.set_config(deeper_width=12)
import lovely_tensors as lt
lt.monkey_patch()
import sys
import os
# Project root, taken to be two levels above this file. It is appended to
# sys.path so the project-local `data` package can be imported, and made the
# working directory so relative paths used later (e.g. the figure output folder)
# resolve from it. resolve() makes __file__ absolute first: with a relative
# __file__ (e.g. "script.py"), .parent.parent would collapse to "." instead.
# Note that resolve() also follows symlinks.
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
os.chdir(base_dir)
import torchvision.transforms as transforms
from PIL import Image
from data.Imagenet import Imagenet_Segmentation
import torch
from torch.utils.data import DataLoader
import pandas as pd
# %%


# Data
# mean=std=0.5 maps ToTensor's [0, 1] range to [-1, 1]. The transform is defined
# but currently left out of the image pipeline, so image tensors stay in [0, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # normalize,
])
# Label masks use nearest-neighbour resizing so no interpolated (non-label)
# values appear at region borders. InterpolationMode replaces the PIL integer
# constant, which torchvision deprecated for the `interpolation` argument.
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
])


# Path to the ImageNet-segmentation HDF5 file; it is read both through the
# project dataset class and directly with h5py below. Indexing os.environ fails
# fast with a KeyError naming the variable, instead of Path(None)'s TypeError.
imagenet_seg_path = Path(os.environ["IMAGENET_SEGMENTATION_PATH"])
ds = Imagenet_Segmentation(imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
import h5py
# NOTE: this rebinds the name `h5py` from the module to the open File handle.
# Later cells index it as a file (h5py['/value/...']), so the name is kept; the
# handle stays open (read-only) for the rest of the session.
h5py = h5py.File(imagenet_seg_path, 'r')

# %%
# Peek at the raw HDF5 layout for one sample. Each field under /value holds object
# references, so the stored value is used to index the file again (gt needs two
# hops). Shapes in the trailing comments are the author's observations.
# In an interactive session only the last expression's value is shown.
# h5py[h5py["/value/gt"][0,0]]
index=0
h5py[h5py['/value/img'][index, 0]] # image; noted as (3, H, W), but the imshow cell further down transposes with (2,1,0), so the stored order may be (3, W, H) — confirm
h5py[h5py[h5py['/value/gt'][index, 0]][0,0]] # ground-truth mask, noted as (H, W) (axis order unverified)
h5py[h5py['/value/id'][index, 0]] # presumably the sample id, stored as character codes, (14?, 1); decode as on the next line
#"".join(chr(i) for i in np.array(h5py[h5py['/value/id'][index, 0]]).flatten())
# h5py['/value/n'][0,0] # 4276 = number of samples (used as the loop bound in the next cell)
h5py[h5py['/value/target'][index,0]] # class wnid stored as character codes, (9, 1); decode as on the next line
# "".join(chr(i) for i in np.array(h5py[h5py['/value/target'][index, 0]]).flatten())
# %%

# Decode every sample's class wnid: each /value/target entry references a
# dataset of character codes, which are joined back into a string.
n = int(h5py['/value/n'][0, 0])
target_refs = h5py['/value/target']
target_lst = [
    "".join(map(chr, np.array(h5py[target_refs[k, 0]]).flatten()))
    for k in range(n)
]

# %%
# One row per sample, holding its class wnid.
df = pd.DataFrame(target_lst, columns=['target'])
# value_counts() gives samples per class; hist() then bins those counts, so the
# plot shows how many classes have a given number of samples.
ax = df['target'].value_counts().hist()
ax.set_yscale('log')
ax.set_title("Histogram of number of samples per class")
ax.set_xlabel("Number of samples")
ax.set_ylabel("Number of classes")
# %%
from nltk.corpus import wordnet as wn
import torchvision
# Sample indices per class wnid, ordered by wnid. Grouping the index Series
# directly avoids DataFrameGroupBy.apply operating on the grouping column,
# which pandas 2.2+ deprecates.
df.index.to_series().groupby(df["target"]).agg(list)

def batched(iterable, n=1):
    """Yield successive slices of at most *n* items from a sequence.

    Despite the parameter name, *iterable* must support ``len()`` and slicing
    (e.g. a list, tuple or str); each chunk is a slice of the same type, and
    the last chunk may be shorter than *n*.

    Adapted from https://stackoverflow.com/a/8290508

    Raises:
        ValueError: if *n* is less than 1 (raised when iteration starts).
    """
    if n < 1:
        # A negative step would otherwise silently yield nothing.
        raise ValueError(f"n must be at least 1, got {n}")
    length = len(iterable)
    for start in range(0, length, n):
        # Slicing clamps at the end, so no explicit min() is needed.
        yield iterable[start:start + n]

# Render one PNG per group of `batch_size` classes. Each row is one class: its
# images on the top strip of the grid and their ground-truth masks below.
batch_size = 10
viz_dir = Path("imagenet_seg_viz")
# savefig() does not create missing directories.
viz_dir.mkdir(parents=True, exist_ok=True)
# (wnid, [sample indices]) pairs, ordered by wnid.
class_items = list(df.index.to_series().groupby(df["target"]).agg(list).items())
for bi, batch in enumerate(batched(class_items, n=batch_size)):
    # squeeze=False keeps `axes` 2-D even for a single-class batch; otherwise
    # plt.subplots returns a bare Axes and zip() over it raises TypeError.
    fig, axes = plt.subplots(len(batch), 1, figsize=(10, 2 * len(batch)), squeeze=False)
    for ax, (wnid, sample_indices) in zip(axes[:, 0], batch):
        # A wnid is 'n' followed by the WordNet noun offset.
        name = wn.synset_from_pos_and_offset('n', int(wnid.lstrip("n")))
        print(name.name())
        img_lst = []
        seg_gt_lst = []
        for i in sample_indices:
            img, seg_gt = ds[i]
            img_lst.append(img)
            seg_gt_lst.append(seg_gt)
        img_pt = torch.stack(img_lst, dim=0)
        # Cast masks to float so they can be concatenated with the image tensors.
        seg_gt_pt = torch.stack(seg_gt_lst, dim=0).to(torch.float32)
        n_imgs = len(img_pt)
        # Repeat each mask over 3 channels; assumes masks are (H, W) with the same
        # size as the images — TODO confirm against Imagenet_Segmentation.
        seg_rgb = seg_gt_pt.unsqueeze(1).expand(-1, 3, -1, -1)
        # nrow=n_imgs puts all images on the first grid row and all masks on the second.
        grid = torchvision.utils.make_grid(torch.cat((img_pt, seg_rgb), dim=0), nrow=n_imgs)
        ax.imshow(grid.permute(1, 2, 0))
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_title(f"{wnid} ({name.name()})")
    fig.tight_layout()
    fig.savefig(viz_dir / f"batch_{bi:03}.png")
    # Close this specific figure so memory does not grow across batches.
    plt.close(fig)
    print(bi + 1)

# %%
import itertools

# Print the class wnids (sorted, matching groupby's key order) in groups of
# three, each group followed by a separator line. Only the keys are needed, so
# the per-class index lists are not built.
for batch in batched(sorted(df["target"].unique()), n=3):
    for wnid in batch:
        print(wnid)
    print("*******")
# %%

# Show one image read straight from the HDF5 file, with its axis order reversed
# (transpose(2, 1, 0)) before handing it to imshow.
index = 4271
raw_img = np.array(h5py[h5py['/value/img'][index, 0]])
plt.imshow(raw_img.transpose(2, 1, 0))
# Mask variant: plt.imshow(np.array(h5py[h5py[h5py['/value/gt'][index, 0]][0,0]]).transpose(1, 0))

# %%

# Load ImageNet through torchvision.datasets
from torchvision import datasets
# ImageNet validation split, used to compare its wnids with the segmentation set.
# Indexing os.environ raises a KeyError naming the variable when it is unset,
# instead of passing None as the dataset root.
IN = datasets.ImageNet(os.environ["IMAGENET_PATH"], split='val')
# IN.classes

# Number of distinct segmentation-set wnids that also appear in ImageNet's wnids.
df['target'].drop_duplicates().isin(IN.wnids).sum()
# "n06278475" in IN.wnids
         
# %%
from nltk.corpus import wordnet as wn
# Spot-check WordNet noun offsets (a wnid minus its leading 'n'). In an
# interactive session only the second lookup's result is shown.
wn.synset_from_pos_and_offset('n', 1322343) # pup
wn.synset_from_pos_and_offset('n', 6278475) # high-definition_television (wnid n06278475)
# %%


# %%
from tqdm import tqdm
batch_size = 16
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False)
# Per-sample share of mask pixels: mask sum divided by the pixel count of one
# mask (a foreground fraction if masks are 0/1). Author's note: masks arrive as
# (B, H, W) and images as (B, 3, H, W).
seg_size_lst = []
# Iterate the DataLoader directly (not enumerate(dl)) so tqdm can read len(dl)
# and show a total and ETA; the image tensors are not needed here.
for _, seg_gt in tqdm(dl):
    seg_gt = seg_gt.to(torch.float32)
    seg_size = seg_gt.sum(dim=(1, 2))
    seg_size_ratio = seg_size / seg_gt[0].numel()
    seg_size_lst.append(seg_size_ratio)
# %%
seg_size_pt = torch.cat(seg_size_lst, dim=0) # (N,)
plt.hist(seg_size_pt.numpy(), bins="auto")
plt.xlabel("Segmentation size ratio")
plt.ylabel("Number of samples")
plt.title("Histogram of segmentation size ratio")
# `display` only exists under IPython/Jupyter; printing the repr shows the same
# text and also works when this file runs as a plain script.
print(repr(seg_size_pt))
# %%
