import pathlib
from multiprocessing import Process
import copy
import torch.utils.tensorboard
import datetime
import matplotlib.pyplot as plt

import gmc.fitting
from gmc.model import Layer, Config as ModelConfig
import gmc.render

import mnist_classification.main as main
from mnist_classification.config import Config

# Compute device name; hard-coded here rather than read from the first
# command-line argument.
device = "cuda"

# Experiment configuration: a default Config with the overrides below.
c: Config = Config()
c.input_fitting_iterations = 100
c.input_fitting_components = 64
c.n_epochs = 32
c.batch_size = 40
c.log_interval = 1000

# The model consists of two layers.
c.model.layers = [
    Layer(8, 1.5, 32),
    Layer(10, 2.5, -1),
]

# TensorBoard events are written to
# <data_base_path>/tensorboard_/badfitting2_<MMDD_HHMM>, stamped at start-up.
_run_stamp = datetime.datetime.now().strftime("%m%d_%H%M")
_log_dir = c.data_base_path / f'tensorboard_' / f'badfitting2_{_run_stamp}'
tensor_board_writer = torch.utils.tensorboard.SummaryWriter(_log_dir)

def show(mixture, name, *, imagesize=400, pos_range=5):
    """Render a mixture, display it with matplotlib and log it to TensorBoard.

    The mixture is rasterised by ``gmc.render.render`` over the square
    ``[-pos_range, pos_range]`` on both axes, colour mapped by
    ``gmc.render.colour_mapped`` with the bounds -1.5 and 1.5, shown via
    ``plt.imshow``/``plt.show`` and written to the module-level
    ``tensor_board_writer`` under the tag ``name`` at step 0.

    Args:
        mixture: Mixture handed to ``gmc.render.render``. The result is
            reshaped to ``(imagesize, imagesize)``, so the rendering must
            contain exactly one image.
            NOTE(review): presumably shaped (1, 1, n_components, 7) as in
            the callers below -- confirm against gmc.render.
        name: TensorBoard tag for the logged image.
        imagesize: Width and height of the rendering in pixels.
        pos_range: Half extent of the rendered square region.
    """
    rendering = gmc.render.render(
        mixture, torch.zeros(1, 1),
        x_low=-pos_range, y_low=-pos_range,
        x_high=pos_range, y_high=pos_range,
        width=imagesize, height=imagesize,
    )
    # Tensor.numpy() refuses CUDA or grad-tracking tensors; detach and move
    # to the CPU first (both are no-ops for a plain CPU tensor).
    pixels = rendering.view(imagesize, imagesize).detach().cpu().numpy()
    colour_image = gmc.render.colour_mapped(pixels, -1.5, 1.5)
    # Keep only the first three channels of the colour-mapped image.
    rgb = colour_image[:, :, :3]

    plt.imshow(rgb)
    plt.show()
    tensor_board_writer.add_image(name, rgb, 0, dataformats='HWC')
    # NOTE(review): constant scalar under the tag "tt" -- looks like a
    # leftover smoke test; kept so the logged output stays unchanged.
    tensor_board_writer.add_scalar("tt", 32, 0)


# Hand-written test mixture with two components of seven values each,
# viewed as a (1, 1, 2, 7) tensor.
# NOTE(review): the meaning of the seven per-component values is defined by
# gmc and is not visible here -- confirm before editing the numbers.
mixture = torch.tensor(
    [
        [2.5, 0.0, -1.0, 4.0, 0.0, 0.0, 4.0],
        [-2.5, 0.0, 1.0, 4.0, 0.0, 0.0, 4.0],
    ]
).view(1, 1, -1, 7)

# Run the fitting before anything is displayed. The third argument (4) is
# presumably the number of fitting components -- confirm in gmc.fitting.
fitting, ret_const, (initial_fitting, fp_fitting) = (
    gmc.fitting.fixed_point_and_tree_hem2(mixture, torch.zeros(1, 1), 4)
)

# Display and log the input mixture, then the fixed-point fitting under two
# tags. `initial_fitting` and `fitting` are intentionally not displayed.
for _shown, _tag in (
    (mixture, "mixture"),
    (fp_fitting, "fp_fitting"),
    (fp_fitting, "fp_fitting2"),
):
    show(_shown, _tag)

