import pickle

import numpy as np

# Seed NumPy's global random generator so every run draws the same samples.
np.random.seed(42)

from typing import Dict, List, Optional, Set, Tuple

import matplotlib.pyplot as plt
import torch

# Reversed red/blue colormap, plus a 10-entry palette sampled from it at the
# positions 0/9, 1/9, ..., 9/9 (one colour per plotted time step).
cmap = plt.get_cmap("RdBu_r")
colors = list(map(cmap, (step / 9 for step in range(10))))


def visualise(data: np.ndarray):
    """Show a scatter plot of ``data``, with column 0 as x and column 1 as y."""
    xs, ys = data[:, 0], data[:, 1]
    plt.scatter(xs, ys)
    plt.show()


def visualise_overtime(data: np.ndarray):
    """Scatter-plot every time step of ``data`` in one figure, one colour per step.

    Args:
        data: Array indexed as ``data[time, point, coordinate]``; coordinates
            0 and 1 are plotted as x and y.

    Each time step gets a legend entry ``"time: <index>"``. The figure is
    displayed with ``plt.show()``.
    """
    num_steps = data.shape[0]

    # The module-level ``colors`` palette has a fixed length, so indexing it
    # directly raised IndexError for longer sequences. Keep using it whenever
    # it is long enough (unchanged colours for existing callers) and only
    # sample a longer palette from ``cmap`` when more steps are plotted.
    if num_steps <= len(colors):
        palette = colors
    else:
        # num_steps > len(colors) >= 0 here, so num_steps - 1 is never zero.
        palette = [cmap(step / (num_steps - 1)) for step in range(num_steps)]

    for time in range(num_steps):
        plt.scatter(
            data[time, :, 0],
            data[time, :, 1],
            label=f"time: {time}",
            color=palette[time],
        )

    plt.legend()
    plt.show()


def sliced_plot(
    x1: int, y1: int, x2: int, y2: int, num_slice: int = 5
) -> Tuple[np.ndarray, np.ndarray]:
    """Return evenly spaced x and y coordinates running from (x1, y1) to (x2, y2).

    Args:
        x1, y1: Coordinates of the first position.
        x2, y2: Coordinates of the last position.
        num_slice: Number of positions to produce along each axis.

    Returns:
        ``(x_slice, y_slice)``: two arrays of ``num_slice`` values each, built
        with ``np.linspace`` from the start to the end coordinate.
    """
    axis_bounds = ((x1, x2), (y1, y2))
    x_slice, y_slice = (
        np.linspace(start, stop, num_slice) for start, stop in axis_bounds
    )
    return x_slice, y_slice


def create_trace_linear(
    traces: Optional[List[Tuple[int, int, int, int]]] = None,
    num_slice: int = 10,
) -> List[np.ndarray]:
    """Generate point clouds whose centre moves along a straight line over time.

    For every ``(x1, y1, x2, y2)`` trace the centre is moved from ``(x1, y1)``
    to ``(x2, y2)`` in ``num_slice`` evenly spaced steps, and at each step 200
    two-dimensional points are drawn around the centre with
    ``np.random.normal`` (default scale). Each trace is printed before it is
    generated and plotted with ``visualise_overtime`` afterwards.

    Args:
        traces: Start/end coordinates of each trace. When ``None`` (the
            default), the four diagonals between the corners of the
            ``[-3, 3] x [-3, 3]`` square are used.
        num_slice: Number of time steps per trace.

    Returns:
        One array per trace, stacking the per-step samples along the first
        axis (``(num_slice, 200, 2)`` when ``num_slice`` is at least 1).
    """
    if traces is None:
        # Built per call rather than in the signature: a list default would be
        # a single object shared (and mutable) across all calls.
        traces = [
            (-3, 3, 3, -3),
            (3, -3, -3, 3),
            (-3, -3, 3, 3),
            (3, 3, -3, -3),
        ]

    data_traces = []
    for trace in traces:
        print(trace)
        x1, y1, x2, y2 = trace
        x_slice, y_slice = sliced_plot(x1, y1, x2, y2, num_slice=num_slice)
        # One (x, y) centre per time step.
        shifts = np.vstack((x_slice, y_slice)).T

        # Steps are sampled in time order so the random stream is consumed in
        # the same sequence on every run.
        data = np.array(
            [np.random.normal(loc=centre, size=(200, 2)) for centre in shifts]
        )
        data_traces.append(data)
        visualise_overtime(data)

    return data_traces


def create_trace_joint(
    traces: Optional[List[List[Tuple[int, int, int, int]]]] = None,
    num_slice: int = 10,
) -> List[np.ndarray]:
    """Generate traces made of two point clouds that move along two lines.

    Each trace is a pair of ``(x1, y1, x2, y2)`` segments. For every time step
    100 two-dimensional points are drawn with ``np.random.normal`` (default
    scale) around the current centre of each segment, and the two samples are
    concatenated into one 200-point cloud (first segment's points first).
    Each trace is printed before it is generated and plotted with
    ``visualise_overtime`` afterwards.

    Args:
        traces: Pairs of start/end coordinates. When ``None`` (the default),
            four pairs are used in which both segments start at different
            points and end at the same point.
        num_slice: Number of time steps per trace.

    Returns:
        One array per trace, stacking the per-step samples along the first
        axis (``(num_slice, 200, 2)`` when ``num_slice`` is at least 1).
    """
    if traces is None:
        # Built per call rather than in the signature: a list default would be
        # a single object shared (and mutable) across all calls.
        traces = [
            [(-3, 3, 3, 0), (-3, -3, 3, 0)],
            [(-3, 3, 0, -3), (3, 3, 0, -3)],
            [(-3, -3, 0, 3), (3, -3, 0, 3)],
            [(3, 3, -3, 0), (3, -3, -3, 0)],
        ]

    data_traces = []
    for trace in traces:
        print(trace)
        x1_1, y1_1, x2_1, y2_1 = trace[0]
        x_slice_1, y_slice_1 = sliced_plot(x1_1, y1_1, x2_1, y2_1, num_slice=num_slice)
        # One (x, y) centre per time step for the first cloud.
        shifts_1 = np.vstack((x_slice_1, y_slice_1)).T

        x1_2, y1_2, x2_2, y2_2 = trace[1]
        x_slice_2, y_slice_2 = sliced_plot(x1_2, y1_2, x2_2, y2_2, num_slice=num_slice)
        # One (x, y) centre per time step for the second cloud.
        shifts_2 = np.vstack((x_slice_2, y_slice_2)).T

        data = []
        for centre_1, centre_2 in zip(shifts_1, shifts_2):
            # Draw the first cloud before the second so the random stream is
            # consumed in the same sequence on every run.
            d_1 = np.random.normal(loc=centre_1, size=(100, 2))
            d_2 = np.random.normal(loc=centre_2, size=(100, 2))
            data.append(np.concatenate((d_1, d_2), axis=0))

        data = np.array(data)
        data_traces.append(data)
        visualise_overtime(data)

    return data_traces


def create_trace_separate(
    traces: Optional[List[List[Tuple[int, int, int, int]]]] = None,
    num_slice: int = 10,
) -> List[np.ndarray]:
    """Generate traces made of two point clouds that move along two lines.

    Each trace is a pair of ``(x1, y1, x2, y2)`` segments. For every time step
    100 two-dimensional points are drawn with ``np.random.normal`` (default
    scale) around the current centre of each segment, and the two samples are
    concatenated into one 200-point cloud (first segment's points first).
    Each trace is printed before it is generated and plotted with
    ``visualise_overtime`` afterwards.

    Args:
        traces: Pairs of start/end coordinates. When ``None`` (the default),
            four pairs are used in which both segments start at the same
            point and end at different points.
        num_slice: Number of time steps per trace.

    Returns:
        One array per trace, stacking the per-step samples along the first
        axis (``(num_slice, 200, 2)`` when ``num_slice`` is at least 1).
    """
    if traces is None:
        # Built per call rather than in the signature: a list default would be
        # a single object shared (and mutable) across all calls.
        traces = [
            [(3, 0, -3, 3), (3, 0, -3, -3)],
            [(0, -3, -3, 3), (0, -3, 3, 3)],
            [(0, 3, -3, -3), (0, 3, 3, -3)],
            [(-3, 0, 3, 3), (-3, 0, 3, -3)],
        ]

    data_traces = []
    for trace in traces:
        print(trace)
        x1_1, y1_1, x2_1, y2_1 = trace[0]
        x_slice_1, y_slice_1 = sliced_plot(x1_1, y1_1, x2_1, y2_1, num_slice=num_slice)
        # One (x, y) centre per time step for the first cloud.
        shifts_1 = np.vstack((x_slice_1, y_slice_1)).T

        x1_2, y1_2, x2_2, y2_2 = trace[1]
        x_slice_2, y_slice_2 = sliced_plot(x1_2, y1_2, x2_2, y2_2, num_slice=num_slice)
        # One (x, y) centre per time step for the second cloud.
        shifts_2 = np.vstack((x_slice_2, y_slice_2)).T

        data = []
        for centre_1, centre_2 in zip(shifts_1, shifts_2):
            # Draw the first cloud before the second so the random stream is
            # consumed in the same sequence on every run.
            d_1 = np.random.normal(loc=centre_1, size=(100, 2))
            d_2 = np.random.normal(loc=centre_2, size=(100, 2))
            data.append(np.concatenate((d_1, d_2), axis=0))

        data = np.array(data)
        data_traces.append(data)
        visualise_overtime(data)

    return data_traces


# Script body: generate the three families of traces, index their samples by
# label for every time step, and write one pickle file per time step.
time_slice = 10
data_linear = create_trace_linear(num_slice=time_slice)
data_joint = create_trace_joint(num_slice=time_slice)
data_separate = create_trace_separate(num_slice=time_slice)

# One label per generated trace, in the order the traces are returned.
labels_linear = ["linear_1", "linear_2", "linear_3", "linear_4"]
labels_joint = ["joint_1", "joint_2", "joint_3", "joint_4"]
labels_separate = ["separate_1", "separate_2", "separate_3", "separate_4"]

# label2vecs_all[t] maps each label to that trace's samples at time step t.
label2vecs_all = [{} for _ in range(time_slice)]

for time in range(time_slice):
    print(f"time: {time}")
    for label_linear, d_linear in zip(labels_linear, data_linear):
        print(label_linear)
        print(d_linear[time][:3])  # preview of the first three points
        label2vecs_all[time][label_linear] = torch.from_numpy(d_linear[time])

    for label_joint, d_joint in zip(labels_joint, data_joint):
        print(label_joint)
        print(d_joint[time][:3])
        label2vecs_all[time][label_joint] = torch.from_numpy(d_joint[time])

    for label_separate, d_separate in zip(labels_separate, data_separate):
        print(label_separate)
        print(d_separate[time][:3])
        label2vecs_all[time][label_separate] = torch.from_numpy(d_separate[time])

for time in range(time_slice):
    label2vecs = label2vecs_all[time]
    # Use a context manager so the file is flushed and closed deterministically;
    # the previous bare open() left the handle to the garbage collector.
    with open(f"pseudo_label2vecs_{time}.pkl", "wb") as pkl_file:
        pickle.dump(label2vecs, pkl_file)
