# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py

import random
import threading
import time
from typing import List, Union

import nni.nas.execution.api
import nni.nas.nn.pytorch as nn
import nni.nas.strategy as strategy
import pytest
import torch
import torch.nn.functional as F
from nni.nas.execution import wait_models
from nni.nas.execution.common import (
    AbstractExecutionEngine,
    AbstractGraphListener,
    DebugEvaluator,
    MetricData,
    Model,
    ModelStatus,
    WorkerInfo,
)
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.nn.pytorch.mutator import process_inline_mutation


class MockExecutionEngine(AbstractExecutionEngine):

    def __init__(self, failure_prob=0.):
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4

    def _model_complete(self, model: Model):
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        for model in models:
            self.models.append(model)
            self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model, )).start()

    def list_models(self) -> List[Model]:
        return self.models

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        return self._resource_left

    def budget_exhausted(self) -> bool:
        pass

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        pass

    def trial_execute_graph(cls) -> MetricData:
        pass


def _reset_execution_engine(engine=None):
    nni.nas.execution.api._execution_engine = engine


class Net(nn.Module):

    def __init__(self, hidden_size=32, diff_size=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice(
            [
                nn.Linear(4 * 4 * 50, hidden_size, bias=True),
                nn.Linear(4 * 4 * 50, hidden_size, bias=False)
            ],
            label='fc1'
        )
        self.fc2 = nn.LayerChoice(
            [
                nn.Linear(hidden_size, 10, bias=False),
                nn.Linear(hidden_size, 10, bias=True)
            ] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]),
            label='fc2'
        )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def _get_model_and_mutators(**kwargs):
    base_model = Net(**kwargs)
    script_module = torch.jit.script(base_model)
    base_model_ir = convert_to_graph(script_module, base_model)
    base_model_ir.evaluator = DebugEvaluator()
    mutators = process_inline_mutation(base_model_ir)
    return base_model_ir, mutators


@pytest.mark.skip(
    reason="NNI currently uses OpenAI Gym"
)  # TODO: Remove once NNI transitions to Gymnasium
def test_rl():
    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators(diff_size=True))
    wait_models(*engine.models)
    _reset_execution_engine()

    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    _reset_execution_engine()


if __name__ == '__main__':
    test_rl()
