import numpy as np
import argparse
import torch
import time
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.integrate import odeint

class MultiNodeSimulation:
    """Simulate ``n`` interacting agents moving in 3D.

    The state is a flat vector ``[positions (n*3), velocities (n*3)]``. The
    acceleration of agent ``i`` is minus a weighted sum over every other agent
    ``j`` of ``kp * (p_i - p_j) + kv * (v_i - v_j) + kd * (v_j x v_i)``. The
    pairwise weight ``omega_ij`` depends on ``level``.

    Args:
        n: number of agents.
        level: interaction complexity, one of ``LEVELS``.
        kp: gain on relative position.
        kv: gain on relative velocity.
        kd: gain on the ``v_j x v_i`` cross-product term.
        min_distance: initial positions must be pairwise farther apart than this.
        dtype: torch dtype used to evaluate the dynamics. Defaults to float64 so
            the derivative is computed at a precision compatible with the
            ``atol=rtol=1e-8`` tolerances used in ``sample_trajectory``.

    Raises:
        ValueError: if ``level`` is not one of ``LEVELS``.
    """

    LEVELS = ("easy", "medium", "difficult")

    def __init__(self, n, level="easy", kp=0.2, kv=0.15, kd=0.15, min_distance=2,
                 dtype=torch.float64):
        # Fail fast here rather than from inside the ODE solver on the first step.
        if level not in self.LEVELS:
            raise ValueError(f"No such level: {level!r}; expected one of {self.LEVELS}")
        self.n = n
        self.kp = kp
        self.kv = kv
        self.kd = kd
        self.min_distance = min_distance
        self.level = level
        self.dtype = dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def control_rate(self, state, t):
        """Return d(state)/dt in the ``func(y, t)`` form expected by ``odeint``.

        Args:
            state: flat array ``[p (n*3), v (n*3)]``.
            t: current time. It is unused because the dynamics do not depend
                on t, but ``odeint`` always passes it.

        Returns:
            Flat numpy array ``[dp/dt, dv/dt]`` where ``dp/dt = v``.

        Raises:
            ValueError: if ``self.level`` is not a supported level.
        """
        state = torch.as_tensor(state, dtype=self.dtype, device=self.device)
        n = self.n
        p = state[:3 * n].reshape(n, 3)
        v = state[3 * n:].reshape(n, 3)

        # Broadcast to all (i, j) pairs: *_i is (n, 1, 3), *_j is (1, n, 3).
        p_i, p_j = p.unsqueeze(1), p.unsqueeze(0)
        v_i, v_j = v.unsqueeze(1), v.unsqueeze(0)

        # Drop the i == j terms from the pairwise sum.
        mask = ~torch.eye(n, dtype=torch.bool, device=self.device).unsqueeze(-1)

        pp = torch.sum(p_i * p_j, dim=-1)
        vv = torch.sum(v_i * v_j, dim=-1)
        if self.level == "easy":
            omega_ij = 0.2 * torch.abs(pp + vv)
        elif self.level == "medium":
            cross_p = torch.cross(p_i, p_j, dim=-1)
            omega_ij = 0.2 * torch.abs(pp + vv + torch.sum(v_i * cross_p, dim=-1))
        elif self.level == "difficult":
            cross_p = torch.cross(p_i, p_j, dim=-1)
            cross_v = torch.cross(v_i, v_j, dim=-1)
            omega_ij = 0.15 * torch.abs(
                pp + vv
                + torch.sum(v_i * cross_p, dim=-1) + torch.sum(v_j * cross_p, dim=-1)
                + torch.sum(p_i * cross_v, dim=-1) + torch.sum(p_j * cross_v, dim=-1)
                + torch.sum(cross_p * cross_v, dim=-1)
            )
        else:
            raise ValueError(f"No such level: {self.level!r}; expected one of {self.LEVELS}")
        omega_ij = omega_ij.unsqueeze(-1)  # (n, n, 1) so it scales each 3-vector

        p_diff = p_i - p_j
        v_diff = v_i - v_j
        cross_product = torch.cross(v_j, v_i, dim=-1)

        coupling = self.kp * p_diff + self.kv * v_diff + self.kd * cross_product
        dvdt = -torch.sum(omega_ij * coupling * mask, dim=1)

        return torch.cat((v.reshape(-1), dvdt.reshape(-1))).cpu().numpy()

    def generate_valid_initial_positions(self, max_attempts=100_000):
        """Draw ``n`` positions in the shell 1 <= |x| <= 3 with pairwise spacing > ``min_distance``.

        Candidates are drawn one at a time (rejection sampling) and kept only if
        they are farther than ``min_distance`` from every already-accepted point.

        NOTE: ``phi ~ U(0, pi)`` is not area-uniform on the sphere: points are
        denser near the poles. It is kept unchanged so the sampling distribution
        matches previously generated datasets.

        Args:
            max_attempts: maximum number of candidate draws before giving up.

        Returns:
            Flat array of length ``n * 3``.

        Raises:
            RuntimeError: if ``n`` points cannot be placed within ``max_attempts``
                draws. Without this guard, infeasible settings (large ``n`` or
                ``min_distance``) would loop forever.
        """
        positions = []
        attempts = 0
        while len(positions) < self.n:
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could only place {len(positions)} of {self.n} agents with "
                    f"min_distance={self.min_distance} after {max_attempts} attempts")
            attempts += 1

            r = np.random.uniform(low=1, high=3)
            theta = np.random.uniform(0, 2 * np.pi)
            phi = np.random.uniform(0, np.pi)
            candidate = r * np.array([np.sin(phi) * np.cos(theta),
                                      np.sin(phi) * np.sin(theta),
                                      np.cos(phi)])

            if all(np.linalg.norm(candidate - existing) > self.min_distance
                   for existing in positions):
                positions.append(candidate)
        return np.array(positions).flatten()

    def sample_trajectory(self, T, num_points):
        """Integrate one trajectory from a random initial condition.

        Initial velocities are standard normal. Initial positions come from
        ``generate_valid_initial_positions``.

        Args:
            T: final time. The trajectory is sampled on ``linspace(0, T, num_points)``.
            num_points: number of time samples.

        Returns:
            ``(p, v, edges)``. ``p`` and ``v`` have shape ``(num_points, n, 3)``.
            ``edges`` is the ``(n, n)`` all-ones matrix with a zero diagonal.
        """
        p0 = self.generate_valid_initial_positions()
        v0 = np.random.randn(self.n * 3)
        initial_state = np.concatenate((p0, v0))

        t = np.linspace(0, T, num_points)
        solution = odeint(self.control_rate, initial_state, t, atol=1e-8, rtol=1e-8)

        p_solution = solution[:, :3 * self.n].reshape((-1, self.n, 3))
        v_solution = solution[:, 3 * self.n:].reshape((-1, self.n, 3))

        edges = np.ones((self.n, self.n)) - np.eye(self.n)

        return p_solution, v_solution, edges


def generate_dataset(sim, num_sims, length, num_points):
    """Run ``num_sims`` independent trajectories and stack them along a new axis 0.

    Args:
        sim: simulator with an ``n`` attribute and a
            ``sample_trajectory(T=..., num_points=...)`` method returning
            ``(loc, vel, edges)``, e.g. ``MultiNodeSimulation``.
        num_sims: number of trajectories. Zero (or negative) yields empty arrays.
        length: final time ``T`` passed to ``sample_trajectory``.
        num_points: number of time samples per trajectory.

    Returns:
        ``(loc_all, vel_all, edges_all)``. For ``MultiNodeSimulation`` the shapes
        are ``(num_sims, num_points, n, 3)``, ``(num_sims, num_points, n, 3)`` and
        ``(num_sims, n, n)``.
    """
    loc_all = []
    vel_all = []
    edges_all = []

    for i in range(num_sims):
        start = time.perf_counter()
        loc, vel, edges = sim.sample_trajectory(T=length, num_points=num_points)
        if i % 5 == 0:
            print("Iter: {}, Simulation time: {}".format(i, time.perf_counter() - start))
        loc_all.append(loc)
        vel_all.append(vel)
        edges_all.append(edges)

    if not loc_all:
        # np.stack raises on an empty list, so return correctly shaped empty arrays instead.
        n = sim.n
        return (np.empty((0, num_points, n, 3)),
                np.empty((0, num_points, n, 3)),
                np.empty((0, n, n)))

    return np.stack(loc_all), np.stack(vel_all), np.stack(edges_all)


def plot_trajectory(loc, vel, folder_name):
    """Save velocity-magnitude and 3D position plots for one trajectory.

    Args:
        loc: positions indexed as ``loc[time, agent, xyz]``.
        vel: velocities with the same layout as ``loc``.
        folder_name: output directory. It is created if missing.

    Writes ``velocity_plot.png`` and ``3d_position_plot.png`` into ``folder_name``.
    """
    os.makedirs(folder_name, exist_ok=True)
    num_agents = loc.shape[1]

    # Scope the font size to these figures instead of mutating the global
    # rcParams for every later plot in the process.
    with plt.rc_context({'font.size': 16}):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for i in range(num_agents):
                speed = np.linalg.norm(vel[:, i, :], axis=1)
                ax.plot(np.arange(len(speed)), speed, label=f'Agent {i + 1}')
            ax.set_xlabel('Time Step')
            ax.set_ylabel('Velocity Magnitude')
            ax.set_title('Velocity Magnitude Over Time')
            ax.legend(fontsize=14)
            ax.grid(True)
            fig.savefig(os.path.join(folder_name, 'velocity_plot.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Close even if savefig fails, so figures don't accumulate.
            plt.close(fig)

        fig = plt.figure(figsize=(10, 6))
        try:
            ax = fig.add_subplot(111, projection='3d')
            for i in range(num_agents):
                ax.plot(loc[:, i, 0], loc[:, i, 1], loc[:, i, 2], label=f'Agent {i + 1}')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title('3D Position Trajectories')
            ax.legend(fontsize=14)
            fig.savefig(os.path.join(folder_name, '3d_position_plot.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-train', type=int, default=7000,
                        help='Number of training simulations to generate.')
    parser.add_argument('--num-valid', type=int, default=2000,
                        help='Number of validation simulations to generate.')
    parser.add_argument('--num-test', type=int, default=1000,
                        help='Number of test simulations to generate.')
    parser.add_argument('--length', type=int, default=5,
                        help='Length of trajectory.')
    parser.add_argument('--length_test', type=int, default=5,
                        help='Length of test set trajectory.')
    parser.add_argument('--num_points', type=int, default=1000,
                        help='How often to sample the trajectory.')
    # Restrict to the levels the simulator implements ("hell" was advertised but unsupported).
    parser.add_argument('--level', type=str, default="easy",
                        choices=["easy", "medium", "difficult"],
                        help='Interaction complexity: easy, medium or difficult.')
    parser.add_argument('--n_balls', type=int, default=5,
                        help='Number of balls in the simulation.')
    parser.add_argument('--seed', type=int, default=40,
                        help='Random seed.')
    # '--suffix' is accepted as a correctly spelled alias of the original '--sufix'.
    parser.add_argument('--sufix', '--suffix', dest='sufix', type=str, default="",
                        help='Suffix appended to the output file names.')
    parser.add_argument('--plot', action='store_true',
                        help='Plot sample trajectory before generating dataset')

    args = parser.parse_args()

    # Initial positions and velocities are drawn from NumPy's global RNG, so
    # seeding it makes the generated datasets reproducible.
    np.random.seed(args.seed)

    suffix = "consensus" + str(args.n_balls) + args.sufix

    sim = MultiNodeSimulation(n=args.n_balls, level=args.level)

    folder_name = args.level
    os.makedirs(folder_name, exist_ok=True)

    if args.plot:
        print("Generating sample trajectory for plotting...")
        loc_sample, vel_sample, _ = sim.sample_trajectory(
            T=args.length, num_points=args.num_points)
        plot_trajectory(loc_sample, vel_sample, folder_name)

    # (file tag, description, number of simulations, trajectory length)
    splits = [
        ("train", "training", args.num_train, args.length),
        ("valid", "validation", args.num_valid, args.length),
        ("test", "test", args.num_test, args.length_test),
    ]
    for split, description, num_sims, length in splits:
        print("Generating {} {} simulations".format(num_sims, description))
        loc, vel, edges = generate_dataset(sim, num_sims, length, args.num_points)
        # Save each split as soon as it is generated so a later failure
        # doesn't discard finished work.
        for prefix, array in (("loc", loc), ("vel", vel), ("edges", edges)):
            np.save(os.path.join(folder_name, prefix + '_' + split + suffix + '.npy'), array)