import os, math
import numpy as np
import scipy.signal
from typing import List, Optional

from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib import devrf
from .masked_adam import MaskedAdam

import lib_extra.metrics_jax as metrics_jax 
import matplotlib.pyplot as plt
from skimage.color import label2rgb
import cv2

from typing import Sequence, Union
Array = Union[np.ndarray, torch.Tensor]

import pdb



''' Misc
'''
def mse2psnr(x):
    """Convert a mean-squared-error tensor to PSNR in dB, for a peak signal value of 1."""
    return -10. * torch.log10(x)


def to8b(x):
    """Clip ``x`` to [0, 1] and scale it to ``uint8`` in [0, 255] (truncating, not rounding)."""
    scaled = 255 * np.clip(x, 0, 1)
    return scaled.astype(np.uint8)

# function used in static_DVGO
def create_optimizer_or_freeze_model2(model, cfg_train, global_step):
    """Build a MaskedAdam optimizer over the model attributes configured in ``cfg_train``.

    Every ``cfg_train`` key of the form ``lrate_<name>`` selects the model
    attribute ``<name>`` (skipped if the model has no such attribute or it is
    ``None``). Its learning rate is ``cfg_train.lrate_<name>`` multiplied by
    ``0.1 ** (global_step / (cfg_train.lrate_decay * 1000))``.

    * lr > 0: the attribute (or its ``.parameters()`` if it is an ``nn.Module``)
      becomes a param group; ``skip_zero_grad`` is set when ``<name>`` is listed
      in ``cfg_train.skip_zero_grad_fields``.
    * lr <= 0: the attribute is frozen (``requires_grad`` disabled).

    Args:
        model: object holding tensors/parameters or ``nn.Module`` attributes.
        cfg_train: config exposing ``keys()``, attribute access,
            ``lrate_decay`` and ``skip_zero_grad_fields``.
        global_step: current step, used for the learning-rate decay.

    Returns:
        A ``MaskedAdam`` optimizer over the collected param groups.
    """
    decay_steps = cfg_train.lrate_decay * 1000
    decay_factor = 0.1 ** (global_step / decay_steps)

    param_group = []
    for k in cfg_train.keys():
        if not k.startswith('lrate_'):
            continue
        k = k[len('lrate_'):]

        if not hasattr(model, k):
            continue

        param = getattr(model, k)
        if param is None:
            print(f'create_optimizer_or_freeze_model: param {k} not exist')
            continue

        lr = getattr(cfg_train, f'lrate_{k}') * decay_factor
        if lr > 0:
            print(f'create_optimizer_or_freeze_model: param {k} lr {lr}')
            if isinstance(param, nn.Module):
                param = param.parameters()
            param_group.append({'params': param, 'lr': lr, 'skip_zero_grad': (k in cfg_train.skip_zero_grad_fields)})
        else:
            print(f'create_optimizer_or_freeze_model: param {k} freeze')
            if isinstance(param, nn.Module):
                # Assigning `module.requires_grad` would only set a plain Python
                # attribute; requires_grad_() propagates to every parameter.
                param.requires_grad_(False)
            else:
                param.requires_grad = False
    return MaskedAdam(param_group)



def create_optimizer_or_freeze_model_2stage(model, cfg_train, global_step):
    """Build two MaskedAdam optimizers, one per configured parameter group.

    Every ``cfg_train`` key ``lrate_<name>`` selects the model attribute
    ``<name>`` (skipped if missing or ``None``), with learning rate
    ``cfg_train.lrate_<name> * 0.1 ** (global_step / (cfg_train.lrate_decay * 1000))``.
    Trainable attributes (lr > 0) go to the first optimizer when ``<name>`` is
    in ``cfg_train.group1`` and to the second when it is in ``cfg_train.group2``.
    Attributes with lr <= 0 are frozen.

    Args:
        model: object holding tensors/parameters or ``nn.Module`` attributes.
        cfg_train: config exposing ``keys()``, attribute access, ``group1``,
            ``group2``, ``lrate_decay`` and ``skip_zero_grad_fields``.
        global_step: current step, used for the learning-rate decay.

    Returns:
        ``(optimizer_for_group1, optimizer_for_group2)``, both ``MaskedAdam``.

    Raises:
        NotImplementedError: a trainable attribute is in neither group.
    """
    group1 = cfg_train.group1
    group2 = cfg_train.group2
    decay_steps = cfg_train.lrate_decay * 1000
    decay_factor = 0.1 ** (global_step / decay_steps)

    param_group1 = []
    param_group2 = []
    for k in cfg_train.keys():
        if not k.startswith('lrate_'):
            continue
        k = k[len('lrate_'):]

        if not hasattr(model, k):
            continue

        param = getattr(model, k)
        if param is None:
            print(f'create_optimizer_or_freeze_model: param {k} not exist')
            continue

        lr = getattr(cfg_train, f'lrate_{k}') * decay_factor
        if lr > 0:
            print(f'create_optimizer_or_freeze_model: param {k} lr {lr}')
            if isinstance(param, nn.Module):
                param = param.parameters()
            if k in group1:
                target = param_group1
            elif k in group2:
                target = param_group2
            else:
                raise NotImplementedError(
                    f'create_optimizer_or_freeze_model_2stage: param {k} is in neither group1 nor group2')
            target.append({'params': param, 'lr': lr, 'skip_zero_grad': (k in cfg_train.skip_zero_grad_fields)})
        else:
            print(f'create_optimizer_or_freeze_model: param {k} freeze')
            if isinstance(param, nn.Module):
                # Assigning `module.requires_grad` would only set a plain Python
                # attribute; requires_grad_() propagates to every parameter.
                param.requires_grad_(False)
            else:
                param.requires_grad = False
    print(param_group1, param_group2)
    return MaskedAdam(param_group1), MaskedAdam(param_group2)





# Create optimizer for different motion timesteps
def create_optimizer_or_freeze_model(model, cfg_train, timesteps, current_step, decay_factor):
    """Build a MaskedAdam optimizer for one motion timestep.

    Every ``cfg_train`` key ``lrate_<name>`` selects the model attribute
    ``<name>`` (skipped if missing or ``None``), with learning rate
    ``cfg_train.lrate_<name> * decay_factor``.

    * For ``motion_list`` and ``fwdmotion_list`` only the entry at
      ``current_step`` is optimized.
    * Any selected ``nn.Module`` contributes its ``.parameters()``.
    * Attributes with lr <= 0 are frozen.

    Args:
        model: object holding tensors/parameters, modules, and per-step lists.
        cfg_train: config exposing ``keys()``, attribute access and
            ``skip_zero_grad_fields``.
        timesteps: accepted for interface compatibility; not used here.
        current_step: index into the per-step motion lists.
        decay_factor: multiplier applied to every configured learning rate.

    Returns:
        A ``MaskedAdam`` optimizer over the collected param groups.
    """
    per_step_keys = ('motion_list', 'fwdmotion_list')
    param_group = []
    for k in cfg_train.keys():
        if not k.startswith('lrate_'):
            continue
        k = k[len('lrate_'):]

        if not hasattr(model, k):
            continue

        param = getattr(model, k)
        if param is None:
            print(f'create_optimizer_or_freeze_model: param {k} not exist')
            continue

        lr = getattr(cfg_train, f'lrate_{k}') * decay_factor
        if lr > 0:
            print(f'create_optimizer_or_freeze_model: param {k} lr {lr}')
            if k in per_step_keys:
                # Select this timestep's entry *before* any Module -> parameters()
                # conversion: the generator returned by .parameters() cannot be indexed.
                param = param[current_step]
            if isinstance(param, nn.Module):
                param = param.parameters()
            param_group.append({'params': param, 'lr': lr, 'skip_zero_grad': (k in cfg_train.skip_zero_grad_fields)})
        else:
            print(f'create_optimizer_or_freeze_model: param {k} freeze')
            if isinstance(param, nn.Module):
                # Assigning `module.requires_grad` would only set a plain Python
                # attribute; requires_grad_() propagates to every parameter.
                param.requires_grad_(False)
            else:
                param.requires_grad = False
    return MaskedAdam(param_group)


''' Checkpoint utils
'''
def load_checkpoint(model, optimizer, ckpt_path, no_reload_optimizer):
    """Restore a model (and optionally its optimizer) from a training checkpoint.

    Args:
        model: module restored in place from ``ckpt['model_state_dict']``.
        optimizer: restored in place from ``ckpt['optimizer_state_dict']``
            unless ``no_reload_optimizer`` is truthy.
        ckpt_path: file passed to ``torch.load``.
        no_reload_optimizer: when truthy, the optimizer is left untouched.

    Returns:
        ``(model, optimizer, start_step)``, where ``start_step`` is ``ckpt['global_step']``.
    """
    checkpoint = torch.load(ckpt_path)
    start_step = checkpoint['global_step']
    model.load_state_dict(checkpoint['model_state_dict'])
    reload_optimizer = not no_reload_optimizer
    if reload_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, start_step

def load_model(model_class, ckpt_path):
    """Instantiate ``model_class`` from a checkpoint and restore weights and motions.

    The constructor receives ``ckpt['model_kwargs']``. Weights are loaded
    non-strictly from ``ckpt['model_state_dict']``, and ``ckpt['motion_list']``
    is attached as the ``motion_list`` attribute.
    """
    checkpoint = torch.load(ckpt_path)
    net = model_class(**checkpoint['model_kwargs'])
    net.load_state_dict(checkpoint['model_state_dict'], strict=False)
    setattr(net, 'motion_list', checkpoint['motion_list'])
    return net

def load_model_ours(model_class, ckpt_path):
    """Instantiate ``model_class`` from ``ckpt['model_kwargs']`` and load its weights non-strictly.

    Unlike ``load_model``, no ``motion_list`` is restored from the checkpoint.
    """
    checkpoint = torch.load(ckpt_path)
    ctor_kwargs = checkpoint['model_kwargs']
    net = model_class(**ctor_kwargs)
    net.load_state_dict(checkpoint['model_state_dict'], strict=False)
    return net

def load_staticmodel(model_class, num_voxels_motion, timesteps, motion_dim, warp_ray, ckpt_path, world_motion_bound_scale, mpi_depth_motion):
    """Rebuild a model from a static-scene checkpoint, adding motion-related constructor args.

    The constructor receives the motion arguments below together with
    ``ckpt['model_kwargs']``. ``mpi_depth_motion`` is forwarded only when
    ``model_class`` is not ``devrf.DeVRF``. Weights are loaded with
    ``strict=False``, so parameters absent from the checkpoint keep their
    constructor-initialised values.
    """
    checkpoint = torch.load(ckpt_path)
    motion_kwargs = dict(
        num_voxels_motion=num_voxels_motion,
        timesteps=timesteps,
        motion_dim=motion_dim,
        warp_ray=warp_ray,
        world_motion_bound_scale=world_motion_bound_scale,
    )
    uses_devrf = model_class == devrf.DeVRF
    if not uses_devrf:
        motion_kwargs['mpi_depth_motion'] = mpi_depth_motion
    # A key present in both dicts raises TypeError, exactly as the explicit call form would.
    model = model_class(**motion_kwargs, **checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    return model


def load_pretrained_model(model_class, num_voxels_motion, timesteps, warp_ray, ckpt_path, dynamics_path, sa_path, world_motion_bound_scale, mpi_depth_motion, kwargs, xyz_min, xyz_max):
    """Build a fresh model and optionally load pretrained predictor weights.

    The model is constructed from the motion arguments plus ``kwargs``. If
    ``dynamics_path`` is given, ``torch.load(dynamics_path)['model']`` is read.
    Each key has its first two dot-separated components stripped, and the
    entries whose stripped key exists in ``model.predictor`` are loaded
    strictly into ``model.predictor``.

    ``ckpt_path``, ``sa_path``, ``mpi_depth_motion``, ``xyz_min`` and
    ``xyz_max`` are accepted for interface compatibility but unused.

    Returns:
        The constructed model.
    """
    def delete_pre(name, order):
        """Drop the first ``order`` dot-separated components of ``name``."""
        return '.'.join(name.split('.')[order:])

    model = model_class(num_voxels_motion=num_voxels_motion, timesteps=timesteps, warp_ray=warp_ray, world_motion_bound_scale=world_motion_bound_scale, **kwargs)

    if dynamics_path is not None:
        # Load onto the CPU: load_state_dict copies into the model's existing
        # tensors, so the result is device-independent and no GPU is required.
        ckpt = torch.load(dynamics_path, map_location='cpu')['model']
        predictor_keys = model.predictor.state_dict().keys()
        state_dict = {}
        for key, value in ckpt.items():
            short_key = delete_pre(key, 2)
            if short_key in predictor_keys:
                state_dict[short_key] = value
        model.predictor.load_state_dict(state_dict)
        print("Load {} successfully!".format(dynamics_path))

    return model


def load_pretrained_model_v2(model_class, num_voxels_motion, timesteps, warp_ray, ckpt_path, dynamics_path, sa_path, world_motion_bound_scale, mpi_depth_motion, kwargs, xyz_min, xyz_max):
    """Build a model from a scene checkpoint, then optionally load pretrained dynamics weights.

    Steps:
      1. Load ``ckpt_path``. Its ``model_kwargs['xyz_min'/'xyz_max']`` override
         those keys in a *copy* of ``kwargs`` (the caller's dict is not modified).
      2. Construct the model and load ``model_state_dict`` non-strictly.
      3. If ``dynamics_path`` is given, read ``torch.load(dynamics_path)['model']``:
         * keys with two leading components stripped that match
           ``model.predictor`` or ``model.slot_attention`` are loaded
           non-strictly into both;
         * keys starting with ``'encoder'``, with one leading component
           stripped, are loaded non-strictly into ``model.encoder``.

    ``sa_path``, ``mpi_depth_motion``, ``xyz_min`` and ``xyz_max`` are accepted
    for interface compatibility but unused.

    Returns:
        The constructed model.
    """
    def delete_pre(name, order):
        """Drop the first ``order`` dot-separated components of ``name``."""
        return '.'.join(name.split('.')[order:])

    ckpt = torch.load(ckpt_path)
    # Copy so the caller's dict is not mutated as a hidden side effect.
    model_kwargs = dict(kwargs)
    model_kwargs['xyz_min'] = ckpt['model_kwargs']['xyz_min']
    model_kwargs['xyz_max'] = ckpt['model_kwargs']['xyz_max']

    model = model_class(num_voxels_motion=num_voxels_motion, timesteps=timesteps, warp_ray=warp_ray, world_motion_bound_scale=world_motion_bound_scale, **model_kwargs)
    model.load_state_dict(ckpt['model_state_dict'], strict=False)
    print("Load {} successfully!".format(ckpt_path))

    if dynamics_path is not None:
        # Load onto the CPU: load_state_dict copies into the model's existing
        # tensors, so the result is device-independent and no GPU is required.
        dyn_ckpt = torch.load(dynamics_path, map_location='cpu')['model']

        # Predictor + slot-attention weights share one filtered dict; each
        # submodule ignores the other's keys thanks to strict=False.
        target_keys = set(model.predictor.state_dict().keys())
        target_keys.update(model.slot_attention.state_dict().keys())
        state_dict = {delete_pre(k, 2): v for k, v in dyn_ckpt.items() if delete_pre(k, 2) in target_keys}
        model.predictor.load_state_dict(state_dict, strict=False)
        model.slot_attention.load_state_dict(state_dict, strict=False)

        encoder_keys = model.encoder.state_dict().keys()
        state_dict = {delete_pre(k, 1): v for k, v in dyn_ckpt.items() if k.startswith('encoder') and delete_pre(k, 1) in encoder_keys}
        model.encoder.load_state_dict(state_dict, strict=False)
        print("Load {} successfully!".format(dynamics_path))

    return model



''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    """Structural similarity (SSIM) between two RGB images.

    A separable Gaussian window (``filter_size`` taps, std ``filter_sigma``)
    is applied to each channel with 'valid' convolution. The SSIM map is
    therefore ``filter_size - 1`` pixels smaller than the inputs in each
    spatial dimension.

    Args:
        img0, img1: arrays of identical shape ``(H, W, 3)``.
        max_val: dynamic range of the pixel values (e.g. 1.0 or 255).
        filter_size, filter_sigma: Gaussian window size and standard deviation.
        k1, k2: SSIM stabilising constants.
        return_map: if true, return the per-pixel, per-channel SSIM map
            instead of its mean.

    Returns:
        Mean SSIM (scalar) or the SSIM map.
    """
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # 1-D Gaussian taps; the half-tap offset re-centres the window for even sizes.
    half = filter_size // 2
    offset = (2 * half - filter_size + 1) / 2
    taps = np.exp(-0.5 * ((np.arange(filter_size) - half + offset) / filter_sigma) ** 2)
    taps /= np.sum(taps)
    col_taps = taps[:, None]
    row_taps = taps[None, :]

    def blur(z):
        # Separable blur per channel: vertical pass, then horizontal pass.
        channels = []
        for c in range(z.shape[-1]):
            vertical = scipy.signal.convolve2d(z[..., c], col_taps, mode='valid')
            channels.append(scipy.signal.convolve2d(vertical, row_taps, mode='valid'))
        return np.stack(channels, -1)

    mean0 = blur(img0)
    mean1 = blur(img1)
    mean0_sq = mean0 * mean0
    mean1_sq = mean1 * mean1
    mean01 = mean0 * mean1
    var0 = blur(img0 ** 2) - mean0_sq
    var1 = blur(img1 ** 2) - mean1_sq
    cov01 = blur(img0 * img1) - mean01

    # Clamp to valid second moments: variances >= 0 and |cov| <= sqrt(var0 * var1).
    var0 = np.maximum(0., var0)
    var1 = np.maximum(0., var1)
    cov01 = np.sign(cov01) * np.minimum(np.sqrt(var0 * var1), np.abs(cov01))

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    ssim_map = ((2 * mean01 + c1) * (2 * cov01 + c2)) / ((mean0_sq + mean1_sq + c1) * (var0 + var1 + c2))
    if return_map:
        return ssim_map
    return np.mean(ssim_map)


# Lazily populated cache of LPIPS networks (filled on first use).
__LPIPS__ = dict()


def init_lpips(net_name, device):
    """Create an LPIPS (version 0.1) network with backbone ``net_name`` in eval mode on ``device``.

    ``net_name`` must be 'alex' or 'vgg'. ``lpips`` is imported lazily, so the
    package is only needed when LPIPS is actually evaluated.
    """
    assert net_name in ['alex', 'vgg']
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    network = lpips.LPIPS(net=net_name, version='0.1')
    return network.eval().to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    """Return the LPIPS distance between two ``(H, W, C)`` NumPy images as a Python float.

    The images are permuted to ``(C, H, W)`` tensors on ``device`` and passed
    with ``normalize=True`` (presumably pixel values in [0, 1] — TODO confirm
    with callers).

    Networks are cached in ``__LPIPS__`` per ``(net_name, device)``. Keying on
    the device as well prevents reusing a network that lives on a different
    device than the inputs.
    """
    cache_key = (net_name, str(device))
    if cache_key not in __LPIPS__:
        __LPIPS__[cache_key] = init_lpips(net_name, device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    return __LPIPS__[cache_key](gt, im, normalize=True).item()


class ARI(nn.Module):
    """Adjusted Rand Index between predicted and ground-truth segmentations.

    Wraps ``metrics_jax.Ari.from_model_output`` and evaluates it twice: once
    counting the background and once ignoring it.
    """

    def forward(self, pr_seg, gt_seg, slots):
        """Return ``(ari_bg, ari_nobg)`` for label maps ``pr_seg`` and ``gt_seg``.

        Every position is marked valid (all-ones int64 padding mask). The
        instance counts are taken as ``max label + 1`` of each input.
        ``slots`` is accepted but unused.
        """
        shared = dict(
            predicted_segmentations=pr_seg,
            ground_truth_segmentations=gt_seg,
            padding_mask=np.ones(pr_seg.shape, dtype=np.int64),
            ground_truth_max_num_instances=gt_seg.max() + 1,
            predicted_max_num_instances=pr_seg.max() + 1,
        )
        ari_bg = metrics_jax.Ari.from_model_output(ignore_background=False, **shared)
        ari_nobg = metrics_jax.Ari.from_model_output(ignore_background=True, **shared)
        return ari_bg, ari_nobg

def plot_image(ax, img, label=None):
    """Draw ``img`` on the axes ``ax`` with ticks hidden, optionally captioned.

    When ``label`` is truthy it is shown as a small (fontsize 3) x-axis label,
    and the axis is switched back on so the label is visible.
    """
    ax.imshow(img)
    ax.axis('off')
    ax.set_xticks([])
    ax.set_yticks([])
    if not label:
        return
    ax.set_xlabel(label, fontsize=3)
    ax.axis('on')

def vis_seg(vid, pr_masks, gt_masks, savedir): # [seq, H, W]
    '''Save per-frame segmentation visualizations under ``<savedir>/seg``.

    For every frame ``i`` (zero-padded to 3 digits) this writes:
      * ``seg/<iii>.png``: a figure with the frame, the ground-truth overlay
        and the predicted overlay side by side;
      * ``seg/seperate_gt/<iii>_gt_seg.png`` and
        ``seg/seperate_pr/<iii>_pred_seg.png``: the overlays alone.

    Args:
        vid: sequence of RGB frames, (L, H, W, C). Each frame must be accepted
            by ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)`` (e.g. uint8).
        pr_masks, gt_masks: per-frame label maps indexed like ``vid``. The
            ``[:, :, 0, :]`` indexing implies each ``label2rgb`` result has a
            singleton third axis, presumably masks of shape (H, W, 1) — TODO confirm.
        savedir: output root directory.
    '''
    savedir = os.path.join(savedir, 'seg')
    gt_dir = os.path.join(savedir, 'seperate_gt')
    pr_dir = os.path.join(savedir, 'seperate_pr')
    # Create all output directories once rather than on every frame.
    for directory in (savedir, gt_dir, pr_dir):
        os.makedirs(directory, exist_ok=True)

    def to_bgr8(rgb):
        # label2rgb yields RGB floats in [0, 1]; cv2.imwrite expects 8-bit BGR.
        # Round and saturate like OpenCV's own float -> uint8 conversion.
        rgb8 = np.clip(np.rint(rgb * 255), 0, 255).astype(np.uint8)
        return cv2.cvtColor(rgb8, cv2.COLOR_RGB2BGR)

    for i in range(len(vid)):
        stem = str(i).zfill(3)
        vidgrey = cv2.cvtColor(vid[i], cv2.COLOR_RGB2GRAY)[..., None]
        gt_seg = label2rgb(gt_masks[i], vidgrey)[:, :, 0, :]
        pred_seg = label2rgb(pr_masks[i], vidgrey)[:, :, 0, :]

        fig, ax = plt.subplots(1, 3, dpi=400)
        plot_image(ax[0], vid[i], 'original')
        plot_image(ax[1], gt_seg, 'gt_seg')
        plot_image(ax[2], pred_seg, 'pred_seg')
        fig.savefig(os.path.join(savedir, stem + '.png'))
        # Close this figure explicitly so none outlives its frame (the final one
        # previously leaked) and unrelated user figures are never closed.
        plt.close(fig)

        cv2.imwrite(os.path.join(gt_dir, stem + '_gt_seg.png'), to_bgr8(gt_seg))
        cv2.imwrite(os.path.join(pr_dir, stem + '_pred_seg.png'), to_bgr8(pred_seg))
        
        
def create_gradient_grid(
    samples_per_dim: Sequence[int], value_range: Sequence[float] = (-1.0, 1.0)
    ) -> Array:
    """Return a dense grid of evenly spaced coordinates.

    Along dimension ``d`` the grid holds ``samples_per_dim[d]`` points from
    ``value_range[0]`` to ``value_range[1]`` inclusive (as ``np.linspace``).

    Args:
        samples_per_dim: number of points along each dimension.
        value_range: ``(start, stop)`` shared by every dimension; defaults to (-1, 1).

    Returns:
        A NumPy array of shape ``(*samples_per_dim, len(samples_per_dim))``.
        The last axis holds each point's coordinates ("ij" indexing).
    """
    low, high = value_range[0], value_range[1]
    axes = [np.linspace(low, high, count) for count in samples_per_dim]
    mesh = np.meshgrid(*axes, sparse=False, indexing="ij")
    return np.stack(mesh, axis=-1)

def convert_to_fourier_features(inputs: Array, basis_degree: int) -> Array:
    """Map coordinates to sinusoidal Fourier features (positional encoding).

    Each of the ``n_dims`` input coordinates is scaled by the frequencies
    ``2**0, ..., 2**(basis_degree - 1)``, and the result is ``[sin(x), cos(x)]``.

    Args:
        inputs: array of shape ``(..., n_dims)``, expected in [-pi, pi] or [0, 2*pi].
        basis_degree: number of octaves per dimension. ``np.concatenate`` raises
            for 0, so it must be at least 1.

    Returns:
        Array of shape ``(..., 2 * n_dims * basis_degree)``: all sines first,
        then all cosines. Within each half, column ``j * n_dims + d`` is
        coordinate ``d`` scaled by ``2**j``.
    """
    dims = inputs.shape[-1]
    scaled_identities = [2 ** octave * np.eye(dims) for octave in range(basis_degree)]
    projection = np.concatenate(scaled_identities, 1)  # (n_dims, n_dims * basis_degree)
    projected = inputs @ projection
    # cos(x) == sin(x + pi/2), so a single sin call yields both halves.
    return np.sin(np.concatenate([projected, projected + 0.5 * np.pi], axis=-1))
