import numpy as np
import os
import sys
import torch
import torchvision
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from LEEP import LEEP
from NCE import NCE
import random
# import caltech101.caltech_dataset as caltech101
import pretrainedmodels 
import datasets_main.datasets as datasets
from argparse import ArgumentParser
import time
import torch.nn as nn
from subsample_selection import sample_pixel
from gbc_tf import get_gbc_score

# Module-level batch-size constant. It is not referenced anywhere else in this
# file; the --batch-size CLI default is declared separately in parse_args().
BATCH_SIZE = 8

# Every model and input batch in this script is placed on this device.
# The GPU index (4) is hard-coded.
device = torch.device('cuda', 4)
torch.cuda.set_device(device)

def parse_args():
    """Parse the command-line options of this script.

    Options:
        -ds / --dataset-name: dataset name (default ``"bdd100k"``).
        -b / --batch-size: minibatch size (int, default 8).
        --resize: image resize value (int, default 128).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = ArgumentParser(description="Collecting activations")

    # (flags, keyword options) for every supported argument, in declaration order.
    option_specs = (
        # Data parameters
        (('-ds', '--dataset-name'), {'help': 'dataset name', 'default': "bdd100k"}),
        (('-b', '--batch-size'), {'help': 'minibatch size', 'default': 8, 'type': int}),
        # Image formatting
        (('--resize',), {'help': 'resize the image', 'default': 128, 'type': int}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()

# Parse the CLI once, at import time. The resulting namespace is shared by the
# dataset-loading code below, which also overwrites config.dataset_name.
config = parse_args()

#preparing the source models
def _load_source_model(num_classes, checkpoint_path):
    """Build an FCN-ResNet50, resize its head and load a checkpoint onto ``device``.

    Args:
        num_classes: number of output channels for the final classifier conv.
        checkpoint_path: path of the state dict passed to ``torch.load``.

    Returns:
        The model, with the checkpoint weights loaded, moved to ``device``.
    """
    model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
    # Swap the last 1x1 conv of the classifier so the head emits ``num_classes``
    # channels before the (strict) state-dict load below.
    head_in_channels = model.classifier[4].in_channels
    model.classifier[4] = nn.Conv2d(
        in_channels=head_in_channels,
        out_channels=num_classes,
        kernel_size=(1, 1),
        stride=(1, 1),
    )
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.to(device)


model_suim = _load_source_model(
    8,
    "/home/ImageSegmentation/ckpts/suim/fcn_resnet50-03-09-2022-0048/dataset-suim-model-fcn_resnet50-epoch8-1.49451.pt",
)

model_cityscapes = _load_source_model(
    34,
    "/home/ImageSegmentation/ckpts/cityscapes/fcn_resnet50-20-08-2022-0127/dataset-cityscapes-model-fcn_resnet50-epoch54-0.83942.pt",
)

model_idd = _load_source_model(
    34,
    "/home/ImageSegmentation/ckpts/idd/fcn_resnet50-03-09-2022-0618/dataset-idd-model-fcn_resnet50-epoch1-0.04053.pt",
)

model_pascalvoc = _load_source_model(
    21,
    "/home/ImageSegmentation/ckpts/pascalvoc/fcn_resnet50-01-09-2022-0253/dataset-pascalvoc-model-fcn_resnet50-epoch97-0.05592.pt",
)

model_camvid = _load_source_model(
    32,
    "/home/ImageSegmentation/ckpts/camvid/fcn_resnet50-20-08-2022-0114/dataset-camvid-model-fcn_resnet50-epoch35-0.78098.pt",
)

# Source models keyed by the dataset named in each checkpoint path.
# Insertion order matters: main() enumerates this mapping and uses the position
# of each entry as its index into the score arrays.
models = dict(
    pascalvoc=model_pascalvoc,
    idd=model_idd,
    cityscapes=model_cityscapes,
    suim=model_suim,
    camvid=model_camvid,
)

# Class count per source dataset; the values mirror the out_channels given to
# each source model's replaced classifier conv.
src_classes = dict(
    pascalvoc=21,
    idd=34,
    cityscapes=34,
    suim=8,
    camvid=32,
)

# Target dataset.
# NOTE(review): this name is hard-coded, while the loader below is built from
# `config` (--dataset-name, default "bdd100k"); the two disagree if -ds is
# overridden on the command line.
target_dataset_name = 'bdd100k'
# load_dataset returns a pair; only its first element (the loader) is used.
target_loader, _ = datasets.load_dataset(config)
target_ds = target_loader.dataset

#preparing the source datasets
def _load_named_dataset(dataset_name):
    """Point the shared ``config`` at ``dataset_name`` and return ``datasets.load_dataset(config)``.

    The assignment to ``config.dataset_name`` is deliberately left in place, so
    after the calls below ``config`` names the last dataset that was loaded.
    """
    config.dataset_name = dataset_name
    return datasets.load_dataset(config)


source_dataset_name = 'camvid'
source_loader_camvid, _ = _load_named_dataset(source_dataset_name)
source_ds_camvid = source_loader_camvid.dataset

source_dataset_name = 'idd'
source_loader_idd, _ = _load_named_dataset(source_dataset_name)
source_ds_idd = source_loader_idd.dataset

source_dataset_name = 'cityscapes'
source_loader_cityscapes, _ = _load_named_dataset(source_dataset_name)
source_ds_cityscapes = source_loader_cityscapes.dataset

source_dataset_name = 'suim'
source_loader_suim, _ = _load_named_dataset(source_dataset_name)
source_ds_suim = source_loader_suim.dataset

source_dataset_name = 'pascalvoc'
source_loader_pascalvoc, _ = _load_named_dataset(source_dataset_name)
source_ds_pascalvoc = source_loader_pascalvoc.dataset

def _mean_pair_scores(source_name, num_source_images, num_target_images, block_units):
    """Average the stored image-pair score matrices of one source dataset.

    One ``.npy`` matrix per entry of ``block_units`` is read from the results
    directory of ``source_name`` and the element-wise mean is returned as a
    ``(num_source_images, num_target_images)`` float64 array.
    """
    accumulated = np.zeros((num_source_images, num_target_images), dtype=np.float64)
    for bu in tqdm(block_units):
        score_path = (
            '/home/ImageSegmentation/fineTuneModel_src/src_hleep_bdd/'
            '{0}_resnet50_fcn_results/{0}_{1}_{2}.npy'
        ).format(source_name, target_dataset_name, bu)
        with open(score_path, 'rb') as f:
            accumulated += np.load(f)
    return accumulated / len(block_units)


def _to_pixel_rows(feature_map):
    """Turn a channel-first ``(B, C, H, W)`` array into ``(B*H*W, C)`` per-pixel rows."""
    channels = feature_map.shape[1]
    # The channel axis must be *moved* last before flattening. A bare reshape
    # would reinterpret memory and mix values from different pixels and
    # channels into each row.
    return np.transpose(feature_map, (0, 2, 3, 1)).reshape(-1, channels)


def _pixel_rows_for_images(image_idxs, pixels_per_image):
    """Expand image indices into the row indices of all of their pixels."""
    image_idxs = np.asarray(image_idxs, dtype=np.int64).reshape(-1)
    within_image = np.arange(pixels_per_image, dtype=np.int64)
    return (image_idxs[:, None] * pixels_per_image + within_image).reshape(-1)


def _collect_pixel_features(backbone, loader):
    """Run ``backbone`` over ``loader`` and gather per-pixel features and labels.

    Returns:
        ``(features, labels, pixels_per_image)`` where ``features`` has one row
        per feature-map pixel, ``labels`` holds the label of the same pixel
        (label maps are resized to the feature-map resolution) and
        ``pixels_per_image`` is ``H * W`` of the feature map.

    Raises:
        ValueError: if the loader is empty, the feature-map size changes
            between batches, or a label map is not single-channel.
    """
    feature_chunks = []
    label_chunks = []
    pixels_per_image = None
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x, y in loader:
            fmap = backbone(x.to(device))['out'].detach().cpu().numpy()
            height, width = fmap.shape[-2:]
            if pixels_per_image is None:
                pixels_per_image = height * width
            elif pixels_per_image != height * width:
                raise ValueError('feature-map size changed between batches')

            # Bring the labels to the feature-map resolution (default
            # interpolation mode), then back to integer class ids.
            y_small = torch.nn.functional.interpolate(y.float(), size=(height, width)).long()
            labels = y_small.reshape(len(x), -1).cpu().numpy()
            if labels.shape[1] != pixels_per_image:
                raise ValueError('expected one single-channel label map per image')

            feature_chunks.append(_to_pixel_rows(fmap))
            label_chunks.append(labels.reshape(-1))

    if not feature_chunks:
        raise ValueError('target loader yielded no batches')
    return np.concatenate(feature_chunks), np.concatenate(label_chunks), pixels_per_image


def main():
    """Correlate GBC scores of the source models with a fixed list of accuracies.

    Steps:
      1. For every source dataset, average the stored source/target image-pair
         score matrices over the block units, reduce them to one score per
         target image and rank the target images by that score.
      2. For every source model, extract per-pixel backbone features and
         labels for the target dataset and compute the GBC score on (a) the
         first ``len(target_ds) / 30`` ranked images and (b) all ranked images
         (the baseline).
      3. Print both score vectors and their Pearson correlation with
         ``accuracies``.
    """
    print(device)

    ######################################################################################
    # Image pair similarity scores average
    block_units = ['block2_unit3', 'block3_unit5', 'block4_unit1']
    source_datasets = {
        'camvid': source_ds_camvid,
        'idd': source_ds_idd,
        'cityscapes': source_ds_cityscapes,
        'suim': source_ds_suim,
        'pascalvoc': source_ds_pascalvoc,
    }
    num_target_images = len(target_ds)

    # Per source dataset: target-image indices ordered by ascending mean score.
    # NOTE(review): argsort is ascending, so the "top-K" slice below takes the
    # lowest-scoring target images - confirm this is the intended direction.
    dict_idxs = {}
    for source_name, source_ds in source_datasets.items():
        pair_scores = _mean_pair_scores(source_name, len(source_ds), num_target_images, block_units)
        target_scores = np.mean(pair_scores, axis=0)
        dict_idxs[source_name] = np.argsort(target_scores)
    ######################################################################################

    gbc_scores = np.zeros((len(models)))
    gbc_scores_base = np.zeros((len(models)))
    num_buckets = 1

    for ds_num, model_name in tqdm(enumerate(models), total=len(models)):
        print(model_name)

        source_model = models[model_name]
        source_model.eval()

        outputs, targets, pixels_per_image = _collect_pixel_features(source_model.backbone, target_loader)
        print(targets.shape)
        print(outputs.shape)

        # The ranking is per *image* while outputs/targets hold one row per
        # *pixel*, so image indices are expanded to pixel rows before indexing.
        # Assumes the loader yields images in dataset-index order - TODO confirm
        # (e.g. no shuffling).
        num_feature_images = outputs.shape[0] // pixels_per_image
        ranked = dict_idxs[model_name]
        # Drop ranked images that have no extracted features (possible if the
        # loader skips samples, e.g. an incomplete last batch).
        ranked = ranked[ranked < num_feature_images]

        topK = int(num_target_images / 30)
        for num in range(num_buckets):
            rows = _pixel_rows_for_images(ranked[:(num + 1) * topK], pixels_per_image)
            gbc_scores[ds_num] += get_gbc_score(outputs[rows], targets[rows], 'spherical')

        # Taking the entire dataset gives the base metric.
        topK = int(num_target_images)
        for num in range(num_buckets):
            rows = _pixel_rows_for_images(ranked[:(num + 1) * topK], pixels_per_image)
            gbc_scores_base[ds_num] += get_gbc_score(outputs[rows], targets[rows], 'spherical')

    print(gbc_scores)
    print(gbc_scores_base)

    # Reference accuracies, positionally matched with the iteration order of
    # `models` (presumably the target accuracy reached from each source model -
    # verify against the experiment logs).
    accuracies = [2.53, 14.90, 14.89, 4.14, 15.07]

    pcc1, p = scipy.stats.pearsonr(gbc_scores, accuracies)
    pcc2, p = scipy.stats.pearsonr(gbc_scores_base, accuracies)

    print(pcc1)
    print(pcc2)

# Run the analysis only when executed as a script. Note that argument parsing
# and the model/dataset loading at module level still happen on import.
if __name__ == '__main__':
    main()