import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset

from data import data_pipeline, get_imagenet_ood_dataset
from model import SSOD


# Device on which the model and input batches are placed. The CUDA selection
# below is left disabled, so everything runs on the CPU.
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def load_model(model_path=None, num_classes=1000):
    """Build an SSOD model, load its weights from a checkpoint and return it.

    Args:
        model_path: Path of the state-dict checkpoint to load. When ``None``
            the checkpoint originally hard-coded in this script is used.
        num_classes: Value forwarded to ``SSOD(num_classes=...)``.

    Returns:
        The SSOD model with the checkpoint loaded, switched to eval mode.
    """
    if model_path is None:
        model_path = './saved_models/ssod/epoch_2_cls_0.7546_fpr95_ssod_0.4624_fpr95_msp_0.7098_auroc_ssod_0.8888_auroc_msp_0.8080.pth'

    model = SSOD(num_classes=num_classes)
    # NOTE: torch.load unpickles the file, so only point this at checkpoints
    # from a trusted source. Tensors are mapped to the CPU on load.
    state_dict = torch.load(model_path, map_location='cpu')
    # strict=True: the checkpoint keys must match the model exactly.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model



def _concat_features(chunks, name):
    """Concatenate per-batch feature tensors along dim 0 into a NumPy array."""
    if not chunks:
        # torch.cat on an empty list fails with an opaque message; say which
        # feature set came back empty instead.
        raise RuntimeError('no %s features were collected' % name)
    return torch.cat(chunks, dim=0).numpy()


def extract_feature(tgt_index=0, thresh=0.9):
    """Extract features for one training class and for iNaturalist, and save them.

    Training batches are filtered to the samples whose label equals
    ``tgt_index`` and passed to ``model.extract_id_ood_feature``; each of the
    three returned values that is not ``None`` is collected. The OOD loader is
    then run through the same method with ``y=None`` and only the third
    returned value is collected. The results are written to
    ``./feat_vis/{id,ood,gap,real_ood}_collect.npy``.

    Args:
        tgt_index: Label value used to select training samples.
        thresh: Threshold forwarded to ``model.extract_id_ood_feature``.

    Raises:
        RuntimeError: If one of the feature sets ends up empty.
    """
    out_dir = './feat_vis'
    # Create the output directory up front so a missing folder cannot make
    # np.save fail after the whole extraction pass has already run.
    os.makedirs(out_dir, exist_ok=True)

    model = load_model()

    # load dataset (the validation loader is not needed here)
    train_loader, _ = data_pipeline(batch_size=32)

    ood_set = get_imagenet_ood_dataset(ood_type='iNaturalist')
    ood_loader = DataLoader(ood_set, batch_size=32, shuffle=True, num_workers=3)

    model.eval()
    model = model.to(device)

    id_chunks, ood_chunks, gap_chunks = [], [], []
    with torch.no_grad():
        matched_batches = 0
        for x, y in tqdm.tqdm(train_loader):
            select_indices = (y == tgt_index).nonzero().reshape(-1)
            if len(select_indices):
                select_images = x[select_indices].float().to(device)
                select_labels = y[select_indices].long().to(device)
                id_feat, ood_feat, pool_feat = model.extract_id_ood_feature(
                    select_images, select_labels, thresh)
                if id_feat is not None:
                    id_chunks.append(id_feat.detach().cpu())
                if ood_feat is not None:
                    ood_chunks.append(ood_feat.detach().cpu())
                if pool_feat is not None:
                    gap_chunks.append(pool_feat.detach().cpu())
                matched_batches += 1
            # Stops once 201 batches containing the target class were used.
            if matched_batches > 200:
                break

    id_collect = _concat_features(id_chunks, 'id')
    ood_collect = _concat_features(ood_chunks, 'ood')
    gap_collect = _concat_features(gap_chunks, 'gap')
    print(id_collect.shape, ood_collect.shape, gap_collect.shape)
    np.save(os.path.join(out_dir, 'id_collect.npy'), id_collect)
    np.save(os.path.join(out_dir, 'ood_collect.npy'), ood_collect)
    np.save(os.path.join(out_dir, 'gap_collect.npy'), gap_collect)

    real_ood_chunks = []
    with torch.no_grad():
        seen_batches = 0
        for x in tqdm.tqdm(ood_loader):
            # NOTE(review): tolerate loaders that yield (images, ...) tuples;
            # plain tensor batches pass through unchanged.
            if isinstance(x, (list, tuple)):
                x = x[0]
            x = x.float().to(device)
            _, _, pool_feat = model.extract_id_ood_feature(x, y=None, thresh=thresh)
            if pool_feat is not None:
                real_ood_chunks.append(pool_feat.detach().cpu())
            seen_batches += 1
            # Stops after 51 OOD batches.
            if seen_batches > 50:
                break

    real_ood_collect = _concat_features(real_ood_chunks, 'real_ood')
    print(real_ood_collect.shape)
    np.save(os.path.join(out_dir, 'real_ood_collect.npy'), real_ood_collect)


def gaussian_sample(feat, num=100, shift=30):
    """Draw synthetic points far away from the per-dimension statistics of ``feat``.

    Every sample is drawn from a Gaussian with independent dimensions whose
    standard deviation is the column-wise std of ``feat`` and whose centre is
    the column-wise mean moved by ``shift`` standard deviations in a random
    direction (sign chosen independently per dimension and per sample).

    The draw is a single vectorised ``np.random.normal`` call. Because the
    covariance is diagonal this is the same distribution as sampling from a
    multivariate normal, without factorising a d x d matrix per sample; the
    exact values produced for a given seed differ from a per-sample loop.

    Args:
        feat: 2-D array of features, one row per sample.
        num: Number of samples to draw; ``0`` yields an empty ``(0, d)`` array.
        shift: Offset of each sample's centre from the mean, in units of std.

    Returns:
        Array of shape ``(num, feat.shape[1])``.
    """
    mean = np.mean(feat, axis=0)
    std = np.std(feat, axis=0)
    # samples = np.random.normal(loc=mean, scale=std)
    # One random +/-1 direction per (sample, dimension).
    signs = np.sign(np.random.uniform(-1, 1, size=(num, len(mean))))
    samples = np.random.normal(loc=mean + shift * std * signs, scale=std)
    print(mean.shape, std.shape, samples.shape)
    return samples


def _random_rows(points, num):
    """Return up to ``num`` rows of ``points`` picked at random without replacement."""
    count = max(0, min(num, points.shape[0]))
    if count == 0:
        return points[:0]
    chosen = np.random.choice(points.shape[0], size=count, replace=False)
    return points[chosen]


def tsne_vis(num=100):
    """Embed the saved feature sets with t-SNE and save a scatter plot.

    Loads the four ``*_collect.npy`` files from ``./feat_vis``, embeds their
    concatenation jointly in 2-D, and writes ``./feat_vis/SSOD_vis.png``.

    Each of the ID, OOD and iNaturalist sets is subsampled independently and
    capped at its own size, so sets of different lengths (or ``num`` larger
    than a set) no longer raise. All GlobalAvgPool points are plotted.

    Args:
        num: Maximum number of points to draw for each subsampled set.
    """
    base_path = './feat_vis'

    id_feat = np.load('%s/id_collect.npy' % base_path)[:1000]
    ood_feat = np.load('%s/ood_collect.npy' % base_path)[:1000]
    gap_feat = np.load('%s/gap_collect.npy' % base_path)
    real_ood_feat = np.load('%s/real_ood_collect.npy' % base_path)[:1000]

    # Embed everything in one t-SNE run so all groups share the same 2-D
    # space, then cut the embedding back into the four groups.
    groups = [id_feat, ood_feat, gap_feat, real_ood_feat]
    embedding = TSNE(n_components=2).fit_transform(np.concatenate(groups))
    boundaries = np.cumsum([group.shape[0] for group in groups])[:-1]
    ssod_id, ssod_ood, ssod_gap, ssod_real_ood = np.split(embedding, boundaries)

    print(ssod_gap.shape)

    ssod_id_select = _random_rows(ssod_id, num)
    ssod_ood_select = _random_rows(ssod_ood, num)
    ssod_real_ood = _random_rows(ssod_real_ood, num)

    plt.scatter(ssod_id_select[:, 0], ssod_id_select[:, 1], label='SSOD ID', s=30, edgecolor='black', color='mediumseagreen')
    plt.scatter(ssod_ood_select[:, 0], ssod_ood_select[:, 1], label='SSOD OOD', s=30, edgecolor='black', color='salmon')
    plt.scatter(ssod_gap[:, 0], ssod_gap[:, 1], label='GlobalAvgPool', s=30, edgecolor='black', color='blue')
    plt.scatter(ssod_real_ood[:, 0], ssod_real_ood[:, 1], label='iNaturalist', s=30, edgecolor='black', color='gray')
    plt.legend(markerscale=2.2, fontsize='13')
    plt.xticks([])
    plt.yticks([])
    # plt.title('SSOD')
    plt.savefig('%s/SSOD_vis.png' % base_path, dpi=300)
    plt.close()


if __name__ == '__main__':
    # Feature extraction is disabled here, so the .npy files that tsne_vis
    # loads from ./feat_vis/ must already exist.
    # extract_feature(thresh=0.99)
    tsne_vis(num=200)
