import os
import warnings

import cv2
import matplotlib.cm as cm
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from lightning import LightningModule
from torchmetrics import MeanMetric,MinMetric

from models.module.utils import *


class Net(nn.Module):
    """LINO inference network: image encoder plus pixel-sampling decoder.

    The constructor builds the sub-modules and immediately loads pretrained
    weights from ``ckpt_path`` with ``strict=False``, so constructing this
    class requires that checkpoint to exist on disk.

    Args:
        pixel_samples: Maximum number of pixels decoded per chunk. Note that
            ``forward_test`` resets this to 2048 on every call.
        output: Forwarded to ``Regressor(output=...)``; also stored as
            ``self.target``.
        depth: Forwarded to ``ScaleInvariantSpatialLightImageEncoder``.
        ckpt_path: Checkpoint file to load the pretrained weights from.
    """

    def __init__(self, pixel_samples, output, depth, ckpt_path="ckpt/lino_test.ckpt"):
        super().__init__()
        self.target = output
        self.pixel_samples = pixel_samples
        self.depth = depth
        self.glc_smoothing = True
        # The encoder is constructed with input_dim=4, while the decoder-side
        # modules below are constructed with 256 + 0 features, hence the reset.
        self.input_dim = 4
        self.image_encoder = ScaleInvariantSpatialLightImageEncoder(
            self.input_dim, self.depth, use_efficient_attention=False)
        self.mode = "Test"
        self.input_dim = 0
        self.glc_upsample = GLC_Upsample(
            256 + self.input_dim, num_enc_sab=1, dim_hidden=256,
            dim_feedforward=1024, use_efficient_attention=True)
        self.glc_aggregation = GLC_Aggregation(
            256 + self.input_dim, num_agg_transformer=2, dim_aggout=384,
            dim_feedforward=1024, use_efficient_attention=False)
        # Maps each 3-channel per-pixel observation to 256 features so it can
        # be added to the sampled GLC features in forward_test.
        self.img_embedding = nn.Sequential(
            nn.Linear(3, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 256),
        )
        self.regressor = Regressor(
            384, num_enc_sab=1, use_efficient_attention=True,
            dim_feedforward=1024, output=self.target)

        # SECURITY: weights_only=False unpickles arbitrary objects -- only load
        # trusted checkpoints. map_location="cpu" lets checkpoints saved from
        # CUDA load on any machine without a transient GPU copy.
        pretrain_weight = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        n_expected = len(self.state_dict())
        result = self.load_state_dict(pretrain_weight, strict=False)
        del pretrain_weight
        # strict=False silently ignores mismatched keys; surface the case where
        # nothing at all was loaded (e.g. a wrapped/prefixed checkpoint).
        if n_expected and len(result.missing_keys) == n_expected:
            warnings.warn(
                f"No parameters were loaded from {ckpt_path!r}: none of the "
                "model's state_dict keys were found in the checkpoint "
                "(load_state_dict was called with strict=False).")

    def forward_test(self, I, M, nImgArray, decoder_resolution, canonical_resolution):
        """Predict normals by tiling the input into 512x512 patches.

        Each patch is encoded at ``canonical_resolution``. Its masked pixels
        are then decoded in randomly shuffled chunks of at most
        ``self.pixel_samples`` pixels. Finally the per-patch normal maps are
        merged back with ``tile_stride``.

        Args:
            I: 5-D tensor unpacked as ``(B, C, H, W, Nmax)``; the last axis
                indexes the input images.
            M: Mask tensor tiled alongside ``I`` (presumably ``(B, 1, H, W)``
                -- TODO confirm).
            nImgArray: Per-sample count of valid images, read as
                ``nImgArray[b]``.
            decoder_resolution: Unused; every patch is decoded at its own
                size. Kept for interface compatibility.
            canonical_resolution: Tensor whose ``[0, 0]`` entry is the
                encoder resolution.

        Returns:
            ``(merged_normals, 0, 0, 0)``.

        Raises:
            RuntimeError: if ``self.mode`` is not ``"Test"`` (no output is
                produced in any other mode).

        Side effects:
            Sets ``self.pixel_samples = 2048`` and prints a line per patch.
        """
        if self.mode != "Test":
            raise RuntimeError(
                f"forward_test only supports mode 'Test' (got {self.mode!r})")

        canonical_resolution = canonical_resolution[0, 0].cpu().numpy().astype(np.int32).item()
        # NOTE(review): hard-coded override of the constructor's pixel_samples.
        self.pixel_samples = 2048

        B, C, H, W, Nmax = I.shape
        patch_size = 512
        # (B, C, H, W, Nmax) -> (B*Nmax, C, H, W) -> tiles, regrouped to
        # (B, n_patches, C, patch, patch, Nmax).
        patches_I = decompose_tensors.divide_tensor_spatial(
            I.permute(0, 4, 1, 2, 3).reshape(-1, C, H, W),
            block_size=patch_size, method='tile_stride')
        patches_I = patches_I.reshape(
            B, Nmax, -1, C, patch_size, patch_size).permute(0, 2, 3, 4, 5, 1)
        patches_M = decompose_tensors.divide_tensor_spatial(
            M, block_size=patch_size, method='tile_stride')

        patches_nml = []
        for k in range(patches_I.shape[1]):
            print("please wait for a moment, it may take a while")
            I_k = patches_I[:, k]  # (B, C, patch, patch, Nmax)
            M_k = patches_M[:, k]
            B, C, H, W, Nmax = I_k.shape
            patch_res = H

            # --- image encoder at canonical resolution ---
            img_index = make_index_list(Nmax, nImgArray)
            I_enc = I_k.permute(0, 4, 1, 2, 3).reshape(-1, C, H, W)  # (B*Nmax, C, H, W)
            M_enc = M_k.unsqueeze(1).expand(-1, Nmax, -1, -1, -1).reshape(-1, 1, H, W)
            data = (I_enc * M_enc)[img_index == 1, :, :, :]
            glc, _ = self.image_encoder(data, nImgArray, canonical_resolution)

            # --- per-pixel decoder at patch resolution ---
            I_dec = F.interpolate(
                I_enc.float(), size=(patch_res, patch_res),
                mode='bilinear', align_corners=False).to(torch.bfloat16)
            M_dec = F.interpolate(
                M_k.float(), size=(patch_res, patch_res),
                mode='nearest').to(torch.bfloat16)
            if self.glc_smoothing:
                f_scale = patch_res // canonical_resolution
                smoothing = gauss_filter.gauss_filter(
                    glc.shape[1], 10 * f_scale + 1, 1).to(glc.device)
                glc = smoothing(glc)

            _, _, H, W = I_dec.shape
            nout = torch.zeros(B, H * W, 3, device=I_k.device)
            p = 0
            for b in range(B):
                n_b = int(nImgArray[b])
                # Per-pixel observations: (H*W, n_b, C).
                o_ = I_dec[p:p + n_b].reshape(n_b, C, H * W).permute(2, 0, 1)
                # Per-pixel GLC features, hoisted out of the chunk loop
                # (permute + flatten materializes a copy).
                glc_b = glc[p:p + n_b].permute(2, 3, 0, 1).flatten(0, 1)
                p += n_b

                m_ = M_dec[b].reshape(-1, H * W).permute(1, 0)
                ids = torch.nonzero(m_ > 0)[:, 0]
                ids = ids[np.random.permutation(len(ids))]
                num_split = len(ids) // self.pixel_samples + 1
                for chunk in torch.tensor_split(ids, num_split):
                    if len(chunk) == 0:
                        # Fully masked-out patch: nothing to decode, rows stay 0.
                        continue
                    o_ids = self.img_embedding(o_[chunk])
                    x = o_ids + glc_b[chunk]
                    x = o_ids + self.glc_upsample(x)
                    x = self.glc_aggregation(x)
                    x_n, _, _, _ = self.regressor(x, len(chunk))
                    x_n = F.normalize(x_n, p=2, dim=-1)
                    # NOTE(review): assumes the regressor output has a leading
                    # axis indexable by b (index 0 when B == 1) -- confirm for B > 1.
                    nout[b, chunk, :] = x_n[b, :, :]

            # Finalize once per patch, after every sample has been decoded.
            patches_nml.append(nout.reshape(B, H, W, 3).permute(0, 3, 1, 2))

        # merge_tensor_spatial receives the patches stacked along dim 0:
        # (n_patches, B, 3, patch, patch).
        stacked = torch.stack(patches_nml, dim=0)
        merged_tensor_nml = decompose_tensors.merge_tensor_spatial(stacked, method='tile_stride')
        return merged_tensor_nml, 0, 0, 0
class LINOModule(LightningModule):
    """Lightning wrapper that runs ``net.forward_test`` and writes normal maps.

    Only the test path is implemented in this class. ``test_step`` predicts a
    normal map for one object, pastes it into its ROI and writes
    ``output/<task_name>/results/<object>/{tiled,nml_predict}.png``. The
    metric objects created in ``__init__`` are not updated anywhere in this
    class.
    """

    def __init__(
        self,
        net: torch.nn.Module,
        canonical_resolution: int,
        sample_num: int,
        numberofImages: int,
        task_name: str,
    ) -> None:
        """Store the network and inference settings.

        Args:
            net: Network exposing ``forward_test`` and a ``mode`` attribute.
            canonical_resolution: Encoder resolution passed to the network.
            sample_num: Forwarded to ``forward_test`` (which ignores it).
            numberofImages: Stored as-is.
            task_name: Output sub-directory under ``output/``.
        """
        super().__init__()
        self.numberofImages = numberofImages
        self.task_name = task_name

        self.save_hyperparameters(logger=False)
        self.canonical_resolution = canonical_resolution
        self.net = net
        self.sample_num = sample_num

        self.criterion = torch.nn.MSELoss()

        # Metric objects (not updated within this class).
        self.train_mae = MeanMetric()
        self.val_mae = MeanMetric()
        self.test_mae = MeanMetric()

        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        self.val_mae_best = MinMetric()

    def forward_test(self, img, mask, nImgArray, decoder_resolution, canonical_resolution_, sample_num):
        """Run ``self.net.forward_test`` under bfloat16 autocast.

        Autocast follows ``img``'s device type, so CUDA and CPU inputs both
        run in bfloat16. ``sample_num`` is accepted for interface
        compatibility but unused.
        """
        with torch.autocast(device_type=img.device.type, dtype=torch.bfloat16):
            return self.net.forward_test(img, mask, nImgArray, decoder_resolution, canonical_resolution_)

    def model_test_step(self, batch):
        """Predict normals for ``batch = (img, nml, mask, directlist, roi)``.

        ``img`` is unpacked as ``(B, C, H, W, N)``. Every sample is treated as
        having all ``N`` images valid and a decoder resolution of ``H``.

        Returns:
            The merged normal prediction returned by ``net.forward_test``.
        """
        img, _, mask, _, _ = batch
        B, _, H, _, N = img.shape
        device = img.device
        nImgArray = torch.full((B, 1), N, dtype=torch.int32, device=device)
        decoder_resolution = torch.full((B, 1), float(H), device=device)
        canonical_resolution_ = torch.full((B, 1), float(self.canonical_resolution), device=device)
        nml_predict, _, _, _ = self.forward_test(
            img, mask, nImgArray, decoder_resolution, canonical_resolution_, self.sample_num)
        return nml_predict

    def test_step(self, batch, batch_idx) -> None:
        """Predict, post-process and save the normal map for one object.

        The prediction is resized to the ROI crop. Pixels whose predicted norm
        differs from 1 by 0.5 or more are masked out, and the rest are
        re-normalized. The result is pasted into an ``(h, w)`` canvas and
        written as a BGRA PNG (normals mapped to [0, 1], alpha = mask).
        """
        # NOTE(review): a 6th batch element, if present, is not used here; the
        # model always uses batch[2] as the mask -- confirm this is intended.
        img, _, _, directlist, roi = batch[:5]
        self.net.eval()
        self.net = self.net.to(torch.bfloat16)
        self.net.mode = 'Test'

        # Python ints: cv2.resize's dsize and numpy slicing need integers.
        h_, w_, r_s, r_e, c_s, c_e = (int(v) for v in roi[0].cpu().numpy()[:6])

        nml_predict = self.model_test_step(batch[:5]).squeeze().permute(1, 2, 0).float().cpu().numpy()
        nml_predict = cv2.resize(nml_predict, dsize=(c_e - c_s, r_e - r_s), interpolation=cv2.INTER_AREA)
        norm = np.linalg.norm(nml_predict, axis=2, keepdims=True)
        mask = np.float32(np.abs(1 - norm[:, :, 0]) < 0.5)
        nml_predict = nml_predict / (norm + 1.0e-12) * mask[:, :, np.newaxis]

        nout = np.zeros((h_, w_, 3), np.float32)
        nout[r_s:r_e, c_s:c_e, :] = nml_predict
        mask_out = np.zeros((h_, w_), np.float32)
        mask_out[r_s:r_e, c_s:c_e] = mask

        obj_name = os.path.basename(os.path.dirname(directlist[0][0]))
        save_path = 'output/' + self.task_name + '/results/' + obj_name + '/'
        os.makedirs(save_path, exist_ok=True)
        torchvision.utils.save_image(img.squeeze(0).permute(3, 0, 1, 2), save_path + 'tiled.png')

        # [-1, 1] -> [0, 1], zero outside the mask; RGB -> BGR plus alpha for cv2.
        nout = (nout + 1) / 2 * mask_out[:, :, np.newaxis]
        nout = np.concatenate((nout[:, :, ::-1], mask_out[:, :, np.newaxis]), axis=2)
        out_file = save_path + 'nml_predict.png'
        # cv2.imwrite reports failure via its return value instead of raising.
        if not cv2.imwrite(out_file, (nout * 255).astype(np.uint8)):
            warnings.warn(f"Failed to write {out_file}")
        print("Done")

   



