from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from ultralytics import YOLO

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import numpy as np
import torch
import torchvision.models as models
import math
import time
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
import os
os.environ['http_proxy'] = 'http://10.129.162.81:32421'
os.environ['https_proxy'] = 'http://10.129.162.81:32421'
import skimage.io
import skimage.transform
from skimage import color


def transform_img_fast(path):
    """Load one image and return it as a 224x224 RGB array.

    The image is read with skimage, converted to three channels if it is
    grayscale (not 3-D) or RGBA (4 channels), center-cropped to a square
    whose side is the shorter image edge, resized to 224x224, and finally
    shifted/scaled with ``(x - 0.5) * 2``.

    Args:
        path: Location of the image, in any form ``skimage.io.imread`` accepts.

    Returns:
        The resized image after the ``(x - 0.5) * 2`` rescaling.
    """
    image = skimage.io.imread(path)
    if image.ndim != 3:
        image = skimage.color.gray2rgb(image)
    if image.shape[2] == 4:
        image = color.rgba2rgb(image)

    # Center crop: trim the longer dimension equally on both sides.
    height, width = image.shape[:2]
    side = min(height, width)
    top = int((height - side) / 2)
    left = int((width - side) / 2)
    square = image[top: top + side, left: left + side]

    resized = skimage.transform.resize(square, (224, 224))
    return (resized - 0.5) * 2
def transform_img_fn_fast(paths):
    """Load and preprocess every image in ``paths`` via ``transform_img_fast``.

    The running index is printed before every 100th image (0, 100, 200, ...)
    as a simple progress indicator.

    Args:
        paths: Iterable of image paths.

    Returns:
        ``np.array`` built from the list of per-image results, in input order.
    """
    def _load(index, image_path):
        # Progress output happens before the image at that index is loaded.
        if index % 100 == 0:
            print(index)
        return transform_img_fast(image_path)

    return np.array([_load(index, image_path) for index, image_path in enumerate(paths)])

# Classifier under explanation: load the `yolov8n-cls.pt` checkpoint, move it to the
# first CUDA device, then train it. NOTE: all of this runs as soon as the script
# reaches this point, before any of the helper functions below are defined.
model = YOLO("yolov8n-cls.pt")  # load an official model
model.to("cuda:0")
# Train for 10 epochs at image size 224 on the dataset registered as "imagenet".
model.train(data="imagenet", epochs=10, imgsz=224, device="cuda:0")

import copy
from skimage.segmentation import quickshift, mark_boundaries, slic, felzenszwalb
def ShowImageNoAxis(image, boundaries=None, save=None):
    """Display a single image on an axis-free canvas that fills the figure.

    The image is mapped with ``image / 2 + 0.5`` before drawing, which undoes
    the ``(x - 0.5) * 2`` scaling applied by ``transform_img_fast``.

    Args:
        image: Image array to display.
        boundaries: Optional segmentation label map; when given, segment
            borders are drawn on top of the image with ``mark_boundaries``.
        save: Optional path; when given, the figure is saved there before
            it is shown.
    """
    figure = plt.figure()
    # An Axes covering the whole figure area, with ticks and frame hidden.
    canvas = plt.Axes(figure, [0., 0., 1., 1.])
    canvas.set_axis_off()
    figure.add_axes(canvas)

    if boundaries is None:
        canvas.imshow(image / 2 + .5)
    else:
        canvas.imshow(mark_boundaries(image / 2 + 0.5, boundaries))

    if save is not None:
        plt.savefig(save)
    plt.show()


def ShowImageNoAxisInline(images, texts=None, boundaries=None, save=None):
    """Display several images side by side in one row, without axes.

    Each image is mapped with ``image / 2 + 0.5`` before drawing, which undoes
    the ``(x - 0.5) * 2`` scaling applied by ``transform_img_fast``.

    Args:
        images: Sequence of image arrays; one subplot is created per image.
        texts: Optional sequence of titles, indexed in parallel with
            ``images``.
        boundaries: Optional segmentation label map; when given, the same
            boundaries are drawn over every image with ``mark_boundaries``.
        save: Optional path; when given, the finished figure is saved there
            before it is shown.
    """
    # squeeze=False always yields a 2-D array of axes, so the single-image and
    # multi-image cases share one code path.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)

    for i, (ax, image) in enumerate(zip(axes[0], images)):
        ax.axis('off')
        # `is not None` (rather than `!= None`) also works when texts is an
        # array-like whose comparison is element-wise.
        if texts is not None:
            ax.set_title(texts[i])
        if boundaries is not None:
            ax.imshow(mark_boundaries(image / 2 + 0.5, boundaries))
        else:
            ax.imshow(image / 2 + .5)

    # Save once, after every subplot has been drawn.
    if save is not None:
        plt.savefig(save)
    plt.show()

def normalizeRows(data):
    """Min-max scale ``data`` to the range [0, 1] along axis 0.

    Despite the name, the minimum and maximum are taken with ``axis=0``, so
    for 2-D input every *column* is scaled independently.

    Args:
        data: Array-like of numbers.

    Returns:
        Array of the same shape with each axis-0 slice mapped so that its
        minimum becomes 0 and its maximum becomes 1. A slice whose values are
        all equal maps to 0 instead of producing NaN from a 0/0 division.
    """
    data = np.asarray(data)
    data_min = np.min(data, axis=0)
    data_max = np.max(data, axis=0)
    span = data_max - data_min
    # Constant slices have zero range; divide by 1 so they come out as 0.
    safe_span = np.where(span == 0, 1, span)
    return (data - data_min) / safe_span

def dist(x, y):
    """Return the norm of ``x - y`` as computed by ``np.linalg.norm``.

    For 1-D vectors this is the Euclidean distance between the two points.
    """
    difference = x - y
    return np.linalg.norm(difference)

# Feature extractor used to embed images before clustering: a pretrained
# torchvision ResNet-50 with its last child module removed, so the network
# outputs features instead of class scores.
embedding_model = models.resnet50(pretrained=True)
embedding_model = torch.nn.Sequential(*list(embedding_model.children())[:-1])
# Inference only: switch to eval mode and place the model on the first CUDA device.
embedding_model.eval()
embedding_model.to("cuda:0")

# Per-image preprocessing applied before the embedding model: only a conversion
# to a tensor. Resizing is done by the caller on the PIL image.
# NOTE(review): the ImageNet mean/std normalization below is commented out, so
# the network receives un-normalized inputs -- confirm this is intended.
preprocess = transforms.Compose([
    # transforms.Resize(256),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def embedding(imgs):
    """Embed a collection of images with ``embedding_model``.

    Each image is multiplied by 255, cast to uint8, resized to 224x224 via
    PIL, converted with ``preprocess`` and pushed through the embedding model
    in a single batch on ``cuda:0``.

    Args:
        imgs: Sequence of image arrays. Assumes each is H x W x 3 with float
            values in [0, 1] (implied by the ``* 255`` / uint8 conversion) --
            TODO confirm against callers.

    Returns:
        NumPy array with one flattened feature row per input image.
    """
    pil_images = [
        Image.fromarray((img * 255).astype("uint8")).resize((224, 224))
        for img in imgs
    ]

    # NOTE(review): the whole collection goes through the model as one batch;
    # very large inputs may not fit in GPU memory.
    img_data = torch.stack([preprocess(img) for img in pil_images]).to("cuda:0")

    with torch.no_grad():
        features = embedding_model(img_data)
    # Flatten everything after the batch dimension into one feature vector.
    features = features.view(features.size(0), -1)
    result_features = features.cpu().numpy()

    # Drop the GPU tensors and release cached blocks before returning.
    del img_data
    del features
    torch.cuda.empty_cache()

    return result_features


def predict(images):
    """Classify a batch of images with the YOLO ``model``.

    Args:
        images: Array of images. Assumes float values in [0, 1], since the
            batch is multiplied by 255 and each image is cast to uint8 --
            TODO confirm against callers.

    Returns:
        List with one NumPy array of class probabilities per input image.
    """
    images = images * 255
    probs = []
    for image in images:
        # The model is called one image at a time; its first result holds the
        # classification probabilities.
        result = model(image.astype("uint8"), verbose=False)
        probs.append(result[0].probs.data.cpu().numpy())
    return probs


# Settings for loading the dataset through torchvision.
IMAGE_SHAPE = (224, 224)
BATCH_SIZE = 512

# Dataset transform: resize to IMAGE_SHAPE, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SHAPE),
    transforms.ToTensor(),
])


# Dataset: one sub-folder per class under data_dir, loaded through `transform`.
data_dir = './datasets/imagenet/train/'
dataset = datasets.ImageFolder(data_dir, transform=transform)

# Shuffle all sample indices in place, then hold out the first 10% for
# validation and keep the remainder for training.
indices = list(range(len(dataset)))
split = int(np.floor(0.1 * len(dataset)))
np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Both loaders read from the same dataset; the samplers restrict which
# indices each one draws.
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,
    num_workers=4,
)
val_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=val_sampler,
    num_workers=4,
)

print('Training images:')
print(f'Number of training samples: {len(train_indices)}')

print('Validation images:')
print(f'Number of validation samples: {len(val_indices)}')

def _collect_batches(loader):
    """Drain ``loader`` and return ``(images, labels)`` as NumPy arrays.

    Image batches are stacked along the first axis and then transposed with
    ``(0, 2, 3, 1)``, moving the second axis (channels, as produced by
    ``ToTensor``) to the end. Labels are concatenated into one 1-D array.
    """
    image_batches = []
    label_batches = []
    for batch_images, batch_labels in loader:
        image_batches.append(batch_images.numpy())
        label_batches.append(batch_labels.numpy())
    stacked = np.transpose(np.vstack(image_batches), (0, 2, 3, 1))
    return stacked, np.concatenate(label_batches)


# Materialize the full training split in memory.
num_samples = len(train_indices)
print(f'Total number of training samples: {num_samples}')
images, labels = _collect_batches(train_loader)

# Materialize the full validation split in memory.
num_samples = len(val_indices)
print(f'Total number of validation samples: {num_samples}')
test_images, test_labels = _collect_batches(val_loader)

from collections import defaultdict

# Group the training images by class label.
label_to_img = defaultdict(list)
for img, label in zip(images, labels):
    label_to_img[label].append(img)

# Report how many images each class received.
for i, class_images in label_to_img.items():
    print(i, len(class_images))

# Cluster centers for each class are computed next.

# Per-class clustering: embed every image of a class, reduce with PCA,
# standardize, run KMeans, and pick for each cluster center the sample that
# lies closest to it as that cluster's representative.
cluster_num = 10
label_to_clusters = dict()        # label -> KMeans cluster centers (PCA + scaled space)
label_to_embedding = dict()       # label -> raw embeddings of all class images
label_to_clusters_pred = dict()   # label -> cluster assignment of each class image
pre_traind_sample_id = defaultdict(list)  # label -> indices of representative samples
pre_traind_sample_embedding = dict()      # label -> raw embeddings of the representatives
clf1 = KMeans(n_clusters=cluster_num, max_iter=50)
for label, class_images in label_to_img.items():
    alldata = embedding(class_images)
    print("label:", label, alldata.shape)

    # Keep an independent copy of the raw embeddings before reducing them.
    label_to_embedding[label] = copy.deepcopy(alldata)

    # NOTE(review): PCA(n_components=10) and KMeans(n_clusters=10) both need at
    # least 10 samples in the class -- confirm every class satisfies this.
    pca = PCA(n_components=10)
    alldata = pca.fit_transform(alldata)
    scaler = StandardScaler()
    alldata_scaled = scaler.fit_transform(alldata)
    clf1.fit(alldata_scaled)

    label_to_clusters_pred[label] = clf1.labels_
    label_to_clusters[label] = clf1.cluster_centers_
    for center in label_to_clusters[label]:
        # Distance from every sample to this center in one vectorized call;
        # argmin returns the first index of the minimum, matching a
        # first-strictly-smaller linear scan.
        distances = np.linalg.norm(alldata_scaled - center, axis=1)
        pre_traind_sample_id[label].append(int(np.argmin(distances)))
    pre_traind_sample_embedding[label] = label_to_embedding[label][pre_traind_sample_id[label]]

for label in label_to_img.keys():
    print("label: ", label)
    print(pre_traind_sample_id[label])


from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
# Load the SAM "vit_h" model from the local checkpoint and place it on cuda:1
# (the YOLO classifier and the ResNet embedder are placed on cuda:0).
sam = sam_model_registry["vit_h"](checkpoint="./segment_anything/sam_vit_h_4b8939.pth").to('cuda:1')
sam.eval()
# Automatic mask generator built with its default settings.
mask_generator = SamAutomaticMaskGenerator(sam)

def segmentation_fn(image):
    """Segment ``image`` with SAM and return an integer label map.

    Masks produced by ``mask_generator`` are processed from largest to
    smallest area. A mask whose overlap with an earlier, still-independent
    mask covers at least 90% of its own area is merged into that earlier
    mask; otherwise it becomes its own segment. Later (smaller) masks
    overwrite earlier ones where they overlap. Labels are finally renumbered
    to consecutive integers starting at 0.

    Args:
        image: Image array passed unchanged to ``mask_generator.generate``.

    Returns:
        Integer array with one segment id per pixel. Pixels covered by no mask
        share one label. If SAM yields no masks at all, the whole image is
        returned as a single segment (all zeros).
    """
    masks = mask_generator.generate(image)

    # Largest masks first, so smaller masks can be absorbed into bigger ones.
    sorted_masks = sorted(masks, key=lambda m: m['area'], reverse=True)
    del masks
    masks_result = [m['segmentation'] for m in sorted_masks]
    masks_size = [m['area'] for m in sorted_masks]
    del sorted_masks

    torch.cuda.empty_cache()

    if not masks_result:
        # Nothing was segmented: fall back to a single segment covering the
        # image (assumes the first two dimensions are height and width).
        return np.zeros(np.shape(image)[:2], dtype=np.intp)

    threshold = 0.9
    # 0 marks pixels that no mask covers; mask i is written as label i + 1.
    class_labels = np.full(masks_result[0].shape, 0)
    absorbed = set()  # indices of masks that were merged into an earlier one
    for i, mask in enumerate(masks_result):
        owner = i
        for j in range(i):
            if j in absorbed:
                continue
            overlap = np.sum(np.logical_and(mask, masks_result[j]))
            if overlap / masks_size[i] >= threshold:
                owner = j
                absorbed.add(i)
                break
        class_labels[mask] = owner + 1

    # Compress the labels that are actually used into 0..k-1.
    _, inverse_indices = np.unique(class_labels, return_inverse=True)
    return inverse_indices.reshape(class_labels.shape)

from anchorx import anchor_image
# Build the anchor image explainer. `distribution_path` points at a local directory
# of images (./datasets/imagenet/train_anchor); presumably anchorx loads them through
# `transform_img_fn` to use as the perturbation distribution -- verify against anchorx.
# Superpixels come from the SAM-based `segmentation_fn`.
# NOTE(review): `n=5000` looks like the number of distribution images to use -- confirm.
explainer = anchor_image.AnchorImage(distribution_path='./datasets/imagenet/train_anchor',
                               transform_img_fn=transform_img_fn_fast, n=5000, segmentation_fn=segmentation_fn)


# Explain every representative sample of every class with the anchor explainer
# and keep both the sample image and the third value returned by the explainer.
pre_traind_sample_data = defaultdict(list)  # label -> representative images
pre_traind_sample_res = defaultdict(list)   # label -> explainer results for them
for label in label_to_img.keys():
    print("label:", label)
    sample_ids = pre_traind_sample_id[label]
    # `sample_id` instead of `id`, so the builtin is not shadowed at module scope.
    for count, sample_id in enumerate(sample_ids):
        print(count, "/", len(sample_ids))
        pre_train_sample = label_to_img[label][sample_id]
        segments, exp, pretrain_res = explainer.explain_instance(
            pre_train_sample, predict, threshold=0.95, batch_size=50,
            tau=0.20, verbose=True, min_shared_samples=200, beam_size=2)
        pre_traind_sample_data[label].append(pre_train_sample)
        pre_traind_sample_res[label].append(pretrain_res)

# import json


# data = {
#     'id': pre_traind_sample_id,
#     'embedding': pre_traind_sample_embedding,
#     'data': pre_traind_sample_data,
#     'res': pre_traind_sample_res
# }

# with open('exp_YOLO_imagenet.json', 'w') as f:
#     json.dump(data, f)

# import csv
# data = list(zip(pre_traind_sample_id, pre_traind_sample_embedding, pre_traind_sample_data, pre_traind_sample_res))

# with open('exp_YOLO_imagenet.csv', 'w', newline='') as f:
#     writer = csv.writer(f)
#     writer.writerow(['id', 'embedding', 'data', 'res']) 
#     writer.writerows(data)

import pickle

# Bundle the per-label results in a fixed order: representative sample indices,
# their embeddings, the sample images, and the explainer results.
data = (pre_traind_sample_id, pre_traind_sample_embedding, pre_traind_sample_data, pre_traind_sample_res)

# Persist the bundle with pickle; readers must unpack the tuple in the same order.
with open('exp_YOLO_imagenet_10.pkl', 'wb') as f:
    pickle.dump(data, f)