import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset


class PathIntegration:
    """Generate synthetic trajectories for 1D/2D/3D path-integration tasks.

    The generators simulate ``n_trials`` independent agents for
    ``n_timesteps`` steps each, inside a box of side ``environment_size``
    centred on the origin, alternating between moving and stopped phases.
    Each returns ``(inputs, targets)`` arrays of shape
    ``(n_trials, n_timesteps, features)``: the self-motion signals (heading
    angle(s) and speed) are the inputs and the position is the target.
    """

    def __init__(self, params):
        """
        Initialize the data generator with the specified parameters.

        Parameters:
        -----------
        params : dict
            Configuration dictionary. Keys read here:

            n_trials : int
                Number of trials to generate.
            n_timesteps : int
                Number of timesteps per trial.
            v_max : float
                Maximum speed of the agent.
            dim : int
                Dimensionality of the task (1, 2, or 3).
            speed_std : float
                Std of the per-step speed increment (used by the 2D and 3D
                generators; the 1D generator uses ``v_max / 10`` instead).
            stop_mean_duration, go_mean_duration : float
                Mean lengths (in steps) of the stopped / moving phases; the
                per-step probability of switching phase is their reciprocal.
            environment_size : float
                Side length of the box centred on the origin.

            Additional keys based on dimensionality:
                1D: x_noise_std
                2D: theta_std, xy_noise_std
                3D: phi_std, theta_std, xyz_noise_std

            ``get_train_loader`` additionally reads ``device`` and ``n_batch``.
        """
        self.params = params
        self.n_trials = params['n_trials']
        self.n_timesteps = params['n_timesteps']
        self.v_max = params['v_max']
        self.dim = params['dim']
        self.speed_std = params['speed_std']

        if self.dim == 1:
            self.x_noise_std = params['x_noise_std']
        elif self.dim == 2:
            self.theta_std = params['theta_std']
            self.xy_noise_std = params['xy_noise_std']
        elif self.dim == 3:
            self.phi_std = params['phi_std']
            self.theta_std = params['theta_std']
            self.xyz_noise_std = params['xyz_noise_std']

        self.stop_mean_duration = params['stop_mean_duration']
        self.go_mean_duration = params['go_mean_duration']
        self.environment_size = params['environment_size']

    def generate_train_data(self):
        """Dispatch to the generator matching ``self.dim``.

        Returns:
        --------
        tuple of numpy.ndarray
            ``(inputs, targets)`` as produced by the matching
            ``generate_train_data_{1,2,3}d`` method.

        Raises:
        -------
        ValueError
            If ``self.dim`` is not 1, 2 or 3.
        """
        generators = {
            1: self.generate_train_data_1d,
            2: self.generate_train_data_2d,
            3: self.generate_train_data_3d,
        }
        try:
            generator = generators[self.dim]
        except KeyError:
            raise ValueError(
                f"Unsupported dim={self.dim!r}; expected 1, 2 or 3."
            ) from None
        return generator()

    def generate_train_data_1d(self):
        """
        Generate synthetic data for 1D path integration tasks.

        Simulates an agent on the segment ``|x| < environment_size / 2``,
        alternating between periods of motion and pauses.

        Returns:
        --------
        inputs : numpy.ndarray
            Shape (n_trials, n_timesteps, 1): speed (non-negative).
        targets : numpy.ndarray
            Shape (n_trials, n_timesteps, 1): x-coordinate.
        """
        n_total = self.n_timesteps * self.n_trials
        # Pre-draw every random number up front. Each simulated step consumes
        # exactly one entry of each buffer, so they can never run out. The
        # buffer sizes (and draw order) are kept so seeded runs stay reproducible.
        # NOTE(review): speed increments use v_max / 10 here, whereas the 2D/3D
        # generators use speed_std -- confirm this difference is intended.
        speed_incs = np.random.normal(scale=self.v_max / 10, size=n_total * 2)
        x_incs = np.random.normal(scale=self.x_noise_std, size=n_total * 2)
        stop_go = np.random.uniform(size=n_total)

        # data[0] = x-coordinate, data[1] = speed
        data = np.zeros((2, self.n_trials, self.n_timesteps))
        half_size = self.environment_size / 2
        p_stop = 1.0 / self.go_mean_duration  # chance per step to end a moving phase
        p_go = 1.0 / self.stop_mean_duration  # chance per step to end a stopped phase

        k = 0  # flat index into the pre-drawn buffers
        for n in range(self.n_trials):
            moving = True
            for t in range(1, self.n_timesteps):
                if moving:
                    data[1, n, t] = np.clip(data[1, n, t - 1] + speed_incs[k], 0, self.v_max)

                    dx = data[1, n, t] + x_incs[k]
                    new_x = data[0, n, t - 1] + dx
                    if abs(new_x) < half_size:
                        data[0, n, t] = new_x
                    else:
                        # Simplified boundary handling: step back instead.
                        # NOTE(review): the recorded speed stays non-negative,
                        # so the target is not the integral of the input on
                        # these steps -- confirm this is intended.
                        data[0, n, t] = data[0, n, t - 1] - dx

                    if stop_go[k] < p_stop:
                        moving = False
                else:
                    # Stopped: hold position and report zero speed.
                    data[0, n, t] = data[0, n, t - 1]
                    data[1, n, t] = 0.0
                    if stop_go[k] < p_go:
                        moving = True
                k += 1

        inputs = data[1:].transpose(1, 2, 0)
        targets = data[:1].transpose(1, 2, 0)
        return inputs, targets

    def generate_train_data_2d(self):
        """
        Generate synthetic data for 2D path integration tasks.

        Simulates an agent in the square ``|x|, |y| < environment_size / 2``,
        alternating between periods of motion and pauses.

        Returns:
        --------
        inputs : numpy.ndarray
            Shape (n_trials, n_timesteps, 2):
            - [..., 0]: theta (heading, wrapped to [0, 2*pi))
            - [..., 1]: speed
        targets : numpy.ndarray
            Shape (n_trials, n_timesteps, 2):
            - [..., 0]: x-coordinate
            - [..., 1]: y-coordinate
        """
        n_total = self.n_timesteps * self.n_trials
        # Pre-draw all random numbers (one entry per simulated step is used;
        # sizes and order are kept so seeded runs stay reproducible).
        theta_incs = np.random.normal(scale=self.theta_std, size=n_total * 2)
        speed_incs = np.random.normal(scale=self.speed_std, size=n_total * 2)
        x_incs = np.random.normal(scale=self.xy_noise_std, size=n_total * 2)
        y_incs = np.random.normal(scale=self.xy_noise_std, size=n_total * 2)
        stop_go = np.random.uniform(size=n_total)

        # data rows: 0 = x, 1 = y, 2 = theta, 3 = speed
        data = np.zeros((4, self.n_trials, self.n_timesteps))
        data[2, :, 0] = np.random.uniform(0, 2 * np.pi, self.n_trials)  # Initial heading

        half_size = self.environment_size / 2
        p_stop = 1.0 / self.go_mean_duration
        p_go = 1.0 / self.stop_mean_duration

        k = 0  # flat index into the pre-drawn buffers
        for n in range(self.n_trials):
            moving = True
            for t in range(1, self.n_timesteps):
                if moving:
                    data[2, n, t] = (data[2, n, t - 1] + theta_incs[k]) % (2 * np.pi)
                    data[3, n, t] = np.clip(data[3, n, t - 1] + speed_incs[k], 0, self.v_max)

                    dx = data[3, n, t] * np.cos(data[2, n, t]) + x_incs[k]
                    dy = data[3, n, t] * np.sin(data[2, n, t]) + y_incs[k]
                    new_x = data[0, n, t - 1] + dx
                    new_y = data[1, n, t - 1] + dy

                    if abs(new_x) < half_size and abs(new_y) < half_size:
                        data[0, n, t] = new_x
                        data[1, n, t] = new_y
                    else:
                        # Simplified boundary handling: turn around (keeping
                        # the heading wrapped) and hold position this step.
                        # NOTE(review): speed is not zeroed here, so the target
                        # is not the integral of the input on these steps.
                        data[2, n, t] = (data[2, n, t] + np.pi) % (2 * np.pi)
                        data[0, n, t] = data[0, n, t - 1]
                        data[1, n, t] = data[1, n, t - 1]

                    if stop_go[k] < p_stop:
                        moving = False
                else:
                    # Stopped: copy the previous state, then zero the speed.
                    data[:, n, t] = data[:, n, t - 1]
                    data[3, n, t] = 0.0
                    if stop_go[k] < p_go:
                        moving = True
                k += 1

        inputs = data[2:].transpose(1, 2, 0)
        targets = data[:2].transpose(1, 2, 0)
        return inputs, targets

    def generate_train_data_3d(self):
        """
        Generate synthetic data for 3D path integration tasks.

        Simulates an agent in the cube ``|x|, |y|, |z| < environment_size / 2``,
        alternating between periods of motion and pauses. The heading is given
        in spherical angles: the displacement direction is
        ``(sin(theta) cos(phi), sin(theta) sin(phi), cos(theta))``.

        Returns:
        --------
        inputs : numpy.ndarray
            Shape (n_trials, n_timesteps, 3):
            - [..., 0]: phi (azimuthal angle, wrapped to [0, 2*pi))
            - [..., 1]: theta (polar angle from the +z axis, kept in [0, pi])
            - [..., 2]: speed
        targets : numpy.ndarray
            Shape (n_trials, n_timesteps, 3):
            - [..., 0]: x-coordinate
            - [..., 1]: y-coordinate
            - [..., 2]: z-coordinate
        """
        n_total = self.n_timesteps * self.n_trials
        # Pre-draw all random numbers (one entry per simulated step is used;
        # sizes and order are kept so seeded runs stay reproducible).
        phi_incs = np.random.normal(scale=self.phi_std, size=n_total * 2)
        theta_incs = np.random.normal(scale=self.theta_std, size=n_total * 2)
        speed_incs = np.random.normal(scale=self.speed_std, size=n_total * 2)
        x_incs = np.random.normal(scale=self.xyz_noise_std, size=n_total * 2)
        y_incs = np.random.normal(scale=self.xyz_noise_std, size=n_total * 2)
        z_incs = np.random.normal(scale=self.xyz_noise_std, size=n_total * 2)
        stop_go = np.random.uniform(size=n_total)

        # data rows: 0 = x, 1 = y, 2 = z, 3 = phi, 4 = theta, 5 = speed
        data = np.zeros((6, self.n_trials, self.n_timesteps))
        data[3, :, 0] = np.random.uniform(0, 2 * np.pi, self.n_trials)  # Initial azimuth
        data[4, :, 0] = np.random.uniform(0, np.pi, self.n_trials)  # Initial polar angle

        half_size = self.environment_size / 2
        p_stop = 1.0 / self.go_mean_duration
        p_go = 1.0 / self.stop_mean_duration

        k = 0  # flat index into the pre-drawn buffers
        for n in range(self.n_trials):
            moving = True
            for t in range(1, self.n_timesteps):
                if moving:
                    data[3, n, t] = (data[3, n, t - 1] + phi_incs[k]) % (2 * np.pi)
                    data[4, n, t] = np.clip(data[4, n, t - 1] + theta_incs[k], 0, np.pi)
                    data[5, n, t] = np.clip(data[5, n, t - 1] + speed_incs[k], 0, self.v_max)

                    sin_theta = np.sin(data[4, n, t])
                    dx = data[5, n, t] * sin_theta * np.cos(data[3, n, t]) + x_incs[k]
                    dy = data[5, n, t] * sin_theta * np.sin(data[3, n, t]) + y_incs[k]
                    dz = data[5, n, t] * np.cos(data[4, n, t]) + z_incs[k]
                    new_x = data[0, n, t - 1] + dx
                    new_y = data[1, n, t - 1] + dy
                    new_z = data[2, n, t - 1] + dz

                    if (
                        abs(new_x) < half_size
                        and abs(new_y) < half_size
                        and abs(new_z) < half_size
                    ):
                        data[0, n, t] = new_x
                        data[1, n, t] = new_y
                        data[2, n, t] = new_z
                    else:
                        # Simplified boundary handling: fully reverse the
                        # heading and hold position this step. phi -> phi + pi
                        # flips the xy component; theta -> pi - theta flips the
                        # z component (flipping phi alone would leave an agent
                        # pushing into a z-wall stuck there).
                        # NOTE(review): speed is not zeroed here, so the target
                        # is not the integral of the input on these steps.
                        data[3, n, t] = (data[3, n, t] + np.pi) % (2 * np.pi)
                        data[4, n, t] = np.pi - data[4, n, t]
                        data[0, n, t] = data[0, n, t - 1]
                        data[1, n, t] = data[1, n, t - 1]
                        data[2, n, t] = data[2, n, t - 1]

                    if stop_go[k] < p_stop:
                        moving = False
                else:
                    # Stopped: copy the previous state, then zero the speed.
                    data[:, n, t] = data[:, n, t - 1]
                    data[5, n, t] = 0.0
                    if stop_go[k] < p_go:
                        moving = True
                k += 1

        inputs = data[3:].transpose(1, 2, 0)
        targets = data[:3].transpose(1, 2, 0)
        return inputs, targets

    def get_train_loader(self):
        """Build a shuffled ``DataLoader`` over freshly generated training data.

        Reads ``params['device']`` and ``params['n_batch']``. The float64
        arrays are cast to float32 before being moved to the device, so the
        double-precision copy is never transferred.
        """
        inputs, labels = self.generate_train_data()
        device = self.params['device']
        inputs = torch.from_numpy(inputs).float().to(device)
        labels = torch.from_numpy(labels).float().to(device)
        return DataLoader(
            dataset=TensorDataset(inputs, labels),
            batch_size=self.params['n_batch'],
            shuffle=True,
        )

    def generate_input_for_PCA_plot(self, speed=0.5, direction_change_prob=0.5):
        """
        Generate 2D grid-walk trajectories at constant speed (e.g. for PCA plots).

        Every trial starts at the origin with a random cardinal heading. At
        each step, with probability ``direction_change_prob``, a heading is
        drawn uniformly from the four cardinal directions (possibly the current
        one). The agent then moves exactly ``speed`` along that heading, so
        positions stay on a lattice of spacing ``speed`` (integers when
        ``speed`` is 1). If the step would leave the square
        ``|x|, |y| < environment_size / 2``, the heading is reversed and the
        agent steps backwards instead. If that step is also out of bounds,
        the agent stays put.

        Parameters:
        -----------
        speed : float, default 0.5
            Constant step length, also reported as the speed input.
        direction_change_prob : float, default 0.5
            Per-step probability of redrawing the heading.

        Returns:
        --------
        inputs : numpy.ndarray
            Shape (n_trials, n_timesteps, 2): [..., 0] theta, [..., 1] speed.
        targets : numpy.ndarray
            Shape (n_trials, n_timesteps, 2): [..., 0] x, [..., 1] y.
        """
        # Cardinal headings: right, up, left, down.
        directions = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        half_size = self.environment_size / 2

        # data rows: 0 = x, 1 = y, 2 = theta, 3 = speed; positions start at 0.
        data = np.zeros((4, self.n_trials, self.n_timesteps))
        data[2, :, 0] = np.random.choice(directions, self.n_trials)
        data[3, :, :] = speed

        for n in range(self.n_trials):
            for t in range(1, self.n_timesteps):
                if np.random.rand() < direction_change_prob:
                    data[2, n, t] = np.random.choice(directions)
                else:
                    data[2, n, t] = data[2, n, t - 1]

                # rint snaps cos/sin of the cardinal angles to exactly -1, 0
                # or 1, removing ~1e-16 float noise so positions stay on the lattice.
                dx = speed * np.rint(np.cos(data[2, n, t]))
                dy = speed * np.rint(np.sin(data[2, n, t]))
                new_x = data[0, n, t - 1] + dx
                new_y = data[1, n, t - 1] + dy

                if not (abs(new_x) < half_size and abs(new_y) < half_size):
                    # Reverse the attempted heading and step the other way.
                    data[2, n, t] = (data[2, n, t] + np.pi) % (2 * np.pi)
                    new_x = data[0, n, t - 1] - dx
                    new_y = data[1, n, t - 1] - dy
                    if not (abs(new_x) < half_size and abs(new_y) < half_size):
                        # Box too small to step either way: stay in place.
                        new_x = data[0, n, t - 1]
                        new_y = data[1, n, t - 1]

                data[0, n, t] = new_x
                data[1, n, t] = new_y

        inputs = data[2:].transpose(1, 2, 0)
        targets = data[:2].transpose(1, 2, 0)
        return inputs, targets


