import io
import os.path

import cairosvg
import torch
from PIL import Image

import chess
import chess.svg
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

from torch.utils.tensorboard import SummaryWriter

import config

# Directory of the active run (``<log_path>/<name>``); assigned by init().
_logging_path: "str | None" = None
# Shared SummaryWriter used by every logging helper below; the helpers call
# init() to create it when it is still None.
_writer: "SummaryWriter | None" = None


def init(log_path=os.path.join("..", "tensorboard"), name="temp"):
    """Create (or switch to) the run directory ``<log_path>/<name>`` and open a writer on it.

    If a writer from an earlier ``init()`` call is still open it is closed first,
    so re-initialising does not leak it. The save helpers in this module derive
    their output directories from this path by replacing ``"tensorboard"`` in it.

    Args:
        log_path: Parent directory for runs.
        name: Run name; used as the sub-directory of ``log_path``.
    """
    global _logging_path, _writer
    _logging_path = os.path.join(log_path, name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(_logging_path, exist_ok=True)
    if _writer is not None:
        # Release the previous run's writer before replacing it.
        _writer.close()
    _writer = SummaryWriter(_logging_path)


def close():
    """Close the shared writer; a no-op when ``init()`` has never created one."""
    writer = _writer
    if writer is None:
        return
    writer.close()


def _add_square_dist(pref, board_tensor, iteration, name, pt=None):
    """Log an 8x8 occupancy heat-map twice: on a shared scale and max-normalized.

    Args:
        pref: Tag prefix; images go to ``{pref}/pieces/squares/{name}/`` and
            ``{pref}/pieces/squares/{name}/normalized``.
        board_tensor: Stack of per-square planes; in this file it is called with
            a ``(12, 8, 8)`` tensor of per-square averages.
        iteration: Global step passed to the writer.
        name: Label used in the tags.
        pt: Index of the plane to show; ``None`` shows the sum over all planes.
    """
    # detach: tensors that require grad cannot be converted to numpy.
    total = board_tensor.detach().sum(dim=0)
    # Shared scale (twice the mean total occupancy per square) keeps maps of
    # different planes comparable. Clamped so an all-zero tensor does not
    # produce 0/0 -> NaN, mirroring the 1e-5 guard on the normalized map.
    scale = max(2 * total.mean().item(), 1e-5)
    plane = total if pt is None else board_tensor[pt].detach()
    dist = plane.cpu().numpy()

    cmap = plt.get_cmap('plasma')
    # The colormap returns an (8, 8, 4) RGBA array, hence dataformats='HWC'.
    _writer.add_image(f"{pref}/pieces/squares/{name}/", cmap(dist / scale), iteration,
                      dataformats='HWC')
    _writer.add_image(f"{pref}/pieces/squares/{name}/normalized", cmap(dist / (dist.max() + 1e-5)),
                      iteration, dataformats='HWC')


def add_board_tensor_dist(pref, board_tensor, iteration):
    """Log piece-type frequencies and per-square occupancy maps for a batch of boards.

    Args:
        pref: Tag prefix; outputs go under ``{pref}/pieces/...``.
        board_tensor: Tensor reshaped to ``(-1, 12, 64)``, i.e. 12 planes of 64
            squares per board; the planes are labelled in the order of
            ``piece_names`` below.
        iteration: Global step passed to the writer.

    Logs nothing for an empty batch.
    """
    if _writer is None:
        init()
    # reshape (not view) also accepts non-contiguous input; detach lets
    # tensors that require grad be converted to numpy below.
    board_tensor = board_tensor.detach().reshape(-1, 12, 64)
    count = board_tensor.shape[0]
    if count == 0:
        # Averaging over an empty batch would divide by zero and log NaNs.
        return
    board_tensor = board_tensor.sum(dim=0)  # (12, 64): summed over the batch

    piece_names = [
        "W Pawn", "W Knight", "W Bishop", "W Rook", "W Queen", "W King",
        "B Pawn", "B Knight", "B Bishop", "B Rook", "B Queen", "B King"]
    # White piece types in red, black piece types in blue.
    bar_colors = ['tab:red'] * 6 + ['tab:blue'] * 6
    # Average number of pieces of each type per board.
    counts = (board_tensor.sum(dim=1) / count).cpu().numpy()

    fig, ax = plt.subplots()
    ax.bar(piece_names, counts, color=bar_colors)
    ax.set_xticks(piece_names, labels=piece_names, rotation=45, ha="right")
    ax.set_ylabel('Frequency')
    # NOTE(review): a fixed upper limit of 3 clips bars of piece types that
    # average more than 3 per board (e.g. up to 8 pawns) — confirm intended.
    ax.set_ylim([0, 3])
    ax.set_title('Piece Type Frequency')
    fig.tight_layout()
    _writer.add_figure(f"{pref}/pieces/types", fig, iteration)

    # Per-square average occupancy: one 8x8 map for all pieces, then one per type.
    board_tensor = board_tensor.view(12, 8, 8) / count
    _add_square_dist(pref, board_tensor, iteration, "all")
    for pt, piece_name in enumerate(piece_names):
        _add_square_dist(pref, board_tensor, iteration, piece_name, pt)


def add_histogram(name, square_hist, iteration):
    """Log a histogram of ``square_hist`` under tag ``name`` at step ``iteration``.

    The shared writer is created by ``init()`` (default run) on first use.
    """
    writer_missing = _writer is None
    if writer_missing:
        init()  # rebinds the module-level ``_writer``
    _writer.add_histogram(name, square_hist, iteration)


def add_scalar(name, value, iteration):
    """Log scalar ``value`` under tag ``name`` at step ``iteration``.

    Opens the default run via ``init()`` if no writer exists yet.
    """
    if _writer is None:
        init()
    writer = _writer
    writer.add_scalar(name, value, iteration)


def log_boards(name, boards, iteration):
    """Render chess boards to images and log them as one image batch.

    Args:
        name: Tag for the image batch.
        boards: Sequence of boards accepted by ``chess.svg.board``; nothing is
            logged when it is empty.
        iteration: Global step passed to the writer.

    One or two boards are rendered with size 320, more with size 160. With
    more than five boards, consecutive pairs are stacked vertically into one
    image, and a trailing unpaired board is dropped so every pair is complete.
    """
    if len(boards) == 0:
        return
    if _writer is None:
        init()
    if len(boards) > 5 and len(boards) % 2 == 1:
        # Pairing requires an even count (an odd one makes np.concatenate fail);
        # trim before rendering so the dropped board is never rasterized.
        boards = boards[:-1]
    board_size = 160 if len(boards) > 2 else 320
    images = [_svg_to_chw(chess.svg.board(board, size=board_size)) for board in boards]
    _writer.add_images(name, _tile_board_images(images), iteration)


def _svg_to_chw(svg_data):
    """Rasterize SVG data to a uint8 RGB array in channel-first (C, H, W) layout."""
    png_bytes = cairosvg.svg2png(bytestring=svg_data)
    with Image.open(io.BytesIO(png_bytes)) as png_image:
        # Always convert: drops the alpha channel of RGBA output and guarantees
        # three channels for any other mode, so the transpose below is valid.
        rgb = np.array(png_image.convert('RGB'))
    return np.transpose(rgb, (2, 0, 1))


def _tile_board_images(images):
    """Stack (C, H, W) images into an (N, C, H, W) batch, pairing them when there are more than five.

    With more than five images, a trailing odd image is dropped and each pair
    (0, 1), (2, 3), ... is concatenated along the height axis.
    """
    if len(images) > 5 and len(images) % 2 == 1:
        images = images[:-1]
    batch = np.stack(images, axis=0)
    if len(batch) > 5:
        batch = np.concatenate([batch[0::2], batch[1::2]], axis=2)
    return batch


def save_fens(fens, name, iteration, path="fens"):
    """Write positions to ``<fens dir>/<name>_<iteration>.csv``, one per line.

    The output directory is the run directory with ``"tensorboard"`` replaced
    by ``path``. If no run has been set up yet, ``init()`` is called with its
    defaults, as the other logging helpers do. Nothing is written when ``fens``
    is empty.

    Args:
        fens: Entries that are either a string, written as ``"<fen>, "``, or a
            tuple whose first two items are written as ``"<a>, <b>, "``.
        name: File name prefix.
        iteration: Appended to the file name.
        path: Directory name substituted for ``"tensorboard"``.

    Raises:
        TypeError: if an entry is neither a ``str`` nor a ``tuple``. No file is
            created in that case.
    """
    if len(fens) == 0:
        return
    if _logging_path is None:
        init()
    # Validate and format everything first so a bad entry leaves no partial file.
    lines = []
    for fen in fens:
        if isinstance(fen, str):
            lines.append(f"{fen}, \n")
        elif isinstance(fen, tuple):
            lines.append(f"{fen[0]}, {fen[1]}, \n")
        else:
            raise TypeError(f"expected a FEN string or tuple, got {type(fen).__name__}")
    # Note: str.replace substitutes every occurrence of "tensorboard" in the path.
    fens_path = _logging_path.replace("tensorboard", path)
    os.makedirs(fens_path, exist_ok=True)
    with open(os.path.join(fens_path, f"{name}_{iteration}.csv"), 'w', encoding="utf-8") as f:
        f.writelines(lines)


def save_model(model, name):
    """Save ``model.state_dict()`` as ``<models dir>/<name>``.

    The models directory is the run directory with ``"tensorboard"`` replaced
    by ``"models"``. If no run has been set up yet, ``init()`` is called with
    its defaults, as the other logging helpers do.
    """
    if _logging_path is None:
        init()
    model_path = _logging_path.replace("tensorboard", "models")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_path, name))


def save_checkpoint(model, optimizer, iteration, temperature, name):
    """Save a resumable training checkpoint as ``<checkpoints dir>/<name>``.

    The checkpoints directory is the run directory with ``"tensorboard"``
    replaced by ``"checkpoints"``. The file holds the model and optimizer
    state dicts plus ``iteration``, ``temperature`` and the current
    ``config.continuations`` counter, which ``load_checkpoint`` reads back.
    If no run has been set up yet, ``init()`` is called with its defaults.
    """
    if _logging_path is None:
        init()
    check_path = _logging_path.replace("tensorboard", "checkpoints")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(check_path, exist_ok=True)
    checkpoint_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'iteration': iteration,
        'temperature': temperature,
        'continuations': config.continuations
    }
    torch.save(checkpoint_dict, os.path.join(check_path, name))


def load_checkpoint(model, optimizer, name, checkpoint_path=os.path.join("..", "checkpoints"),
                    map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module that receives ``checkpoint['model']`` via ``load_state_dict``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']`` via ``load_state_dict``.
        name: File name relative to ``checkpoint_path``. ``save_checkpoint`` writes
            into a per-run sub-directory, so this may need to include it — confirm
            with callers.
        checkpoint_path: Directory that contains the checkpoint.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` loads a
            GPU-saved checkpoint on a machine without CUDA. The default ``None``
            keeps ``torch.load``'s own behaviour.

    Returns:
        ``(iteration, temperature)`` as stored in the checkpoint.

    Side effect: sets ``config.continuations`` to the stored count plus one,
    or to 1 for checkpoints saved without that key.
    """
    checkpoint = torch.load(os.path.join(checkpoint_path, name), map_location=map_location,
                            weights_only=True)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    iteration = checkpoint['iteration']
    temperature = checkpoint['temperature']
    # Older checkpoints lack the key; treat them as 0 previous continuations.
    config.continuations = checkpoint.get('continuations', 0) + 1
    print(f"Loaded iteration {iteration} with temperature {temperature}. Continuation #{config.continuations}")
    return iteration, temperature
