import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import os
import pickle
import io
from PIL import Image
import matplotlib as mpl
import time


def dist(x):
    """Return the Euclidean (L2) norm of ``x``, summing the squares of all elements."""
    sum_of_squares = np.sum(x ** 2)
    return np.sqrt(sum_of_squares)

def F_spring(pos, nbr, k, L0=1):
    """Hooke spring force exerted on ``pos`` by a spring attached to ``nbr``.

    The magnitude is ``k * (L0 - d)``, where ``d`` is the current distance. It acts
    along the unit vector from ``nbr`` to ``pos``, so a compressed spring
    (``d < L0``) pushes ``pos`` away and a stretched one pulls it back.

    Args:
        pos: position of the particle the force acts on.
        nbr: position of the particle at the other end of the spring.
        k: spring stiffness.
        L0: rest length.

    Returns:
        Force vector with the same shape as ``pos - nbr``. If the two positions
        coincide, the direction is undefined and a zero vector is returned
        instead of NaN.
    """
    displacement = pos - nbr
    # Compute the length once (the previous version evaluated it twice).
    length = np.sqrt(np.sum(displacement ** 2))
    if length == 0:
        # Avoid 0/0 -> NaN, which would otherwise poison the whole rollout.
        return np.zeros_like(displacement, dtype=float)
    return k * (L0 - length) * displacement / length

def F_gravity(m, g):
    """Return the 2-D weight vector of mass ``m`` under gravity ``g`` (pointing along -y)."""
    downward = -m * g
    return np.array([0, downward])

def F_damping(v, c=0.5):
    """Linear (viscous) damping force ``-c * v``.

    Args:
        v: velocity.
        c: damping coefficient; defaults to the previously hard-coded 0.5.

    Returns:
        Force opposing ``v``, with the same shape as ``v``.
    """
    return -c * v


def one_step(states, velocities, edges, delta_t=0.1, m=1.0, L0=1):
    """Advance the spring-mass system by one semi-implicit Euler step.

    Each particle ``i`` receives a spring force from every other particle ``j``
    whose stiffness ``edges[i, j]`` is non-zero. Gravity and damping are not
    applied here.

    Args:
        states: particle positions, one entry per particle (the dataset
            generator in this file passes an array of shape (num_particles, 2)).
        velocities: particle velocities, same layout as ``states``.
        edges: square stiffness matrix indexed as ``edges[i, j]``.
        delta_t: integration time step.
        m: mass shared by all particles (previously hard-coded to 1.0).
        L0: rest length shared by all springs (previously hard-coded to 1).

    Returns:
        ``(new_states, new_velocities)`` as numpy arrays.
    """
    num_particles = len(states)
    new_states = []
    new_vels = []
    for i in range(num_particles):
        pos = states[i]
        F_spring_total = 0
        # Ascending j keeps the same summation order as the original set loop.
        for j in range(num_particles):
            # Skip self-interaction and unconnected (zero-stiffness) pairs.
            if j == i or edges[i, j] == 0:
                continue
            F_spring_total += F_spring(pos, states[j], edges[i, j], L0)

        # Semi-implicit Euler: update the velocity first, then move with it.
        new_v = velocities[i] + delta_t * F_spring_total / m
        new_pos = pos + delta_t * new_v

        new_vels.append(new_v)
        new_states.append(new_pos)

    return np.array(new_states), np.array(new_vels)

def plot_state(particles, edges=None, fig=None, ax=None, color=None):
    """Render particles (and optional springs) and return a 64-pixel-high RGB image.

    Args:
        particles: array indexed as ``particles[:, 0]`` / ``particles[:, 1]``
            for x / y coordinates, one row per particle.
        edges: optional square matrix. A line is drawn from particle ``i`` to
            particle ``j`` whenever ``edges[i, j] != 0``.
        fig, ax: optional figure/axes to draw into. If ``ax`` is None, a new
            8x8-inch figure is created. If ``ax`` is given without ``fig``, the
            axes' own figure is used.
        color: marker colour for the particles; defaults to ``'red'``.

    Returns:
        ``uint8`` array of shape (64, W, 3), where W preserves the canvas
        aspect ratio.

    The figure that was drawn on is closed before returning.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    elif fig is None:
        # Previously ``fig.canvas.draw()`` crashed when only ``ax`` was passed.
        fig = ax.figure

    if color is None:
        color = 'red'

    ax.scatter(particles[:, 0], particles[:, 1], color=color, s=1000)

    if edges is not None:
        for ix in range(len(edges)):
            for idx in np.flatnonzero(edges[ix]):
                x_values = [particles[ix][0], particles[idx][0]]
                y_values = [particles[ix][1], particles[idx][1]]
                ax.plot(x_values, y_values, 'y-')

    # Fixed view window, independent of where the particles are.
    ax.set_xlim(-2.0, 2.0)
    ax.set_ylim(-2.0, 2.0)
    ax.set_aspect('equal', adjustable='box')
    # Hide frame and ticks on this axes only. The previous version also
    # rewrote the global ``mpl.rcParams`` spine settings on every call, which
    # leaked into every later figure in the process.
    ax.set(frame_on=False)
    ax.set_xticks([])
    ax.set_yticks([])

    fig.canvas.draw()
    # ``canvas.tostring_rgb`` is deprecated since Matplotlib 3.8 (and removed
    # in later releases). ``buffer_rgba`` returns an (H, W, 4) buffer that
    # already matches the rendered pixel size. Drop alpha and copy so the data
    # outlives the figure.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    image_from_plot = np.ascontiguousarray(rgba[..., :3])
    # Close only the figure we drew on, not every open figure.
    plt.close(fig)

    # Downscale to 64 px high, preserving the aspect ratio.
    pil_image = Image.fromarray(image_from_plot)
    width, height = pil_image.size
    new_height = 64
    ratio = width / height
    new_width = int(ratio * new_height)
    im = pil_image.resize((new_width, new_height))

    return np.array(im)


def build_matrix_general(n, k):
    """Build a symmetric (n, n) stiffness matrix from its strict upper triangle.

    ``k`` lists the upper-triangle entries row by row:
    ``(0, 1), (0, 2), ..., (0, n-1), (1, 2), ..., (n-2, n-1)``.
    The diagonal is zero.

    Args:
        n: number of particles.
        k: 1-D sequence with at least ``n * (n - 1) // 2`` values. Extra
            trailing values are ignored.

    Returns:
        float64 ndarray of shape (n, n) with ``edges[i, j] == edges[j, i]``.

    Raises:
        ValueError: if ``k`` has fewer than ``n * (n - 1) // 2`` values.
    """
    num_pairs = n * (n - 1) // 2
    values = np.asarray(k)
    if len(values) < num_pairs:
        raise ValueError(
            f"build_matrix_general: need {num_pairs} values for n={n}, got {len(values)}"
        )

    upper = np.zeros((n, n))
    # triu_indices enumerates (row, col) pairs in the same row-major order
    # used to lay out ``k``.
    rows, cols = np.triu_indices(n, k=1)
    upper[rows, cols] = values[:num_pairs]
    return upper + upper.T

def _rollout(pos, vel, edge_matrix, num_timesteps, delta_t):
    """Integrate one initial condition and return ``(positions, frames)``.

    Both outputs have ``num_timesteps + 1`` entries, the first being the
    initial state. Each frame is ``plot_state(...)`` averaged over its colour
    channels.
    """
    state, velocity = pos, vel
    traj = [state]
    img_traj = [plot_state(state, edge_matrix).mean(-1)]
    for _ in range(num_timesteps):
        state, velocity = one_step(state, velocity, edge_matrix, delta_t=delta_t)
        traj.append(state)
        img_traj.append(plot_state(state, edge_matrix).mean(-1))
    return np.array(traj), np.array(img_traj)


def generate_dataset(num_particles, num_seconds=1, num_params=1000, num_inits=1, data_dir='../data', phase='train', seed=0, imgs=False):
    """Simulate spring-mass trajectories for random stiffnesses and pickle them.

    For each of ``num_params`` stiffness vectors (uniform in [0.5, 2), one
    value per particle pair) and each of ``num_inits`` random initial
    conditions, integrates ``num_seconds`` of motion with ``delta_t = 0.1``.

    Args:
        num_particles: number of particles in the system.
        num_seconds: simulated duration per trajectory.
        num_params: number of stiffness vectors (tasks) to generate.
        num_inits: number of initial conditions per stiffness vector.
        data_dir: root directory; output goes under ``data_dir/deformable/``,
            which is created if missing.
        phase: split name embedded in the file name.
        seed: seed for all random draws (stiffnesses and initial conditions).
        imgs: selects the ``_imgs.pkl`` file-name suffix.

    Returns:
        List with one ``[x, y, y_imgs, param]`` entry per stiffness vector:

        - ``x``: (num_inits, num_particles, 4), initial positions and velocities.
        - ``y``: (num_inits, T + 1, num_particles, 2), positions.
        - ``y_imgs``: (num_inits, T + 1, 64, W), channel-averaged frames.
        - ``param``: the stiffness vector.

        The same list is pickled to disk.

    NOTE(review): frames are rendered even when ``imgs`` is False, since the
    flag only changes the file name, and rendering dominates run time. The
    ``DeformableMassSpring`` loader in this file expects different image file
    names (``deformable_particles_6_...``); confirm which naming is intended.
    """
    rng = np.random.default_rng(seed=seed)

    delta_t = 0.1
    # round() instead of int() truncation: e.g. 0.3 / 0.1 == 2.9999999999999996.
    num_timesteps = int(round(num_seconds / delta_t))

    file_path = os.path.join(data_dir, f"deformable/deformable_s_{num_seconds}_num_params_{num_params}_num_inits_{num_inits}_{phase}")
    file_path += "_imgs.pkl" if imgs else ".pkl"
    print(f"Generating dataset to: {file_path}")

    param_dim = num_particles * (num_particles - 1) // 2

    init_pos = rng.uniform(-1.5, 1.5, size=(num_inits, num_particles, 2))
    # One velocity per particle (was hard-coded to 4 particles).
    init_vel = rng.uniform(-0.4, 0.4, size=(num_inits, num_particles, 2))
    # Use the seeded generator: the global NumPy RNG used previously ignored
    # ``seed``. These are drawn last so that, for a given seed, the initial
    # conditions are the same as before.
    params = rng.uniform(0.5, 2, size=(num_params, param_dim))

    data = []
    for param in tqdm(params):
        # The stiffness matrix depends only on ``param``; build it once.
        edge_matrix = build_matrix_general(num_particles, param)
        x, y, y_imgs = [], [], []
        for pos, vel in zip(init_pos, init_vel):
            traj, img_traj = _rollout(pos, vel, edge_matrix, num_timesteps, delta_t)
            x.append(np.concatenate((pos, vel), -1))
            y.append(traj)
            y_imgs.append(img_traj)
        data.append([np.array(x), np.array(y), np.array(y_imgs), param])

    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)
    print(f"Saved at: {file_path}")

    return data


class DeformableMassSpring(torch.utils.data.Dataset):
    """Few-shot regression tasks built from pickled spring-mass trajectories.

    Each item is one stiffness configuration. Inputs are evenly spaced
    normalised times in [0, 1], one per stored timestep. Targets are particle
    ``y_index``'s 2-D position at those times, or the flattened rendered frame
    when ``imgs=True``. The timesteps are shuffled and split into a support
    set and a query set.
    """

    def __init__(self, ns, nq, phase, data_dir='./data', imgs=False, device='cpu'):
        """Load the pickled split for ``phase``.

        Args:
            ns: number of support points per task.
            nq: stored as ``query_size``. The query set returned by
                ``__getitem__`` is every timestep not in the support set.
            phase: ``'train'`` or ``'test'``.
            data_dir: directory containing the ``deformable/`` pickles.
            imgs: if True, targets are images instead of positions.
            device: device the returned tensors are moved to.

        Raises:
            ValueError: if ``phase`` is not ``'train'`` or ``'test'``.
        """
        super().__init__()

        self.imgs = imgs
        self.device = device

        file_names = {
            ('train', True): 'deformable/deformable_particles_6_s_10_num_params_1000_num_inits_1_train_imgs.pkl',
            ('train', False): 'deformable/deformable_s_10_num_params_1000_num_inits_1_train.pkl',
            ('test', True): 'deformable/deformable_particles_6_s_10_num_params_100_num_inits_1_test_imgs.pkl',
            ('test', False): 'deformable/deformable_s_10_num_params_100_num_inits_1_test.pkl',
        }
        if phase not in ('train', 'test'):
            # Previously this only printed, then failed later with an
            # UnboundLocalError on ``path``.
            raise ValueError(f"Phase: {phase} doesn't exist")
        path = os.path.join(data_dir, file_names[(phase, bool(imgs))])

        self.y_timestep = -1
        self.y_index = 0  # Predict the first particle's dynamics.

        self.support_size = ns
        self.query_size = nq

        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        """Number of tasks (stiffness configurations)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x_support, y_support, x_query, y_query)`` for task ``idx``.

        ``x_*`` has shape (k, 1) and holds normalised times. ``y_*`` has shape
        (k, 2) for positions, or (k, pixels) for images. ``k`` is the number
        of sampled timesteps in each set.
        """
        _x_init, y_all, y_imgs, _param = self.data[idx]

        # As written by ``generate_dataset``, y_all is indexed as
        # [init, timestep, particle, coord]. Take the first initial condition
        # and particle ``y_index`` -> (timesteps, 2).
        y = torch.reshape(torch.from_numpy(y_all[0, :, self.y_index]), (-1, 2)).float()
        x = torch.from_numpy(np.linspace(0, 1, y.shape[0])).float().view(-1, 1)

        # Randomly split the timesteps into support / query sets.
        indices = np.arange(y.shape[0])
        np.random.shuffle(indices)
        support_indices = indices[:self.support_size]
        query_indices = indices[self.support_size:]

        x_support = x[support_indices]
        x_query = x[query_indices]

        if self.imgs:
            frames = torch.from_numpy(y_imgs[0]).float() / 255.
            # Flatten each selected frame. Sizing by the actual index count
            # (not ``support_size``) keeps this valid when fewer timesteps
            # exist than requested, including an empty query set.
            y_support = frames[support_indices].flatten(start_dim=1)
            y_query = frames[query_indices].flatten(start_dim=1)
        else:
            y_support = y[support_indices]
            y_query = y[query_indices]

        return x_support.to(self.device), y_support.to(self.device), x_query.to(self.device), y_query.to(self.device)



if __name__ == '__main__':
    # Generate the train and test splits: 4 particles, 10 simulated seconds,
    # one initial condition per stiffness configuration.
    for split_name, split_params in (('train', 1000), ('test', 100)):
        generate_dataset(4, num_seconds=10, num_params=split_params, num_inits=1, phase=split_name)

















