import torch.nn as nn
#import os
#import sklearn
from sklearn.covariance import EmpiricalCovariance
from scipy.spatial.distance import pdist, cdist, squareform
import torch.nn.functional as F
import utils
import torch
import math
#import sys

from sklearn.cluster import MiniBatchKMeans
from sklearn.cluster import KMeans
#from center_loss import CenterLoss


# Public API of this module: the classifier-layer replacement and the module
# that computes the loss from its features.
__all__ = ['GenericLossFirstPart', 'GenericLossSecondPart']


class GenericLossFirstPart(nn.Module):
    """Replaces classifier layer.

    The behaviour is selected by the ``type`` string:

    * ``"soft..."`` -- an affine layer (``weights`` and ``bias``), both
      initialised uniformly in ``[-sqrt(3/in_features), sqrt(3/in_features)]``.
    * ``"iso..._<distance>..."`` -- a set of class prototypes (``weights``).
      ``<distance>`` is the second ``_``-separated field of ``type`` and must
      start with ``"pn2"`` (prototypes initialised to zero) or ``"pnx"``
      (prototypes drawn from N(0, 1) plus a learnable ``distance_scale``
      initialised to one).

    While the module is in training mode, or ``metrics_evaluation_mode`` is
    set, ``forward`` returns the features untouched. Otherwise it returns
    logits.

    Args:
        in_features: size of each input feature vector.
        out_features: number of rows of ``weights`` (one per output logit).
        type: configuration string as described above.

    Raises:
        ValueError: if ``type`` does not describe a supported configuration.
            Previously such values surfaced later as ``AttributeError``,
            ``IndexError`` or ``UnboundLocalError``, or left the prototypes
            as uninitialised memory.
    """

    def __init__(self, in_features, out_features, type):
        super(GenericLossFirstPart, self).__init__()
        self.type = type
        # Apply or not distance transformation during inference
        # (not read anywhere inside this class).
        self.inference_transform = False
        # Define which scaling will be used during inference
        # (not read anywhere inside this class).
        self.inference_learn = "NO"
        self.in_features = in_features
        self.out_features = out_features
        # When True, forward() returns raw features even in eval mode.
        self.metrics_evaluation_mode = False

        if self.type.startswith("soft"):
            self.weights = nn.Parameter(torch.empty(out_features, in_features))
            self.bias = nn.Parameter(torch.empty(out_features))
            bound = math.sqrt(3 / self.in_features)
            nn.init.uniform_(self.weights, a=-bound, b=bound)
            nn.init.uniform_(self.bias, a=-bound, b=bound)
            print("init softmax!!!")

        elif self.type.startswith("iso"):
            type_fields = self.type.split("_")
            if len(type_fields) < 2:
                raise ValueError(
                    "iso loss type must look like 'iso..._<distance>', got {!r}".format(self.type))
            self.distance = type_fields[1]
            if not self.distance.startswith(("pn2", "pnx")):
                # Refuse to go on with prototypes that would never be initialised.
                raise ValueError(
                    "unsupported distance {!r} in loss type {!r}; "
                    "expected it to start with 'pn2' or 'pnx'".format(self.distance, self.type))

            self.weights = nn.Parameter(torch.empty(out_features, in_features))
            if self.distance.startswith("pn2"):
                nn.init.constant_(self.weights, 0)
                print("\ninit ==>> ZERO <<== for prototypes!!!")
            else:
                self.distance_scale = nn.Parameter(torch.empty(1))
                nn.init.constant_(self.distance_scale, 1)
                print("\ninit one for distance scale!!!")
                nn.init.normal_(self.weights, mean=0.0, std=1.0)
                print("\ninit normal for prototypes!!!")

        else:
            raise ValueError(
                "unsupported loss type {!r}; expected it to start with 'soft' or 'iso'".format(self.type))

        print("\nWEIGHTS/PROTOTYPES SIZE:\n", self.weights.size())
        print("WEIGHTS/PROTOTYPES INITIALIZED [MEAN]:\n", self.weights.mean(dim=0).mean())
        print("WEIGHTS/PROTOTYPES INITIALIZED [STD]:\n", self.weights.std(dim=0).mean())
        print("WEIGHTS/PROTOTYPES INITIALIZED:\n", self.weights)

    def forward(self, features):
        """Return ``features`` unchanged (training / metrics mode) or logits.

        In pure inference the "soft" variant returns
        ``features @ weights.T + bias``; the "iso" variant returns the
        negated distances computed by ``utils.euclidean_distances``.
        """
        if self.training or self.metrics_evaluation_mode:
            # The loss module consumes the raw features in these modes.
            return features

        if self.type.startswith("soft"):
            return features.matmul(self.weights.t()) + self.bias

        if self.type.startswith("iso"):
            if self.distance.startswith("pn2"):
                print("pn2")
                distances = utils.euclidean_distances(features, self.weights, 2)
            elif self.distance.startswith("pnx"):
                print("pnx")
                # Features and prototypes are normalised before the distance
                # is taken; the learnable scale is kept non-negative via abs().
                distances = torch.abs(self.distance_scale) * utils.euclidean_distances(
                    F.normalize(features), F.normalize(self.weights), 2)
            else:
                raise ValueError("unsupported distance {!r}".format(self.distance))
            return -distances

        raise ValueError("unsupported loss type {!r}".format(self.type))

    def extra_repr(self):
        """Describe the layer configuration for ``repr()``."""
        description = 'in_features={}, out_features={}, type={}'.format(
            self.in_features, self.out_features, self.type)
        if self.type.startswith("soft"):
            description += ', bias={}'.format(self.bias is not None)
        return description


class GenericLossSecondPart(nn.Module):
    """Computes the loss and related statistics from features and targets.

    The module shares its parameters (``weights`` and, depending on the
    configuration, ``bias`` or ``distance_scale``) with the
    ``loss_first_part`` object passed to the constructor, and follows that
    object's ``type`` string ("soft..." or "iso...").

    All intermediate tensors are created on the device of the inputs, so the
    module works on CPU as well as on any accelerator.
    """

    # Multiplier applied to the negated distances before the softmax used for
    # the "iso" loss and classification probabilities.
    _ISO_LOGIT_SCALE = 10

    def __init__(self, loss_first_part):
        super().__init__()
        self.weights = loss_first_part.weights
        self.type = loss_first_part.type
        if self.type.startswith("soft"):
            self.loss = nn.CrossEntropyLoss()
            self.bias = loss_first_part.bias
        elif self.type.startswith("iso"):
            self.distance = loss_first_part.distance
            if self.distance.startswith("pnx"):
                self.distance_scale = loss_first_part.distance_scale

    @staticmethod
    def _split_by_target(values, targets):
        """Split a (batch, classes) matrix into target and non-target entries.

        Returns two 1-D tensors, both in row-major order: the entries at each
        row's target column, and all remaining entries. A boolean mask is
        used instead of an infinity sentinel, so genuine ``inf`` values are
        not dropped.
        """
        is_target = torch.eye(
            values.size(1), dtype=torch.bool, device=values.device)[targets]
        return values[is_target], values[~is_target]

    def _iso_distances(self, features):
        """Return the feature-to-prototype distances for the "iso" variants."""
        # NOTE(review): utils.euclidean_distances is defined outside this
        # file; presumably it returns a (batch, classes) distance matrix and
        # the last argument is the norm order -- confirm.
        if self.distance.startswith("pn2"):
            print("pn2")
            return utils.euclidean_distances(features, self.weights, 2)
        if self.distance.startswith("pnx"):
            print("pnx")
            distances = torch.abs(self.distance_scale) * utils.euclidean_distances(
                F.normalize(features), F.normalize(self.weights), 2)
            print("distance scale: {0:.8f}".format(torch.abs(self.distance_scale).item()))
            return distances
        raise ValueError("unsupported distance {!r}".format(self.distance))

    def forward(self, features, targets):
        """Compute the loss and per-sample statistics.

        Only the first ``len(targets)`` rows of the logits are used.

        Args:
            features: feature matrix fed to the classifier replacement.
            targets: class indices, one per sample.

        Returns:
            A tuple ``(loss, cls_probabilities, ood_probabilities,
            max_logits, intra_logits, inter_logits)``. For the "soft" variant
            ``intra_logits``/``inter_logits`` hold the logits at the target /
            non-target classes; for the "iso" variant they hold the
            corresponding distances (negated logits).

        Raises:
            ValueError: if the configured type or distance is not supported.
        """
        batch_size = len(targets)

        if self.type.startswith("soft"):
            logits = features.matmul(self.weights.t()) + self.bias
            batch_logits = logits[:batch_size]

            loss = self.loss(batch_logits, targets)
            intra_logits, inter_logits = self._split_by_target(batch_logits, targets)

            cls_probabilities = F.softmax(batch_logits, dim=1)
            ood_probabilities = F.softmax(batch_logits, dim=1)
            max_logits = batch_logits.max(dim=1)[0]

            return loss, cls_probabilities, ood_probabilities, max_logits, intra_logits, inter_logits

        if self.type.startswith("iso"):
            logits = -self._iso_distances(features)
            batch_logits = logits[:batch_size]
            scaled_logits = self._ISO_LOGIT_SCALE * batch_logits

            # Negative log-likelihood of the target class under the softmax
            # of the scaled logits.
            probabilities_for_training = F.softmax(scaled_logits, dim=1)
            probabilities_at_targets = probabilities_for_training[range(targets.size(0)), targets]
            loss = -torch.log(probabilities_at_targets).mean()

            # Reported as distances, hence the sign flip.
            intra_logits, inter_logits = self._split_by_target(-batch_logits, targets)

            cls_probabilities = F.softmax(scaled_logits, dim=1)
            # The OOD probabilities use the unscaled logits.
            ood_probabilities = F.softmax(batch_logits, dim=1)
            max_logits = batch_logits.max(dim=1)[0]

            return loss, cls_probabilities, ood_probabilities, max_logits, intra_logits, inter_logits

        raise ValueError("unsupported loss type {!r}".format(self.type))


