import numpy as np
import torch
import torch.nn.functional as F
import datetime
import sys
from loguru import logger

# generate a random orthogonal matrix
def random_orthor_mat(dim: int) -> np.ndarray :
    """Draw a random ``dim x dim`` orthogonal matrix.

    A square matrix with i.i.d. standard-normal entries is decomposed by SVD
    and the matrix of left singular vectors is returned.

    Args:
        dim: side length of the square matrix.

    Returns:
        Array of shape ``(dim, dim)`` holding the left singular vectors.
    """
    gaussian = np.random.randn(dim, dim)
    # only the first SVD factor (U) is needed; singular values and V^T are dropped
    left_vectors = np.linalg.svd(gaussian, full_matrices = False)[0]
    return left_vectors

def perturb_matrix(mat: np.ndarray, perturb: int, magnitude: float = 20) -> np.ndarray :
    """Return a copy of ``mat`` with random entries bumped upwards.

    For every column, ``perturb`` row indices are drawn uniformly at random
    (with replacement) and ``magnitude`` is added to each drawn entry. The
    input array is left untouched.

    Args:
        mat: 2-D array to perturb.
        perturb: number of bumps applied per column.
        magnitude: amount added per bump (defaults to 20).

    Returns:
        A new array with the same shape as ``mat``.
    """
    res = mat.copy()
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    for col in range(n_cols) :
        for _ in range(perturb) :
            # one scalar draw per bump keeps the random stream column-major
            res[np.random.randint(0, n_rows), col] += magnitude
    return res

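# generate a random row-stochastic matrix (optionally low-rank and/or perturbed); for a square matrix with radius != 0 it becomes a cyclic permutation matrix following the given order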
def random_matrix(n, m, r = None, perturb = 0, radius = 0, order = None) :
    """Generate a random row-stochastic matrix of shape ``(n, m)``.

    Args:
        n: number of rows.
        m: number of columns.
        r: if given, the result is the product of two row-normalised random
            factors of shapes ``(n, r)`` and ``(r, m)``; ``radius`` and
            ``order`` are ignored in that case.
        perturb: if non-zero, every random factor is passed through
            ``perturb_matrix`` with this count before normalisation.
        radius: if non-zero and ``n == m``, only the entries ``(i, j)`` whose
            cyclic positions satisfy ``pos[j] == (pos[i] + radius) % n`` are
            kept, so after normalisation each row is one-hot.
        order: optional sequence listing the states in cyclic order
            (``order[k]`` is the state at position ``k``). Defaults to the
            identity order.

    Returns:
        Array of shape ``(n, m)`` whose rows sum to one.
    """
    def row_normalize(x) :
        return x / np.sum(x, axis = 1, keepdims = True)

    if r is not None :
        # low-rank case: product of two row-stochastic factors
        left = np.random.rand(n, r)
        right = np.random.rand(r, m)
        if perturb != 0 :
            left = perturb_matrix(left, perturb)
            right = perturb_matrix(right, perturb)
        return row_normalize(left) @ row_normalize(right)

    res = np.random.rand(n, m)
    if perturb != 0 :
        res = perturb_matrix(res, perturb)
    if radius != 0 and n == m :
        # position[s] is the index of state s along the cycle (inverse of `order`)
        position = np.arange(n)
        if order is not None :
            for i in range(n) : position[order[i]] = i
        target = (position + radius) % n
        # keep[i, j] is True iff state j sits `radius` steps after state i
        keep = position[np.newaxis, :] == target[:, np.newaxis]
        res = np.where(keep, res, 0)
    return row_normalize(res)
    
# transform cyclic permutation to transition matrix
def _perm2trans(perm, eps) :
    # perm shape: (action_n, state_dim)
    # return shape: (action_n, state_dim, state_dim)
    state_dim = perm.shape[1]
    action_n = perm.shape[0]

    id = np.zeros((action_n, state_dim), dtype=np.int16)
    for i in range(action_n) :
        for j in range(state_dim) : id[i, perm[i, j]] = j

    res = np.zeros((action_n, state_dim, state_dim))
    for j in range(action_n) :
        for i in range(state_dim) :
            res[j, i, perm[j, (id[j, i] + 1) % state_dim]] += 1 - eps
            res[j, i, perm[j, (id[j, i] + state_dim - 1) % state_dim]] += eps
    return res

class MultiDimStruct :
    """Small record pairing a shape with a discrete-size tag.

    Attributes set by the constructor:
        shape: the given ``shape`` converted to a tuple.
        n: the ``discrete`` argument as passed (default ``-1``).
    """
    def __init__(self, shape, discrete = -1) -> None:
        # freeze whatever iterable was passed into an immutable tuple
        dims = tuple(shape)
        self.shape, self.n = dims, discrete

# base class for environments (subclasses override reset / step / emit_*)
class Env :
    """Common interface and shared state for a batch of environments.

    The constructor stores the sizes and allocates zero-filled placeholders;
    ``reset``, ``step``, ``emit_obs_distribution`` and ``emit_obs`` raise
    ``NotImplementedError`` here and are meant to be overridden.
    """
    def __init__(self, num_envs, state_dims, obs_dims, action_n, horizon) -> None:
        self.num_envs = num_envs
        self.state_dim = state_dims
        self.obs_dim = obs_dims
        self.action_n = action_n
        self.horizon = horizon
        # per-environment step counter, shape: (num_envs,)
        self.cur_step = np.zeros(self.num_envs)

        # emission shape: (state_dim, action_n)
        self.emission = np.zeros((self.state_dim, self.action_n))
        # transition shape: (state_dim, state_dim)
        self.trans = np.zeros((self.state_dim, self.state_dim))
        # state shape: (num_envs, state_dim)
        self.state = np.zeros((self.num_envs, self.state_dim))
        # real_state shape: (num_envs,), integer state index per environment
        self.real_state = np.zeros((self.num_envs, ), dtype = np.int32)

    def reset(self, state_init) :
        """Reset the environments; must be provided by subclasses."""
        raise NotImplementedError

    def step(self, action) :
        """Advance the environments by one step; must be provided by subclasses."""
        raise NotImplementedError

    def change_horizon(self, target_horizon) :
        """Overwrite the stored horizon with ``target_horizon``."""
        self.horizon = target_horizon

    def emit_obs_distribution(self) :
        """Return the observation distribution; must be provided by subclasses."""
        raise NotImplementedError

    def emit_obs(self, real) :
        """Sample an observation; must be provided by subclasses."""
        raise NotImplementedError

    def prediction_stage(self) :
        """Return a boolean mask of shape ``(num_envs,)``, all ``True`` by default.

        The original comment ties this to the stage in CyclicHMM-HARD.
        """
        return np.full((self.num_envs,), True)

    def close(self) :
        """Release resources; a no-op in the base class."""
        # `self` was missing, so calling close() on an instance raised TypeError
        pass

# RanLDS
class LinearDynamicalSystem(Env) :
    """Batch of linear dynamical systems with a precomputed Kalman filter.

    ``self.trans`` holds two scaled random orthogonal matrices: ``trans[0]``
    drives the hidden state (``state <- trans[0].T @ state + noise``) and
    ``trans[1]`` produces observations (``obs = trans[1].T @ state + noise``).
    All filter gains are computed once in the constructor and indexed by the
    scalar ``self.cur_step`` during ``step``.

    NOTE(review): gain tables are filled for step indices ``0 .. horizon + 2``
    only; the last two slots of the ``horizon + 5`` long tables stay zero.
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon) -> None:
        super().__init__(num_envs, state_dims, obs_dims, action_n, horizon)
        # noise standard deviations of the state (x) and the observation (y)
        sigmax, sigmay = 1.0, 1.0
        self.sigmax, self.sigmay = sigmax, sigmay
        # use a slightly contractive transition matrix to avoid numerical instability
        # shape: (2, state_dims, state_dims)
        self.trans = (1 - args.eps) * np.array([random_orthor_mat(self.state_dim) for i in range(2)])
        assert args.real, "LinearDynamicalSystem only supports observation prediction currently"
        assert args.state_dim == args.action_n

        # a single scalar step counter is shared by all environments here
        self.cur_step = 0

        # computing the Kalman filter
        self.mu = np.zeros((self.num_envs, self.state_dim))
        self.sigma = np.zeros((horizon + 5, self.state_dim, self.state_dim))
        self.L_matrix = np.zeros((horizon + 5, self.state_dim, self.state_dim))
        self.kalman_filter_var = np.zeros(horizon + 5)
        self.sigma[0] = np.eye(self.state_dim)
        self.L_matrix[0] = self.trans[1] @ np.linalg.inv(self.trans[1].T @ self.trans[1] + sigmay * sigmay * np.eye(self.state_dim))
        self.kalman_filter_var[0] = np.trace(self.trans[1].T @ self.trans[1] + np.eye(self.state_dim))
        for i in range(horizon + 3) :
            # sigma[i + 1] is used as scratch space twice before receiving its final value
            self.sigma[i + 1] = np.linalg.inv(self.trans[1].T @ self.sigma[i] @ self.trans[1] + sigmay * sigmay * np.eye(self.state_dim))
            self.sigma[i + 1] = self.sigma[i] @ self.trans[1] @ self.sigma[i + 1] @ self.trans[1].T @ self.sigma[i]
            self.sigma[i + 1] = sigmax * sigmax * np.eye(self.state_dim) + self.trans[0].T @ (self.sigma[i] - self.sigma[i + 1]) @ self.trans[0]
            self.kalman_filter_var[i + 1] = np.trace(self.trans[1].T @ self.sigma[i + 1] @ self.trans[1] + sigmay * sigmay * np.eye(self.state_dim))
            self.L_matrix[i + 1] = self.sigma[i + 1] @ self.trans[1] @ np.linalg.inv(self.trans[1].T @ self.sigma[i + 1] @ self.trans[1] + sigmay * sigmay * np.eye(self.state_dim))

        # per-step matrices of the belief update: mu <- mu_upd_L.T @ mu + mu_upd_R.T @ obs
        # the *_mul variants are the transposes tiled over num_envs for batched matmul
        self.mu_upd_L = np.zeros((horizon + 5, self.state_dim, self.state_dim))
        self.mu_upd_L_mul = np.zeros((horizon + 5, self.num_envs, self.state_dim, self.state_dim))
        self.mu_upd_R = np.zeros((horizon + 5, self.state_dim, self.state_dim))
        self.mu_upd_R_mul = np.zeros((horizon + 5, self.num_envs, self.state_dim, self.state_dim))
        for i in range(horizon + 3) :
            self.mu_upd_L[i] = self.trans[0] - self.trans[1] @ self.L_matrix[i].T @ self.trans[0]
            self.mu_upd_R[i] = self.L_matrix[i].T @ self.trans[0]
            self.mu_upd_L_mul[i] = np.tile(self.mu_upd_L[i].T, (self.num_envs, 1, 1))
            self.mu_upd_R_mul[i] = np.tile(self.mu_upd_R[i].T, (self.num_envs, 1, 1))

        # the matrix A, B for the linear dynamical system
        self.trans0_mul = np.tile(self.trans[0].T, (self.num_envs, 1, 1))
        self.trans1_mul = np.tile(self.trans[1].T, (self.num_envs, 1, 1))

    def reset(self, state_init = None) :
        """Reset state and belief mean and return ``(state, {})``.

        Without ``state_init`` the state is drawn from a standard normal and
        the belief mean is zero; otherwise both are ``state_init`` tiled over
        the environments.
        """
        if state_init is None :
            self.state = np.random.randn(self.num_envs, self.state_dim)
            self.mu = np.zeros((self.num_envs, self.state_dim))
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
            self.mu = np.tile(state_init, (self.num_envs, 1))
        self.cur_step = 0

        return self.state, {}

    def step(self, action) :
        """Advance the hidden state and the Kalman belief by one step.

        Args:
            action: array of shape ``(num_envs, state_dim)``; it enters the
                belief update as the latest observation.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        assert action.shape == (self.num_envs, self.state_dim)
        # squeeze only the trailing singleton axis so the batch axis survives
        # when num_envs == 1 (and the feature axis when state_dim == 1)
        self.state = np.squeeze(np.matmul(self.trans0_mul, np.expand_dims(self.state, axis = 2)), axis = 2) + self.sigmax * np.random.randn(self.num_envs, self.state_dim)
        # self.mu: belief state given by Kalman filter
        self.mu = np.squeeze(np.matmul(self.mu_upd_L_mul[self.cur_step], np.expand_dims(self.mu, axis = 2))
                                 + np.matmul(self.mu_upd_R_mul[self.cur_step], np.expand_dims(action, axis = 2)),
                             axis = 2)
        self.cur_step += 1

        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the observation mean implied by the belief, shape ``(num_envs, state_dim)``."""
        return np.squeeze(np.matmul(self.trans1_mul, np.expand_dims(self.mu, axis = 2)), axis = 2)

    def emit_obs(self, real) :
        """Sample a noisy observation of the true state; ``real`` is unused."""
        return np.squeeze(np.matmul(self.trans1_mul, np.expand_dims(self.state, axis = 2)), axis = 2) + self.sigmay * np.random.randn(self.num_envs, self.state_dim)

# MatMul and RanHMM instances
class MatPred(Env) :
    """Matrix-product ('simplified') and random-HMM environments.

    In ``'simplified'`` mode each action owns a random orthogonal matrix and
    the state vector is multiplied by the transpose of the chosen matrix.
    In any other mode a single random row-stochastic transition matrix and a
    random row-stochastic emission matrix are drawn, and ``self.state`` is
    the filtered belief over hidden states.

    The ``args``, ``rank`` and ``perturb`` constructor arguments are accepted
    but not used by this class.
    """
    def __init__(self, args, num_envs, action_n, horizon, state_dims, obs_dims, mode = 'simplified', rank = None, perturb = 0) -> None:
        super().__init__(num_envs, state_dims, obs_dims, action_n, horizon)
        self.mode = mode
        if mode == 'simplified' : # MatMul
            # shape: (action_n, state_dims, state_dims)
            self.trans = np.array([random_orthor_mat(self.state_dim) for i in range(action_n)])
        else : # HMM
            # shape: (state_dims, state_dims), rows sum to one
            self.trans = np.random.rand(self.state_dim, self.state_dim)
            self.trans /= np.sum(self.trans, axis = 1, keepdims = True)
            # shape: (state_dims, action_n), rows sum to one
            self.emission = np.random.rand(self.state_dim, self.action_n)
            self.emission /= np.sum(self.emission, axis = 1, keepdims = True)
            # one-step-ahead observation distribution for every hidden state
            self.emit_temp = self.trans @ self.emission

    def reset(self, state_init = None) :
        """Reset all environments and return ``(state, {})``.

        Without ``state_init`` every environment starts in the first basis
        state; otherwise ``state_init`` is tiled and ``real_state`` is sampled
        from it as a probability vector.
        """
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1)) # shape: (num_envs, state_dims)
            self.real_state = np.zeros(self.num_envs, dtype = np.int32) # shape: (num_envs,)
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
            self.real_state = np.array([np.random.choice(self.state_dim, p = state_init) for k in range(self.num_envs)], dtype = np.int32)
        self.cur_step = np.zeros(self.num_envs, dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Update ``self.state`` given one action index per environment.

        Args:
            action: integer array of shape ``(num_envs,)`` with values in
                ``[0, action_n)``.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        assert (action >= 0).all() and (action < self.action_n).all() # shape : (num_envs,)
        if self.mode == 'simplified' :
            per_env_trans = np.transpose(self.trans[action], axes = (0, 2, 1))
            # squeeze only the trailing axis so the batch axis survives num_envs == 1
            self.state = np.squeeze(np.matmul(per_env_trans, np.expand_dims(self.state, axis = 2)), axis = 2)
        else :
            # batched Bayes filter: predict with the transition matrix, then
            # weight by the likelihood of the observed symbol and renormalise
            likelihood = self.emission[:, np.asarray(action)].T # shape: (num_envs, state_dim)
            unnormalized = (self.state @ self.trans) * likelihood
            self.state = unnormalized / np.sum(unnormalized, axis = 1, keepdims = True)

        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the current state array, shape ``(num_envs, state_dim)``."""
        return self.state

    def emit_obs(self, real) : # shape: (num_envs)
        """Sample one observation per environment; ``real`` is unused.

        In ``'simplified'`` mode the sample is uniform over ``[0, obs_dim)``;
        otherwise it is drawn from the belief-weighted next-observation
        distribution.
        """
        if self.mode == 'simplified' :
            return np.random.randint(0, self.obs_dim, self.num_envs)
        else :
            return np.array([np.random.choice(self.action_n, p = self.state[k].T @ self.emit_temp) for k in range(self.num_envs)])

# Cyclic-DET
class CyclicGroup(Env) :
    """Deterministic cyclic-permutation environment.

    Every action owns a permutation matrix produced by ``random_matrix`` with
    ``radius = 1``; stepping multiplies the state by the transpose of the
    chosen matrix and renormalises.

    NOTE(review): with ``mod = True`` the order passed for action ``k`` is the
    identity order rotated by ``k``. A rotated order describes the same cycle,
    so with ``radius = 1`` every action appears to receive the same "+1" shift
    matrix. If true modular addition (shift by ``k``) was intended, this needs
    a follow-up; behaviour is left unchanged here.
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon, mod = False) -> None:
        super().__init__(num_envs, state_dims, obs_dims, action_n, horizon)
        if mod : # modular addition
            self.trans = np.zeros((action_n, self.state_dim, self.state_dim))
            for k in range(action_n) :
                rotated_order = (np.arange(self.state_dim) + k) % self.state_dim
                self.trans[k] = random_matrix(self.state_dim, self.state_dim, perturb = 0, radius = 1, order = rotated_order)
        else : # random permutation
            self.trans = np.array([random_matrix(self.state_dim, self.state_dim, perturb = 0, radius = 1, order = np.random.permutation(self.state_dim)) for k in range(action_n)])

    def reset(self, state_init = None) :
        """Reset to the first basis state (or tiled ``state_init``); return ``(state, {})``."""
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1))
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
        self.cur_step = np.zeros(self.num_envs, dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Apply the per-environment action matrices and renormalise the state.

        Args:
            action: integer array of shape ``(num_envs,)`` indexing ``self.trans``.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        per_env_trans = np.transpose(self.trans[action], axes = (0, 2, 1))
        # squeeze only the trailing axis; a bare squeeze also removed the batch
        # axis when num_envs == 1 and made the normalisation below fail
        self.state = np.squeeze(np.matmul(per_env_trans, np.expand_dims(self.state, axis = 2)), axis = 2)
        self.state /= np.sum(self.state, axis = 1, keepdims = True)

        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the current state array, shape ``(num_envs, state_dim)``."""
        return self.state

    def emit_obs(self, real = False) : # shape: (num_envs)
        """Sample uniform indices in ``[0, action_n)``; ``real`` is unused."""
        return np.random.randint(0, self.action_n, (self.num_envs,))

    def prediction_stage(self) :
        """Every environment is always in the prediction stage."""
        return np.full((self.num_envs,), True)

# Cyclic-Hard
class CyclicHMM(Env) :
    """Cyclic environment with transition, indicating and prediction phases.

    ``self.cur_step`` stores a phase per environment:
        0 : transition phase;
        1 : indicating phase;
        2 : prediction phase.

    The observation alphabet is enlarged to ``action_n + state_dims + 1``
    symbols: the first ``obs_dim`` select a transition matrix, the next one
    marks the end of the permutation sequence and the last ``state_dim``
    encode the final state.

    Side effect: the constructor overwrites ``args.action_n`` with
    ``args.state_dim + args.action_n + 1``.

    NOTE(review): the emission-table assignment only has matching shapes when
    ``obs_dims == action_n`` — confirm callers guarantee this.
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon, alpha) :
        super().__init__(num_envs, state_dims, obs_dims, action_n + state_dims + 1, horizon)
        # first obs_dim must be the observation indicator
        args.action_n = args.state_dim + args.action_n + 1
        # probability threshold: an env in phase 0 leaves it when its uniform draw is <= alpha
        self.alpha = alpha

        # shape: (obs_dim + 1, state_dims, state_dims): obs_dim random cyclic
        # permutation matrices followed by the identity (used outside phase 0)
        self.trans = np.array([random_matrix(self.state_dim, self.state_dim, perturb = 0, radius = 1, order = np.random.permutation(self.state_dim)) for k in range(self.obs_dim)])
        self.trans = np.concatenate((self.trans, np.expand_dims(np.eye(self.state_dim), axis = 0)), axis = 0)
        # row 0: uniform over the obs_dim transition symbols;
        # rows 1..: one-hot on the indicator symbol and on each final-state symbol
        self.emission = np.zeros((2 + self.state_dim, self.action_n))
        self.emission[0] = np.concatenate((np.ones(self.obs_dim) / self.obs_dim, np.zeros(self.state_dim + 1)))
        self.emission[1:] = np.concatenate((np.zeros((self.state_dim + 1, self.obs_dim)), np.eye(self.state_dim + 1)), axis = 1)

    def reset(self, state_init = None) :
        """Reset state and phases and return ``(state, {})``."""
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1)) # shape: (num_envs, self.state_dim)
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
        self.cur_step = np.zeros((self.num_envs,), dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Advance every environment by one step.

        Environments in phase 0 whose uniform draw exceeds ``alpha`` apply
        the matrix selected by ``action`` and stay in phase 0; all others
        apply the identity and move to the next phase (modulo 3).
        """
        indicator_prob = np.random.rand(self.num_envs)
        # for envs in transition phase, not entering indicating phase
        transit_mask = np.where(np.logical_and(self.cur_step == 0, indicator_prob > self.alpha), 1, 0)
        # index obs_dim selects the identity matrix appended in the constructor
        action_mask = np.where(transit_mask == 0, self.obs_dim, action)
        per_env_trans = np.transpose(self.trans[action_mask], axes = (0, 2, 1))
        # squeeze only the trailing axis so the batch axis survives num_envs == 1
        self.state = np.squeeze(np.matmul(per_env_trans, np.expand_dims(self.state, axis = 2)), axis = 2)
        self.state /= np.sum(self.state, axis = 1, keepdims = True)
        # envs that did not transit advance their phase by one
        phase_advance = 1 - transit_mask
        self.cur_step = (self.cur_step + phase_advance) % 3
        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) : # shape (num_envs, action_n)
        """Return one emission-table row per environment, chosen by its phase."""
        # first obs_dim is to indicate the effective transition matrix at current step, the next single dim to indicate the end of permutation, the last state_dim to indicate the final state
        mask = np.where(self.cur_step <= 1, self.cur_step, np.argmax(self.state, axis = 1) + 2)
        return self.emission[mask]

    def emit_obs(self, real = False) : # shape: (num_envs)
        """Sample one observation symbol per environment; ``real`` is unused.

        Phase 0 yields a uniform transition symbol, phase 1 the indicator
        symbol ``obs_dim`` and phase 2 the symbol encoding the current state.
        """
        res = np.where(self.cur_step == 0, np.random.randint(0, self.obs_dim, (self.num_envs,)), self.obs_dim)
        return np.where(np.logical_and(res == self.obs_dim, self.cur_step == 2), np.argmax(self.state, axis = 1) + self.obs_dim + 1, res)

    def prediction_stage(self) :
        """Boolean mask of the environments currently in phase 2."""
        return self.cur_step == 2

# Grid Group
class Grid(Env) :
    """One-dimensional grid walk with two actions and walls at both ends.

    ``trans[a, i, j]`` is one when action ``a`` moves state ``i`` to state
    ``j``: action 0 moves one cell down (staying put at cell 0) and action 1
    moves one cell up (staying put at the last cell).
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon) -> None:
        super().__init__(num_envs, state_dims, obs_dims, action_n, horizon)
        assert self.obs_dim == 2 and self.action_n == 2

        last = self.state_dim - 1
        self.trans = np.zeros((self.action_n, self.state_dim, self.state_dim))
        # the two boundary cells absorb moves that would leave the grid
        self.trans[0, 0, 0] = 1
        self.trans[1, last, last] = 1
        for i in range(last) :
            self.trans[0, i + 1, i] = 1
            self.trans[1, i, i + 1] = 1

    def reset(self, state_init = None) :
        """Reset to the first basis state (or tiled ``state_init``); return ``(state, {})``."""
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1))
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
        self.cur_step = np.zeros(self.num_envs, dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Apply the per-environment action matrices and renormalise the state.

        Args:
            action: integer array of shape ``(num_envs,)`` indexing ``self.trans``.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        per_env_trans = np.transpose(self.trans[action], axes = (0, 2, 1))
        # squeeze only the trailing axis; a bare squeeze also removed the batch
        # axis when num_envs == 1 and made the normalisation below fail
        self.state = np.squeeze(np.matmul(per_env_trans, np.expand_dims(self.state, axis = 2)), axis = 2)
        self.state /= np.sum(self.state, axis = 1, keepdims = True)

        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the current state array, shape ``(num_envs, state_dim)``."""
        return self.state

    def emit_obs(self, real = False) : # shape: (num_envs)
        """Sample uniform indices in ``[0, action_n)``; ``real`` is unused."""
        return np.random.randint(0, self.action_n, (self.num_envs,))

    def prediction_stage(self) :
        """Every environment is always in the prediction stage."""
        return np.full((self.num_envs,), True)

# Dihedral Group
class Dihedral(Env) :
    """Two-action environment over two equally sized halves of the state space.

    With ``h = state_dim // 2``, action 0 advances states in the first half
    by +1 (mod ``h``) and states in the second half by -1 (mod ``h``), while
    action 1 swaps state ``i`` with state ``i + h``. ``trans[a, i, j]`` is one
    when action ``a`` moves state ``i`` to state ``j``.
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon) -> None:
        super().__init__(num_envs, state_dims, obs_dims, action_n, horizon)
        assert self.obs_dim == 2 and self.action_n == 2
        assert self.state_dim % 2 == 0
        self.state_one_side = self.state_dim // 2
        half = self.state_one_side
        self.trans = np.zeros((self.action_n, self.state_dim, self.state_dim))
        for i in range(half) :
            # action 0: opposite rotation directions on the two halves
            self.trans[0, i, (i + 1) % half] = 1
            self.trans[0, i + half, (i + half - 1) % half + half] = 1
            # action 1: jump to the matching state in the other half
            self.trans[1, i, i + half] = 1
            self.trans[1, i + half, i] = 1

    def reset(self, state_init = None) :
        """Reset to the first basis state (or tiled ``state_init``); return ``(state, {})``."""
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1))
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
        self.cur_step = np.zeros(self.num_envs, dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Apply the per-environment action matrices and renormalise the state.

        Args:
            action: integer array of shape ``(num_envs,)`` indexing ``self.trans``.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        per_env_trans = np.transpose(self.trans[action], axes = (0, 2, 1))
        # squeeze only the trailing axis; a bare squeeze also removed the batch
        # axis when num_envs == 1 and made the normalisation below fail
        self.state = np.squeeze(np.matmul(per_env_trans, np.expand_dims(self.state, axis = 2)), axis = 2)
        self.state /= np.sum(self.state, axis = 1, keepdims = True)

        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the current state array, shape ``(num_envs, state_dim)``."""
        return self.state

    def emit_obs(self, real = False) : # shape: (num_envs)
        """Sample uniform indices in ``[0, action_n)``; ``real`` is unused."""
        return np.random.randint(0, self.action_n, (self.num_envs,))

    def prediction_stage(self) :
        """Every environment is always in the prediction stage."""
        return np.full((self.num_envs,), True)

# Cyclic-RND
class CyclicRealHMM(Env) :
    """HMM built from noisy cyclic MDP transitions with uniform actions.

    A hidden state is a pair ``(s, a)`` flattened as ``s * action_n + a``.
    The pair emits ``a`` deterministically and moves to ``(s', a')`` with
    probability ``mdp_trans[a, s, s'] / action_n``, so the next emitted
    symbol is uniform over the actions.

    Side effect: the constructor overwrites ``args.state_dim`` with
    ``args.state_dim * args.action_n``.
    """
    def __init__(self, args, num_envs, state_dims, obs_dims, action_n, horizon, eps = 0.0) -> None:
        super().__init__(num_envs, state_dims * action_n, obs_dims, action_n, horizon)
        args.state_dim = args.state_dim * args.action_n
        self.eps = eps
        # one random cyclic order of the MDP states per action
        cyclic_orders = np.array([np.random.permutation(state_dims) for _ in range(self.action_n)])

        mdp_trans = _perm2trans(cyclic_orders, eps)
        assert mdp_trans.shape == (self.action_n, state_dims, state_dims)
        for s_index in range(self.state_dim) :
            s, a = divmod(s_index, action_n)
            self.emission[s_index, a] = 1
            for s_prime in range(state_dims) :
                prob = mdp_trans[a, s, s_prime]
                if prob == 0 : continue
                # spread the mass evenly over all (s_prime, a_prime) pairs
                first = s_prime * action_n
                self.trans[s_index, first : first + action_n] = prob / action_n

        assert (np.abs(np.sum(self.trans, axis = 1) - 1) < 1e-8).all()
        assert (np.abs(np.sum(self.emission, axis = 1) - 1) < 1e-8).all()
        assert args.real == False
        self.emit_temp = self.trans @ self.emission # next observation distribution

    def reset(self, state_init = None) :
        """Reset all environments and return ``(state, {})``.

        Without ``state_init`` every environment starts in the first basis
        state; otherwise ``state_init`` is tiled and ``real_state`` is sampled
        from it as a probability vector.
        """
        if state_init is None :
            self.state = np.tile(np.eye(self.state_dim)[0], (self.num_envs, 1)) # shape: (num_envs, state_dims)
            self.real_state = np.zeros(self.num_envs, dtype = np.int32) # shape: (num_envs,)
        else :
            self.state = np.tile(state_init, (self.num_envs, 1))
            self.real_state = np.array([np.random.choice(self.state_dim, p = state_init) for k in range(self.num_envs)], dtype = np.int32)
        self.cur_step = np.zeros(self.num_envs, dtype = np.int32)
        return self.state, {}

    def step(self, action) :
        """Update the belief given one observed symbol per environment.

        Args:
            action: integer array of shape ``(num_envs,)`` with values in
                ``[0, action_n)``.

        Returns:
            ``(state, {})`` with ``state`` of shape ``(num_envs, state_dim)``.
        """
        assert (action >= 0).all() and (action < self.action_n).all() # shape : (num_envs,)
        # batched Bayes filter: predict with the transition matrix, then
        # weight by the likelihood of the observed symbol and renormalise
        likelihood = self.emission[:, np.asarray(action)].T # shape: (num_envs, state_dim)
        unnormalized = (self.state @ self.trans) * likelihood
        self.state = unnormalized / np.sum(unnormalized, axis = 1, keepdims = True)
        return self.state, {} # omit other parameters for simplicity

    def emit_obs_distribution(self) :
        """Return the current belief array, shape ``(num_envs, state_dim)``."""
        return self.state

    def emit_obs(self, real) : # shape: (num_envs)
        """Sample one observation per environment.

        With ``real`` truthy the hidden ``real_state`` of every environment is
        advanced through ``self.trans`` and a symbol is drawn from its
        emission row. Otherwise the symbol is drawn uniformly, after checking
        that the predicted next-observation distribution of the first
        environment is flat.
        """
        if real :
            res = np.zeros(self.num_envs, dtype = np.int32)
            for k in range(self.num_envs) :
                self.real_state[k] = np.random.choice(self.state_dim, p = self.trans[self.real_state[k]])
                res[k] = np.random.choice(self.action_n, p = self.emission[self.real_state[k]])
            return res
        else :
            # tolerance instead of exact equality: the entries come out of
            # floating-point matrix products and may differ in the last bits
            assert np.ptp(self.state[0].T @ self.emit_temp) < 1e-8
            return np.random.randint(0, self.action_n, (self.num_envs,))