import numpy as np
import gymnasium as gym
import pandas as pd

from causalchamber.simulators.lt.image.decoder import DecoderSimple

from relign.generator import BaseEnvironment as RelignEnv
from relign.generator import LensEnv as RelignLensEnv
from relign.generator import GymnasiumEnv as RelignGymEnv


def rmse(arr0, arr1):
    """Return the root-mean-square error between ``arr0`` and ``arr1``.

    The inputs are combined with ``-``, so they must support elementwise
    subtraction and broadcast together (e.g. NumPy arrays or scalars).
    """
    residual = arr0 - arr1
    mean_squared = np.mean(residual ** 2)
    return np.sqrt(mean_squared)


def block_diag_repeat(A: np.ndarray, n: int) -> np.ndarray:
    """
    Return an n×n block-diagonal matrix made of copies of the 2×2 block ``A``.

    The result starts as the identity (with ``A``'s dtype); ``n // 2`` copies
    of ``A`` are written along the diagonal. When ``n`` is odd, the trailing
    1×1 entry keeps its identity value of 1.

    Examples:
      n=6 -> diag(A, A, A)
      n=5 -> diag(A, A, 1)
      n=4 -> diag(A, A)
      n=3 -> diag(A, 1)

    Raises:
        ValueError: if ``A`` is not 2×2.
    """
    block = np.asarray(A)
    if block.shape != (2, 2):
        raise ValueError("A must be 2x2")

    result = np.eye(n, dtype=block.dtype)
    for k in range(n // 2):
        start = 2 * k
        stop = start + 2
        result[start:stop, start:stop] = block
    return result


class BaseEnvironment(RelignEnv):
    """relign environment base that can compute the action leading to the optimum."""

    # Number of observation channels; concrete subclasses assign a value.
    N_CHANNELS: int

    def get_optimal_move(self):
        """Return the action ``a`` solving ``W @ a == r_optimal - r``.

        ``self.W``, ``self.r`` and ``self.r_optimal`` are read from the
        instance. Presumably ``W`` is the linear action-to-movement map of the
        relign environment (TODO confirm against relign).

        ``np.linalg.solve`` is used instead of ``inv(W).dot(...)``. It gives
        the same solution without forming the inverse explicitly, which is
        cheaper and less sensitive to rounding.

        Raises:
            numpy.linalg.LinAlgError: If ``W`` is singular.
        """
        return np.linalg.solve(self.W, self.r_optimal - self.r)


class LensEnv(BaseEnvironment, RelignLensEnv):
    """relign ``LensEnv`` combined with ``BaseEnvironment``; observations have one channel."""

    N_CHANNELS: int = 1

    def reset(self, keep_lenses=False):
        """Reset the environment and return whatever the selected parent ``reset`` returns.

        Args:
            keep_lenses: When false, use the normal MRO ``reset``. When true,
                bypass ``RelignLensEnv.reset`` and call the next ``reset``
                after it in the MRO. Presumably this keeps the current lens
                configuration (TODO confirm against relign).
        """
        if not keep_lenses:
            return super().reset()
        # super(RelignLensEnv, self) starts the lookup *after* RelignLensEnv
        # in type(self).__mro__, so RelignLensEnv.reset itself is skipped.
        return super(RelignLensEnv, self).reset()


class LensPositioning(RelignGymEnv):
    """Gymnasium wrapper around a relign lens environment (``LensEnv`` by default).

    Unless ``distortion == "default"``, the wrapped environment's
    ``_distort_movement`` is replaced by a function built from
    ``_make_distortion_fn``.

    Args:
        env_cls: Environment class to instantiate with ``*args, **kwargs``.
        distortion: ``"movement"`` for a random linear action distortion,
            ``"default"`` to keep the environment's own movement, anything
            else for an undistorted ``r + a`` update.
    """

    def __init__(
        self,
        env_cls=LensEnv,
        distortion="movement",
        *args,
        **kwargs,
    ):
        self.env = env_cls(*args, **kwargs)

        # The position difference and the distortion rely on these attributes.
        assert hasattr(self.env, 'r_optimal'), "wrapped env must define r_optimal"
        assert hasattr(self.env, 'r'), "wrapped env must define r"

        self.observation_space = gym.spaces.Box(
            low=0,
            high=3,
            shape=self.obs_shape,
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            low=-0.2, high=0.2, shape=(self.env.n_actions,), dtype=np.float32
        )

        if distortion != "default":
            self.env._distort_movement = self._make_distortion_fn(distortion)

    @property
    def obs_shape(self):
        """Observation shape ``(N_CHANNELS, width, height)`` of the wrapped env."""
        return (self.env.N_CHANNELS, self.env.width, self.env.height)

    def reset(self, seed=None, options=None, **kwargs):
        """Seed the Gymnasium RNG, reset the wrapped env and return ``(obs, {})``.

        Extra keyword arguments are forwarded to the wrapped env's ``reset``.
        ``options`` is accepted for API compatibility but unused.
        """
        gym.Env.reset(self, seed=seed)
        img = self.env.reset(**kwargs)
        return self._preprocess_obs(img), {}

    def get_position_diff_to_optimum(self):
        """Return ``r_optimal - r`` of the wrapped environment."""
        return self.env.r_optimal - self.env.r

    def _make_distortion_fn(self, distortion="movement"):
        """Build a ``(r, a) -> new_r`` movement function for the wrapped env.

        For ``"movement"``, ``W = I + N(0, noise_movement)`` is drawn once from
        ``self.env.random`` and the action is mapped through ``W``. Any other
        value yields the undistorted update ``r + a``.

        Raises:
            ValueError: If the sampled ``W`` is (numerically) singular.
        """
        if distortion == "movement":
            n = self.env.n_actions
            noise = self.env.random.normal(
                size=(n, n),
                scale=self.env.noise_movement,
            )
            W = np.eye(n) + noise
            # A rank test catches numerically singular W, not only det == 0 exactly.
            if np.linalg.matrix_rank(W) < n:
                raise ValueError("Matrix is singular")

            def distort_movements(r, a):
                # For LensEnv this must return the new position r, not a displacement.
                return W.dot(a.T).T + r

            return distort_movements

        return lambda r, a: r + a


class PositionOnly(gym.Env):
    """Toy positioning task in the unit box ``[0, 1]^n_actions``.

    The agent observes its position ``r`` and outputs a displacement. With
    ``demo_image=True`` it observes a ``(1, 100, 100)`` image rendered from
    ``np.outer(r, r)`` instead. Each displacement is passed through a
    distortion, re-sampled on every ``reset``, before being added to ``r``.

    The reward is the negative Euclidean distance to ``r_optimal``. The
    episode terminates once that distance drops below
    ``score_goal_threshold`` and is truncated after ``max_episode_steps``
    steps.

    Args:
        seed: Seed for the environment's ``np.random.default_rng``.
        n_actions: Dimension of position and action,
            ``1 <= n_actions <= N_ACTIONS_MAX``.
        max_episode_steps: Step budget before truncation.
        score_goal_threshold: Distance below which the goal counts as reached.
        noise_movement: Scale of the randomness used by the distortions.
        distortion: One of ``"movement"``, ``"rotation"``, ``"scale"``,
            ``"regional_rotation"``, ``"sinusodial"`` (``"sinusoidal"`` is
            accepted too) or ``"sqrt"``. Unknown values raise at ``reset``.
        clipping: ``"soft"`` clamps each coordinate to [0, 1]. ``"hard"``
            moves along the proposed direction until the first box face.
        demo_image: Return images instead of raw positions as observations.

    Raises:
        ValueError: If ``n_actions`` or ``clipping`` is invalid.
    """

    N_ACTIONS_MAX: int = 5

    def __init__(
        self,
        seed=None,
        n_actions=5,
        max_episode_steps=200,
        score_goal_threshold=0.02,
        noise_movement=0.05,
        distortion="movement",
        clipping="soft",
        demo_image=False,
    ):
        if n_actions > self.N_ACTIONS_MAX:
            raise ValueError(f"n_actions must be <= {self.N_ACTIONS_MAX}")
        if n_actions < 1:
            raise ValueError("n_actions must be >= 1")
        if clipping not in ['soft', 'hard']:
            raise ValueError('Invalid clipping')

        self.demo_image = demo_image
        self.max_episode_steps = max_episode_steps
        self.score_goal_threshold = score_goal_threshold
        self.random = np.random.default_rng(seed)
        self.n_actions = n_actions
        self.r = np.zeros(n_actions)
        self.noise_movement = noise_movement
        self.distortion = distortion
        self.clipping = clipping

        # Goal position.
        self.r_optimal = 0.5 * np.ones(self.n_actions)

        self.observation_space = gym.spaces.Box(
            low=0, high=1, shape=(self.n_actions,), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            low=-0.2, high=0.2, shape=(self.n_actions,), dtype=np.float32
        )

    # ------------------------------------------------------------------
    # Distortions
    # ------------------------------------------------------------------

    def _sample_linear_map(self):
        """Draw ``W = I + N(0, noise_movement)`` from ``self.random``; reject singular W."""
        W = np.eye(self.n_actions)
        noise = self.random.normal(
            size=(self.n_actions, self.n_actions),
            scale=self.noise_movement,
        )
        W = W + noise
        # A rank test catches numerically singular W, not only det == 0 exactly.
        if np.linalg.matrix_rank(W) < self.n_actions:
            raise ValueError("Matrix is singular")
        return W

    @staticmethod
    def _rotation_block(alpha):
        """Return the 2x2 rotation matrix for angle ``alpha`` (radians)."""
        cos_a, sin_a = np.cos(alpha), np.sin(alpha)
        return np.array([[cos_a, -sin_a], [sin_a, cos_a]])

    def _make_distortion_fn(self):
        """Build the ``(state, actions) -> displacement`` function for ``self.distortion``.

        ``_update`` adds the returned displacement to the current position.
        Any random parameters are drawn from ``self.random`` here, i.e. once
        per ``reset``.

        Raises:
            ValueError: For an unknown distortion, a singular sampled matrix,
                or ``regional_rotation`` with fewer than two actions.
        """
        if self.distortion == "movement":
            W = self._sample_linear_map()

            def distort_movements(_, actions):
                return W.dot(actions.T).T

            return distort_movements

        if self.distortion == "rotation":
            alpha = self.random.normal(scale=self.noise_movement)
            # Rotate each consecutive coordinate pair by the same angle.
            W = block_diag_repeat(self._rotation_block(alpha), self.n_actions)

            def distort_rotational(_, actions):
                return W.dot(actions.T)

            return distort_rotational

        if self.distortion == "scale":

            def scale_distorted(state, actions):
                scale = np.linalg.norm(state - self.r_optimal)
                action_length = np.linalg.norm(actions)
                # Gain is at least 0.7 and capped so the step length stays <= ~0.2.
                # If the cap is below 0.7, np.clip returns the cap.
                scale = np.clip(scale, 0.7, 0.2 / (action_length + 0.001))
                return actions * scale

            return scale_distorted

        if self.distortion == "regional_rotation":
            if self.n_actions < 2:
                raise ValueError("regional_rotation requires n_actions >= 2")
            angles = self.random.normal(
                loc=np.array([-0.3, 0.6, -0.3, 0.6]),
                scale=0.5 * self.noise_movement,
            )
            Ws = [
                block_diag_repeat(self._rotation_block(alpha), self.n_actions)
                for alpha in angles
            ]

            def regional_rotation(state, actions):
                # Region id on a 2x2 grid over the first two coordinates (0..3).
                grid_x = int(state[0] > 0.5)
                grid_y = int(state[1] > 0.5)
                region_id = 2 * grid_y + grid_x
                return Ws[region_id].dot(actions.T)

            return regional_rotation

        # "sinusodial" is the historical spelling; keep accepting it.
        if self.distortion in ("sinusodial", "sinusoidal"):
            scale = self.random.uniform(0, self.noise_movement, size=self.n_actions)

            def sinusoidal_distorted(state, actions):
                warp = np.sin(2 * np.pi * state) * np.cos(2 * np.pi * state)
                offset = 2 * scale * warp * np.linalg.norm(actions)
                return actions + offset

            return sinusoidal_distorted

        if self.distortion == "sqrt":
            W = self._sample_linear_map()

            def distort_sqrt(_, actions):
                return 3 * W.dot(actions.T).T * np.sqrt(np.linalg.norm(actions))

            return distort_sqrt

        raise ValueError(f"Unknown distortion type: {self.distortion}")

    # ------------------------------------------------------------------
    # Observations
    # ------------------------------------------------------------------

    def _make_image(self, r=None):
        """Render position ``r`` (default ``self.r``) as a ``(1, 100, 100)`` image.

        The image is ``np.outer(r, r)`` upsampled to 100x100 by
        nearest-index lookup.
        """
        if r is None:
            r = self.r
        X = np.outer(r, r)
        r_idx = np.linspace(0, X.shape[0] - 0.01, 100).astype(int)
        c_idx = np.linspace(0, X.shape[1] - 0.01, 100).astype(int)
        img = X[r_idx][:, c_idx]
        return img[np.newaxis]

    def _observation(self):
        """Return the current observation: an image or a copy of the position."""
        if self.demo_image:
            # Pass r explicitly so subclasses whose override requires it keep working.
            return self._make_image(self.r)
        # Copy so callers cannot mutate the environment state through the observation.
        return self.r.copy()

    # ------------------------------------------------------------------
    # Episode logic
    # ------------------------------------------------------------------

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, {})``.

        A fresh distortion and a random start ``r_optimal + U(-0.5, 0.5)``,
        clipped to the unit box, are drawn. ``options`` is unused.
        """
        if seed is not None:
            self.random = np.random.default_rng(seed)
        self._elapsed_steps = 0
        self._scores = []

        self.distort_fn = self._make_distortion_fn()

        self.r_offset = self.random.uniform(-0.5, 0.5, size=self.n_actions)
        self.r = np.clip(self.r_optimal + self.r_offset, 0.0, 1.0)

        self._scores.append(self.compute_distance())
        return self._observation(), {}

    def compute_distance(self):
        """Return the Euclidean distance between ``r`` and ``r_optimal``."""
        return np.linalg.norm(self.get_position_diff_to_optimum())

    def get_position_diff_to_optimum(self):
        """Return ``r_optimal - r``."""
        return self.r_optimal - self.r

    def _update(self, a):
        """Apply the distorted action ``a`` to ``self.r``, keeping it in the unit box."""
        if not a.shape[0] == self.n_actions:
            raise ValueError("moving vector must have same dimension as `Environment.n_actions`")

        r_proposed = self.distort_fn(self.r, a) + self.r

        if np.all((r_proposed >= 0) & (r_proposed <= 1)):
            self.r = r_proposed
        else:
            self.r = self._clip(r_proposed)

    def _clip(self, r):
        """Map an out-of-box proposed position ``r`` back into ``[0, 1]^n``.

        ``"soft"``: clamp each coordinate independently to [0, 1].
        ``"hard"``: move from ``self.r`` towards ``r`` only until the first
        face of the box is hit.
        """
        if self.clipping == "soft":
            return np.clip(r, 0.0, 1.0)

        direction = r - self.r

        # Box as {x : n.x <= b}: outward normals +e_i (b = 1) and -e_i (b = 0).
        normals = np.concatenate([np.eye(self.n_actions), -np.eye(self.n_actions)])
        intercepts = np.concatenate([np.ones(self.n_actions), np.zeros(self.n_actions)])

        speeds = normals.dot(direction)
        # Only faces the ray moves towards can be hit. Filtering on speed > 0
        # avoids division by zero. It also avoids a spurious t = 0 hit on a
        # face the ray is leaving, which would freeze the agent on a boundary.
        towards = speeds > 0
        if not np.any(towards):
            return np.clip(r, 0.0, 1.0)

        t = (intercepts[towards] - normals[towards].dot(self.r)) / speeds[towards]
        t_hit = min(max(t.min(), 0.0), 1.0)
        # Final clip only removes floating-point overshoot at the face.
        return np.clip(self.r + t_hit * direction, 0.0, 1.0)

    def positioning_reward(self, score, alpha=100.0):
        """Alternative shaped reward (not used by ``step``).

        Returns 0 inside the goal threshold and
        ``-exp(-alpha * (score - threshold))`` outside it.
        NOTE(review): this is -1 just outside the threshold and tends to 0 far
        away, i.e. it favours being far from the goal. Confirm the intent
        before using it.
        """
        if score <= self.score_goal_threshold:
            return 0.0
        else:
            return -np.exp(-alpha * (score - self.score_goal_threshold))

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        # Accept lists, tuples and arrays alike.
        action = np.asarray(action, dtype=float)

        self._elapsed_steps += 1
        self._update(action)

        score = self.compute_distance()
        self._scores.append(score)

        terminated = bool(score < self.score_goal_threshold)
        truncated = self._elapsed_steps >= self.max_episode_steps
        reward = -score

        # Gymnasium API order: terminated before truncated.
        return self._observation(), reward, terminated, truncated, {}


class LightTunnel(PositionOnly):
    """``PositionOnly`` task observed through the causal-chamber light-tunnel simulator.

    The first ``n_actions`` position coordinates drive the simulator inputs
    in the order pol_1, pol_2, red, green, blue. Inputs not controlled by an
    action stay at 0.5.

    Observations are ``0.5 + 0.5 * (image - target_image)``. The target
    image is rendered at ``r_optimal``, which is redrawn in [0.2, 0.8] on
    every ``reset``. Reset and step use this same mapping, in line with the
    declared ``[0, 1]`` observation space (assuming simulator images lie in
    [0, 1] — TODO confirm).

    Args:
        n_actions: Number of controlled simulator inputs (<= 5).
        distortion: Distortion name forwarded to ``PositionOnly``.
        simulator_root: ``root`` passed to ``DecoderSimple``.
        *args, **kwargs: Forwarded to ``PositionOnly``.
    """

    N_ACTIONS_MAX: int = 5

    def __init__(self, n_actions=5, distortion="movement", *args, simulator_root='/tmp/', **kwargs):
        self.simulator = DecoderSimple(root=simulator_root)
        super().__init__(*args, n_actions=n_actions, distortion=distortion, **kwargs)

        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(3, 64, 64), dtype=np.float32)

        # Neutral value for every simulator input not driven by an action.
        self.r_center = 0.5 * np.ones(self.N_ACTIONS_MAX)

    def _make_image(self, r=None):
        """Simulate the image for position ``r`` (default ``self.r``) as a float32 array.

        A channel-last result with 3 channels is transposed to channel-first.
        """
        if r is None:
            r = self.r

        # Start from the neutral input vector and overwrite the controlled entries.
        r_pad = self.r_center.copy()
        r_pad[: len(r)] = r

        inputs = pd.DataFrame(
            {
                'pol_1': [-180 + 360 * r_pad[0]],
                'pol_2': [-180 + 360 * r_pad[1]],
                'red': [255 * r_pad[2]],
                'green': [255 * r_pad[3]],
                'blue': [255 * r_pad[4]],
            }
        )
        img = self.simulator.simulate_from_inputs(inputs)[0].astype(np.float32)
        if img.shape[2] == 3:
            img = np.transpose(img, (2, 0, 1))
        return img

    def _observation_from_image(self, img):
        """Map a rendered image to the shared reset/step observation ``0.5 + 0.5 * (img - target)``."""
        return 0.5 + 0.5 * (img - self.img_target)

    def reset(self, seed=None, options=None):
        """Draw a new optimum, reset the episode and return ``(observation, info)``."""
        # Reseed *before* drawing the optimum so a seed reproduces r_optimal too.
        # The parent then gets seed=None so it does not restart the stream.
        if seed is not None:
            self.random = np.random.default_rng(seed)
        self.r_optimal = self.random.uniform(0.2, 0.8, size=self.n_actions)
        _, info = super().reset(seed=None, options=options)

        img = self._make_image(self.r)
        self.img_target = self._make_image(self.r_optimal)
        return self._observation_from_image(img), info

    def step(self, action):
        """Step the underlying task and replace its observation with the image observation.

        The two episode-end flags are forwarded unchanged, in the order the
        parent ``step`` returns them.
        """
        _, reward, done_flag_a, done_flag_b, info = super().step(action)

        img = self._make_image(self.r)
        return self._observation_from_image(img), reward, done_flag_a, done_flag_b, info
