import numpy as np
import os
import sys
import torch
import torchvision
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from LEEP import LEEP
from NCE import NCE
import random
import pretrainedmodels 
import datasets_main.datasets as datasets
from argparse import ArgumentParser
import time
import torch.nn as nn
from subsample_selection import sample_pixel

# Module-level batch-size constant; mirrors the --batch-size default in parse_args().
BATCH_SIZE = 8

# GPU index this script asks for by default.
_PREFERRED_GPU_INDEX = 4

# Select the compute device. The original unconditionally used cuda:4, which
# fails at import time on hosts with fewer than five GPUs (or none at all).
# Behaviour is unchanged whenever GPU 4 exists; otherwise fall back gracefully.
if torch.cuda.is_available():
    if torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
        _gpu_index = _PREFERRED_GPU_INDEX
    else:
        _gpu_index = 0
        print('cuda:' + str(_PREFERRED_GPU_INDEX) + ' is not present; using cuda:0 instead',
              file=sys.stderr)
    device = torch.device('cuda:' + str(_gpu_index))
    torch.cuda.set_device(device)
else:
    print('CUDA is not available; running on CPU', file=sys.stderr)
    device = torch.device('cpu')

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse`` namespace with three attributes:
        ``dataset_name`` (default ``"bdd100k"``), ``batch_size`` (int,
        default 8) and ``resize`` (int, default 128).
    """
    cli = ArgumentParser(description="Collecting activations")

    # (flags, keyword options) for every supported argument, in declaration order.
    option_specs = (
        # Data parameters
        (('-ds', '--dataset-name'), dict(help='dataset name', default="bdd100k")),
        (('-b', '--batch-size'), dict(help='minibatch size', default=8, type=int)),
        # Image formatting
        (('--resize',), dict(help='resize the image', default=128, type=int)),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()

# Parsed once at import time; the dataset-loading code below reads and mutates it.
config = parse_args()


def _load_source_model(num_classes, checkpoint_path):
    """Build an FCN-ResNet50 with a ``num_classes``-way head and load a checkpoint.

    The network is created with torchvision's pretrained weights, its final
    1x1 classifier convolution is swapped for one with ``num_classes`` output
    channels, the state dict stored at ``checkpoint_path`` is loaded on top,
    and the model is moved to ``device``.
    """
    net = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
    head_in_channels = net.classifier[4].in_channels
    net.classifier[4] = nn.Conv2d(
        in_channels=head_in_channels,
        out_channels=num_classes,
        kernel_size=(1, 1),
        stride=(1, 1),
    )
    state = torch.load(checkpoint_path, map_location=torch.device(device))
    net.load_state_dict(state)
    return net.to(device)


# Preparing the source models (same construction order as before).
model_suim = _load_source_model(
    8,
    "/home/ImageSegmentation/ckpts/suim/fcn_resnet50-03-09-2022-0048/dataset-suim-model-fcn_resnet50-epoch8-1.49451.pt",
)
model_cityscapes = _load_source_model(
    34,
    "/home/ImageSegmentation/ckpts/cityscapes/fcn_resnet50-20-08-2022-0127/dataset-cityscapes-model-fcn_resnet50-epoch54-0.83942.pt",
)
model_idd = _load_source_model(
    34,
    "/home/ImageSegmentation/ckpts/idd/fcn_resnet50-03-09-2022-0618/dataset-idd-model-fcn_resnet50-epoch1-0.04053.pt",
)
model_pascalvoc = _load_source_model(
    21,
    "/home/ImageSegmentation/ckpts/pascalvoc/fcn_resnet50-01-09-2022-0253/dataset-pascalvoc-model-fcn_resnet50-epoch97-0.05592.pt",
)
model_camvid = _load_source_model(
    32,
    "/home/ImageSegmentation/ckpts/camvid/fcn_resnet50-20-08-2022-0114/dataset-camvid-model-fcn_resnet50-epoch35-0.78098.pt",
)

# Source models keyed by the dataset they were trained on.
# Insertion order matters: main() enumerates this dict and pairs each position
# with the corresponding entry of its hard-coded `accuracies` list.
models = dict(
    pascalvoc=model_pascalvoc,
    idd=model_idd,
    cityscapes=model_cityscapes,
    suim=model_suim,
    camvid=model_camvid,
)

# Number of output classes of each source model's classifier head.
src_classes = dict(
    pascalvoc=21,
    idd=34,
    cityscapes=34,
    suim=8,
    camvid=32,
)

# Target dataset: loaded with the dataset name currently held by `config`
# (the --dataset-name option), before `config` is re-pointed at the sources.
# NOTE(review): `target_dataset_name` is hard-coded while the loader follows
# the CLI option — confirm they are meant to stay in sync.
target_dataset_name = 'bdd100k'
target_loader, _ = datasets.load_dataset(config)
target_ds = target_loader.dataset


def _load_source_split(name):
    """Re-point ``config`` at dataset ``name`` and return ``(loader, dataset)``.

    Only the first element returned by ``datasets.load_dataset`` is used; the
    second one is discarded.
    """
    config.dataset_name = name
    loader, _unused = datasets.load_dataset(config)
    return loader, loader.dataset


# Preparing the source datasets (same load order as before; `config.dataset_name`
# is left at 'pascalvoc' afterwards).
source_loader_camvid, source_ds_camvid = _load_source_split('camvid')
source_loader_idd, source_ds_idd = _load_source_split('idd')
source_loader_cityscapes, source_ds_cityscapes = _load_source_split('cityscapes')
source_loader_suim, source_ds_suim = _load_source_split('suim')
source_loader_pascalvoc, source_ds_pascalvoc = _load_source_split('pascalvoc')
    

# Number of pixel locations sample_pixel() is expected to return per image;
# the per-batch reshape below fails loudly if that does not hold.
_PIXELS_PER_IMAGE = 1000


def _mean_target_scores(source_name, num_source_images, num_target_images, block_units):
    """Return one averaged similarity score per target image for a source dataset.

    Loads the saved image-pair score matrix of every block unit, averages the
    matrices over the block units and then over the source-image axis (axis 0).
    """
    pair_scores = np.zeros((num_source_images, num_target_images), dtype=np.float64)
    for bu in tqdm(block_units):
        path = '/home/ImageSegmentation/fineTuneModel_src/src_hleep_bdd/{0}_resnet50_fcn_results/{0}_{1}_{2}.npy'.format(
            source_name, target_dataset_name, bu)
        with open(path, 'rb') as f:
            pair_scores += np.load(f)
    pair_scores /= len(block_units)
    return np.mean(pair_scores, axis=0)


def _collect_pixel_predictions(source_model, source_classes):
    """Run ``source_model`` over ``target_loader`` and sample per-pixel predictions.

    Returns:
        ``(probs, labels)`` where ``probs`` has shape
        ``(n_images, _PIXELS_PER_IMAGE, source_classes)`` (float64 softmax
        outputs) and ``labels`` has shape ``(n_images, _PIXELS_PER_IMAGE)``,
        both in the order the loader yields the images.
    """
    probs_per_batch = []
    labels_per_batch = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for x, y in target_loader:
            curr_batch_size = len(x)
            out = source_model(x.to(device))['out']
            # Softmax over dim=1 treats axis 1 as the class axis, so each
            # out[j] below is laid out class-first: (classes, H, W).
            out = torch.nn.functional.softmax(out, dim=1).cpu().numpy()

            batch_probs = []
            batch_labels = []
            for j in range(curr_batch_size):
                y_temp = torch.squeeze(y[j], 0)
                # Move the class axis last with a transpose. The previous
                # reshape((128, 128, classes)) only reinterpreted the memory,
                # mixing values from different classes and pixels.
                probs_hwc = np.transpose(out[j], (1, 2, 0))
                for p in sample_pixel(y_temp):
                    h, w = p[0], p[1]
                    batch_labels.append(y_temp[h][w].item())
                    batch_probs.append(probs_hwc[h][w])

            probs_per_batch.append(
                np.array(batch_probs, dtype=np.float64).reshape(
                    curr_batch_size, _PIXELS_PER_IMAGE, source_classes))
            labels_per_batch.append(
                np.array(batch_labels).reshape(curr_batch_size, _PIXELS_PER_IMAGE))

    # Concatenating per-batch results copes with a short final batch and does
    # not depend on any assumed batch size.
    return np.concatenate(probs_per_batch), np.concatenate(labels_per_batch)


def _score_images(probs, labels, image_idxs):
    """Compute ``(LEEP, NCE)`` on the sampled pixels of the selected images."""
    # Select whole images first, then flatten to one row per sampled pixel.
    sel_probs = probs[image_idxs].reshape(-1, probs.shape[-1])
    sel_labels = labels[image_idxs].reshape(-1)
    leep = LEEP(sel_probs, sel_labels)
    nce = NCE(np.argmax(sel_probs, axis=1), sel_labels)
    return leep, nce


def main():
    """Correlate LEEP/NCE transferability scores with fine-tuning accuracies.

    For every source model the target images are ranked by their averaged
    image-pair score (ascending), LEEP and NCE are computed on the first 20%
    of that ranking and on all images, and the Pearson correlation of each of
    the four score vectors with the hard-coded ``accuracies`` is printed.

    NOTE(review): the ranking indices refer to positions in ``target_ds``;
    this assumes ``target_loader`` yields images in dataset order (no
    shuffling) — confirm against ``datasets.load_dataset``.

    Raises:
        ValueError: if the loader yields a different number of images than
            ``len(target_ds)``, since the ranking could not be applied.
    """
    print(device)

    ######################################################################################
    # Image pair similarity scores average
    block_units = ['block2_unit3', 'block3_unit5', 'block4_unit1']
    num_target_images = len(target_ds)

    # Same processing order as the result files were originally read in.
    source_datasets = (
        ('camvid', source_ds_camvid),
        ('idd', source_ds_idd),
        ('cityscapes', source_ds_cityscapes),
        ('suim', source_ds_suim),
        ('pascalvoc', source_ds_pascalvoc),
    )

    # Per source dataset: target-image indices sorted by ascending mean score.
    dict_idxs = {}
    for source_name, source_ds in source_datasets:
        target_scores = _mean_target_scores(
            source_name, len(source_ds), num_target_images, block_units)
        dict_idxs[source_name] = np.argsort(target_scores)
    ######################################################################################

    leep_scores_base = np.zeros(len(models))
    leep_scores_m = np.zeros(len(models))
    nce_scores_m = np.zeros(len(models))
    nce_scores_base = np.zeros(len(models))

    num_buckets = 1

    for ds_num, model_name in tqdm(enumerate(models), total=len(models)):
        print(model_name)

        source_model = models[model_name]
        source_model.eval()
        source_classes = src_classes[model_name]

        probs, labels = _collect_pixel_predictions(source_model, source_classes)
        print(labels.reshape(-1).shape)
        print(probs.reshape(-1, source_classes).shape)

        if probs.shape[0] != num_target_images:
            raise ValueError(
                'target_loader yielded {} images but target_ds has {}; '
                'per-image rankings cannot be applied'.format(
                    probs.shape[0], num_target_images))

        ranked_idxs = dict_idxs[model_name]

        # using top 20% of hard samples (the first 20% of the ascending ranking)
        topK = int(num_target_images / 5)
        for num in range(num_buckets):
            leep, nce = _score_images(probs, labels, ranked_idxs[:(num + 1) * topK])
            leep_scores_m[ds_num] += leep
            nce_scores_m[ds_num] += nce

        # setting topK=entire dataset translates to base metric
        topK = int(num_target_images)
        for num in range(num_buckets):
            leep, nce = _score_images(probs, labels, ranked_idxs[:(num + 1) * topK])
            leep_scores_base[ds_num] += leep
            nce_scores_base[ds_num] += nce

    # finetune the models and add their mIOU in this list
    # note: follow the order of models dictionary
    accuracies = [2.53, 14.90, 14.89, 4.14, 15.07]

    # Only the correlation coefficient is reported; the p-value is discarded.
    pcc1, _ = scipy.stats.pearsonr(nce_scores_m, accuracies)
    pcc2, _ = scipy.stats.pearsonr(nce_scores_base, accuracies)
    pcc3, _ = scipy.stats.pearsonr(leep_scores_base, accuracies)
    pcc4, _ = scipy.stats.pearsonr(leep_scores_m, accuracies)

    print(pcc1)
    print(pcc2)
    print(pcc3)
    print(pcc4)


if __name__ == '__main__':
    main()
