import numpy as np
import torch
from matplotlib import pyplot as plt


# Display names of the candidate operations. Index i labels the i-th
# probability curve of a layer. The first layer of each stage (see
# Plotter.__init__) has one fewer operation, so its curves use every name
# except the final 'Zero'.
# Order: for each kernel size (3, 5, 7), expansion 3 then 6, followed by 'Zero'.
op_names = [
    f'MBConv{kernel}X{kernel}_{expansion}'
    for kernel in (3, 5, 7)
    for expansion in (3, 6)
] + ['Zero']


class SingleLayerPlotter:
    """Record and plot per-operation probabilities for one searchable layer.

    One history list is kept per candidate operation; every call to
    ``add_iter_result`` appends one value to each history, so the histories
    grow in lockstep (one entry per recorded iteration).
    """

    def __init__(self, operation_num):
        """Create an empty probability history for each operation.

        Args:
            operation_num: number of candidate operations in this layer.
        """
        # A comprehension gives each operation its own independent list
        # (unlike ``[[]] * operation_num``, which would alias one list).
        self.probs = [[] for _ in range(operation_num)]

    def add_iter_result(self, iter_results):
        """Append one iteration's probabilities to the histories.

        Args:
            iter_results: iterable of scalar values exposing ``.item()``
                (e.g. the elements of a 1-D torch tensor), one per operation.
                ``zip`` pairs them with the histories, so any extra values
                beyond ``operation_num`` are ignored.
        """
        for prob, iter_result in zip(self.probs, iter_results):
            prob.append(iter_result.item())

    def plot_ops(self, axes, stage, layer, labels=None):
        """Draw one probability curve per operation on ``axes``.

        Args:
            axes: a matplotlib ``Axes`` (anything with ``plot``, ``set_xlabel``,
                ``set_ylabel``, ``legend`` and ``set_title``).
            stage: stage index, shown in the title.
            layer: layer index, shown in the title.
            labels: optional sequence of legend labels, indexed by operation
                position. Defaults to the module-level ``op_names``.

        Raises:
            IndexError: if there are more operations than labels.
        """
        if labels is None:
            labels = op_names
        for i, prob in enumerate(self.probs):
            # x-axis is 1-based iteration number.
            axes.plot(range(1, len(prob) + 1), prob, label=labels[i])
        # Axis decoration is per-axes, not per-curve: apply it once after all
        # curves are drawn (previously it was redundantly repeated in the loop).
        axes.set_xlabel('iteration')
        axes.set_ylabel('probability')
        if self.probs:
            # Only build a legend when there are labelled curves to show.
            axes.legend(loc='upper left')
        axes.set_title(f'stage_{stage}, layer_{layer}')


class Plotter:
    """Track and plot architecture-parameter probabilities for a whole network.

    Holds one ``SingleLayerPlotter`` per (stage, layer) pair, stored in
    stage-major order: the plotter for ``(i, j)`` lives at index
    ``i * layer_num + j``.
    """

    def __init__(self, stage_num, layer_num):
        """Create a plotter for every layer of every stage.

        Args:
            stage_num: number of stages in the network.
            layer_num: number of searchable layers per stage.
        """
        self.stage_num = stage_num
        self.layer_num = layer_num
        # The first layer (j == 0) of each stage tracks one fewer operation,
        # i.e. every entry of ``op_names`` except the last ('Zero').
        self.plotters = [
            SingleLayerPlotter(len(op_names) - 1 if j == 0 else len(op_names))
            for _ in range(stage_num)
            for j in range(layer_num)
        ]

    def add_iter_results(self, model):
        """Record the current softmax probabilities of every layer.

        Args:
            model: object whose ``architecture_parameters()`` yields one
                logits tensor per layer. NOTE(review): presumably 1-D tensors
                in the same stage-major order as ``self.plotters`` — confirm
                against the model.
        """
        # Only the numeric values are recorded, so skip autograd bookkeeping.
        with torch.no_grad():
            for plotter, alpha in zip(self.plotters, model.architecture_parameters()):
                plotter.add_iter_result(torch.softmax(alpha, dim=0))

    def plot_layers(self, path):
        """Plot every layer's probability history in a grid and save it.

        Args:
            path: destination passed to ``Figure.savefig``.
        """
        # Scope the font size to this figure instead of mutating the global
        # matplotlib rcParams for the rest of the process.
        with plt.rc_context({"font.size": 20}):
            # squeeze=False keeps ``axs`` 2-D even when stage_num or
            # layer_num is 1, so ``axs[i, j]`` indexing always works.
            fig, axs = plt.subplots(
                self.stage_num,
                self.layer_num,
                figsize=(self.layer_num * 10, self.stage_num * 10),
                squeeze=False,
            )
            try:
                for i in range(self.stage_num):
                    for j in range(self.layer_num):
                        self.plotters[i * self.layer_num + j].plot_ops(axs[i, j], i, j)
                fig.savefig(path)
            finally:
                # Release the figure; repeated calls would otherwise
                # accumulate open figures in pyplot's figure manager.
                plt.close(fig)

