import os
import json

import numpy as np
import torch
import yaml

from stable_baselines3.common.running_mean_std import RunningMeanStd
from dlp2.models import ObjectDLP
from bsa.checkpoint import Checkpoint
from bsa.encoder import BSAEncoder, OSRTEncoder

"""
Misc
"""
def check_config(config, isaac_env_cfg=None, policy_config=None):
    """Sanity-check the method / observation settings under ``config['Model']``.

    Args:
        config: Mapping with a ``'Model'`` entry providing ``'method'``,
            ``'obsType'`` and ``'obsMode'``.
        isaac_env_cfg: Accepted for signature compatibility; not read here.
        policy_config: Accepted for signature compatibility; not read here.

    Raises:
        KeyError: If one of the required entries is missing.
        AssertionError: If the method, observation type, or the combination
            of method and observation mode is not supported. Note that
            assertions are skipped when Python runs with ``-O``.
    """
    model_cfg = config['Model']
    method = model_cfg['method']
    obs_type = model_cfg['obsType']
    obs_mode = model_cfg['obsMode']

    assert method in ('BT', 'EIT')
    assert obs_type in ('Image',)

    if obs_type != 'Image':
        return

    # Observation modes each method accepts for image observations.
    modes_per_method = (
        ('EIT', ('dlp', '3d_slot', '3d_block')),
        ('BT', ('3d_block',)),
    )
    for candidate, allowed_modes in modes_per_method:
        if method == candidate:
            assert obs_mode in allowed_modes


def get_run_name(config, isaac_env_cfg, seed):
    """Build the run identifier string from the experiment settings.

    The name is assembled from underscore-separated pieces in this order:
    ``<numObjects>O``, a method tag (``EIT`` / ``BT``), an observation-mode
    tag (``DLP`` / ``SLOT`` / ``BLOCK``), ``<numViews>views`` and the seed.
    A method or observation mode without a known tag contributes nothing.

    Args:
        config: Mapping with ``config['Model']`` providing ``'method'``,
            ``'obsMode'`` and ``'numViews'``.
        isaac_env_cfg: Mapping with ``isaac_env_cfg['env']['numObjects']``.
        seed: Value appended as the final piece of the name.

    Returns:
        The assembled run name as a string.
    """
    pieces = [f"{isaac_env_cfg['env']['numObjects']}O"]

    model_cfg = config['Model']

    method = model_cfg['method']
    method_tags = (('EIT', 'EIT'), ('BT', 'BT'))
    pieces.extend(tag for key, tag in method_tags if method == key)

    obs_mode = model_cfg['obsMode']
    mode_tags = (('dlp', 'DLP'), ('3d_slot', 'SLOT'), ('3d_block', 'BLOCK'))
    pieces.extend(tag for key, tag in mode_tags if obs_mode == key)

    pieces.append(f"{model_cfg['numViews']}views")
    pieces.append(f"{seed}")

    return "_".join(pieces)


"""
Logging
"""
def compute_gradients(parameters):
    """Return the global L2 norm of the gradients of ``parameters``.

    The result is the square root of the sum of the squared per-parameter
    gradient 2-norms.

    Parameters whose ``.grad`` is ``None`` (for example frozen parameters or
    parameters that did not take part in the last backward pass) are
    skipped instead of raising ``AttributeError``.

    Args:
        parameters: Iterable of tensors / ``nn.Parameter`` objects.

    Returns:
        A 0-dim tensor holding the norm. If no parameter carries a gradient
        (including an empty iterable) a 0-dim zero tensor is returned.
    """
    total_sq_norm = None
    for p in parameters:
        if p.grad is None:
            continue
        current = p.grad.detach().norm(2) ** 2
        if total_sq_norm is None:
            total_sq_norm = current
        else:
            total_sq_norm = total_sq_norm + current

    if total_sq_norm is None:
        # Nothing to accumulate: report a zero norm rather than failing on None ** 0.5.
        return torch.zeros(())
    return total_sq_norm ** 0.5


def compute_params(parameters):
    """Return the global L2 norm of the values of ``parameters``.

    The result is the square root of the sum of the squared per-parameter
    2-norms.

    Args:
        parameters: Iterable of tensors / ``nn.Parameter`` objects.

    Returns:
        A 0-dim tensor holding the norm. For an empty iterable a 0-dim zero
        tensor is returned instead of failing on ``None ** 0.5``.
    """
    total_sq_norm = None
    for p in parameters:
        current = p.data.norm(2) ** 2
        if total_sq_norm is None:
            total_sq_norm = current
        else:
            total_sq_norm = total_sq_norm + current

    if total_sq_norm is None:
        return torch.zeros(())
    return total_sq_norm ** 0.5


def get_max_param(parameters):
    """Return the largest absolute entry found across ``parameters``.

    Args:
        parameters: Iterable of tensors / ``nn.Parameter`` objects.

    Returns:
        The 0-dim tensor holding the largest absolute value seen, or the
        integer ``0`` when the iterable is empty or no entry has an absolute
        value greater than zero.
    """
    largest = 0
    for param in parameters:
        candidate = param.data.abs().max()
        # Strict comparison: ties keep the value found earlier.
        largest = candidate if candidate > largest else largest
    return largest


"""
Pretrained Representation
"""
def _load_yaml_config(conf_path):
    """Parse the YAML file at ``conf_path`` and return its contents."""
    # Prefer the libyaml-backed loader, but fall back to the pure-Python
    # equivalent when PyYAML was installed without the C extension.
    loader_cls = getattr(yaml, 'CLoader', yaml.Loader)
    # NOTE(review): this is the full (non-safe) loader, which can construct
    # arbitrary Python objects. Only use it on trusted config files; consider
    # a safe loader if the configs do not rely on Python-specific tags.
    with open(conf_path, 'r') as f:
        return yaml.load(f, Loader=loader_cls)


def _load_dlp_model(dir_path):
    """Build an ObjectDLP from ``hparams.json`` and load ``dlp_panda_push.pth``."""
    # load config
    conf_path = os.path.join(dir_path, 'hparams.json')
    with open(conf_path, 'r') as f:
        config = json.load(f)

    ckpt_path = os.path.join(dir_path, 'dlp_panda_push.pth')
    # initialize model
    model = ObjectDLP(cdim=config['cdim'], enc_channels=config['enc_channels'],
                      prior_channels=config['prior_channels'],
                      image_size=config['image_size'], n_kp=config['n_kp'],
                      learned_feature_dim=config['learned_feature_dim'],
                      bg_learned_feature_dim=config['bg_learned_feature_dim'],
                      pad_mode=config['pad_mode'],
                      sigma=config['sigma'],
                      dropout=False, patch_size=config['patch_size'], n_kp_enc=config['n_kp_enc'],
                      n_kp_prior=config['n_kp_prior'], kp_range=config['kp_range'],
                      kp_activation=config['kp_activation'],
                      anchor_s=config['anchor_s'],
                      use_resblock=False,
                      scale_std=config['scale_std'],
                      offset_std=config['offset_std'], obj_on_alpha=config['obj_on_alpha'],
                      obj_on_beta=config['obj_on_beta'])
    # Load the weights onto the CPU first so a checkpoint saved from a GPU
    # run can be restored on a machine without that device. load_state_dict
    # copies the values into the freshly built model's own parameters.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    return model


def _load_checkpointed_encoder(dir_path, encoder_cls):
    """Build ``encoder_cls`` from ``config.yaml`` and load its checkpoint.

    The encoder is constructed from ``cfg['model']['encoder_kwargs']`` and the
    checkpoint file ``model_<load_iter>.pt`` (with ``load_iter`` taken from
    ``cfg['training']``) is loaded through ``Checkpoint``.
    """
    cfg = _load_yaml_config(os.path.join(dir_path, 'config.yaml'))

    # initialize model
    model = encoder_cls(**cfg['model']['encoder_kwargs'])

    # presumably Checkpoint.load restores the weights into `model` in place
    # -- confirm against bsa.checkpoint.
    checkpoint = Checkpoint(dir_path, encoder=model)
    load_iter = cfg['training']['load_iter']
    checkpoint.load(f'model_{load_iter}.pt')
    return model


def load_pretrained_rep_model(dir_path, model_type='dlp'):
    """Load a frozen, eval-mode pretrained representation model.

    Args:
        dir_path: Directory holding the model's config and checkpoint files
            (``hparams.json`` + ``dlp_panda_push.pth`` for ``'dlp'``;
            ``config.yaml`` + ``model_<load_iter>.pt`` for the other types).
        model_type: One of ``'dlp'``, ``'3d_slot'`` or ``'3d_block'``.

    Returns:
        The loaded model, switched to eval mode with gradients disabled, or
        ``None`` when ``model_type`` is not one of the supported values.
    """
    # Unsupported types yield None (callers may rely on this to mean
    # "no pretrained model"), so nothing below needs a fallback branch.
    if model_type not in ('dlp', '3d_slot', '3d_block'):
        return None

    if model_type == 'dlp':
        print("\nLoading pretrained DLP...")
        model = _load_dlp_model(dir_path)
    elif model_type == '3d_slot':
        print("\nLoading pretrained Slot-Attention...")
        model = _load_checkpointed_encoder(dir_path, OSRTEncoder)
    else:
        print("\nLoading pretrained Block-Slot Attention...")
        model = _load_checkpointed_encoder(dir_path, BSAEncoder)

    model.eval()
    model.requires_grad_(False)

    return model


def get_dlp_rep(dlp_output):
    """Concatenate selected DLP outputs into one tensor along the last axis.

    The pieces are joined in this order: ``'z'``, ``'mu_scale'``,
    ``'mu_depth'``, ``'mu_features'`` and ``'obj_on'``. ``'obj_on'`` gets a
    trailing singleton dimension added before concatenation.

    Args:
        dlp_output: Mapping holding the tensors listed above.
            NOTE(review): presumably the dict returned by the DLP model's
            forward pass -- confirm against the caller.

    Returns:
        The concatenated tensor.
    """
    ordered_keys = ('z', 'mu_scale', 'mu_depth', 'mu_features')
    components = [dlp_output[key] for key in ordered_keys]
    # obj_on has one dimension fewer than the other entries, so append a
    # trailing axis to make it concatenable.
    components.append(dlp_output['obj_on'].unsqueeze(dim=-1))
    return torch.cat(components, dim=-1)

def get_camera_ray(camera_view_matrix, camera_intrinsic_matrix, H, W):
    """
    Compute the camera position and one unit ray direction per pixel.

    camera_view_matrix: (4, 4) numpy array (world-to-camera pose)
    camera_intrinsic_matrix: (3, 3) numpy array
    H, W: image height and width

    Returns:
        cam_pos: (3,) translation column of the inverted view matrix
        rays_world: (H, W, 3) float32 ray directions, normalized to unit length
    """
    # The inverse of the view matrix maps camera coordinates to world ones;
    # its rotation block and translation column are all that is needed.
    world_from_cam = np.linalg.inv(camera_view_matrix)
    rotation = world_from_cam[:3, :3]
    cam_pos = world_from_cam[:3, 3]

    # Homogeneous pixel-centre coordinates (x + 0.5, y + 0.5, 1), shape (H, W, 3).
    cols, rows = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
    pixel_centres = np.stack([cols + 0.5, rows + 0.5, np.ones_like(cols)], axis=-1)

    # Back-project through the inverse intrinsics to camera-space directions.
    intrinsics_inv = np.linalg.inv(camera_intrinsic_matrix)
    cam_dirs = pixel_centres @ intrinsics_inv.T

    # Rotate into world space. The rotated directions are negated.
    # NOTE(review): the sign flip presumably matches the simulator's camera
    # axis convention -- confirm before changing it.
    world_dirs = -(cam_dirs @ rotation.T)

    # Scale every direction to unit length.
    lengths = np.linalg.norm(world_dirs, axis=-1, keepdims=True)
    unit_dirs = world_dirs / lengths

    return cam_pos, unit_dirs.astype(np.float32)


"""
Agent
"""
def action_noise_schedule(sig_start, sig_end, init_episodes, ss_episodes, tot_episodes):
    """Build a per-episode noise-sigma schedule made of up to three phases.

    1. ``init_episodes`` entries held at ``sig_start`` (skipped if not > 0).
    2. A linear ramp from ``sig_start`` to ``sig_end`` with
       ``tot_episodes - init_episodes - ss_episodes`` entries.
    3. ``ss_episodes`` entries held at ``sig_end`` (skipped if not > 0).

    Returns:
        1-D numpy array containing the three phases back to back.
    """
    ramp_length = tot_episodes - init_episodes - ss_episodes

    segments = []
    if init_episodes > 0:
        # Constant warm-up phase at the starting sigma.
        segments.append(np.ones(init_episodes) * sig_start)
    segments.append(np.linspace(sig_start, sig_end, ramp_length))
    if ss_episodes > 0:
        # Constant tail phase at the final sigma.
        segments.append(np.ones(ss_episodes) * sig_end)

    return np.concatenate(segments)


class RMSNormalizer:
    """Normalize observations with running mean / variance statistics.

    Statistics are tracked by stable_baselines3's ``RunningMeanStd``. Both
    numpy arrays and torch tensors are accepted by every method.
    """

    def __init__(self, epsilon=1e-6, shape=()):
        """Create the tracker.

        Args:
            epsilon: Passed to ``RunningMeanStd`` and also added to the
                variance before taking the square root.
            shape: Shape argument forwarded to ``RunningMeanStd``.
        """
        self.epsilon = epsilon
        self.rms = RunningMeanStd(epsilon=epsilon, shape=shape)

    def _torch_stats(self, like):
        """Return ``(mean, std)`` as tensors on ``like``'s device and dtype."""
        device = like.device
        dtype = like.dtype
        mean = torch.tensor(self.rms.mean, device=device, dtype=dtype)
        var = torch.tensor(self.rms.var, device=device, dtype=dtype)
        epsilon = torch.tensor(self.epsilon, device=device, dtype=dtype)
        return mean, torch.sqrt(var + epsilon)

    def update(self, obs):
        """Fold a batch of observations into the running statistics.

        Torch tensors are converted to numpy first, because the tracker
        operates on numpy arrays while ``normalize`` / ``unnormalize``
        already accept tensors.
        """
        if torch.is_tensor(obs):
            obs = obs.detach().cpu().numpy()
        self.rms.update(obs)

    def normalize(self, obs):
        """Standardize ``obs`` and clip the result to ``[-5, 5]``.

        Returns:
            A float32 tensor (on the input's device) for tensor input,
            otherwise a float32 numpy array.
        """
        if torch.is_tensor(obs):
            mean, std = self._torch_stats(obs)
            return torch.clip((obs - mean) / std, -5, 5).to(torch.float32)
        std = np.sqrt(self.rms.var + self.epsilon)
        return np.clip((obs - self.rms.mean) / std, -5, 5).astype(np.float32)

    def unnormalize(self, obs):
        """Map standardized values back to the original scale.

        Values that ``normalize`` clipped cannot be recovered exactly.
        """
        if torch.is_tensor(obs):
            mean, std = self._torch_stats(obs)
            return (obs * std) + mean
        return (obs * np.sqrt(self.rms.var + self.epsilon)) + self.rms.mean
