import torch
from .generator import Generator
from .estimator import Estimator
from .cosine_temperature import CosineTemperature


class Evaluator():
    """Bundle a ``Generator`` and an ``Estimator`` with the generator's optimizer.

    The generator is trained with Adam and a cosine-annealed learning rate. The
    estimator is only constructed and moved to ``device`` here; its weights can
    be restored from a checkpoint with :meth:`load_estimator`.

    Args:
        device: Target device for both modules (anything ``nn.Module.to`` accepts).
        generator_lr: Initial Adam learning rate for the generator.
        num_blocks: Forwarded to both ``Generator`` and ``Estimator``.
        num_ops: Forwarded to both ``Generator`` and ``Estimator``.
        total_epoch: Forwarded to ``Generator``. Also sets the scheduler
            half-period ``T_max = max(1, total_epoch // 3)``.
    """

    def __init__(self, device, generator_lr=0.01, num_blocks=9, num_ops=7, total_epoch=120):
        self.num_ops = num_ops
        self.num_blocks = num_blocks
        # NOTE(review): the semantics of Generator/Estimator live in sibling
        # modules not visible here; presumably the generator proposes
        # architectures and the estimator scores them — confirm.
        self.generator = Generator(num_blocks, num_ops, total_epoch=total_epoch)
        self.estimator = Estimator(num_blocks, num_ops)
        self.generator.to(device)
        self.estimator.to(device)
        self.generator_opt = torch.optim.Adam(self.generator.parameters(), lr=generator_lr)
        # Clamp T_max to >= 1: with total_epoch < 3 the integer division yields 0,
        # and CosineAnnealingLR divides by T_max on the first step() call.
        # The LR follows a cosine down to eta_min over T_max scheduler steps and
        # (being periodic) rises again afterwards.
        self.generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.generator_opt,
            T_max=max(1, total_epoch // 3),
            eta_min=0.0001,
        )
        # TODO: arch_temperature: not work when binarize
        # self.arch_temperature = CosineTemperature(eta_max=5, eta_min=0.5, total_epoch=total_epoch)

    def load_estimator(self, estimator_path):
        """Restore the estimator's weights from a ``state_dict`` checkpoint file.

        The checkpoint is deserialized onto the CPU first so that files saved
        from a GPU run can be loaded on any host. ``load_state_dict`` then
        copies the tensors into the estimator's existing parameters, wherever
        they currently live.

        Args:
            estimator_path: Path to a file produced by
                ``torch.save(estimator.state_dict(), path)``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(estimator_path, map_location="cpu")
        self.estimator.load_state_dict(state_dict)

