from asyncio import constants
import torch
import pickle 
from unet import *
from utilities import *
from forward_process import *
from dataset import *
from visualize import *
from anomaly_map import *
from backbone import *
from metrics import metric
from feature_extractor import *
import time
from datetime import timedelta
from diffusers import AutoencoderKL
import wandb
from sklearn.neighbors import NearestNeighbors





class KNN:
    """Nearest-neighbour index over feature vectors with a distance histogram.

    Wraps ``sklearn.neighbors.NearestNeighbors`` and, when fitted, records a
    histogram of the mean neighbour distance of every fitted sample
    (``self.histogram`` / ``self.bin_edges``).
    """

    def __init__(self,config, k=5, num_bins=10):
        """Store the settings and build the (unfitted) neighbour index.

        Args:
            config: configuration object; ``config.model.KNN_metric`` is used
                as the distance metric of the index.
            k: number of neighbours returned per query.
            num_bins: number of bins of the distance histogram built by ``fit``.
        """
        self.k = k
        self.config = config
        self.model = NearestNeighbors(
            n_neighbors=k,
            metric=config.model.KNN_metric,
        )
        self.num_bins = num_bins

    def fit(self, X):
        """Fit the index on ``X`` and histogram the mean neighbour distances.

        Args:
            X: torch tensor of samples; it is detached, moved to the CPU and
                converted to a NumPy array before fitting.
        """
        samples = X.detach().cpu().numpy()
        self.model.fit(samples)

        # NOTE(review): the index is queried with the very points it was
        # fitted on, so each sample's closest neighbour is presumably itself
        # at distance 0 -- confirm this is the intended reference statistic.
        neighbour_distances = self.model.kneighbors(samples)[0]
        mean_distance_per_sample = np.mean(neighbour_distances, axis=1)

        # The bin edges derived here are what later queries are binned against.
        self.histogram, self.bin_edges = np.histogram(
            mean_distance_per_sample, bins=self.num_bins
        )

        print(f"bin edges: {self.bin_edges}")
        print(f"histogram: {self.histogram}")

    def transform(self, X):
        """Return ``(distances, indices)`` of the nearest neighbours of ``X``.

        Args:
            X: torch tensor of query samples (detached and moved to the CPU).
        """
        queries = X.detach().cpu().numpy()
        neighbour_distances, neighbour_indices = self.model.kneighbors(queries)
        return neighbour_distances, neighbour_indices



def get_bins_and_mappings(knn, distances, indices):
    """Assign every query sample to a distance bin of the fitted KNN.

    For each row the mean neighbour distance is located in ``knn.bin_edges``
    and turned into a 1-based bin id. Ids are clamped to ``[1, num_bins]``:
    values above the last edge go to the last bin and values at or below the
    first edge go to the first bin. Previously the latter produced id 0,
    which falls outside the histogram and wraps around when used as
    ``array[bin_id - 1]``.

    Args:
        knn: object exposing ``bin_edges`` (sequence of ``num_bins + 1``
            increasing edges), e.g. a fitted ``KNN``.
        distances: 2-D array, one row of neighbour distances per sample.
        indices: 2-D array, one row of neighbour indices per sample.

    Returns:
        ``(mappings, keys)`` where ``keys[i]`` is the bin id (plain ``int``)
        of sample ``i`` and ``mappings[i]`` is ``{keys[i]: [neighbour ids]}``.
    """
    num_bins = len(knn.bin_edges) - 1
    mappings = []
    keys = []
    for row_distances, row_indices in zip(distances, indices):
        avg_distance = np.mean(row_distances)

        # With right=True digitize returns 0 for values <= first edge and
        # num_bins + 1 for values > last edge; clamp both into valid bins.
        raw_bin = np.digitize(avg_distance, knn.bin_edges, right=True)
        bin_id = int(min(max(raw_bin, 1), num_bins))

        keys.append(bin_id)
        mappings.append({bin_id: [ind.item() for ind in row_indices]})
    return mappings, keys



def _build_validation_loaders(config):
    """Create the (test, train) data loaders used by ``validate``."""
    test_dataset = MVTecDataset(
        root=config.data.data_dir,
        category=config.data.category,
        config=config,
        is_train=False,
    )
    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=config.model.num_workers,
        drop_last=False,
    )

    train_dataset = MVTecDataset(
        root=config.data.data_dir,
        category=config.data.category,
        config=config,
        is_train=True,
    )
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        #batch_size=config.data.batch_size,
        batch_size=18,
        shuffle=True,
        num_workers=config.model.num_workers,
        drop_last=False,
    )
    return testloader, trainloader


def _load_latent_backbone(config):
    """Load the latent backbone; returns ``(vae, encoder)``, one of them None."""
    backbone = config.model.latent_backbone
    if backbone == "VAE":
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        vae.to(config.model.device)
        vae.eval()
        return vae, None

    if backbone == "wide_resnet50_2":
        encoder = wide_resnet50_2(pretrained=True)[0]
    elif backbone == "wide_resnet101_2":
        encoder = wide_resnet101_2(pretrained=True)[0]
    else:
        encoder = resnet18(pretrained=True)
    encoder.to(config.model.device)
    encoder.eval()
    return None, encoder


def _round_to_nearest_multiple(x, n=10):
    """Round every entry of array ``x`` to the nearest multiple of ``n`` (ties go up)."""
    res = np.ceil(x / n) * n
    mask = np.logical_and(x % n < n / 2, x % n > 0)
    res[mask] -= n
    return res


def _plot_bin_distribution(knn, step_list, config):
    """Save a bar chart comparing train and test bin counts of the KNN."""
    num_bins = len(knn.histogram)
    test_histo = np.zeros(num_bins, dtype=int)
    for val in step_list:
        # Clamp so that an out-of-range bin id can never wrap around to
        # another bin through negative indexing.
        test_histo[min(max(int(val), 1), num_bins) - 1] += 1

    x_positions = np.arange(num_bins) + 1

    # Overlapping transparent bars: training vs. test distribution.
    plt.bar(x_positions, knn.histogram, alpha=0.5, label='Train Avg. Distance Binning')
    plt.bar(x_positions, test_histo, alpha=0.5, label='Test Avg. Distance Binning')
    plt.xticks(x_positions)
    plt.title(f"{config.data.category} Binning")
    plt.xlabel('Bins')
    plt.ylabel('Counts')
    plt.legend()

    plt.savefig(f"results/{config.data.category}/binning_distribution.png", dpi=300)


def validate(unet, constants_dict, config):
    """Run inference on the test split and report the detection metrics.

    Every test batch is (optionally) encoded into a latent space, noised,
    reconstructed with the sampling routine, and scored with an anomaly map.
    When ``config.model.dynamic_steps`` is set, the number of diffusion steps
    per image is derived from the KNN distance bin of that image. The
    threshold returned by ``metric`` and the timing are printed and logged
    to wandb.

    Args:
        unet: denoising network handed to the sampling routines and to
            ``Domain_adaptation``.
        constants_dict: diffusion constants; the ``'betas'`` entry is read.
        config: experiment configuration (``config.data.*``, ``config.model.*``).

    Returns:
        None.

    Raises:
        ValueError: if ``config.data.name`` is not ``'BTAD'`` or ``'VisA'``,
            or if KNN/feature based options are enabled without
            ``config.model.latent`` (both cases previously ended in a
            ``NameError`` further down).
    """
    if config.data.name not in ('BTAD', 'VisA'):
        raise ValueError(
            f"validate() only supports the 'BTAD' and 'VisA' datasets, got {config.data.name!r}"
        )

    testloader, trainloader = _build_validation_loaders(config)
    device = config.model.device

    labels_list = []
    predictions = []
    anomaly_map_list = []
    GT_list = []
    reconstructed_list = []
    forward_list = []
    forward_list_orig = []
    l1_latent_list = []
    cos_dist_list = []
    step_list = []
    filename_list = []
    anomaly_map_recon_list = []
    anomaly_map_feature_list = []
    anomaly_map_latent_list = []
    KNN_feature_list = []

    # Optional components; bound up front so later code can test for them
    # instead of hitting an unbound name.
    vae = None
    encoder = None
    feature_extractor = None
    knn = None
    knn_transform = None
    train_stack = []

    uses_knn = config.model.dynamic_steps or config.model.repeated_sampling
    combined = config.model.distance_metric_eval == "combined"

    if config.model.latent:
        vae, encoder = _load_latent_backbone(config)

        if uses_knn or combined:
            feature_extractor = resnet34(pretrained=True)[0]
            feature_extractor.to(device)
            feature_extractor = Domain_adaptation(
                unet,
                feature_extractor,
                vae,
                config,
                fine_tune=config.model.DA_fine_tune,
                constants_dict=constants_dict,
                dataloader=trainloader,
            )
            feature_extractor.eval()

            knn_transform = transforms.Compose([
                transforms.Lambda(lambda t: (t + 1) / (2)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

            knn = KNN(config=config, k=config.model.knn_k, num_bins=10)

            # Stack the training features; no gradients are needed here, so
            # avoid keeping an autograd graph alive for every training batch.
            with torch.no_grad():
                for train_batch in trainloader:
                    train_batch = knn_transform(train_batch[0])
                    train_batch = feature_extractor(train_batch.to(device))[config.model.DA_layer]
                    train_batch = train_batch.flatten(start_dim=1)
                    train_stack.append(train_batch)

            # Fit the KNN model on the training features and persist it.
            knn.fit(torch.cat(train_stack, dim=0))
            knn_path = os.path.join(
                os.getcwd(),
                config.model.checkpoint_dir,
                config.data.category,
                f"knn_{config.model.knn_k}_{config.model.DA_epochs}",
            )
            with open(knn_path, 'wb') as knn_file:
                pickle.dump(knn, knn_file)
    elif uses_knn or combined:
        raise ValueError(
            "dynamic_steps, repeated_sampling and distance_metric_eval='combined' "
            "require config.model.latent to be enabled"
        )

    with torch.no_grad():
        start = time.time()
        for data, targets, labels, filename in testloader:
            # Keep the untouched input batch; `data` is re-bound below.
            data_placeholder = data

            step_size = config.model.test_trajectoy_steps2
            skip = config.model.skip2

            if uses_knn:
                test_batch = knn_transform(data)
                test_batch = feature_extractor(test_batch.to(device))[config.model.DA_layer]
                test_batch = test_batch.flatten(start_dim=1)

                distances, indices = knn.transform(test_batch)

                KNN_feature_list.append(KNN_heat_map(test_batch, train_stack, indices, config))

                _, keys = get_bins_and_mappings(knn, distances, indices)
                wandb.log({"dynamic step": keys})

                if config.model.dynamic_steps:
                    bin_ids_array = np.array(keys)

                    # Per-image number of steps, rounded to a multiple of 10.
                    step_sizes_array = np.maximum(bin_ids_array, 2) / 10 * config.model.test_trajectoy_steps2
                    step_size = _round_to_nearest_multiple(step_sizes_array)

                    # Per-image stride through the trajectory (at least 1).
                    skip = np.maximum(step_size / 10, 1).astype(int)
                    data = data.to(device)
                    step_list.extend(keys)

            filename_list.append(filename)
            forward_list_orig.append(data)
            forward_list.append(data)
            if config.model.latent:
                if not config.model.latent_backbone == "VAE":
                    data = encoder(data.to(device))[0]
                else:
                    data = vae.encode(data.to(device)).latent_dist.sample() * 0.18215

            test_trajectoy_steps = torch.Tensor([step_size]).type(torch.int64).to(device)[0]

            at = compute_alpha2(constants_dict['betas'], test_trajectoy_steps.long(), config)

            if config.model.noise_sampling:
                noise = torch.randn_like(data).to(device)
                noisy_image = at.sqrt() * data + (1 - at).sqrt() * noise
            else:
                noisy_image = data
                if config.model.downscale_first:
                    noisy_image = noisy_image * at.sqrt()

            if config.model.dynamic_steps:
                # One time-step sequence per image of the batch.
                seq = [
                    torch.arange(0, end, step).to(test_trajectoy_steps.device)
                    for end, step in zip(test_trajectoy_steps, skip)
                ]
                reconstructed, rec_x0 = my_generalized_steps(data, noisy_image, seq, unet, constants_dict['betas'], config, eta2=config.model.eta2 , eta3=0 , constants_dict=constants_dict ,eraly_stop = False)
            else:
                seq = range(0, test_trajectoy_steps, skip)
                reconstructed, rec_x0 = DA_generalized_steps(data, noisy_image, seq, unet, constants_dict['betas'], config, eta2=config.model.eta2 , eta3=0 , constants_dict=constants_dict ,eraly_stop = False)

            data_reconstructed = reconstructed[-1].to(device)

            if config.model.latent_backbone == "VAE":
                # Reconstruct the image from the latent space.
                reconstructed = 1 / 0.18215 * data_reconstructed
                reconstructed = vae.decode(reconstructed.to(device)).sample

            if combined:
                l1_latent_list.append(
                    color_distance(data_reconstructed, data, config, out_size=config.data.image_size)
                )
                cos_dist_list.append(
                    feature_distance_new(reconstructed, data_placeholder, feature_extractor, config)
                )
                anomaly_map_latent_list.append(recon_heat_map(data_reconstructed, data, config))
                anomaly_map_feature_list.append(
                    feature_heat_map(reconstructed, data_placeholder, feature_extractor, config)
                )
                labels_list.extend(0 if label == 'good' else 1 for label in labels)
            else:
                anomaly_map_recon = recon_heat_map(reconstructed, data_placeholder, config)
                if config.model.dynamic_steps:
                    anomaly_map = heat_map_recon(reconstructed, data_placeholder, data_reconstructed, data, constants_dict, config, feature_extractor)
                else:
                    anomaly_map = heat_map(data_reconstructed, data, constants_dict, config)

                anomaly_map_list.append(anomaly_map)
                anomaly_map_recon_list.append(anomaly_map_recon)
                for pred, label in zip(anomaly_map, labels):
                    labels_list.append(0 if label == 'good' else 1)
                    # Image-level score: maximum of the pixel-level map.
                    predictions.append(torch.max(pred).item())

            GT_list.append(targets)
            reconstructed_list.append(reconstructed)

    if combined:
        l1_latent_normalized_list = scale_values_between_zero_and_one(l1_latent_list)
        cos_dist_normalized_list = scale_values_between_zero_and_one(cos_dist_list)

        heatmap_latent_list = heatmap_latent(l1_latent_normalized_list, cos_dist_normalized_list, config)

        concat_heatmap = torch.cat(heatmap_latent_list, dim=0)
        image_scores = [torch.max(heatmap).item() for heatmap in concat_heatmap]

        threshold = metric(labels_list, image_scores, heatmap_latent_list, GT_list, config)
    else:
        image_scores = predictions
        threshold = metric(labels_list, predictions, anomaly_map_list, GT_list, config)

    end = time.time()
    elapsed = end - start
    print('Inference time is ', str(timedelta(seconds=elapsed)))
    print('threshold: ', threshold)
    wandb.log({"inference_time": str(timedelta(seconds=elapsed))})
    wandb.log({"threshold": threshold})
    for step in step_list:
        wandb.log({"steps_time_series": step})

    # The binning plot needs the fitted KNN; skip it when none was built.
    if config.model.visual_all and knn is not None:
        _plot_bin_distribution(knn, step_list, config)

    # NOTE: the tensors assembled below are the inputs of the visualize(...)
    # calls, which are currently disabled.
    reconstructed_list = torch.cat(reconstructed_list, dim=0)
    forward_list = torch.cat(forward_list, dim=0)
    if config.model.latent:
        forward_list_orig = torch.cat(forward_list_orig, dim=0)

        filename_list = [item for tup in filename_list for item in tup]
        # These lists are only filled in "combined" mode; torch.cat raises on
        # an empty list, so leave them untouched otherwise.
        if anomaly_map_latent_list:
            anomaly_map_latent_list = torch.cat(anomaly_map_latent_list, dim=0)
        if anomaly_map_feature_list:
            anomaly_map_feature_list = torch.cat(anomaly_map_feature_list, dim=0)
        if config.model.dynamic_steps:
            KNN_feature_list = torch.cat(KNN_feature_list, dim=0)

    GT_list = torch.cat(GT_list, dim=0)
    # One score exists per test image in both evaluation modes.
    wandb.log({"inference_time_image": elapsed / max(len(image_scores), 1)})
    if combined:
        pred_mask = (concat_heatmap > threshold).float()
        #visualize(forward_list, reconstructed_list, GT_list, pred_mask, concat_heatmap, config.data.category, config, forward_list_orig, step_list,filename_list, anomaly_map_recon_list, anomaly_map_latent_list, anomaly_map_feature_list,KNN_feature_list)
    else:
        anomaly_map_list = torch.cat(anomaly_map_list, dim=0)
        pred_mask = (anomaly_map_list > threshold).float()
        #visualize(forward_list, reconstructed_list, GT_list, pred_mask, anomaly_map_list, config.data.category, config, forward_list_orig, step_list,filename_list, anomaly_map_recon_list, anomaly_map_latent_list, anomaly_map_feature_list)