from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import special_ortho_group
from scipy.spatial.transform import Rotation
from data_utils.chair_dataset import render, read_off
from matplotlib.animation import FuncAnimation, PillowWriter, ArtistAnimation
import torch
from loss_functions import normlize_vector
from sklearn.decomposition import PCA
from utils import get_device

def save_screen_animation(images, path, axis='x'):
    """Save a sequence of frames as an animated GIF, then show it.

    Args:
        images: array whose first dimension indexes frames; each ``images[j]``
            is drawn with ``imshow`` (presumably a 2-D grayscale image --
            confirm against callers).
        path: output directory. The GIF is written to
            ``<path>/Screen_animation_{axis}.gif``.
        axis: tag inserted into the output filename.
    """
    from matplotlib.animation import writers

    fig = plt.figure(figsize=(4, 4))
    plt.axis("off")
    # ArtistAnimation expects one list of artists per frame.
    frames = [
        [plt.imshow(images[j], animated=True, cmap="binary", interpolation='spline36')]
        for j in range(images.shape[0])
    ]
    # interval is in milliseconds: one frame per second.
    anim = ArtistAnimation(
        fig, frames, interval=1000, repeat_delay=1000, blit=True
    )
    # ImageMagick is an external program that may not be installed; fall back
    # to matplotlib's Pillow-based GIF writer instead of failing the save.
    writer = "imagemagick" if writers.is_available("imagemagick") else "pillow"
    anim.save(path + f"/Screen_animation_{axis}.gif", dpi=128, writer=writer)
    plt.show()
    # With a non-interactive backend plt.show() does not close anything, so
    # release the figure explicitly to avoid accumulating open figures.
    plt.close(fig)


class EvalModule:
    """Visualise an encoder's representations along a one-axis rotation orbit.

    A chair mesh is rendered under rotations about a single coordinate axis,
    composed with a fixed random initial orientation. The renders are encoded,
    and a 2-D PCA projection of the codes is saved as a scatter plot coloured
    by rotation angle.
    """

    # Position of each supported axis name in the 'xyz' Euler-angle vector.
    _AXIS_INDEX = {'x': 0, 'y': 1, 'z': 2}

    def __init__(
            self,
            args,
            enc
            ):
        """Store configuration and put the encoder into evaluation mode.

        Args:
            args: configuration object; ``visualize`` reads ``image_channels``,
                ``img_w``, ``img_h``, ``data_dir`` and ``results_dir`` from it.
            enc: encoder called on a batch of rendered images (presumably a
                ``torch.nn.Module`` -- ``.eval()`` is called on it here).
        """
        self.args = args
        self.enc = enc
        self.enc.eval()
        self.device = get_device(enc)
        # Random rotation matrix drawn once per instance; every visualised
        # axis starts from this same orientation.
        self.init_rot = special_ortho_group.rvs(3)

    @staticmethod
    def _rotation_matrix(axis, theta):
        """Return the 3x3 matrix rotating by ``theta`` radians about ``axis``.

        Raises:
            ValueError: if ``axis`` is not ``'x'``, ``'y'`` or ``'z'``.
        """
        # Equality-based dict lookup (not ``is``) so dynamically built
        # strings such as ''.join(['x']) are accepted too.
        try:
            index = EvalModule._AXIS_INDEX[axis]
        except KeyError:
            raise ValueError(
                f"axis must be one of 'x', 'y', 'z', got {axis!r}"
            ) from None
        angles = [0.0, 0.0, 0.0]
        angles[index] = theta
        return Rotation.from_euler('xyz', angles).as_matrix()

    def _render_rotations(self, axis, thetas):
        """Render the normalised chair mesh once per angle in ``thetas``."""
        # NOTE(review): buffer is laid out (N, C, img_w, img_h); confirm this
        # matches the array layout that render() returns.
        frames = np.empty(
            (len(thetas), self.args.image_channels, self.args.img_w, self.args.img_h)
        )
        vertices, triangles = read_off(self.args.data_dir + '/chair.off')
        # Centre the mesh on its vertex centroid, then scale so the mean
        # vertex distance from the origin is 1.
        vertices -= np.mean(vertices, axis=0)
        vertices /= np.mean(np.linalg.norm(vertices, axis=1))
        for j, theta in enumerate(tqdm(thetas)):
            frames[j] = render(
                vertices, triangles, self._rotation_matrix(axis, theta) @ self.init_rot
            )
        return frames

    def _encode(self, images):
        """Encode ``images`` and return the ``normlize_vector`` output as NumPy."""
        viz_tensors = torch.Tensor(images).to(self.device)
        # Inference only: avoid building an autograd graph for the whole batch.
        with torch.no_grad():
            codes = normlize_vector(self.enc(viz_tensors))
        return codes.detach().cpu().numpy()

    def _plot_codes(self, points, angles, axis):
        """Scatter 2-D ``points`` coloured by ``angles`` and save the figure."""
        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(points[:, 0], points[:, 1], c=angles, alpha=0.7)
        ax.set_xlabel('Dim 0')
        ax.set_ylabel('Dim 1')
        ax.set_title('2D Projection of the representations')
        fig.savefig(self.args.results_dir + f'/Representations_{axis}.png', dpi=600)
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)

    def visualize(self, axis='x'):
        """Render a full turn about ``axis``, encode it, and save plots.

        Writes ``Screen_animation_{axis}.gif`` (via ``save_screen_animation``)
        and ``Representations_{axis}.png`` into ``args.results_dir``.

        Args:
            axis: rotation axis, one of ``'x'``, ``'y'``, ``'z'``.

        Raises:
            ValueError: if ``axis`` is not a supported axis name (checked
                before any rendering work is done).
        """
        if axis not in self._AXIS_INDEX:
            raise ValueError(f"axis must be one of 'x', 'y', 'z', got {axis!r}")
        print(f'Visualizing rotation along {axis} axis ...')
        N_THETA = 120
        # linspace includes both endpoints, so the first and last frames show
        # the same pose (0 and 2*pi).
        thetas = np.linspace(0, 2 * np.pi, N_THETA)
        viz_dataset = self._render_rotations(axis, thetas)
        save_screen_animation(viz_dataset[:, 0, :, :], self.args.results_dir, axis=axis)
        X = self._encode(viz_dataset)
        pca = PCA(n_components=2)
        pca.fit(X)
        pca_2_codes = pca.transform(X)
        # Colour each point by the rotation angle that produced it.
        self._plot_codes(pca_2_codes, thetas, axis)
        print(f'Plot of visualization along {axis} axis is saved at {self.args.results_dir}...')





