import argparse
import sys
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
import numpy as np
import torch as th
import torch.nn as nn
import cv2
import os
import pickle
import torch.distributed as dist
from utils.utils import Gaus2D
from einops import rearrange, repeat, reduce
from utils.configuration import Configuration
from model.lightning.pretrainer import LociPretrainerModule
from data.lightning_objects import LociPretrainerDataModule
plt.rcParams["text.usetex"] = True

class LociPlayground:
    """Interactive matplotlib playground for inspecting a pretrained Loci module.

    The figure is laid out on a 7x3 subplot grid and shows a slider for the
    currently selected gestalt entry, a line plot of the compressing-VAE
    statistics (sigma, capacity, mask), the gestalt code rendered as a grid of
    coloured cells, the position rendered through ``Gaus2D`` and up to three
    decoder output panels.

    Event handlers are only connected while the instance is used as a context
    manager (``with LociPlayground(...): plt.show()``):

      * click on a gestalt cell     -> select it; the slider then edits it
      * click on the line plot      -> swap the line plot and the gestalt grid
      * click on the position image -> move the position (x, y)
      * scroll on the position image-> change one position component
      * click on output panel 2 / 3 -> swap it with the large output panel
      * key 'n' / 'b' / 'r'         -> next sample / previous sample / reset edits
    """

    def __init__(self, cfg, device, file, gestalt = None, position = None):
        """Load data and model, build the figure and encode the first sample.

        Args:
            cfg: project configuration; ``cfg.model.encoder.gestalt_size``,
                ``cfg.model.input_size`` and ``cfg.pretraining_mode`` are read.
            device: anything accepted by ``th.device`` (e.g. a CUDA index).
            file: checkpoint path handed to ``load_from_checkpoint``.
            gestalt: optional array-like that replaces the gestalt code of the
                first sample (assumed to be shaped like the model output --
                TODO confirm against the pickle writer).
            position: optional array-like that replaces the position of the
                first sample (same assumption as for ``gestalt``).
        """

        device = th.device(device)

        self.device = device
        self.data_module = LociPretrainerDataModule(cfg)
        self.dataloader  = self.data_module.val_dataloader()
        self.data_iter   = iter(self.dataloader)

        self.cfg      = cfg
        # Placeholders; all three are replaced by the first compute_latent() call
        self.gestalt  = th.zeros((1, cfg.model.encoder.gestalt_size)).to(device)
        self.position = th.tensor([[0,0,0,0.05]]).to(device)
        self.capacity = th.zeros((1, cfg.model.encoder.gestalt_size)).to(device) # TODO display capacity; maybe create a plot for capacity mu and sigma (per index)
        self.size     = cfg.model.input_size
        self.gaus2d   = Gaus2D(cfg.model.input_size).to(device)
        self.gestalt_gridcell_size   = 25
        self.gestalt_gridcell_margin = 5
        self.gestalt_grid_width      = 32
        self.gestalt_grid_height     = 8

        # One snapshot per encoded sample, see _snapshot() for the layout
        self.states = []

        self.model = LociPretrainerModule.load_from_checkpoint(file, cfg=cfg).to(device)
        #self.model = LociPretrainerModule(cfg=cfg).to(device)
        self.model.eval()

        self.fig = plt.figure(figsize=(6,6))

        self.ax_slider   = plt.subplot2grid((7, 3), (0, 0), colspan=2, rowspan=1)
        self.ax_lineplot = plt.subplot2grid((7, 3), (0, 2), rowspan=1)
        self.ax_gestalt  = plt.subplot2grid((7, 3), (1, 0), colspan=2, rowspan=2)
        self.ax_position = plt.subplot2grid((7, 3), (1, 2), rowspan=2)
        self.ax_output1  = plt.subplot2grid((7, 3), (3, 0), colspan=2, rowspan=4)
        self.ax_output2  = plt.subplot2grid((7, 3), (3, 2), rowspan=2)
        self.ax_output3  = plt.subplot2grid((7, 3), (5, 2), rowspan=2)

        self.slider = Slider(self.ax_slider, '', -1, 1, valinit=0)
        self.slider.on_changed(self.update_slider)
        self.selected_cell = (0, 0)

        # Dummy curves; real data is filled in by compute_latent() below
        x_data = self._channel_axis()
        y_data = np.random.rand(self.cfg.model.encoder.gestalt_size)

        self.line1, = self.ax_lineplot.plot(x_data, y_data, label=r"$\sigma$", color='g')  # Green line
        self.line2, = self.ax_lineplot.plot(x_data, y_data, label=r"$C$", color='b')  # Blue line
        self.line3, = self.ax_lineplot.plot(x_data, y_data, label=r"$M$", color='r')  # Red line
        self.plot_big = False

        # self.indices[k] selects which panel receives the k-th decoder output
        self.outputs = [self.ax_output1, self.ax_output2, self.ax_output3]
        self.indices = [0, 1, 2]
        self.data_index = -1     # index of the newest encoded sample
        self.current_index = -1  # index of the sample currently displayed

        self.connections = ()

        self.compute_latent()

        # The overrides must be applied AFTER the first encoding, otherwise
        # compute_latent() would immediately overwrite them again.
        if gestalt is not None or position is not None:
            self._apply_latent_override(gestalt, position)

        self.update_outputs()

        plt.tight_layout()

    # ------------------------------------------------------------------ helpers

    def _channel_axis(self):
        """Return the x-axis of the line plot: 1 .. gestalt_size."""
        n = self.cfg.model.encoder.gestalt_size
        return np.linspace(1, n, n)

    def _selected_index(self):
        """Return the flat gestalt index of the selected cell, clamped to the code length."""
        h, w = self.selected_cell
        return min(h * self.gestalt_grid_width + w, self.gestalt.shape[-1] - 1)

    def _cell_from_pixel(self, x, y):
        """Translate image coordinates of the gestalt grid into a valid (row, col) cell.

        Clicks on the outer margin would otherwise map to a column / row one
        past the grid, and grids larger than the gestalt code would produce an
        out-of-range index; both cases are clamped to the nearest valid cell.
        """
        pitch = self.gestalt_gridcell_margin + self.gestalt_gridcell_size
        w = min(max(int(x / pitch), 0), self.gestalt_grid_width - 1)
        h = min(max(int(y / pitch), 0), self.gestalt_grid_height - 1)

        last = self.gestalt.shape[-1] - 1
        if h * self.gestalt_grid_width + w > last:
            h, w = divmod(last, self.gestalt_grid_width)

        return (h, w)

    def _next_batch(self):
        """Fetch the next validation batch, restarting the loader when it is exhausted."""
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.dataloader)
            return next(self.data_iter)

    def _snapshot(self):
        """Bundle everything needed to restore the current sample.

        Layout: (gestalt, position, mu, sigma, capacity, mask,
                 input_rgb, input_depth, input_instance_mask).
        gestalt and position are cloned because they are edited in place.
        """
        return (
            self.gestalt.clone(), self.position.clone(),
            self.mu, self.sigma, self.capacity, self.mask,
            self.input_rgb, self.input_depth, self.input_instance_mask,
        )

    def _apply_latent_override(self, gestalt, position):
        """Replace gestalt / position of the displayed sample with user supplied values."""
        if gestalt is not None:
            self.gestalt = th.as_tensor(gestalt).to(device=self.device, dtype=self.gestalt.dtype)

        if position is not None:
            self.position = th.as_tensor(position).to(device=self.device, dtype=self.position.dtype)

        # Make 'r' (reset) return to the overridden values, not the encoded ones
        self.states[-1] = self._snapshot()

        self._sync_slider()
        self.add_image(self.ax_gestalt, self.create_gestalt_image())
        self.add_image(self.ax_position, self.create_position_image())

    def _refresh_lineplot(self):
        """Push sigma, capacity and mask of the current sample into the three curves."""
        x_data = self._channel_axis()
        self.line1.set_data(x_data, self.sigma.cpu().numpy()[0])
        self.line2.set_data(x_data, self.capacity.cpu().numpy()[0])
        self.line3.set_data(x_data, self.mask.cpu().numpy()[0])

        # Adjust the axis limits and redraw
        self.ax_lineplot.relim()
        self.ax_lineplot.autoscale_view()
        self.fig.canvas.draw()

    def _sync_slider(self):
        """Show value, sigma and capacity of the selected gestalt entry on the slider."""
        i = self._selected_index()
        # Title first: set_val() fires update_slider(), which redraws the canvas
        self.ax_slider.set_title(r"$\sigma = {:.2e}, C = {:.2f}$".format(self.sigma[0,i].item(), self.capacity[0,i].item()))
        self.slider.set_val(self.gestalt[0,i].item())

    # --------------------------------------------------------------- model side

    def compute_capacity(self, capacity):
        """Turn the raw capacity parameter into one gate value per gestalt channel.

        The raw value is squashed with ``tanh`` to (-2.99, 2.99) and used to
        shift / scale a logistic curve evaluated on ``linspace(-3, 3)``; the
        result lies roughly between -0.0474 and 1.0475.
        """
        intervall = th.linspace(-3, 3, self.cfg.model.encoder.gestalt_size).to(self.device).view(1, -1)
        capacity = th.tanh(capacity)*2.99
        capacity = (3 / (capacity + 3)) * (intervall - capacity)
        capacity = th.exp(th.clamp(capacity, -25, 25))  # clamp avoids overflow in exp
        return 1.0949 / (1 + capacity) - 0.0474

    def compute_latent(self):
        """Encode the next validation sample and display it as a new state."""
        data = [d.to(self.device) for d in self._next_batch()]
        self.input_rgb = data[0]
        self.input_depth = data[1]
        self.input_instance_mask = data[2]

        with th.no_grad():
            results = self.model(*data)
            self.gestalt = results['gestalt']
            self.position = results['position']

        mu, sigma, capacity, mask = self.model.get_cvae_state()
        self.capacity = self.compute_capacity(capacity)
        self.mask = mask
        self.mu = mu
        self.sigma = sigma

        self.states.append(self._snapshot())

        self._refresh_lineplot()
        self._sync_slider()

        self.add_image(self.ax_gestalt, self.create_gestalt_image())
        self.add_image(self.ax_position, self.create_position_image())
        self.data_index += 1
        self.current_index += 1

    def set_state(self, i):
        """Restore the i-th recorded sample, discarding any manual edits made to it."""
        (
            self.gestalt, self.position,
            self.mu, self.sigma, self.capacity, self.mask,
            self.input_rgb, self.input_depth, self.input_instance_mask,
        ) = self.states[i]

        # Work on copies so that the stored snapshot stays untouched by edits
        self.gestalt = self.gestalt.clone()
        self.position = self.position.clone()

        self._refresh_lineplot()
        self._sync_slider()

        self.add_image(self.ax_gestalt, self.create_gestalt_image())
        self.add_image(self.ax_position, self.create_position_image())
        self.update_outputs()
        self.fig.canvas.draw()

    def update_outputs(self):
        """Decode the current gestalt / position and show the result in the output panels.

        Only the pretraining modes 'mask' and 'depth' produce output here.
        """
        with th.no_grad():
            if self.cfg.pretraining_mode == 'mask':
                mask = self.model.net.mask_pretrainer.decode(self.gestalt, self.position)
                # Softmax against a constant "ones" channel, keep the object channel
                mask = th.softmax(th.cat((mask, th.ones_like(mask)), dim=1), dim=1)[:,:1]

                self.add_image(self.outputs[self.indices[0]], mask.cpu().numpy()[0].transpose(1,2,0))

            if self.cfg.pretraining_mode == 'depth':
                depth = self.model.net.depth_pretrainer.decode(self.position, self.gestalt, self.input_instance_mask)
                depth = th.sigmoid(depth) * self.input_instance_mask

                self.add_image(self.outputs[self.indices[0]], depth.cpu().numpy()[0].transpose(1,2,0))
                self.add_image(self.outputs[self.indices[1]], self.input_instance_mask.cpu().numpy()[0].transpose(1,2,0))
                self.add_image(self.outputs[self.indices[2]], self.input_rgb.cpu().numpy()[0].transpose(1,2,0))

    # ------------------------------------------------------------- event wiring

    def __enter__(self):
        """Connect mouse, scroll and keyboard handlers to the figure canvas."""
        canvas = self.fig.canvas
        self.connections = (
            canvas.mpl_connect('button_press_event', self.onclick),
            canvas.mpl_connect('scroll_event', self.onscroll),
            canvas.mpl_connect('key_press_event', self.on_key_press),
        )
        return self

    def __exit__(self, *args, **kwargs):
        """Disconnect every handler registered in ``__enter__``."""
        for connection in self.connections:
            self.fig.canvas.mpl_disconnect(connection)
        self.connections = ()

    def on_key_press(self, event):
        """Keyboard navigation: 'n' next sample, 'b' previous sample, 'r' reset edits."""
        if event.key == 'n':
            if self.current_index == self.data_index:
                # Already at the newest sample: encode a fresh one
                self.compute_latent()
                self.update_outputs()
                self.fig.canvas.draw()
            else:
                self.current_index += 1
                self.set_state(self.current_index)

        elif event.key == 'b':
            if self.current_index > 0:
                self.current_index -= 1
                self.set_state(self.current_index)

        elif event.key == 'r':
            self.set_state(self.current_index)

    def update_slider(self, val):
        """Slider callback: write ``val`` into the selected gestalt entry and redraw."""
        if self.selected_cell is None:
            return

        self.gestalt[0, self._selected_index()] = val
        self.add_image(self.ax_gestalt, self.create_gestalt_image())
        self.update_outputs()
        self.fig.canvas.draw()

    # ------------------------------------------------------------------ drawing

    def create_gestalt_image(self):
        """Render the gestalt code as an RGB image made of one coloured cell per entry.

        Entries are laid out row by row (``gestalt_grid_width`` cells per row).
        The red / green mix encodes the value (mapped with ``* 0.5 + 0.5``),
        the brightness is scaled by the capacity of that entry, and the
        selected cell gets a lighter frame.
        """

        gestalt = self.gestalt[0].cpu().numpy() * 0.5 + 0.5
        capacity = self.capacity[0].cpu().numpy()
        size = self.gestalt_gridcell_size
        margin = self.gestalt_gridcell_margin
        pitch = size + margin

        width = self.gestalt_grid_width * pitch + margin
        height = self.gestalt_grid_height * pitch + margin
        img = np.zeros((height, width, 3)) + 0.3

        for i, (value, cap) in enumerate(zip(gestalt, capacity)):
            h, w = divmod(i, self.gestalt_grid_width)

            if self.selected_cell == (h, w):
                # Frame: the cell plus the margin on all four sides
                top    = h * pitch
                left   = w * pitch
                img[top:top + size + 2 * margin, left:left + size + 2 * margin, :] = 0.6

            top    = h * pitch + margin
            left   = w * pitch + margin
            bottom = top + size
            right  = left + size

            img[top:bottom, left:right, 0] = ((1 - value) * 0.8 + value * 0.2) * cap
            img[top:bottom, left:right, 1] = (value * 0.8 + (1 - value) * 0.2) * cap
            img[top:bottom, left:right, 2] = 0.2 * cap

        return img

    def create_position_image(self):
        """Render the current position through ``Gaus2D`` as an (h, w, c) numpy image.

        The map is repeated for the first two colour groups and brightened
        (``* 0.6 + 0.4``) for the last one; presumably Gaus2D yields a single
        channel so that the result is RGB -- TODO confirm.
        """

        img = self.gaus2d(self.position, compute_std=self.cfg.pretraining_mode == 'all')
        img = rearrange(img[0], 'c h w -> h w c')

        return th.cat((img, img, img * 0.6 + 0.4), dim=2).cpu().numpy()

    def add_image(self, ax, img):
        """Replace the content of ``ax`` with ``img`` and hide the axis decorations."""
        ax.clear()
        ax.imshow(img)
        ax.axis('off')

    def swap_layout(self):
        """Exchange the subplot slots of the line plot and the gestalt grid."""
        # Remove current plots from the canvas
        for ax in [self.ax_lineplot, self.ax_gestalt]:
            ax.remove()

        # Swap the positions of the line plot and gestalt plot
        if self.plot_big:
            self.ax_lineplot = plt.subplot2grid((7, 3), (0, 2), rowspan=1)
            self.ax_gestalt  = plt.subplot2grid((7, 3), (1, 0), colspan=2, rowspan=2)
        else:
            self.ax_gestalt  = plt.subplot2grid((7, 3), (0, 2), rowspan=1)
            self.ax_lineplot = plt.subplot2grid((7, 3), (1, 0), colspan=2, rowspan=2)

        self.plot_big = not self.plot_big

        # Re-plot the data on the new axes (the old Line2D objects died with their axes)
        x_data = self._channel_axis()
        self.line1, = self.ax_lineplot.plot(x_data, self.sigma.cpu().numpy()[0], label=r"$\sigma$", color='g')  # Green line
        self.line2, = self.ax_lineplot.plot(x_data, self.capacity.cpu().numpy()[0], label=r"$C$", color='b')  # Blue line
        self.line3, = self.ax_lineplot.plot(x_data, self.mask.cpu().numpy()[0], label=r"$M$", color='r')  # Red line
        if self.plot_big:
            self.ax_lineplot.grid(True)
            self.ax_lineplot.set_title("Compressing VAE")
            self.ax_lineplot.set_xlabel("channels")
            self.ax_lineplot.set_ylabel(r"$\sigma, C$")
            self.ax_lineplot.legend()
        self.add_image(self.ax_gestalt, self.create_gestalt_image())

        # Redraw the canvas
        self.fig.canvas.draw()

    # ----------------------------------------------------------- mouse handlers

    def onclick(self, event):
        """Mouse click handler; the action depends on the axes that was clicked."""
        clicked = event.inaxes
        if clicked is None:
            return

        x, y = event.xdata, event.ydata

        if clicked is self.ax_gestalt:
            self.selected_cell = self._cell_from_pixel(x, y)
            self._sync_slider()

            self.add_image(self.ax_gestalt, self.create_gestalt_image())
            self.update_outputs()
            self.fig.canvas.draw()

        elif clicked is self.ax_lineplot:
            self.swap_layout()

        elif clicked is self.ax_position:
            # Image coordinates -> [-1, 1]
            self.position[0,0] = (x / self.size[1]) * 2 - 1
            self.position[0,1] = (y / self.size[0]) * 2 - 1

            self.add_image(self.ax_position, self.create_position_image())
            self.update_outputs()
            self.fig.canvas.draw()

        elif clicked is self.ax_output2 or clicked is self.ax_output3:
            # Exchange the clicked panel's slot with slot 0
            other = 1 if clicked is self.ax_output2 else 2
            self.indices[0], self.indices[other] = self.indices[other], self.indices[0]
            self.update_outputs()
            self.fig.canvas.draw()

    def onscroll(self, event):
        """Scroll handler on the position image: step one position component up / down.

        Component 3 with step 0.1 is used in pretraining mode 'all', otherwise
        component 2 with step 0.01 (presumably the std / depth entry of the
        position -- TODO confirm).
        """
        if event.inaxes is not self.ax_position:
            return

        all_mode = self.cfg.pretraining_mode == 'all'
        i = 3 if all_mode else 2
        magnitude = 0.1 if all_mode else 0.01

        std = self.position[0,i]
        if event.button == 'down':
            self.position[0,i] = std - magnitude

        elif event.button == 'up':
            self.position[0,i] = std + magnitude

        self.add_image(self.ax_position, self.create_position_image())
        self.update_outputs()
        self.fig.canvas.draw()

if __name__=="__main__":
    # Command line entry point: load configuration and checkpoint, then open
    # the interactive playground window.

    parser = argparse.ArgumentParser()
    parser.add_argument("-cfg", "--cfg", required=True, type=str)
    # The checkpoint path is always handed to load_from_checkpoint, so it is mandatory
    parser.add_argument("-load", "--load", required=True, type=str)
    parser.add_argument("-latent", "--latent", default="", type=str)
    parser.add_argument("-device", "--device", default=0, type=int)
    parser.add_argument("-seed", "--seed", default=1234, type=int)
    parser.add_argument("-port", "--port", default=29500, type=int)

    args = parser.parse_args(sys.argv[1:])
    cfg  = Configuration(args.cfg)

    th.manual_seed(args.seed)
    np.random.seed(args.seed)
    cfg.seed = args.seed
    cfg.model.batch_size = 1

    # Single-process "distributed" setup; the port is offset by the device
    # index so that several playgrounds can run side by side.
    os.environ['RANK'] = "0"
    os.environ['WORLD_SIZE'] = str(1)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(args.port + args.device)
    dist.init_process_group(backend='nccl', init_method='env://')

    try:
        gestalt  = None
        position = None
        if args.latent != "":
            # SECURITY NOTE: pickle.load executes arbitrary code contained in
            # the file -- only open latent files from a trusted source.
            with open(args.latent, 'rb') as infile:
                state = pickle.load(infile)
            gestalt  = state["gestalt"]
            position = state["position"]

        with LociPlayground(cfg, args.device, args.load, gestalt, position):
            plt.show()
    finally:
        # Release the process group even if the window setup fails
        if dist.is_initialized():
            dist.destroy_process_group()
