import os
import pickle
from sklearn import tree
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import time
from cross_f1 import *
from sklearn import metrics
import logging
from utils import *
from torchvision import transforms
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
import PIL
import numpy as np
from matplotlib import pyplot as plt
import skimage, skimage.transform
#from encoder import *
import numpy as np

from pytorch_metric_learning import losses as m_losses
from pytorch_metric_learning import miners

class HLoss(nn.Module):
    """Entropy of ``softmax(x)`` taken along dim 1, summed over every other position.

    Has no learnable parameters.  ``forward(x)`` returns a 0-dim tensor equal to
    ``-sum(softmax(x, 1) * log_softmax(x, 1))`` (summed, not averaged, over the batch).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        probs = F.softmax(x, dim=1)
        log_probs = F.log_softmax(x, dim=1)
        return -(probs * log_probs).sum()

def kl(out_s, out_t, T=4.0):
    """Temperature-scaled KL distillation loss KL(teacher || student).

    Args:
        out_s: student logits, classes along dim 1.
        out_t: teacher logits, same shape as ``out_s``.
        T: softmax temperature; defaults to 4.0, the value previously hard-coded.

    Returns:
        ``KL(softmax(out_t/T) || softmax(out_s/T))`` averaged over the batch
        (``reduction='batchmean'``), multiplied by ``T**2`` so gradient
        magnitudes stay comparable across temperatures.
    """
    log_p_student = F.log_softmax(out_s / T, dim=1)
    p_teacher = F.softmax(out_t / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * T * T

def CXE(predicted, target):
    """Soft-label cross-entropy.

    Args:
        predicted: logits, classes along dim 1.
        target: per-row target distribution, same shape as ``predicted``.

    Returns:
        Mean over rows of ``-sum_j target_j * log softmax(predicted)_j``.

    ``log_softmax`` is used instead of ``log(softmax(...))``. For confident
    logits, softmax underflows to exactly 0, so ``log`` gives ``-inf`` and
    ``0 * -inf`` gives NaN.
    """
    log_probs = F.log_softmax(predicted, dim=1)
    return -(target * log_probs).sum(dim=1).mean()

def get_orth_loss(model, device, reg=1e-6):
    """Soft orthogonality penalty over all non-bias parameters of ``model``.

    Each parameter is flattened to ``(out_dim, -1)`` and contributes
    ``reg * sum(|W W^T - I|)``.

    Args:
        model: module whose ``named_parameters()`` are penalised.
            Names containing ``'bias'`` and 0-dim parameters are skipped.
        device: device for the accumulator tensor.
        reg: weight of the penalty (default 1e-6, the previously hard-coded value).

    Returns:
        A 1-element tensor on ``device``. It is built under ``enable_grad`` so it
        stays differentiable even if the caller is inside ``no_grad``.
    """
    with torch.enable_grad():
        orth_loss = torch.zeros(1, device=device)
        for name, param in model.named_parameters():
            if 'bias' in name or param.dim() == 0:
                continue
            # reshape (not view): weights may be non-contiguous, e.g. channels_last.
            param_flat = param.reshape(param.shape[0], -1)
            gram = torch.mm(param_flat, param_flat.t())
            identity = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
            orth_loss = orth_loss + reg * (gram - identity).abs().sum()
    return orth_loss


def soft_entropy_batched(x, bits=8, T=10):
    """Differentiable (soft-binned) entropy, in bits, of each row of ``x``.

    :param x: normalized (min=0, max=1), non-quantized input of shape (batch, n)
    :param bits: number of bits; values are soft-assigned to ``2**bits`` bins
    :param T: temperature (large T means softer bin assignment)
    :return: per-sample entropy of shape (batch,)

    Each value is scaled onto the bin-centre grid ``[0, 2**bits - 1]`` and
    soft-assigned with ``softmax(-(v - c)^2 / T)`` over the centres ``c``, so
    nearer centres get more weight. The per-row histogram is the mean of
    those assignments.
    """
    if torch.numel(torch.unique(x)) <= 1:
        # Constant (or empty) input has zero entropy; keep the documented shape.
        return torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    bins = int(2 ** bits)
    centers = torch.linspace(0, bins - 1, bins).to(x)
    scaled = x * (bins - 1)  # [0, 1] -> [0, bins - 1]
    # Squared distance from every value to every centre: (bins, batch, n).
    dist = (scaled.unsqueeze(0) - centers.view(-1, 1, 1)) ** 2
    weights = torch.softmax(-dist / T, dim=0)
    hist = torch.mean(weights, dim=-1)  # (bins, batch)
    # Empty bins contribute 0 (0 * log 1); avoid in-place edits on the graph.
    safe = torch.where(hist > 0, hist, torch.ones_like(hist))
    ent = -hist * torch.log(safe)
    return torch.sum(ent, dim=0) / np.log(2)

# Image resize
def imresize(img, height=None, width=None):
    """Resize ``img`` with ``skimage.transform.resize`` (``mode='constant'``).

    If both ``height`` and ``width`` are given they are used directly. If only
    one is given, the other is derived to keep the aspect ratio. If neither is
    given, the original size is kept. Target sizes are truncated with ``int``.
    """
    src_h, src_w = img.shape[0], img.shape[1]
    if height is None and width is None:
        new_h, new_w = src_h, src_w
    elif width is None:
        new_h = height
        new_w = src_w * new_h / src_h
    elif height is None:
        new_w = width
        new_h = src_h * new_w / src_w
    else:
        new_h, new_w = height, width
    return skimage.transform.resize(img, (int(new_h), int(new_w)), mode='constant')

def show_heatmaps(imgs, masks, K, enhance=1, title=None, cmap='gist_rainbow', save_path='1.png'):
    """Draw each image with K coloured mask layers on top and save the figure.

    Args:
        imgs: sequence of HxWxC host arrays. Images whose max is <= 1 are
            scaled to [0, 255]. The caller's arrays are never modified.
        masks: ``masks[i][k]`` is an HxW alpha map for layer ``k`` of image ``i``.
        K: number of mask layers per image (0 draws only the images).
        enhance: PIL colour-enhancement factor applied to each image.
        title: optional figure super-title.
        cmap: matplotlib colormap name used to pick the K layer colours.
        save_path: output file (defaults to the previously hard-coded '1.png').
    """
    # `import PIL` alone does not load these submodules.
    from PIL import Image, ImageEnhance

    colors = []
    if K > 0:
        # plt.get_cmap survives matplotlib 3.9, where plt.cm.get_cmap was removed;
        # linspace always yields exactly K evenly spaced colours.
        _cmap = plt.get_cmap(cmap)
        colors = [np.array(_cmap(v)[:3]) for v in np.linspace(0, 1, K, endpoint=False)]
    n_imgs = len(imgs)
    plt.figure(figsize=(4 * n_imgs, 4))
    if title is not None:
        plt.suptitle(title + '\n', fontsize=24).set_y(1.05)
    for i in range(n_imgs):
        plt.subplot(1, n_imgs, i + 1)
        img = np.array(imgs[i], dtype=np.float64)  # private copy
        if img.max() <= 1:
            img = img * 255
        # Clip before casting so out-of-range values saturate instead of wrapping.
        img = np.clip(img, 0, 255).astype(np.uint8)
        img = np.array(ImageEnhance.Color(Image.fromarray(img)).enhance(enhance))
        plt.imshow(img)
        plt.axis('off')
        for k in range(K):
            layer = np.ones((*img.shape[:2], 4))
            layer[:, :, :3] *= colors[k]
            layer[:, :, 3] = masks[i][k]
            plt.imshow(layer)
            plt.axis('off')
    plt.tight_layout(pad=0, w_pad=0, h_pad=0)
    plt.savefig(save_path)

def read_images(paths, cuda=True, max_images=4, size=224):
    """Load, resize and ImageNet-normalise images into an NxCxHxW float64 tensor.

    Args:
        paths: image file paths.
        cuda: move the result to the default CUDA device when True.
        max_images: read at most this many paths (default 4, the previously
            hard-coded limit; ``None`` reads all of them).
        size: square output side length (default 224).

    Returns:
        ``torch.float64`` tensor of shape (N, 3, size, size), normalised with the
        ImageNet mean/std below.

    Raises:
        ValueError: if no paths are given.

    Greyscale images are replicated to 3 channels and an alpha channel is
    dropped, so the per-channel normalisation always broadcasts.
    """
    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
    resized = []
    for path in paths[:max_images]:
        img = plt.imread(path)
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        img = img[..., :3]
        resized.append(imresize(img, size, size))
    if not resized:
        raise ValueError('read_images: no image paths given')
    # NxHxWxC -> NxCxHxW, float64
    images = np.stack(resized).transpose((0, 3, 1, 2)).astype('float64')
    images = (images - mean) / std
    images = torch.from_numpy(images)
    if cuda:
        images = images.cuda()
    return images

def plot_attributes(patches, paths, K=2, exit_after=True):
    """Factorise patch features with NMF and save K-part heat-map overlays.

    Args:
        patches: NxCxHxW feature tensor. Its flattened (N*H*W)xC features are
            passed to ``NMF``.
        paths: image paths; passed to ``read_images``, which keeps at most the
            first 4.
        K: number of NMF components (default 2, the only value the old loop ran).
        exit_after: when True (default, as before), terminate the process with
            exit code -1 once the figure is saved.
    """
    images = read_images(paths)
    # read_images returns ImageNet-normalised tensors, possibly on CUDA. Undo the
    # normalisation and move to host NxHxWxC numpy so show_heatmaps can render it.
    mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    raw_images = (images * std + mean).clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()

    flat_features = patches.permute(0, 2, 3, 1).contiguous().view((-1, patches.size(1)))
    with torch.no_grad():
        W, _ = NMF(flat_features, K, random_seed=0, cuda=True, max_iter=50)

    # (N*H*W)xK -> NxKxHxW, then upsample to the 224x224 display size.
    heatmaps = W.cpu().view(patches.size(0), patches.size(2), patches.size(3), K).permute(0, 3, 1, 2)
    heatmaps = F.interpolate(heatmaps, size=(224, 224), mode='bilinear', align_corners=False)
    # Scale each map to a peak of 1; clamp guards an all-zero component.
    peak = heatmaps.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0].clamp_min(1e-12)
    heatmaps = (heatmaps / peak).numpy()
    show_heatmaps(raw_images, heatmaps, K, title='$k$ = {}'.format(K), enhance=0.3)
    if exit_after:
        sys.exit(-1)

def _mixed_criterion(criterion, output, target_a, target_b, lam_a, lam_b):
    """Mean of the lambda-weighted criterion against both mixed targets."""
    return torch.mean(criterion(output, target_a) * lam_a + criterion(output, target_b) * lam_b)


def _tree_distill_loss(decision_tree, quantized, output):
    """Return 0.1 * CXE between the tree's softmaxed predictions and softmax(output).

    ``quantized`` must already be detached; the tree receives it as numpy.
    """
    tree_pred = decision_tree.predict(quantized.cpu().numpy())
    # Match the network's dtype so the loss is not silently promoted to float64.
    tree_pred = torch.from_numpy(tree_pred).to(device=output.device, dtype=output.dtype)
    tree_probs = F.softmax(tree_pred, dim=1)
    net_probs = F.softmax(output, dim=1)
    # NOTE(review): CXE applies another softmax to its first argument, so the
    # tree probabilities are softmaxed twice; kept as in the original design.
    return 0.1 * CXE(tree_probs, net_probs)


def _fit_surrogate_tree(tree_x, tree_y, fallback):
    """Fit a depth-7 regression tree mapping features to softmax(outputs).

    Returns ``fallback`` when nothing was collected, e.g. in mix-method epochs,
    where the old code crashed on ``torch.cat([])``.
    """
    if not tree_x:
        return fallback
    features = torch.cat(tree_x).numpy()
    targets = F.softmax(torch.cat(tree_y), dim=1).numpy()
    return DecisionTreeRegressor(max_depth=7).fit(features, targets)


def train(epoch, train_loader, model, criterion, optimizer, conf, wmodel, decision_tree=None):
    """Run one training epoch and fit a surrogate decision tree on the outputs.

    Args:
        epoch: epoch index (unused here).
        train_loader: iterable of ``(input, target, path)`` batches.
        model: network whose forward returns ``(output, _, moutput, patches, quantized)``.
        criterion: loss function; its result is reduced with ``torch.mean``.
        optimizer: optimizer stepped once per batch.
        conf: config; reads ``mixmethod``, ``netname`` and optionally ``midlevel``.
        wmodel: model passed to the mix method; replaced by ``model`` when None.
        decision_tree: optional fitted regressor. In the non-mix branch its
            predictions on ``quantized`` add a 0.1-weighted CXE term.

    Returns:
        ``(losses.value(), tree)``. ``tree`` is a ``DecisionTreeRegressor`` fit on
        this epoch's non-mix ``quantized`` features -> softmax(output), or the
        incoming ``decision_tree`` when no such batches were seen.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()

    mixmethod = None
    if 'mixmethod' in conf and 'baseline' not in conf.mixmethod:
        mixmethod = conf.mixmethod
        if wmodel is None:
            wmodel = model

    # Kept on the CPU so epoch-long buffers do not grow GPU memory.
    tree_x, tree_y = [], []

    end = time.time()
    pbar = tqdm(train_loader, dynamic_ncols=True, total=len(train_loader))
    for inputs, target, _paths in pbar:
        data_time.add(time.time() - end)
        inputs = inputs.cuda()
        target = target.cuda()

        if mixmethod is not None:
            # NOTE(review): eval() of a config string; acceptable only for trusted
            # configs. A name -> function dict would be safer.
            inputs, target_a, target_b, lam_a, lam_b = eval(mixmethod)(inputs, target, conf, wmodel)
            output, _, moutput, _patches, _quantized = model(inputs)
            loss = _mixed_criterion(criterion, output, target_a, target_b, lam_a, lam_b)
            if 'inception' in conf.netname:
                loss = loss + 0.4 * _mixed_criterion(criterion, moutput, target_a, target_b, lam_a, lam_b)
            if 'midlevel' in conf and conf.midlevel:
                loss = loss + _mixed_criterion(criterion, moutput, target_a, target_b, lam_a, lam_b)
        else:
            output, _, moutput, _patches, quantized = model(inputs)
            quantized_detached = quantized.detach()
            tree_x.append(quantized_detached.cpu())
            tree_y.append(output.detach().cpu())

            # Small L1 penalty on the quantized features.
            l1_loss = torch.norm(quantized, p=1, dim=1).mean() * 0.0001
            loss = torch.mean(criterion(output, target)) + l1_loss
            if decision_tree is not None:
                loss = loss + _tree_distill_loss(decision_tree, quantized_detached, output)
            if 'inception' in conf.netname:
                loss = loss + 0.4 * torch.mean(criterion(moutput, target))
            if 'midlevel' in conf and conf.midlevel:
                loss = loss + torch.mean(criterion(moutput, target))

        losses.add(loss.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.add(time.time() - end)
        end = time.time()
        pbar.set_postfix(batch_time=batch_time.value(), data_time=data_time.value(), loss=losses.value(), score=0)

    return losses.value(), _fit_surrogate_tree(tree_x, tree_y, decision_tree)
