import torch
from helpers import gen_matrix, DIAG_MATS, NON_DIAG_MATS
from enum import Enum
from tqdm import tqdm

class DataFormat(Enum):
    '''Enumeration of the supported data representations.'''

    # points in vector space, sampled from an underlying distribution
    POINTS = 1

    # distribution, given as a vector of length 2*d: the first d entries specify the mean,
    # the next d entries specify the variance along the standard basis vectors
    DISTRIB = 2

    # set of points from an unknown underlying distribution
    POINT_SET = 3


class NoisyDynSys:
    '''
    Generates data based on an autoregressive (linear state-space) system of the form
        x_{n+1} = A x_n + w_n
        y_n     = C x_n + v_n
    where w_n (state noise) and v_n (measurement noise) are zero-mean Gaussian with
    diagonal covariance.

    Two internal representations are supported:
      * diagonal (is_diag=True): A and C are stored as vectors holding only their diagonals,
        and all products are element-wise. Requires state_dim == output_dim.
      * full (is_diag=False): A and C are stored as full matrices and applied with torch.bmm.

    TODO: the diagonal mode only works for fully observed systems, i.e. when dim_x = dim_y.
    This should maybe be extended.
    '''

    def __init__(self, state_dim, output_dim, device) -> None:
        '''
            state_dim: dimension of the hidden state x
            output_dim: dimension of the observation y
            device: torch device on which all tensors are created
        '''
        # Assumes all noise is 0 mean for now
        self.device = device
        self.state_dim = state_dim
        self.output_dim = output_dim
        # One small random initial state, drawn once here and shared by every sequence
        # this instance generates.
        self.state0 = torch.normal(0.0, 1e-3, size=(state_dim, 1), device=self.device)

    def get_data(self, batch_sz, seq_len, A_type, C_type, single_sys=False,
                                                is_diag=True, s_noise_var=None, m_noise_var=None, silent=False):
        '''
        Sample systems (A, C) and simulate batch_sz noisy trajectories of length seq_len.

            batch_sz: positive integer >= 1
            seq_len: positive integer >= 1
            A_type: a MatType enum constant describing the transition matrix A (passed to gen_matrix)
            C_type: a MatType enum constant describing the measurement matrix C (passed to gen_matrix)
            single_sys: boolean stating whether all runs are produced by the same underlying system
            is_diag: boolean stating if all involved matrices are diagonal (computation is optimized)
            s_noise_var: None (no state noise) or a tensor of dimension batch_sz x state_dim x 1
                         containing variance vectors (so cov matrices are always diagonal)
            m_noise_var: None (no measurement noise) or a tensor of dimension batch_sz x output_dim x 1
                         containing variance vectors (so cov matrices are always diagonal)
            silent: if True, suppress the progress messages printed to stdout

        returns the tuple (ys, As, Cs, init_state, xs, Qw, Qv):
            ys: batch_sz x seq_len x output_dim observations y_1 .. y_seq_len
            As, Cs: per-sequence system parameters; diagonals (batch_sz x state_dim) when is_diag,
                    otherwise matrices (batch_sz x state_dim x state_dim and
                    batch_sz x output_dim x state_dim)
            init_state: the shared initial state x_0 (state_dim x 1)
            xs: batch_sz x seq_len x state_dim hidden states x_1 .. x_seq_len
            Qw, Qv: state / measurement noise covariances; variance vectors when is_diag,
                    otherwise diagonal matrices. All zeros when the respective noise is disabled.
        '''
        self.__check_parameters(A_type, C_type, is_diag, batch_sz, seq_len)

        if not silent:
            print("Generating dynamical system data ...")
        init_state = self.state0

        m_noise = m_noise_var is not None
        s_noise = s_noise_var is not None
        As, Cs = self.__create_transitions(batch_sz, self.state_dim, self.output_dim, A_type, C_type, single_sys,
                                                            is_diag, m_noise, s_noise)

        ys, xs = self.__evolve_system(batch_sz, seq_len, is_diag, As, Cs,
                                            init_state, self.state_dim, self.output_dim, s_noise_var, m_noise_var)

        Qw, Qv = self.__noise_covariances(batch_sz, is_diag, s_noise_var, m_noise_var)

        if not silent:
            print("\n *** Finished generating dynamical system data.\n")
        return ys, As, Cs, init_state, xs, Qw, Qv

    def __check_parameters(self, A_type, C_type, is_diag, batch_size, seq_len):
        '''
        Validate the arguments of get_data before any work is done.
        Raises AssertionError when a requirement is violated.
        '''
        assert batch_size >= 1 and seq_len >= 1, "batch_sz and seq_len must both be >= 1"

        if is_diag:
            assert A_type in DIAG_MATS, "is_diag=True requires a diagonal A_type"
            # Checked up front so that a bad configuration fails before the simulation runs.
            assert self.state_dim == self.output_dim, "For diagonal matrices, state_dim and output_dim must be equal"

        # A and C must be of the same family (both diagonal or both non-diagonal).
        assert (A_type in DIAG_MATS and C_type in DIAG_MATS) \
                or (A_type in NON_DIAG_MATS and C_type in NON_DIAG_MATS), \
                "A_type and C_type must both be diagonal or both be non-diagonal"

    def __noise_covariances(self, batch_sz, is_diag, s_noise_var, m_noise_var):
        '''
        Build the noise covariances (Qw, Qv) returned to the caller.

        In diagonal mode each covariance is the batch of variance vectors (batch_sz x dim);
        in full mode it is the batch of diagonal matrices (batch_sz x dim x dim).
        A disabled noise source (None) yields an all-zero covariance of the matching shape.
        '''
        def as_covariance(noise_var, dim):
            if noise_var is None:
                shape = (batch_sz, dim) if is_diag else (batch_sz, dim, dim)
                return torch.zeros(*shape, device=self.device)
            variances = noise_var[:, :, 0]  # drop the trailing singleton axis
            return variances if is_diag else torch.diag_embed(variances)

        return as_covariance(s_noise_var, self.state_dim), as_covariance(m_noise_var, self.output_dim)

    def __sample_system(self, state_dim, output_dim, A_type, C_type, is_diag):
        '''Draw one (A, C) pair with gen_matrix; in diagonal mode only the diagonals are kept.'''
        if is_diag:
            A = torch.diag(gen_matrix(state_dim, state_dim, A_type, device=self.device)) # extract diagonal
            C = torch.diag(gen_matrix(state_dim, state_dim, C_type, device=self.device))
        else:
            A = gen_matrix(state_dim, state_dim, A_type, device=self.device)
            C = gen_matrix(output_dim, state_dim, C_type, device=self.device)
        return A, C

    def __create_transitions(self, batch_sz, state_dim, output_dim, A_type, C_type, single_sys, is_diag, m_noise, s_noise):
        '''
        Create the per-sequence transition (A) and measurement (C) parameters.

        With single_sys=True one (A, C) pair is drawn and reused for every sequence;
        otherwise a fresh pair is drawn for each sequence.
        m_noise and s_noise are only forwarded to __add_transition, which ignores them.

        returns (transitions, measurements), batched along the first dimension.
        '''
        if is_diag:
            # A and C are 1D tensors
            assert state_dim == output_dim, "For diagonal matrices, state_dim and output_dim must be equal"
            transitions = torch.zeros(batch_sz, state_dim, device=self.device)
            measurements = torch.zeros(batch_sz, state_dim, device=self.device)
        else:
            transitions = torch.zeros(batch_sz, state_dim, state_dim, device=self.device)
            measurements = torch.zeros(batch_sz, output_dim, state_dim, device=self.device)

        A = C = None
        for i in range(batch_sz):
            # Draw lazily, right before use, so no matrix pair is generated and then discarded.
            if i == 0 or not single_sys:
                A, C = self.__sample_system(state_dim, output_dim, A_type, C_type, is_diag)
            self.__add_transition(transitions, i, A, is_diag, m_noise, s_noise)
            self.__add_transition(measurements, i, C, is_diag, m_noise, s_noise)

        return transitions, measurements

    def __add_transition(self, transitions, idx, W_update, is_diag, m_noise, s_noise):
        '''Store W_update as entry idx of the batched tensor `transitions` (in place).'''
        if is_diag:
            transitions[idx, :] = W_update
        else:
            transitions[idx, :, :] = W_update

    def __evolve_system(self, batch_sz, seq_len, is_diag, transitions, measurements,
                                                        init_state, state_dim, output_dim, s_noise_var, m_noise_var):
        '''
        Simulate all sequences for seq_len steps, starting from init_state.

        returns:
        y of size batch_sz x seq_len x output_dim, which are the observations of the noisy dyn sys.
        Starts from y_1 since y_0 is just C * x_0 which is deterministic and we don't care.
        x_seq of size batch_sz x seq_len x state_dim, the corresponding hidden states x_1, x_2, ...
        '''
        y = torch.zeros(batch_sz, seq_len, output_dim, device=self.device)
        x_seq = torch.zeros(batch_sz, seq_len, state_dim, device=self.device)

        # Replicate the shared initial state across the batch. Diagonal mode works on
        # (batch, dim) tensors; full mode keeps column vectors (batch, dim, 1) for torch.bmm.
        if is_diag:
            x = init_state.transpose(0, 1).repeat(batch_sz, 1)
        else:
            x = init_state.unsqueeze(0).repeat(batch_sz, 1, 1)

        s_std, m_std = self.__prepare_noise_tensors(is_diag, state_dim, output_dim, s_noise_var, m_noise_var)

        # evolve system
        for j in range(seq_len):
            x = self.__compute_new_state(x, transitions, is_diag, s_std)
            x_seq[:, j, :] = x if is_diag else x[:, :, 0]

            self.__update_observations(y, j, x, measurements, is_diag, m_std)

        return y, x_seq

    def __prepare_noise_tensors(self, is_diag, state_dim, output_dim, s_noise_var, m_noise_var):
        '''
        Convert the variance inputs into standard-deviation tensors shaped for sampling.

        torch.normal takes a standard deviation, so the square root is taken once here;
        this keeps the sampled noise consistent with the covariances returned by get_data.

        returns (s_std, m_std); an entry is None when the respective noise is disabled.
            s_std matches the state layout: (batch, state_dim) in diagonal mode,
                  (batch, state_dim, 1) in full mode.
            m_std is (batch, output_dim) in both modes, matching one time-slice of the observations.
        '''
        s_std = None
        m_std = None

        if m_noise_var is not None:
            m_std = torch.sqrt(m_noise_var[:, :, 0])

        if s_noise_var is not None:
            s_std = torch.sqrt(s_noise_var[:, :, 0] if is_diag else s_noise_var)

        return s_std, m_std

    def __compute_new_state(self, old_state, transitions, is_diag, s_std):
        '''Return A x + w for the whole batch; w ~ N(0, s_std^2) is skipped when s_std is None.'''
        if is_diag:
            new_state = transitions * old_state
        else:
            new_state = torch.bmm(transitions, old_state)

        if s_std is not None:
            new_state += torch.normal(0.0, s_std)

        return new_state

    def __update_observations(self, observations, idx, new_state, measurements, is_diag, m_std):
        '''Write y = C x + v for time step idx into `observations` (in place); v is skipped when m_std is None.'''
        if is_diag:
            observations[:, idx, :] = measurements * new_state
        else:
            observations[:, idx, :] = torch.bmm(measurements, new_state)[:, :, 0]

        if m_std is not None:
            # m_std is (batch, output_dim) in both modes, which matches this time-slice
            # because the observations are stored without the trailing singleton axis.
            observations[:, idx, :] += torch.normal(0.0, m_std)