import torchvision.transforms as transforms

from matplotlib.pyplot import imsave
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from spixelcnn import sp_models
from spixelcnn.train_util import shift9pos, update_spixl_map, get_spixel_image

import warnings
warnings.filterwarnings("ignore")

# Pixel normalisation pipeline: first divide every channel by 255 (mean 0,
# std 255), then subtract a fixed per-channel mean with no further scaling
# (std 1). Presumably expects 0-255 valued input tensors; TODO confirm.
input_transform = transforms.Compose(
    [
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
    ]
)


class SPIXELCNN(nn.Module):
    """Superpixel extractor built on a pretrained SpixelCNN-style network.

    The checkpoint at ``args.pretrained`` provides the architecture name
    (``'arch'``) used to look up the model in ``sp_models``. The network's
    ``extract_feat`` output is softmax-normalised and converted into superpixel
    maps with the helpers from ``spixelcnn.train_util``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        network_data = torch.load(args.pretrained, map_location='cpu')
        self.sp_model = sp_models.__dict__[network_data['arch']](data=network_data)
        self.sp_model.eval()
        # Lower bound on the number of k-means clusters per video.
        self.min_k_ = args.k_means
        # Default number of nearest centroids each superpixel is assigned to.
        self.max_spread_scale = args.max_spread_scale
        self.num_k_means_steps = args.num_k_means_steps
        self.downsample = transforms.Resize(args.frame_resolution)

    def _grid_spixel_ids(self, H, W, device):
        """Build the regular-grid initial superpixel ids for an H x W frame.

        Returns ``(spixel_ids, n_spixel, H_, W_)``: ``H_``/``W_`` are H/W rounded
        up to a multiple of ``args.downsize``, ``n_spixel`` is the number of grid
        cells, and ``spixel_ids`` is ``shift9pos``'s id tensor repeated
        ``downsize`` times along axes 1 and 2, as a float tensor on ``device``.
        """
        downsize = self.args.downsize
        H_ = int(np.ceil(H / downsize) * downsize)
        W_ = int(np.ceil(W / downsize) * downsize)
        n_spixl_h = int(np.floor(H_ / downsize))
        n_spixl_w = int(np.floor(W_ / downsize))
        n_spixel = int(n_spixl_h * n_spixl_w)

        spix_values = np.int32(np.arange(0, n_spixel).reshape((n_spixl_h, n_spixl_w)))
        spix_idx_tensor_ = shift9pos(spix_values, h_shift_unit=self.args.shift_unit, w_shift_unit=self.args.shift_unit)
        # Expand each grid cell to downsize x downsize entries
        # (axes 1 and 2 presumably are the grid rows/cols of shift9pos's output).
        spix_idx_tensor = np.repeat(
            np.repeat(spix_idx_tensor_, downsize, axis=1), downsize, axis=2)
        spixel_ids = torch.from_numpy(np.tile(spix_idx_tensor, (1, 1, 1, 1))).type(torch.float).to(device)
        return spixel_ids, n_spixel, H_, W_

    def _sam_label_maps(self, img_, mask_generator):
        """Run ``mask_generator`` on every frame of ``img_`` and stack the merged label maps.

        Each frame is shifted by +0.5, clamped to [0, 1] and scaled to 0-255
        uint8 HWC before being passed to ``mask_generator.generate`` (presumably
        undoing a mean-centred input range; TODO confirm).
        """
        label_maps = []
        for frame in img_:
            frame = (frame + 0.5).clamp(0, 1) * 255
            frame = rearrange(frame, 'c h w -> h w c')
            masks = mask_generator.generate(frame.cpu().numpy().astype('uint8'))
            label_map = self.show_sam_anns(masks)
            if label_map is None:
                # No masks found: use the same background value (1) show_sam_anns fills with,
                # instead of crashing torch.stack on None.
                label_map = torch.ones(frame.shape[:2], dtype=torch.float)
            label_maps.append(label_map)
        return torch.stack(label_maps, dim=0)

    @staticmethod
    def _superpixel_means(values_flat, labels_flat):
        """Average ``values_flat`` over each label of ``labels_flat``, per batch row.

        Both inputs are (B, P). Returns ``(means, num_labels)`` where
        ``num_labels = max(label) + 1`` over the whole batch and ``means`` is
        (B, num_labels); labels that do not occur in a row get a mean of 0.
        """
        labels = labels_flat.to(torch.int64)
        num_labels = int(labels.max().item()) + 1
        one_hot_ = F.one_hot(labels, num_classes=num_labels).float()
        # squeeze(1) (not squeeze()) keeps the (B, num_labels) shape even when B or num_labels is 1.
        sums = torch.bmm(values_flat.unsqueeze(1).float(), one_hot_).squeeze(1)
        counts = one_hot_.sum(dim=1)
        # Labels absent from a row have sum 0 and count 0; clamp avoids 0/0.
        return sums / counts.clamp(min=1), num_labels

    def forward(self, img_):
        """Return per-frame superpixel visualisations for ``img_``, scaled to [-1, 1].

        ``img_`` is resized with ``self.downsample``; superpixels are seeded on a
        regular grid, refined with the network's soft assignments, and rendered
        with ``get_spixel_image``. The stacked result is moved to ``img_``'s device.
        """
        img_ = self.downsample(img_)
        _, _, H, W = img_.shape
        spixeIds, n_spixel, H_, W_ = self._grid_spixel_ids(H, W, img_.device)

        output = F.softmax(self.sp_model.extract_feat(img_), dim=1)
        curr_spixl_map, _ = update_spixl_map(spixeIds, output)
        ori_sz_spixel_map = F.interpolate(curr_spixl_map.type(torch.float), size=(H_, W_), mode='nearest').type(torch.int)

        outs = []
        for frame, spixel_map in zip(img_, ori_sz_spixel_map):
            spixel_viz, _ = get_spixel_image(frame, spixel_map.squeeze(), n_spixels=n_spixel, b_enforce_connect=True)
            lo, hi = spixel_viz.min(), spixel_viz.max()
            # A constant visualisation would give 0/0 -> NaN; span 1 maps it to -1 instead.
            span = (hi - lo) if hi > lo else 1
            spixel_viz = ((spixel_viz - lo) / span) * 2 - 1
            outs.append(torch.from_numpy(spixel_viz))

        return torch.stack(outs, 0).to(img_.device)

    def show_sam_anns(self, anns):
        """Merge SAM annotations into a single float label map.

        Annotations are painted in decreasing ``'area'`` order, so smaller masks
        overwrite larger ones; the i-th painted mask gets label i. Returns
        ``None`` when ``anns`` is empty.
        """
        if len(anns) == 0:
            return None
        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
        # NOTE(review): uncovered pixels keep 1, the same value as the second-largest mask.
        img = torch.ones(sorted_anns[0]['segmentation'].shape, dtype=torch.float)
        for aid, ann in enumerate(sorted_anns):
            img[ann['segmentation']] = aid
        return img

    def get_k_means_feature(self, img_, n_videos=1, n_clusters=0, sam_package=None):
        """Cluster the superpixels of ``img_`` with (overlapping) k-means.

        Superpixels are described by the mean of the per-pixel max softmax
        probability of ``sp_model``; they come from SAM masks or a regular grid
        depending on ``sam_package``.

        Args:
            img_: batch of frames, resized with ``self.downsample``. Its pixels
                are folded as ``n_videos`` rows, so frames of one video are
                presumably contiguous in the batch — TODO confirm with callers.
            n_videos: number of videos in the batch.
            n_clusters: int used for every video, or a list/tuple with one value
                per video; each is raised to at least ``self.min_k_``.
            sam_package: optional dict. ``sam_type == 'mask_gen'`` uses SAM masks
                directly as superpixels with single (non-overlapping) assignment;
                ``'mask_init'`` uses them to seed ``update_spixl_map``; anything
                else (including ``None``) uses the regular grid. Both SAM modes
                read ``sam_package['mask_generator']``.

        Returns:
            ``(clusters, k_)``: ``clusters`` has shape (n_videos, spread, frames,
            H, W) with per-pixel cluster ids; ``k_`` is the per-video cluster count.

        Raises:
            TypeError: if ``n_clusters`` is neither an int nor a list/tuple.
        """
        if sam_package is None:
            sam_package = {}
        if isinstance(n_clusters, int):
            k_ = [max(self.min_k_, n_clusters)] * n_videos
        elif isinstance(n_clusters, (list, tuple)):
            k_ = [max(self.min_k_, nc) for nc in n_clusters]
        else:
            raise TypeError(f"n_clusters must be an int or a list of ints, got {type(n_clusters).__name__}")

        img_ = self.downsample(img_)
        _, _, H, W = img_.shape

        output = F.softmax(self.sp_model.extract_feat(img_), dim=1)
        output_max, _ = torch.max(output, dim=1, keepdim=True)
        output_max_flat = output_max.view(n_videos, -1)

        sam_type = sam_package.get('sam_type')
        # Kept local so a 'mask_gen' call does not change self.max_spread_scale for later calls.
        spread = self.max_spread_scale
        if sam_type == 'mask_gen':
            spread = 1
            sam_maps = self._sam_label_maps(img_, sam_package['mask_generator'])
            curr_spixl_map = sam_maps.unsqueeze(1).to(img_.device)
        elif sam_type == 'mask_init':
            img_sam_masks = self._sam_label_maps(img_, sam_package['mask_generator'])
            sam_idx_tensor_ = shift9pos(img_sam_masks.int(), h_shift_unit=self.args.shift_unit, w_shift_unit=self.args.shift_unit, concat_axis=1)
            spixeIds = torch.from_numpy(np.tile(sam_idx_tensor_, (1, 1, 1, 1))).type(torch.float).to(img_.device)
            curr_spixl_map, _ = update_spixl_map(spixeIds, output)
        else:
            spixeIds, _, _, _ = self._grid_spixel_ids(H, W, img_.device)
            curr_spixl_map, _ = update_spixl_map(spixeIds, output)

        curr_spixl_map_flat = curr_spixl_map.view(n_videos, -1)

        # Coarse-grained clustering, where k = k_[video]
        with torch.cuda.amp.autocast(enabled=False):
            sp_feats, num_labels = self._superpixel_means(output_max_flat, curr_spixl_map_flat)

            new_cluster = []
            for bid, batch_sp_feats in enumerate(sp_feats):
                # Labels with mean 0 do not occur in this video; cluster only the others.
                nonzero_ids = torch.nonzero(batch_sp_feats).squeeze(1)
                nz_batch_sp_feats = batch_sp_feats[nonzero_ids].unsqueeze(1)

                assigned_clusters, _ = self.overlapping_k_means(samples=nz_batch_sp_feats, n_clusters=k_[bid], spread=spread)

                # Map assignments back to label ids; absent labels keep cluster 0,
                # which is harmless because no pixel of this video carries them.
                cluster_table = torch.zeros((num_labels, spread), device=img_.device, dtype=torch.long)
                cluster_table[nonzero_ids] = assigned_clusters
                new_cluster.append(cluster_table[curr_spixl_map_flat[bid].long()])

        new_cluster_cc = torch.stack(new_cluster, dim=0)
        return rearrange(new_cluster_cc, 'b (f h w) k -> b k f h w', b=n_videos, h=H, w=W), k_

    def assign_samples_to_clusters(self, samples, centroids, spread=None):
        """Return the indices of the ``spread`` nearest centroids for every sample.

        ``spread`` defaults to ``self.max_spread_scale``. The result has one row
        per sample, ordered nearest-first (column 0 is the hard assignment), and
        at most ``spread`` columns.
        """
        if spread is None:
            spread = self.max_spread_scale
        distances = torch.cdist(samples, centroids)
        return torch.argsort(distances, dim=1)[:, :spread]

    def update_centroids(self, samples, assigned_clusters, n_clusters):
        """Recompute each centroid as the mean of the samples nearest to it.

        Only column 0 of ``assigned_clusters`` is used, so overlapping
        assignments do not move centroids. Clusters with no samples are reset
        to the zero vector.
        """
        new_centroids = torch.zeros((n_clusters, *samples.shape[1:]), device=samples.device)
        nearest = assigned_clusters[:, 0]
        for i in range(n_clusters):
            members = samples[nearest == i]
            if len(members) > 0:
                new_centroids[i] = members.mean(dim=0)
        return new_centroids

    def overlapping_k_means(self, samples, n_clusters, spread=None):
        """K-means where every sample is also linked to its ``spread`` nearest centroids.

        Centroids start from a random subset of ``samples`` and are refined for
        ``self.num_k_means_steps`` iterations (at least one, so an assignment is
        always produced). ``spread`` defaults to ``self.max_spread_scale``.

        Returns ``(assigned_clusters, centroids)``. ``assigned_clusters`` comes
        from the last assignment step, computed before the final centroid update.
        """
        centroids = samples[torch.randperm(len(samples))[:n_clusters]]

        for _ in range(max(1, self.num_k_means_steps)):
            assigned_clusters = self.assign_samples_to_clusters(samples, centroids, spread)
            centroids = self.update_centroids(samples, assigned_clusters, n_clusters)
        return assigned_clusters, centroids