import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os
import cv2
from torchvision import transforms
from utils.tools import load_yaml,save_args_to_txt
import importlib
from datetime import datetime
from tqdm import tqdm

class SpatialTransform(nn.Module):
    """Warp images by sampling them at an identity grid displaced by a flow.

    The identity grid spans [-1, 1] along both axes, i.e. the normalised
    coordinate convention used by ``grid_sample`` with ``align_corners=True``,
    so ``flow`` is added directly in those normalised units.

    Args:
        h2: Number of rows of the sampling grid.
        w2: Number of columns of the sampling grid.
        device: Device on which the identity grid is stored.
    """

    def __init__(self, h2, w2, device = 'cpu'):
        super(SpatialTransform, self).__init__()
        # 'ij' is the layout the legacy default produced; spelling it out
        # silences the deprecation warning and matches getTraj.
        grid_h, grid_w = torch.meshgrid(
            torch.linspace(-1, 1, h2),
            torch.linspace(-1, 1, w2),
            indexing='ij',
        )
        # Stored as frozen parameters so they follow the module across
        # .to(device) calls; registration order (w, then h) is kept.
        self.grid_w = nn.Parameter(grid_w.to(device).float(), requires_grad=False)
        self.grid_h = nn.Parameter(grid_h.to(device).float(), requires_grad=False)

    def forward(self, mov_image, flow, mod = 'bilinear'):
        '''Sample ``mov_image`` at the identity grid shifted by ``flow``.

        Args:
            mov_image: Image batch of shape (B, C, H, W).
            flow: Displacement of shape (B, H, W, 2); channel 0 is added to
                the row (h) grid and channel 1 to the column (w) grid.
            mod: Interpolation mode forwarded to ``grid_sample``.

        Returns:
            The warped image batch produced by ``grid_sample``.
        '''
        flow_h = flow[:, :, :, 0]
        flow_w = flow[:, :, :, 1]

        # Broadcasting (H, W) + (B, H, W) already yields (B, H, W). The
        # previous .squeeze(1) dropped the height axis whenever H == 1.
        disp_h = self.grid_h + flow_h
        disp_w = self.grid_w + flow_w

        # grid_sample expects (x, y) order in the last dimension.
        sample_grid = torch.stack((disp_w, disp_h), 3)  # shape (N, H, W, 2)
        warped = torch.nn.functional.grid_sample(mov_image, sample_grid, mode = mod, align_corners = True)

        return warped


class DiffeomorphicTransform(nn.Module):
    """Integrate a flow field by repeated self-composition.

    The input is divided by ``2 ** time_step`` and then composed with itself
    ``time_step`` times (each step adds the flow resampled at the displaced
    grid), which is the scaling-and-squaring scheme.

    Args:
        time_step: Number of halvings / self-compositions.
        device: Retained for backward compatibility; the sampling grid is
            now created on the device of the incoming flow.
    """

    def __init__(self, time_step=7, device='cpu'):
        super(DiffeomorphicTransform, self).__init__()
        self.time_step = time_step
        self.device = device

    def forward(self, flow):
        """Return the integrated flow.

        Args:
            flow: Tensor of shape (B, 2, H, W) in normalised [-1, 1] grid
                units; channel 0 is the row (h) and channel 1 the column (w)
                component.

        Returns:
            Tensor with the same shape as ``flow``.
        """
        h2, w2 = flow.shape[-2:]
        flow = flow / (2 ** self.time_step)

        # Build the identity grid on the flow's own device/dtype so a
        # mismatch with the constructor's ``device`` cannot break sampling.
        grid_h, grid_w = torch.meshgrid(
            torch.linspace(-1, 1, h2, device=flow.device, dtype=flow.dtype),
            torch.linspace(-1, 1, w2, device=flow.device, dtype=flow.dtype),
            indexing='ij',
        )

        for _ in range(self.time_step):
            # (H, W) + (B, H, W) broadcasts to (B, H, W); no squeeze needed
            # (squeezing dim 1 used to drop the height axis when H == 1).
            disp_h = grid_h + flow[:, 0, :, :]
            disp_w = grid_w + flow[:, 1, :, :]
            # grid_sample expects (x, y) order in the last dimension.
            deformation = torch.stack((disp_w, disp_h), dim=3)
            flow = flow + torch.nn.functional.grid_sample(flow, deformation, mode='bilinear',
                                                          padding_mode="border", align_corners = True)

        return flow


def getFlow(Imgs, model, time_step=7):
    """Estimate the integrated flow between every pair of consecutive frames.

    For each pair ``model(frame_i, frame_{i+1})`` is evaluated without
    gradient tracking, integrated with ``DiffeomorphicTransform`` and rescaled
    from normalised units by ``(H - 1) / 2`` (channel 0) and ``(W - 1) / 2``
    (channel 1).

    Args:
        Imgs: Frame stack of shape (N, C, H, W).
        model: Callable taking two single-frame batches.
            NOTE(review): presumably returns a (1, 2, H, W) field - confirm.
        time_step: Number of integration steps for the transform.

    Returns:
        Tensor of shape (N - 1, 2, H, W) on the device of ``Imgs``.
    """
    num_frames, _, height, width = Imgs.shape
    dev = Imgs.device

    pair_count = num_frames - 1
    flow = torch.zeros((pair_count, 2, height, width), device=dev)
    integrator = DiffeomorphicTransform(time_step, device=dev)

    for idx in range(pair_count):
        current_frame = Imgs[idx:idx + 1]
        next_frame = Imgs[idx + 1:idx + 2]
        with torch.no_grad():
            velocity = model(current_frame, next_frame)

        displacement = integrator(velocity)  # [1,2,H,W]

        # Normalised [-1, 1] units -> pixel units.
        displacement[:, 0, :, :] = displacement[:, 0, :, :] * (height - 1) / 2
        displacement[:, 1, :, :] = displacement[:, 1, :, :] * (width - 1) / 2

        flow[idx] = displacement.squeeze(0)

    return flow


def getTraj(flow, N, H, W, device='cpu'):
    """Chain per-frame flows into per-pixel trajectories.

    Frame 0 holds the pixel grid itself (channel 0 = row, channel 1 = column).
    Every later frame is the previous position minus the flow bilinearly
    sampled at that position, clamped to the image bounds.

    Args:
        flow: Tensor indexed as ``flow[i - 1]`` with shape (2, H, W), in
            pixel units.
        N: Number of trajectory frames to produce.
        H: Image height.
        W: Image width.
        device: Device for the returned tensor.

    Returns:
        Float32 tensor of shape (N, 2, H, W).
    """
    Traj = torch.zeros((N, 2, H, W), dtype=torch.float32, device=device)

    row_axis = torch.linspace(0, H - 1, H, device=device)
    col_axis = torch.linspace(0, W - 1, W, device=device)
    rows, cols = torch.meshgrid(row_axis, col_axis, indexing='ij')
    Traj[0, 0] = rows
    Traj[0, 1] = cols

    for step in range(1, N):
        previous = Traj[step - 1:step]  # [1,2,H,W]

        # Pixel coordinates -> normalised [-1, 1] sampling coordinates.
        norm_rows = (previous[:, 0] / (H - 1)) * 2 - 1
        norm_cols = (previous[:, 1] / (W - 1)) * 2 - 1
        lookup = torch.stack((norm_cols, norm_rows), dim=-1)  # [1,H,W,2]

        sampled = torch.nn.functional.grid_sample(
            flow[step - 1].unsqueeze(0),
            lookup,
            mode='bilinear',
            align_corners=True,
        )  # [1,2,H,W]

        moved = previous - sampled
        Traj[step, 0] = torch.clamp(moved[0, 0], 0, H - 1)
        Traj[step, 1] = torch.clamp(moved[0, 1], 0, W - 1)

    return Traj


def getSmoothOffset(Traj, Smoothmodel, H, W):
    """Normalise trajectories to [-1, 1] and run the smoothing model on them.

    ``Smoothmodel`` is switched to eval mode and called without gradient
    tracking. The caller's ``Traj`` is left untouched.

    Args:
        Traj: Trajectories of shape (T, 2, H, W) in pixel units; channel 0 is
            divided by ``H - 1`` and channel 1 by ``W - 1``.
        Smoothmodel: Module called with a (1, 2, H, W, T) tensor.
            NOTE(review): the original comment says it returns
            (1, 2, H, W, T-1) - confirm against the model.
        H: Image height used for normalisation.
        W: Image width used for normalisation.

    Returns:
        Tuple ``(traj_normalized, offset_norm)``: the normalised trajectories
        as (T, 2, H, W), and the model output with its batch axis removed and
        its last axis moved to the front.
    """
    Smoothmodel.eval()

    with torch.no_grad():
        # [T,2,H,W] => [1,2,H,W,T]
        # permute/unsqueeze only return views, so clone before normalising in
        # place; otherwise the caller's Traj is silently overwritten.
        traj_normalized = Traj.permute(1, 2, 3, 0).unsqueeze(0).clone()
        traj_normalized[:, 0, :, :, :] = (traj_normalized[:, 0, :, :, :] / (H - 1)) * 2 - 1  # Y
        traj_normalized[:, 1, :, :, :] = (traj_normalized[:, 1, :, :, :] / (W - 1)) * 2 - 1  # X
        offset_norm = Smoothmodel(traj_normalized)

    return traj_normalized[0].permute(3, 0, 1, 2), offset_norm[0].permute(3, 0, 1, 2)



def smooth_optim(Traj, alpha=0.0001, lr=1e-4, niters=2000, verbose=True):
    """Smooth trajectories along the first axis by gradient descent.

    Minimises ``mean(second_difference ** 2) + alpha * mean((x - Traj) ** 2)``
    with Adam, starting from a copy of ``Traj``.

    Args:
        Traj: Tensor whose first axis is the one being smoothed.
        alpha: Weight of the fidelity term.
        lr: Adam learning rate.
        niters: Number of optimisation steps.
        verbose: Print the loss every 200 steps when true.

    Returns:
        The optimised tensor, detached from the autograd graph.
    """
    estimate = Traj.detach().clone().requires_grad_(True)
    solver = torch.optim.Adam([estimate], lr=lr)

    for step in range(niters):
        solver.zero_grad()

        # Discrete second derivative along the first axis.
        second_diff = estimate[2:] - 2 * estimate[1:-1] + estimate[:-2]

        roughness = second_diff.pow(2).mean()
        drift = (estimate - Traj).pow(2).mean()
        objective = roughness + alpha * drift

        objective.backward()
        solver.step()

        if verbose and step % 200 == 0:
            print(f'Epoch {step:04d} | Loss: {objective.item():.6f}')

    return estimate.detach()


def smooth_gauss_seidel(Traj, lambda_=0.0001, niters=100):
    """Smooth trajectories along the first axis with Gauss-Seidel sweeps.

    Frame 0 is never modified. Every other frame is replaced in place by
    ``(lambda_ * original + previous + next) / (lambda_ + 2)``; for the last
    frame its own current value stands in for the missing ``next``.

    Args:
        Traj: 4-D tensor (N, C, H, W); it is not modified.
        lambda_: Weight pulling each frame towards its original value.
        niters: Number of full sweeps.

    Returns:
        A new tensor with the same shape as ``Traj``.
    """
    num_frames, _, _, _ = Traj.shape
    smoothed = Traj.clone()
    final_index = num_frames - 1
    divisor = lambda_ + 2

    for _ in range(niters):
        for t in range(1, num_frames):
            following = smoothed[t + 1] if t < final_index else smoothed[t]
            smoothed[t] = (lambda_ * Traj[t] + smoothed[t - 1] + following) / divisor

    return smoothed




def scatter_to_grid(coords, values, H, W):
    """Splat scattered samples onto a regular grid with bilinear weights.

    Each sample contributes to the four integer cells around its position;
    contributions that fall outside the ``H x W`` grid are dropped. The
    accumulated values are divided by the accumulated weights (plus a small
    epsilon), so cells that receive nothing stay zero.

    Args:
        coords: (P, 2) tensor of positions; column 0 is the row coordinate and
            column 1 the column coordinate.
        values: (P, K) tensor of values to scatter.
        H: Grid height.
        W: Grid width.

    Returns:
        Tensor of shape (H, W, K) on the device of ``coords``.
    """
    device = coords.device

    rows, cols = coords[:, 0], coords[:, 1]
    row_lo = torch.floor(rows).long()
    col_lo = torch.floor(cols).long()
    row_hi = row_lo + 1
    col_hi = col_lo + 1

    frac_r = rows - row_lo.float()
    frac_c = cols - col_lo.float()

    accum = torch.zeros(H, W, values.size(1), device=device)
    mass = torch.zeros(H, W, device=device)

    # (row index, column index, bilinear weight) for each surrounding cell,
    # visited in the same order as the accumulation has always been done.
    corners = (
        (row_lo, col_lo, (1 - frac_c) * (1 - frac_r)),
        (row_hi, col_lo, (1 - frac_c) * frac_r),
        (row_lo, col_hi, frac_c * (1 - frac_r)),
        (row_hi, col_hi, frac_c * frac_r),
    )
    for r_idx, c_idx, wt in corners:
        inside = (r_idx >= 0) & (r_idx < H) & (c_idx >= 0) & (c_idx < W)
        target = (r_idx[inside], c_idx[inside])
        accum.index_put_(target, values[inside] * wt[inside].unsqueeze(-1), accumulate=True)
        mass.index_put_(target, wt[inside], accumulate=True)

    eps = 1e-8
    return accum / (mass.unsqueeze(-1) + eps)



def getGridOffset(Traj, offset):
    """Resample per-trajectory offsets onto the regular pixel grid.

    For every frame after the first, the offsets of that frame are scattered
    to the positions given by the trajectories of the same frame.

    Args:
        Traj: Trajectory positions of shape (T, 2, H, W).
        offset: Offsets indexed like ``Traj`` (frame, 2, H, W).

    Returns:
        Tensor of shape (T - 1, H, W, 2) on the device of ``Traj``; entry
        ``z`` corresponds to frame ``z + 1``.
    """
    num_frames, _, height, width = Traj.shape
    grid_offset = torch.zeros((num_frames - 1, height, width, 2), device=Traj.device)

    for frame in range(1, num_frames):
        # [2,H,W] -> [H*W,2] for both positions and values.
        positions = Traj[frame].permute(1, 2, 0).reshape(-1, 2)
        deltas = offset[frame].permute(1, 2, 0).reshape(-1, 2)
        grid_offset[frame - 1] = scatter_to_grid(positions, deltas, height, width)  # [H,W,2]

    return grid_offset



def getRegImgs(Imgs, grid_offset):
    """Warp every frame except the first with its grid offset.

    Frame 0 is copied unchanged; frame ``z + 1`` is warped by
    ``grid_offset[z]`` through a ``SpatialTransform``.

    Args:
        Imgs: Frame stack of shape (T, C, H, W).
        grid_offset: Offsets of shape (T - 1, H, W, 2), added to the
            transform's [-1, 1] sampling grid.

    Returns:
        Tensor of shape (T, C, H, W) on the device of ``Imgs``.
    """
    num_frames, channels, height, width = Imgs.shape
    dev = Imgs.device

    registered = torch.zeros((num_frames, channels, height, width), device=dev)
    registered[0] = Imgs[0]

    warper = SpatialTransform(height, width, device=dev)
    for frame in range(1, num_frames):
        moving = Imgs[frame:frame + 1]  # [1,C,H,W]
        shift = grid_offset[frame - 1:frame]  # [1,H,W,2]
        registered[frame:frame + 1] = warper(moving, shift)

    return registered


class FlowSmoothing(nn.Module):
    """Iteratively blur a two-channel flow while anchoring it to the input.

    Each iteration convolves both channels independently with a 3x3 binomial
    kernel (zero padded) and then blends the result with the original flow:
    ``lambda_ * flow + (1 - lambda_) * blurred``.

    Args:
        lambda_: Blend weight of the original flow.
        iterations: Number of blur/blend rounds.
    """

    def __init__(self, lambda_=0.1, iterations=10):
        super().__init__()
        self.lambda_ = lambda_
        self.iterations = iterations

        # groups=2 gives each flow channel its own single-channel filter.
        self.conv = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2, bias=False)

        # Separable [1, 2, 1] x [1, 2, 1] / 16 binomial stencil.
        taps = torch.tensor([1.0, 2.0, 1.0], dtype=torch.float32)
        stencil = (taps[:, None] * taps[None, :]) / 16.0
        self.conv.weight.data = stencil.repeat(2, 1, 1, 1)  # [2, 1, 3, 3]

    def forward(self, flow):
        """Smooth ``flow`` of shape [B, H, W, 2] and return the same layout."""
        anchor = flow.permute(0, 3, 1, 2)  # [B,2,H,W]
        current = anchor.clone()

        keep = self.lambda_
        for _ in range(self.iterations):
            blurred = self.conv(current)
            current = keep * anchor + (1 - keep) * blurred

        return current.permute(0, 2, 3, 1)  # [B,H,W,2]




def restore_model(model, pretrained_file):
    """Load the ``'net'`` entry of a checkpoint into ``model`` by key name.

    Keys the checkpoint lacks keep the model's current values; keys the model
    does not have are dropped. Both situations are reported on stdout.

    Args:
        model: Module to receive the weights.
        pretrained_file: Path to a checkpoint whose ``'net'`` entry maps
            parameter names to tensors.

    Returns:
        The same ``model`` instance, after ``load_state_dict``.
    """
    from collections import OrderedDict

    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    weights = torch.load(pretrained_file, map_location='cpu')['net']

    current_state = model.state_dict()
    model_keys = set(current_state.keys())
    weight_keys = set(weights.keys())

    # load weights by name
    missing_from_ckpt = sorted(model_keys - weight_keys)
    unused_in_ckpt = sorted(weight_keys - model_keys)

    if missing_from_ckpt:
        print('Warning: There are weights in model but not in pre-trained.')
        for key in missing_from_ckpt:
            print(key)
            # Fall back to the value the model already holds.
            weights[key] = current_state[key]

    if unused_in_ckpt:
        print('Warning: There are pre-trained weights not in model.')
        for key in unused_in_ckpt:
            print(key)
        # Keep only the entries the model can accept.
        weights = OrderedDict((key, weights[key]) for key in model_keys)

    model.load_state_dict(weights)
    return model


def postprocessing(warpTensor):
    """Convert a [0, 1] image tensor to a ``uint8`` NumPy array.

    The tensor is moved to the CPU, singleton dimensions are squeezed away,
    values are scaled by 255, saturated to [0, 255] and truncated to
    ``uint8``.

    Args:
        warpTensor: Image tensor, e.g. of shape (1, 1, H, W).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the squeezed shape.
    """
    scaled = warpTensor.detach().cpu().numpy().squeeze() * 255
    # Saturate before casting: an unclipped float -> uint8 cast wraps around
    # (or is platform dependent) for values outside [0, 255].
    warp = np.clip(scaled, 0, 255).astype(np.uint8)
    # asarray keeps the return type an ndarray even for a single-pixel input.
    return np.asarray(warp)
    

def parse_args():
    """Parse the command-line options of the script.

    Returns:
        ``argparse.Namespace`` holding one attribute per option below.
    """
    # (flag, type, default) - kept in help-output order.
    option_table = (
        # data locations and checkpoint selection
        ("--data_name", str, 'EFPL_test1'),
        ("--base_input", str, ''),
        ("--base_output", str, ''),
        ("--output_name", str, None),
        ("--base_traj_model_path", str, ''),
        ("--traj_name", str, 'exp_2025-04-20-20-12-46'),
        ("--traj_ckpt_num", int, 490),
        # Gauss-Seidel trajectory smoothing / flow-field smoothing
        ("--gs_lamba_", float, 0.01),
        ("--gs_niters", int, 400),
        ("--sm_lamba_", float, 0.001),
        ("--sm_niters", int, 100),
        # windowing and runtime
        ("--radius", int, 100),
        ("--overlap", int, 2),
        ("--device", str, 'cuda:0'),
        ("--total_iters", int, 2),
        ("--save_flow", int, 0),
    )

    parser = argparse.ArgumentParser()
    for flag, kind, default in option_table:
        parser.add_argument(flag, type=kind, default=default)

    return parser.parse_args()


def _read_gray_tensor(img_path, to_tensor):
    """Read one image as grayscale and convert it with ``to_tensor``.

    Raises:
        ValueError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Image not found or invalid: {img_path}")
    return to_tensor(img)


def _refill_window(Imgs, regImgs, keep, names, imgs_path, to_tensor):
    """Reuse the last ``keep`` registered frames and load ``names`` after them (in place)."""
    # Guard keep == 0: regImgs[-0:] would select the whole tensor, not nothing.
    if keep > 0:
        Imgs[:keep] = regImgs[-keep:]
    for k, name in enumerate(names):
        Imgs[keep + k:keep + k + 1] = _read_gray_tensor(os.path.join(imgs_path, name), to_tensor)


def _register_window(Imgs, traj_model, smooth_flow, args, window_size, H, W):
    """Run flow -> trajectory -> smoothing -> warp for one window; return (regImgs, grid_offset)."""
    Flows = getFlow(2.0 * Imgs - 1.0, traj_model)
    Trajs = getTraj(Flows, window_size, H, W, device=args.device)

    sm_Trajs = smooth_gauss_seidel(Trajs, lambda_=args.gs_lamba_, niters=args.gs_niters)

    # Pixel offsets -> normalised [-1, 1] grid units.
    offset = Trajs - sm_Trajs
    offset[:, 0, :, :] = (offset[:, 0, :, :] / (H - 1)) * 2
    offset[:, 1, :, :] = (offset[:, 1, :, :] / (W - 1)) * 2

    grid_offset = getGridOffset(Trajs, offset)  # [T-1,H,W,2]
    with torch.no_grad():
        for g_idx in range(grid_offset.size(0)):
            grid_offset[g_idx:g_idx + 1] = smooth_flow(grid_offset[g_idx:g_idx + 1])

    regImgs = getRegImgs(Imgs, grid_offset)  # [T,C,H,W]
    return regImgs, grid_offset


def _write_window(regImgs, grid_offset, frame_names, args, save_flow):
    """Write the registered frames of one window and, optionally, their flow fields."""
    if save_flow:
        for j in range(grid_offset.size(0)):
            # grid_offset[j] belongs to frame j + 1 of the window.
            flow_save_path = os.path.join(args.save_flow_path,
                                          frame_names[j + 1].replace(".png", ".pt"))
            if not os.path.exists(flow_save_path):  # [H,W,2]
                torch.save(grid_offset[j].cpu(), flow_save_path)

    for j in range(regImgs.size(0)):
        cv2.imwrite(os.path.join(args.out_dir, frame_names[j]),
                    postprocessing(regImgs[j:j + 1]))


def main():
    """Register an image sequence window by window and write the results.

    The trajectory model named in the experiment's ``config.yaml`` is restored
    from its checkpoint. The frames are then processed in sliding windows of
    ``--radius`` frames that overlap by ``--overlap`` frames, and a final
    window aligned to the end of the sequence covers any remainder. The whole
    pass is repeated ``--total_iters`` times; passes after the first read
    their input from the output directory.

    Raises:
        ValueError: If ``--overlap`` is not in ``[0, --radius)``, if the input
            folder holds fewer images than one window, or if an image cannot
            be read.
    """
    args = parse_args()
    if not 0 <= args.overlap < args.radius:
        raise ValueError(
            f"--overlap must satisfy 0 <= overlap < radius "
            f"(got overlap={args.overlap}, radius={args.radius})")

    formatted_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # ---------------------------model--------------------------
    args.traj_model_path = os.path.join(args.base_traj_model_path, args.traj_name,
                                        'checkpoints', f'net_{args.traj_ckpt_num:05d}.pth')
    args.traj_cfg_path = os.path.join(args.base_traj_model_path, args.traj_name, 'config.yaml')
    traj_cfg = load_yaml(args.traj_cfg_path)
    args.size = traj_cfg['data']['image_size']

    for arg, value in vars(args).items():
        print(f"{arg}: {value}\n")

    traj_model = importlib.import_module('models.' + traj_cfg['model']['model_name'])
    traj_model = traj_model.Framework(in_channel=2, n_classes=2,
                                      start_channel=traj_cfg['model']['start_ch']).to(args.device)
    traj_model = restore_model(traj_model, args.traj_model_path)
    traj_model.eval()

    smooth_flow = FlowSmoothing(lambda_=args.sm_lamba_, iterations=args.sm_niters).to(args.device)
    # ----------------------------------------------------------------------

    args.imgs_path = os.path.join(args.base_input, args.data_name)

    # output: use the requested name, otherwise a time-stamped experiment folder
    exp_name = args.output_name if args.output_name else f'exp_{formatted_time}'
    exp_dir = os.path.join(args.base_output, exp_name)
    args.out_dir = os.path.join(exp_dir, 'results')
    os.makedirs(args.out_dir, exist_ok=True)
    save_args_to_txt(args, os.path.join(exp_dir, 'config.txt'))
    if args.save_flow:
        args.save_flow_path = os.path.join(exp_dir, 'flow')
        os.makedirs(args.save_flow_path, exist_ok=True)

    input_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    window_size = args.radius
    stride = args.radius - args.overlap

    for t_iter in range(args.total_iters):
        # Later passes refine the output of the previous pass.
        if t_iter > 0:
            args.imgs_path = args.out_dir
        imgs_names = sorted(os.listdir(args.imgs_path))
        if len(imgs_names) < window_size:
            raise ValueError(
                f"Need at least {window_size} images in {args.imgs_path}, "
                f"found {len(imgs_names)}")

        # Flow fields are only stored during the first pass.
        save_flow = bool(args.save_flow) and t_iter == 0

        # The first frame is the reference: copy it through and take the
        # image size from it (validated before it is written or measured).
        first_path = os.path.join(args.imgs_path, imgs_names[0])
        first_img = cv2.imread(first_path, cv2.IMREAD_GRAYSCALE)
        if first_img is None:
            raise ValueError(f"Image not found or invalid: {first_path}")
        cv2.imwrite(os.path.join(args.out_dir, imgs_names[0]), first_img)
        H, W = first_img.shape

        Imgs, regImgs = None, None
        radius_start_idx = 0
        for i in tqdm(range(0, len(imgs_names) - window_size + 1, stride)):
            window_names = imgs_names[i:i + window_size]

            if i == 0:
                # Stack and transfer once, after all frames are loaded.
                Imgs = torch.stack(
                    [_read_gray_tensor(os.path.join(args.imgs_path, name), input_transform)
                     for name in window_names],
                    dim=0,
                ).to(args.device)  # [T, 1, H, W]
            else:
                _refill_window(Imgs, regImgs, args.overlap, window_names[args.overlap:],
                               args.imgs_path, input_transform)

            radius_start_idx = i
            regImgs, grid_offset = _register_window(Imgs, traj_model, smooth_flow, args,
                                                    window_size, H, W)
            _write_window(regImgs, grid_offset, window_names, args, save_flow)

        # Frames left over after the last full stride: process one more
        # window aligned to the end of the sequence, reusing what is done.
        last_num = len(imgs_names) - (radius_start_idx + window_size)
        if last_num > 0:
            done_num = window_size - last_num
            _refill_window(Imgs, regImgs, done_num, imgs_names[-last_num:],
                           args.imgs_path, input_transform)
            regImgs, grid_offset = _register_window(Imgs, traj_model, smooth_flow, args,
                                                    window_size, H, W)
            _write_window(regImgs, grid_offset, imgs_names[-window_size:], args, save_flow)

        
            
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
    