import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
import torchvision
from torchvision.transforms import transforms
import math 
from utilities import *
from backbone import *
from dataset import *
from visualize import *
from feature_extractor import *
# import cv2
import numpy as np


def heat_map(output, target,constants_dict, config):
    """Build a smoothed single-channel anomaly map from `output` vs. `target`.

    The raw distance is chosen by ``config.model.distance_metric_eval``:

    * ``"l1"``     -- per-pixel distance from :func:`color_distance`.
    * ``"cosine"`` -- distance from ``feature_distance`` (defined in another
      module; its exact return type is not visible here, so it is wrapped in
      ``torch.Tensor`` and moved to the configured device).

    The distance map is then Gaussian-blurred (sigma=4, 33x33 kernel) and
    summed over the channel dimension.

    Args:
        output: Reconstructed batch; moved to ``config.model.device``.
        target: Reference batch; moved to ``config.model.device``.
        constants_dict: Forwarded unchanged to ``feature_distance``.
        config: Reads ``model.device``, ``model.distance_metric_eval`` and
            ``data.image_size``.

    Returns:
        Tensor of shape ``(B, 1, H, W)`` (``gaussian_blur2d`` requires a 4-D
        input, and the channel sum keeps a singleton channel axis).

    Raises:
        ValueError: If ``distance_metric_eval`` is neither ``"l1"`` nor
            ``"cosine"``.
    """
    sigma = 4
    # +-4 sigma support, forced odd: 33 for sigma=4.
    kernel_size = 2 * int(4 * sigma + 0.5) + 1

    output = output.to(config.model.device)
    target = target.to(config.model.device)

    metric = config.model.distance_metric_eval
    if metric == "l1":
        anomaly_map = color_distance(output, target, config, config.data.image_size)
    elif metric == "cosine":
        f_d = feature_distance(output, target, constants_dict, config)
        f_d = torch.Tensor(f_d).to(config.model.device)
        print('feature_distance max : ', torch.max(f_d))
        anomaly_map = f_d
    else:
        # Without this guard the map stays the integer 0 and the failure
        # surfaces much later, inside gaussian_blur2d.
        raise ValueError(
            f"Unsupported config.model.distance_metric_eval {metric!r}; "
            "expected 'l1' or 'cosine'"
        )

    anomaly_map = gaussian_blur2d(
        anomaly_map, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma)
    )
    anomaly_map = torch.sum(anomaly_map, dim=1).unsqueeze(1)
    return anomaly_map

def heat_map_recon(output, target, output_latent, target_latent,constants_dict, config, fe):
    """Build a smoothed single-channel anomaly map for reconstruction scoring.

    Mode selection:

    * ``config.model.reconstruction_comparision`` truthy: feature distance
      between `output` and `target` from :func:`feature_distance_new`, using
      the feature extractor `fe`.
    * otherwise, by ``config.model.distance_metric_eval``:

      - ``"l1"``: :func:`color_distance` between the latent tensors.
      - ``"cosine"``: ``feature_distance`` (defined in another module) between
        `output` and `target`.

    The chosen map is Gaussian-blurred (sigma=4, 33x33 kernel) and summed over
    channels.

    Args:
        output, target: Image-space batches.
        output_latent, target_latent: Latent batches; always moved to
            ``config.model.device``.
        constants_dict: Forwarded to ``feature_distance`` in cosine mode.
        config: Reads ``model.device``, ``model.reconstruction_comparision``,
            ``model.distance_metric_eval`` and ``data.image_size``.
        fe: Feature extractor passed to :func:`feature_distance_new`.

    Returns:
        Tensor of shape ``(B, 1, H, W)``.

    Raises:
        ValueError: In non-reconstruction mode, if ``distance_metric_eval`` is
            neither ``"l1"`` nor ``"cosine"``.
    """
    sigma = 4
    # +-4 sigma support, forced odd: 33 for sigma=4.
    kernel_size = 2 * int(4 * sigma + 0.5) + 1

    output_latent = output_latent.to(config.model.device)
    target_latent = target_latent.to(config.model.device)

    if config.model.reconstruction_comparision:
        output = output.to(config.model.device)
        target = target.to(config.model.device)
        f_d = feature_distance_new(output, target, fe, config)
        anomaly_map = torch.Tensor(f_d).to(config.model.device)
    else:
        metric = config.model.distance_metric_eval
        if metric == "l1":
            anomaly_map = color_distance(
                output_latent, target_latent, config, config.data.image_size
            )
        elif metric == "cosine":
            f_d = feature_distance(output, target, constants_dict, config)
            f_d = torch.Tensor(f_d).to(config.model.device)
            print('feature_distance max : ', torch.max(f_d))
            anomaly_map = f_d
        else:
            # Otherwise the map would stay the integer 0 and fail inside
            # gaussian_blur2d with an unrelated-looking error.
            raise ValueError(
                f"Unsupported config.model.distance_metric_eval {metric!r}; "
                "expected 'l1' or 'cosine'"
            )

    anomaly_map = gaussian_blur2d(
        anomaly_map, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma)
    )
    anomaly_map = torch.sum(anomaly_map, dim=1).unsqueeze(1)
    return anomaly_map

def recon_heat_map(output, target, config):
    """Return a smoothed pixel-space anomaly map between `output` and `target`.

    Both inputs are moved to ``config.model.device``. Their
    :func:`color_distance` at ``config.data.image_size`` is blurred with a
    33x33 Gaussian (sigma=4). The result is then summed over channels.

    Returns:
        Tensor of shape ``(B, 1, H, W)``.
    """
    device = config.model.device
    blur_sigma = 4
    blur_kernel = 2 * int(4 * blur_sigma + 0.5) + 1

    pixel_distance = color_distance(
        output.to(device), target.to(device), config, config.data.image_size
    )
    smoothed = gaussian_blur2d(
        pixel_distance,
        kernel_size=(blur_kernel, blur_kernel),
        sigma=(blur_sigma, blur_sigma),
    )
    return smoothed.sum(dim=1, keepdim=True)

def feature_heat_map(output,target,fe,config):
    """Return a smoothed feature-space anomaly map between `output` and `target`.

    Both inputs are moved to ``config.model.device`` and compared with
    :func:`feature_distance_new` using the extractor `fe`. The distance is
    blurred with a 33x33 Gaussian (sigma=4) and then summed over channels.

    Returns:
        Tensor of shape ``(B, 1, H, W)``.
    """
    device = config.model.device
    blur_sigma = 4
    blur_kernel = 2 * int(4 * blur_sigma + 0.5) + 1

    raw_distance = feature_distance_new(output.to(device), target.to(device), fe, config)
    raw_distance = torch.Tensor(raw_distance).to(device)

    smoothed = gaussian_blur2d(
        raw_distance,
        kernel_size=(blur_kernel, blur_kernel),
        sigma=(blur_sigma, blur_sigma),
    )
    return smoothed.sum(dim=1, keepdim=True)

def KNN_heat_map(batch,train_stack, indices,config, feature_shape=(256, 16, 16)):
    """Average smoothed distance map of each query against its nearest neighbours.

    For every query ``batch[i]`` and every neighbour index in ``indices[i]``:

    1. Both flattened tensors are viewed as ``feature_shape``.
    2. They are compared with :func:`color_distance`.
    3. The distance is blurred with a 33x33 Gaussian (sigma=4) and summed
       over channels.

    The per-neighbour maps are then averaged.

    Args:
        batch: Indexable collection of query tensors, each with
            ``prod(feature_shape)`` elements.
        train_stack: List of tensors concatenated along dim 0; rows are
            addressed by ``indices``.
        indices: ``indices[i]`` is an iterable of row indices into the
            concatenated ``train_stack`` for query ``i``. It must be non-empty,
            because ``torch.cat`` of an empty list raises.
        config: Forwarded to :func:`color_distance`; ``data.image_size`` is
            used as its output size. NOTE(review): with a 256-channel view this
            presumably requires ``config.model.latent`` to be true (the
            non-latent path applies 3-channel normalization) -- confirm.
        feature_shape: ``(C, H, W)`` each flattened tensor is viewed as.
            Defaults to the previously hard-coded ``(256, 16, 16)``.

    Returns:
        Tensor of shape ``(len(batch), 1, H', W')``.
    """
    sigma = 4
    kernel_size = 2 * int(4 * sigma + 0.5) + 1

    train_stack = torch.cat(train_stack, dim=0)
    batch_list = []
    for i, query in enumerate(batch):
        # The query view does not depend on the neighbour, so build it once.
        output = query.view(*feature_shape).unsqueeze(0)
        avg_map_list = []
        for idx in indices[i]:
            target = train_stack[idx].view(*feature_shape).unsqueeze(0)
            i_d = color_distance(output, target, config, config.data.image_size)
            ano_map = gaussian_blur2d(
                i_d, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma)
            )
            avg_map_list.append(torch.sum(ano_map, dim=1).unsqueeze(1))
        avg_map = torch.mean(torch.cat(avg_map_list, dim=0), dim=0).unsqueeze(0)
        batch_list.append(avg_map)
    return torch.cat(batch_list, dim=0)

def heatmap_latent(l1_latent,cos_list, config):
    """Blend paired L1 and cosine maps, then smooth each blend.

    For each position ``i`` this computes
    ``w * l1_latent[i] + (1 - w) * cos_list[i]``, where
    ``w = config.model.anomap_weighting``. Each blend is blurred with a 33x33
    Gaussian (sigma=4) and summed over channels.

    Returns:
        A list with one ``(B, 1, H, W)`` tensor per entry of `l1_latent`.
    """
    weight = config.model.anomap_weighting
    blur_sigma = 4
    blur_kernel = 2 * int(4 * blur_sigma + 0.5) + 1

    def _smooth(raw_map):
        blurred = gaussian_blur2d(
            raw_map,
            kernel_size=(blur_kernel, blur_kernel),
            sigma=(blur_sigma, blur_sigma),
        )
        return blurred.sum(dim=1, keepdim=True)

    return [
        _smooth(weight * l1_map + (1 - weight) * cos_list[i])
        for i, l1_map in enumerate(l1_latent)
    ]



def color_distance(image1, image2, config,out_size=256):
    """Per-pixel channel-mean absolute difference between two 4-D batches.

    Two modes, selected by ``config.model.latent``:

    * latent (truthy): both tensors are moved to ``config.model.device``. The
      channel-mean absolute difference is then bilinearly resized
      (``align_corners=True``) to ``out_size``.
    * image (falsy): both tensors go through ``(t + 1) / 2`` and then
      per-channel normalization with the ImageNet mean/std constants before
      the channel mean. This path does not resize and does not move tensors
      between devices. NOTE(review): the ``(t + 1) / 2`` step suggests inputs
      in [-1, 1] -- confirm with callers.

    Args:
        image1, image2: Tensors of shape ``(B, C, H, W)``. The image path
            requires ``C == 3``.
        config: Reads ``model.latent`` and, in latent mode, ``model.device``.
        out_size: Output spatial size for the latent path only.

    Returns:
        Tensor of shape ``(B, 1, H', W')``.
    """
    if config.model.latent:
        image1 = image1.to(config.model.device)
        image2 = image2.to(config.model.device)
        distance_map = torch.mean(torch.abs(image1 - image2), dim=1).unsqueeze(1)
        return F.interpolate(distance_map, size=out_size, mode='bilinear', align_corners=True)

    # Only the image path needs the normalization pipeline, so build it here
    # instead of on every call.
    transform = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / (2)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image1 = transform(image1)
    image2 = transform(image2)
    return torch.mean(torch.abs(image1 - image2), dim=1).unsqueeze(1)


def cal_anomaly_map(fs_list, ft_list, config, out_size=256, amap_mode='mul'):
    """Cosine-distance anomaly map between two sets of feature maps.

    For each compared layer, ``1 - cosine_similarity`` is taken along the
    channel axis. The result is bilinearly resized (``align_corners=False``)
    to ``config.data.image_size`` and combined into a single map.

    Args:
        fs_list: If ``config.model.multi_feature`` is truthy, a sequence of
            ``(B, C_i, H_i, W_i)`` feature maps. Layer 0 is skipped and each
            remaining layer is passed through :func:`patchify` first.
            Otherwise a single ``(B, C, H, W)`` tensor.
        ft_list: Same structure as `fs_list`.
        config: Reads ``data.image_size``, ``model.multi_feature`` and
            ``model.device``.
        out_size: Ignored. The output size always comes from
            ``config.data.image_size``; the parameter is kept for backward
            compatibility.
        amap_mode: ``'mul'`` multiplies the per-layer maps, starting from ones.
            Any other value sums them, starting from zeros.

    Returns:
        ``(anomaly_map, a_map_list)``. ``anomaly_map`` has shape
        ``(B, 1, image_size, image_size)``. ``a_map_list`` is always an empty
        list.
    """
    out_size = config.data.image_size

    if config.model.multi_feature:
        batch_size = fs_list[0].shape[0]
        # Layer 0 is deliberately excluded. The generator patchifies lazily,
        # so only one pair of layers is materialized at a time.
        layer_pairs = (
            (patchify(fs_list[i]), patchify(ft_list[i]))
            for i in range(1, len(ft_list))
        )
    else:
        batch_size = fs_list.shape[0]
        layer_pairs = [(fs_list, ft_list)]

    init = torch.ones if amap_mode == 'mul' else torch.zeros
    anomaly_map = init([batch_size, 1, out_size, out_size]).to(config.model.device)

    for fs, ft in layer_pairs:
        a_map = 1 - F.cosine_similarity(fs, ft)
        a_map = torch.unsqueeze(a_map, dim=1)
        a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=False)
        if amap_mode == 'mul':
            anomaly_map *= a_map
        else:
            anomaly_map += a_map

    # Never populated; returned so existing callers can keep unpacking two values.
    a_map_list = []
    return anomaly_map, a_map_list

def feature_distance_new(output, target, FE, config):
    '''
    Feature distance between output and target.

    Switches `FE` to eval mode (a side effect on the caller's module). Both
    inputs are moved to ``config.model.device``, passed through
    ``(t + 1) / 2`` and normalized with the ImageNet mean/std constants, then
    fed to `FE`. `target` goes through `FE` first.

    `FE` must return an indexable sequence of 4-D feature maps. Layer 0 is
    skipped. For every other layer, ``1 - cosine_similarity`` along dim 1 is
    bilinearly resized (``align_corners=True``) to ``config.data.image_size``,
    and the resized maps are summed.

    NOTE(review): this runs under autograd (no ``torch.no_grad``); callers
    that only need scores may want to wrap the call.

    Returns:
        Tensor of shape ``(B, 1, image_size, image_size)``, where ``B`` is the
        batch size of the first feature map.
    '''
    FE.eval()
    device = config.model.device
    to_backbone_input = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / (2)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    output = output.to(device)
    target_in = to_backbone_input(target.to(device))
    output_in = to_backbone_input(output)

    target_feats = FE(target_in)
    output_feats = FE(output_in)

    size = config.data.image_size
    accumulated = torch.zeros([target_feats[0].shape[0], 1, size, size]).to(device)
    for layer in range(1, len(target_feats)):
        dissimilarity = 1 - F.cosine_similarity(target_feats[layer], output_feats[layer])
        accumulated += F.interpolate(
            dissimilarity.unsqueeze(1), size=size, mode='bilinear', align_corners=True
        )
    return accumulated




def patchify(features, return_spatial_info=False):
    """Locally average a feature map over 3x3 neighbourhoods.

    Each spatial position is replaced by the mean of the 3x3 window centred on
    it. The window uses stride 1 and zero padding 1, so border windows include
    the padded zeros in their average.

    Args:
        features: Tensor of shape ``(B, C, H, W)``. H and W need not be equal.
        return_spatial_info: If True, return the raw unfolded patches and the
            per-axis patch counts instead of the averaged map.

    Returns:
        If ``return_spatial_info`` is False: tensor of shape ``(B, C, H, W)``.
        Otherwise: ``(patches, [n_h, n_w])``, where ``patches`` has shape
        ``(B, n_h * n_w, C, 3, 3)``.
    """
    patchsize = 3
    stride = 1
    padding = (patchsize - 1) // 2
    unfolder = torch.nn.Unfold(
        kernel_size=patchsize, stride=stride, padding=padding, dilation=1
    )
    # Unfold output: (B, C * k * k, L), channel-major within dim 1.
    unfolded_features = unfolder(features)
    number_of_total_patches = [
        (s + 2 * padding - (patchsize - 1) - 1) // stride + 1
        for s in features.shape[-2:]
    ]
    # (B, C*k*k, L) -> (B, L, C, k, k)
    unfolded_features = unfolded_features.reshape(
        *features.shape[:2], patchsize, patchsize, -1
    ).permute(0, 4, 1, 2, 3)
    if return_spatial_info:
        return unfolded_features, number_of_total_patches

    n_h, n_w = number_of_total_patches
    mean_features = torch.mean(unfolded_features, dim=(3, 4))  # (B, L, C)
    # Unfold enumerates positions row-major, so L == n_h * n_w. Using the real
    # per-axis counts (not sqrt(L)) keeps non-square maps working.
    return mean_features.reshape(
        features.shape[0], n_h, n_w, mean_features.shape[-1]
    ).permute(0, 3, 1, 2)


def unpatch_scores(x, batchsize):
    """Split the leading dimension of `x` into ``(batchsize, -1)``.

    All trailing dimensions are kept unchanged.
    """
    trailing_dims = tuple(x.shape[1:])
    return x.reshape((batchsize, -1) + trailing_dims)
    
    
    
    

def normalize_list_of_tensors(tensors):
    """Z-score every tensor using one mean/std pooled over all of them.

    The statistics come from :func:`calculate_mean_std_of_tensors`, which
    concatenates the tensors along dim 0. The result is a new list in the
    same order as the input.
    """
    pooled_mean, pooled_std = calculate_mean_std_of_tensors(tensors)
    return [(tensor - pooled_mean) / pooled_std for tensor in tensors]


def calculate_mean_std_of_tensors(tensors):
    """Return the global ``(mean, std)`` of all elements across `tensors`.

    The tensors are concatenated along dim 0, so they must agree on all other
    dimensions. ``std`` is torch's default (unbiased) estimator.
    """
    pooled = torch.cat(tensors, dim=0)
    return pooled.mean(), pooled.std()



def calculate_min_max_of_tensors(tensors):
    """Return the global ``(min, max)`` over every element of every tensor.

    The per-tensor extrema are gathered into one tensor each, and the overall
    extrema are taken from those.
    """
    per_tensor_min = torch.tensor([tensor.min() for tensor in tensors])
    per_tensor_max = torch.tensor([tensor.max() for tensor in tensors])
    return per_tensor_min.min(), per_tensor_max.max()

def scale_values_between_zero_and_one(tensors):
    """Min-max scale a list of tensors into [0, 1] using their global range.

    The global min and max come from :func:`calculate_min_max_of_tensors`, so
    all tensors share one scale. If every value is identical (``max == min``),
    the result is all zeros instead of the NaNs a 0/0 division would produce.

    Returns:
        A new list of scaled tensors in the input order.
    """
    min_value, max_value = calculate_min_max_of_tensors(tensors)
    value_range = max_value - min_value
    if value_range == 0:
        # Constant input: divide by 1 so (tensor - min) == 0 stays 0 rather
        # than becoming NaN. Dtype/device semantics match the normal path.
        value_range = torch.ones_like(value_range)
    return [(tensor - min_value) / value_range for tensor in tensors]