import os, random, time, copy
from skimage import io, transform
import numpy as np
import os.path as path
import scipy.io as sio
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import roc_curve, roc_auc_score
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler 
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import models, transforms



# 
# ref:  https://github.com/lwneal/counterfactual-open-set/blob/master/generativeopenset/evaluation.py




class CustomizedPoolList(nn.Module):
    """Pool a list of feature maps and concatenate them into one flat vector per sample.

    Entry ``i`` of ``poolSizeList`` is the kernel size passed to
    ``F.max_pool2d`` / ``F.avg_pool2d`` for ``feaList[i]``. Entries that are
    not positive cause the corresponding feature map to be skipped.
    """

    def __init__(self, poolSizeList=None, poolType='max'):
        """
        Args:
            poolSizeList: sequence of pooling kernel sizes, one per feature
                map. ``None`` (the default) selects ``[32, 32, 16, 8, 4]``.
            poolType: ``'max'`` for max pooling or ``'avg'`` for average
                pooling.
        """
        super(CustomizedPoolList, self).__init__()

        # A list literal as the default argument would be shared by every
        # instance created without an explicit value; build a fresh list
        # instead so that mutating one instance cannot affect another.
        if poolSizeList is None:
            poolSizeList = [32, 32, 16, 8, 4]

        self.poolSizeList = poolSizeList
        self.poolType = poolType
        # Not used by forward(); kept so the attribute stays available.
        self.relu = nn.ReLU()

    def forward(self, feaList):
        """Pool every selected feature map and flatten the concatenation.

        Args:
            feaList: indexable collection of tensors; ``feaList[i]`` is read
                for every ``i`` whose pool size is positive.

        Returns:
            The pooled tensors concatenated along dim 1 and reshaped to
            ``(x.shape[0], -1)``.
        """
        if self.poolType == 'max':
            pool_fn = F.max_pool2d
        elif self.poolType == 'avg':
            pool_fn = F.avg_pool2d
        else:
            # Any other poolType pools nothing; the empty list is handed to
            # torch.cat below exactly as before.
            pool_fn = None

        x = []
        if pool_fn is not None:
            for i, pool_size in enumerate(self.poolSizeList):
                if pool_size > 0:
                    x.append(pool_fn(feaList[i], pool_size))

        x = torch.cat(x, 1)
        x = x.view(x.shape[0], -1)
        return x



class weightedL1Loss(nn.Module):
    """L1 loss multiplied by a constant factor.

    The underlying criterion is ``nn.L1Loss()`` with its default arguments
    (mean over all elements).
    """

    def __init__(self, weight=1):
        """
        Args:
            weight: factor applied to the L1 loss value.
        """
        super(weightedL1Loss, self).__init__()
        self.weight = weight
        self.loss = nn.L1Loss()

    def forward(self, inputs, target):
        """Return ``weight`` times the L1 loss between ``inputs`` and ``target``."""
        unweighted = self.loss(inputs, target)
        return self.weight * unweighted




    

def evaluate_openset(scores_closeset, scores_openset):
    """Compute ROC statistics for separating closed-set from open-set scores.

    Closed-set scores are labelled 0 and open-set scores are labelled 1; the
    two groups are concatenated in that order and passed to ``plot_roc``.

    Args:
        scores_closeset: scores of the closed-set samples.
        scores_openset: scores of the open-set samples.

    Returns:
        The ``(auc, roc_to_plot)`` pair produced by ``plot_roc``.
    """
    n_closed = len(scores_closeset)
    n_open = len(scores_openset)
    labels = np.array([0] * n_closed + [1] * n_open)
    scores = np.concatenate([scores_closeset, scores_openset])
    return plot_roc(labels, scores, 'Discriminator ROC')


def plot_roc(y_true, y_score, title="Receiver Operating Characteristic", **options):
    """Compute the ROC curve and its AUC for the given labels and scores.

    Despite the name, nothing is drawn here. ``title`` and ``options`` are
    accepted but not used.

    Args:
        y_true: ground-truth labels, forwarded to ``roc_curve`` and
            ``roc_auc_score``.
        y_score: scores, forwarded to the same two functions.

    Returns:
        ``(auc_score, roc_to_plot)`` where ``roc_to_plot`` is a dict with the
        keys ``'tp'``, ``'fp'``, ``'thresh'`` and ``'auc_score'``.
    """
    false_pos, true_pos, cutoffs = roc_curve(y_true, y_score)
    area = roc_auc_score(y_true, y_score)
    roc_to_plot = {
        'tp': true_pos,
        'fp': false_pos,
        'thresh': cutoffs,
        'auc_score': area,
    }
    return area, roc_to_plot


def plot_xy(x, y, x_axis="X", y_axis="Y", title="Plot"):
    """Draw ``y`` against ``x`` as a line plot with grid, title and axis labels.

    Args:
        x: values for the horizontal axis.
        y: values for the vertical axis.
        x_axis: label of the horizontal axis.
        y_axis: label of the vertical axis.
        title: title of the plot.

    Returns:
        The object returned by ``DataFrame.plot`` (a Matplotlib Axes).
    """
    df = pd.DataFrame({'x': x, 'y': y})
    plot = df.plot(x='x', y='y')

    # Pass the on/off flag positionally: its keyword was renamed from ``b``
    # to ``visible`` in Matplotlib 3.5, and later releases no longer accept
    # ``b=``. The positional form works with both old and new versions.
    for which in ('major', 'minor'):
        plot.grid(True, which=which)

    plot.set_title(title)
    plot.set_ylabel(y_axis)
    plot.set_xlabel(x_axis)
    return plot


def backup_Weibull():
    """Fit one Weibull model per class on training logits and score both test sets.

    The function takes no arguments and reads these module-level names, none
    of which is defined or imported in the visible part of this file:
    ``dataloader_train_closeset``, ``dataloader_test_closeset``,
    ``dataloader_test_openset``, ``device``, ``encoder``, ``clsModel`` and
    ``libmr``.

    Steps, as visible here:
      1. Collect the logits of every correctly classified training sample,
         grouped by label.
      2. Average them per class to get a mean activation vector.
      3. Fit a ``libmr.MR`` object per class on the Euclidean distances of
         that class's logits to its mean activation vector.
      4. For the closed-set and the open-set test loader, compute one score
         per sample from the logits weighted by ``1 - w_score(distance)``.

    Within the lines shown, nothing is returned: the two score arrays remain
    local variables.
    NOTE(review): the definition may continue beyond this excerpt -- confirm.
    """
    print("Weibull: computing features for all correctly-classified training data")
    # label -> list of logit vectors of correctly classified training samples
    activation_vectors = {}
    for images, labels in dataloader_train_closeset:         
        images = images.to(device)
        labels = labels.type(torch.long).view(-1).to(device)

        embFeature = encoder(images)
        logits = clsModel(embFeature)
        #logits =  F.softmax(logits, dim=1)

        # Index of the largest logit along dim 1, compared with the label.
        correctly_labeled = (logits.data.max(1)[1] == labels)
        labels_np = labels.cpu().numpy()
        logits_np = logits.data.cpu().numpy()
        for i, label in enumerate(labels_np):
            if not correctly_labeled[i]:
                continue
            if label not in activation_vectors:
                activation_vectors[label] = []
            activation_vectors[label].append(logits_np[i])

    print("Computed activation_vectors for {} known classes".format(len(activation_vectors)))
    for class_idx in activation_vectors:
        print("Class {}: {} images".format(class_idx, len(activation_vectors[class_idx])))    
        
    # Compute a mean activation vector for each class
    print("Weibull computing mean activation vectors...")
    mean_activation_vectors = {}
    for class_idx in activation_vectors:
        mean_activation_vectors[class_idx] = np.array(activation_vectors[class_idx]).mean(axis=0)        
        
    # Upper bound on the tail size passed to fit_high below.
    WEIBULL_TAIL_SIZE = 20
    # Initialize one libMR Weibull object for each class
    print("Fitting Weibull to distance distribution of each class")
    weibulls = {}
    for class_idx in activation_vectors:
        # Euclidean distance of every stored logit vector to the class mean.
        distances = []
        mav = mean_activation_vectors[class_idx]
        for v in activation_vectors[class_idx]:
            distances.append(np.linalg.norm(v - mav))
        mr = libmr.MR()
        # The tail cannot be larger than the number of available distances.
        tail_size = min(len(distances), WEIBULL_TAIL_SIZE)
        mr.fit_high(distances, tail_size)
        weibulls[class_idx] = mr
        print("Weibull params for class {}: {}".format(class_idx, mr.get_params()))
        
        
    # Apply Weibull score to every logit
    weibull_scores_closeset = []
    logits_closeset = []
    classes = activation_vectors.keys()
    for images, labels in dataloader_test_closeset:
        images = images.to(device)
        # labels is not read again inside this loop.
        labels = labels.type(torch.long).view(-1).to(device)    
        embFeature = encoder(images)
        batch_logits = clsModel(embFeature).data.cpu().numpy()    
        # batch_weibull is allocated but never read in the visible code.
        batch_weibull = np.zeros(shape=batch_logits.shape)
        for activation_vector in batch_logits:
            # NOTE(review): the row has len(classes) entries but is indexed by
            # class label, which assumes the labels are exactly
            # 0..len(classes)-1 -- confirm.
            weibull_row = np.ones(len(classes))
            for class_idx in classes:
                mav = mean_activation_vectors[class_idx]
                dist = np.linalg.norm(activation_vector - mav)
                weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist)
            weibull_scores_closeset.append(weibull_row)
            logits_closeset.append(activation_vector)

    weibull_scores_closeset = np.array(weibull_scores_closeset)
    logits_closeset = np.array(logits_closeset)
    # Per sample: negative log of the summed exponentials of the weighted logits.
    # NOTE(review): the elementwise product assumes the number of logits per
    # sample is compatible with len(classes) -- confirm.
    openmax_scores_closeset = -np.log(np.sum(np.exp(logits_closeset * weibull_scores_closeset), axis=1))        
        

    # Apply Weibull score to every logit
    # Same procedure as for the closed-set loader, on the open-set loader.
    weibull_scores_openset = []
    logits_openset = []
    classes = activation_vectors.keys()
    for images, labels in dataloader_test_openset:
        images = images.to(device)
        # labels is not read again inside this loop.
        labels = labels.type(torch.long).view(-1).to(device)    
        embFeature = encoder(images)
        batch_logits = clsModel(embFeature).data.cpu().numpy()    
        # batch_weibull is allocated but never read in the visible code.
        batch_weibull = np.zeros(shape=batch_logits.shape)
        for activation_vector in batch_logits:
            weibull_row = np.ones(len(classes))
            for class_idx in classes:
                mav = mean_activation_vectors[class_idx]
                dist = np.linalg.norm(activation_vector - mav)
                weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist)
            weibull_scores_openset.append(weibull_row)
            logits_openset.append(activation_vector)

    weibull_scores_openset = np.array(weibull_scores_openset)
    logits_openset = np.array(logits_openset)
    # Per sample: negative log of the summed exponentials of the weighted logits.
    openmax_scores_openset = -np.log(np.sum(np.exp(logits_openset * weibull_scores_openset), axis=1))        