import shutil
import torch
import os

import matplotlib as mpl
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm
from PIL import Image

import matplotlib.animation as animation


def plot_field(solver, settings):
    """Show the field produced by ``solver.a_matrix`` as an interactive 3-D surface.

    Args:
        solver: Object exposing ``a_matrix(seed)``, which returns the values to draw.
        settings: Mapping providing ``'LD'`` (grid side length, used for the x/y
            coordinates) and ``'seed'`` (forwarded to ``solver.a_matrix``).
    """
    side = settings['LD']

    # Switch to the Qt5 backend; plt.show() below opens a window with it.
    mpl.use('Qt5Agg')
    field = solver.a_matrix(settings['seed'])

    _, surface_ax = plt.subplots(subplot_kw={"projection": "3d"})
    axis_coords = np.arange(side)
    grid_x, grid_y = np.meshgrid(axis_coords, axis_coords)
    surface_ax.plot_surface(grid_x, grid_y, field, cmap=cm.inferno)
    plt.show()


def _rescale_to_symmetric_unit(field):
    """Affinely map ``field`` so its minimum becomes -1 and its maximum becomes 1.

    A constant field has zero range; dividing by it would fill the result with NaNs,
    so such a field is mapped to all zeros instead.
    """
    low = field.min()
    span = field.max() - low
    if span == 0:
        return field * 0
    return 2 * ((field - low) / span) - 1


def animate_eq(solver):
    """Run the solver and save animations of its ``u`` and ``m_u`` fields.

    Calls ``solver.calculate_fields()``, then reshapes ``solver.u`` to
    ``(timesteps, grid_dimension, grid_dimension)`` and ``solver.m_u`` to
    ``(timesteps, output_dimension, output_dimension)``. Each is rescaled to [-1, 1]
    and passed to ``animate`` under the names ``'u_animated'`` and ``'m_u_animated'``.

    Args:
        solver: Object exposing ``calculate_fields()``, ``grid_dimension``,
            ``timesteps``, ``output_dimension``, ``delta_t``, ``seed``, ``u`` and ``m_u``.
    """
    solver.calculate_fields()
    gd = solver.grid_dimension
    t = solver.timesteps
    q = solver.output_dimension
    d_t = solver.delta_t

    u = _rescale_to_symmetric_unit(solver.u.reshape((t, gd, gd)))
    m_u = _rescale_to_symmetric_unit(solver.m_u.reshape((t, q, q)))

    print("Generating animation for u field...")
    print("Saving animation for u field...")
    animate(u, d_t, t, solver.seed, 'u_animated')

    print("Generating animation for m_u field...")
    print("Saving animation for m_u field...")
    animate(m_u, d_t, t, solver.seed, 'm_u_animated')


def animate(u, d_t, t, seed, name, speed_up=1):
    """Render a stack of frames as heat maps and save them to ``data/generated/{name}.mp4``.

    Args:
        u: Frame stack, indexed by frame along its first axis.
        d_t: Time per frame of ``u``; multiplied by ``speed_up`` for the frame titles.
        t: Number of frames before sub-sampling; ``t // speed_up`` frames are rendered.
        seed: Unused; kept for call compatibility.
        name: Output file stem.
        speed_up: Keep only every ``speed_up``-th frame.
    """
    u = u[::speed_up]
    d_t = d_t * speed_up
    t = t // speed_up

    # All frames share one colour scale. Compute it once here: the previous
    # per-frame recomputation rescanned every frame on every frame (quadratic in t).
    v_min = u[:t].min()
    v_max = u[:t].max()

    fig = plt.figure()
    try:
        anim = animation.FuncAnimation(
            fig,
            lambda k: plot_heat_map(u[k], k, d_t, v_min, v_max),
            frames=t,
            repeat=False,
        )
        os.makedirs('data/generated', exist_ok=True)
        anim.save(f'data/generated/{name}.mp4', fps=60)
    finally:
        # Release the figure; otherwise pyplot keeps a reference to it.
        plt.close(fig)


def save_video(video, seed, name, fps=60, adjust_clim=True):
    """Encode a stack of 2-D frames as an MP4 at ``{seed}/{name}.mp4`` using FFmpeg.

    Args:
        video: 3-D ``(frames, height, width)`` array-like; each frame is drawn with
            ``imshow`` and the ``inferno_r`` colormap.
        seed: Output directory (created if missing).
        name: Output file stem.
        fps: Frames per second passed to ``FFMpegWriter``.
        adjust_clim: When true, colour limits are reset to each frame's own min/max.
            When false, the limits taken from the first frame are kept for all frames.
    """
    # Unpacking also rejects anything that is not a 3-D stack.
    num_frames, _height, _width = video.shape
    out_path = f'{seed}/{name}.mp4'

    fig, ax = plt.subplots()
    try:
        ax.axis("off")  # Turn off axes
        first = video[0]
        im = ax.imshow(first, cmap='inferno_r')
        im.set_clim(vmin=first.min(), vmax=first.max())

        def update(frame_index):
            frame = video[frame_index]
            im.set_data(frame)
            if adjust_clim:
                im.set_clim(vmin=frame.min(), vmax=frame.max())
            return [im]

        anim = animation.FuncAnimation(fig, update, frames=num_frames, blit=True)
        writer = animation.FFMpegWriter(fps=fps)

        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        anim.save(out_path, writer=writer)
    finally:
        # Close even if encoding fails, so pyplot does not keep the figure alive.
        plt.close(fig)


def save_tensor(video, name, frame_cut=True, frame_cut_proportion=5):
    """Store ``video`` as a float32 torch tensor at ``./data/{name}.pt``.

    Args:
        video: Either a 3-D ``(frames, height, width)`` stack, or a 2-D
            ``(frames, side * side)`` stack of flattened square frames. A 2-D stack is
            reshaped to ``(frames, side, side)``. Torch tensors and NumPy arrays are
            both accepted.
        name: File stem, relative to ``./data`` (missing directories are created).
        frame_cut: When true, drop the first ``num_frames // frame_cut_proportion``
            frames. This is burn-in removal, only meaningful for chaotic behaviour.
        frame_cut_proportion: ``k`` such that the first ``1/k`` of the frames is dropped.

    Raises:
        ValueError: If a 2-D ``video`` has rows whose length is not a perfect square,
            or if ``video`` is not 3-D after the optional reshape.
    """
    if len(video.shape) == 2:
        pixels_per_frame = video.shape[1]
        side = int(round(np.sqrt(pixels_per_frame)))
        if side * side != pixels_per_frame:
            # Without this check, a non-square row length could be silently re-split
            # into wrongly sized frames (e.g. 4 rows of 8 values -> 8 frames of 2x2).
            raise ValueError(
                f"cannot reshape rows of length {pixels_per_frame} into square frames"
            )
        # reshape (unlike Tensor.view) also works for NumPy arrays and non-contiguous tensors.
        video = video.reshape(-1, side, side)
    # Unpacking also rejects inputs that are not 3-D at this point.
    num_frames, _height, _width = video.shape

    start = num_frames // frame_cut_proportion if frame_cut else 0
    # clone() gives the result its own compact storage. torch.save writes a tensor's
    # whole underlying storage, so saving a slice/view would also write the cut frames.
    tensor = torch.asarray(video[start:], dtype=torch.float32).clone()

    file_name = f'./data/{name}.pt'
    os.makedirs(os.path.dirname(os.path.abspath(file_name)), exist_ok=True)
    torch.save(tensor, file_name)


def to_image(tensor):
    """Render a single field snapshot as an RGB ``PIL.Image`` with the ``inferno_r`` colormap.

    Args:
        tensor: A torch tensor, handled in one of two ways:
            - 4-D: size-1 dims 0 and 1 are squeezed away. This is presumably a
              ``(1, 1, H, W)`` batch/channel layout; TODO confirm.
            - Otherwise: dim 0 is treated as ``side * side`` values and reshaped
              to ``(side, side)``.

    Returns:
        A ``PIL.Image`` built from the colour-mapped ``uint8`` array.
    """
    tensor = tensor.detach().cpu()
    if len(tensor.shape) == 4:  # Image comes from KS
        tensor = tensor.squeeze(0).squeeze(0)  # Remove batch and channel dimensions
    else:  # Image comes from Wave or Heat
        side_length = int(np.sqrt(tensor.shape[0]))
        tensor = tensor.reshape(side_length, side_length)

    # Min-max normalise to [0, 1]. A constant snapshot has zero range: skip the
    # division (0/0 would give NaNs) and leave it at 0, the bottom of the colormap.
    tensor = tensor - tensor.min()
    value_range = tensor.max()
    if value_range > 0:
        tensor = tensor / value_range
    array = tensor.numpy()

    if array.ndim == 2:  # Grayscale image, apply colormap
        colormap = plt.get_cmap('inferno_r')
        array = colormap(array)[:, :, :3]  # Apply colormap and remove alpha channel
        array = (array * 255).astype(np.uint8)

    return Image.fromarray(array)


def plot_heat_map(u_k, k, delta_t, v_min, v_max):
    """Clear the current pyplot figure and draw frame ``k`` as a heat map.

    Because the figure is cleared first, this is suitable for redrawing once per
    animation frame.

    Args:
        u_k: 2-D array of values for this frame (passed to ``pcolormesh``).
        k: Frame index; the title shows the time ``k * delta_t``.
        delta_t: Time per frame.
        v_min, v_max: Fixed colour-scale limits.

    Returns:
        The ``matplotlib.pyplot`` module itself.
    """
    plt.clf()

    elapsed = k * delta_t
    # Alternative title (presumably for heat-equation runs):
    # f"Temperature at t = {elapsed:.1f} unit time"
    plt.title(f"Wave amplitude at t = {elapsed:.1f} unit time")
    plt.xlabel("x")
    plt.ylabel("y")

    plt.pcolormesh(u_k, cmap=plt.cm.inferno, vmin=v_min, vmax=v_max)
    plt.colorbar()
    return plt


def save_loss_curves(data, seed):
    """Archive the run configuration and loss statistics under ``output/{seed}/``.

    Copies ``config.py`` to ``output/{seed}/config.txt``. Also writes the mean/variance
    loss plot of ``data`` to ``output/{seed}/stats.png``.

    Args:
        data: Loss values accepted by ``plot_mean_var_curve``.
        seed: Run identifier used as the output sub-directory name (created if missing).
    """
    out_dir = f'output/{seed}'
    # shutil.copy does not create parent directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    shutil.copy('config.py', f'{out_dir}/config.txt')
    plot_mean_var_curve(data, f'{out_dir}/stats.png')


def plot_err_curves(data, seed):
    """Save error-vs-epoch curves, one per entry of ``data``, to ``output/{seed}/history_curves.png``.

    The series at position ``i`` in ``data`` is labelled ``"Length i"`` in the legend.
    """
    plt.figure()
    for position, err_series in enumerate(data):
        plt.plot(err_series, label=f"Length {position}")

    plt.xlabel('Epoch')
    plt.ylabel('Err')
    plt.title('Error metric against epoch for different history lengths')
    plt.legend()

    plt.savefig(f'output/{seed}/history_curves.png')
    plt.close()


def plot_err_against_length(data, seed):
    """Plot the final error of every series in ``data`` and save to ``output/{seed}/history_err.png``.

    The x-axis is each series' position in ``data``.
    """
    plt.figure()
    final_errors = [series[-1] for series in data]
    plt.plot(final_errors)

    plt.xlabel('Length')
    plt.ylabel('Err')
    plt.title('Error metric against history length')

    plt.savefig(f'output/{seed}/history_err.png')
    plt.close()


def plot_sample_size_err(data, lengths, parameter, path):
    """Plot the final error against a parameter, averaged over samples, to ``output/{path}.png``.

    Args:
        data: Sequence of samples; ``sample[i][-1]`` is the final error recorded for the
            ``i``-th parameter value in that sample.
        lengths: Number of parameter values (an int).
        parameter: Parameter name used for the x-label and title.
        path: Output file stem, relative to ``output/``.
    """
    plt.figure()
    # Transpose: row i collects the final error of parameter value i from every sample,
    # so get_statistics (which reduces over axis 1) averages across samples.
    finals_by_param = [[sample[i][-1] for sample in data] for i in range(lengths)]
    avg, var = get_statistics(finals_by_param)
    # NOTE(review): error bars show the variance, not the standard deviation, and the
    # 'Mean' label is never shown because no legend is drawn. Confirm both are intended.
    plt.errorbar(np.arange(len(avg)), avg, yerr=var, label='Mean')

    plt.xlabel(parameter)
    plt.ylabel('Err')
    plt.title(f'Error metric against {parameter}. Sample size: {len(data)}')

    plt.savefig(f'output/{path}.png')
    plt.close()


def plot_mean_var_curve(data, path):
    """Save a log-scale plot of the per-epoch mean loss with variance error bars to ``path``.

    Args:
        data: Array-like reduced by ``get_statistics``. Row ``i`` holds the loss values
            for epoch ``i`` (they are averaged across axis 1).
        path: Destination image file.
    """
    plt.figure()
    mean_loss, loss_var = get_statistics(data)
    epochs = np.arange(len(mean_loss))

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title('Loss against epoch')
    plt.errorbar(epochs, mean_loss, yerr=loss_var, label='Mean')
    plt.legend()

    plt.savefig(path)
    plt.close()


def get_statistics(data, axis=1, ddof=0):
    """Return the mean and variance of ``data`` along ``axis``.

    Args:
        data: Array-like accepted by ``np.asarray``. With the default ``axis=1`` it must
            be at least 2-D, and each row is reduced to a single mean/variance.
        axis: Axis to reduce over (default 1, matching the original behaviour).
        ddof: Delta degrees of freedom for the variance. ``0`` (default) gives the
            population variance; ``1`` gives the unbiased sample variance.

    Returns:
        ``(avg, var)`` as NumPy arrays.
    """
    values = np.asarray(data)
    return np.mean(values, axis=axis), np.var(values, axis=axis, ddof=ddof)
