from itertools import product

import pytest
import torch
from ding.torch_utils import is_differentiable

from lzero.model.alphazero_model import PredictionNetwork

# Hyper-parameter grids for ``test_prediction_network``. Each list holds the
# candidate values for one constructor argument; every combination of them
# (their Cartesian product) becomes one parametrized test case.
action_space_size = [2, 3]
batch_size = [100, 200]
num_res_blocks = [3]
num_channels = [3]
value_head_channels = [8]
policy_head_channels = [8]
# Each candidate is itself a list of hidden-layer widths for the MLP head.
fc_value_layers = [[16]]
fc_policy_layers = [[16]]
output_support_size = [2]
# Indices 1 and 2 are used by the test as the spatial height/width of the input.
observation_shape = [1, 3, 3]

# Ordered to match the argument names in the ``pytest.mark.parametrize`` string.
prediction_network_args = list(
    product(
        action_space_size, batch_size, num_res_blocks, num_channels,
        value_head_channels, policy_head_channels,
        fc_value_layers, fc_policy_layers, output_support_size,
    )
)


@pytest.mark.unittest
class TestAlphaZeroModel:
    """Unit tests for the AlphaZero ``PredictionNetwork``."""

    def output_check(self, model, outputs):
        """Sum ``outputs`` into a scalar loss and hand it to ``is_differentiable``.

        Arguments:
            - model: the module passed through to ``is_differentiable``.
            - outputs: a ``torch.Tensor``, a ``list`` of tensors, or a ``dict`` whose
              values are tensors. Every tensor is summed and the sums are added up.

        Raises:
            - TypeError: if ``outputs`` is not one of the supported types above.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Fail with a clear message instead of an UnboundLocalError on ``loss``.
            raise TypeError(f"unsupported output type for output_check: {type(outputs).__name__}")
        is_differentiable(loss, model)

    @pytest.mark.parametrize(
        'action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels, fc_value_layers, fc_policy_layers, output_support_size',
        prediction_network_args
    )
    def test_prediction_network(
            self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
            policy_head_channels,
            fc_value_layers, fc_policy_layers, output_support_size
    ):
        """Check ``PredictionNetwork`` output shapes for one hyper-parameter combination.

        A random input of shape ``(batch_size, num_channels, H, W)`` is fed through the
        network, with ``H, W`` taken from the module-level ``observation_shape``. The
        policy output must have shape ``(batch_size, action_space_size)`` and the value
        output must have shape ``(batch_size, output_support_size)``.
        """
        # Take the spatial dims from observation_shape so that the input tensor and
        # the flattened head sizes are always computed from the same numbers.
        height, width = observation_shape[1], observation_shape[2]
        obs = torch.rand(batch_size, num_channels, height, width)
        flatten_output_size_for_value_head = value_head_channels * height * width
        flatten_output_size_for_policy_head = policy_head_channels * height * width
        prediction_network = PredictionNetwork(
            action_space_size=action_space_size,
            continuous_action_space=False,
            num_res_blocks=num_res_blocks,
            num_channels=num_channels,
            value_head_channels=value_head_channels,
            policy_head_channels=policy_head_channels,
            fc_value_layers=fc_value_layers,
            fc_policy_layers=fc_policy_layers,
            output_support_size=output_support_size,
            flatten_output_size_for_value_head=flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head=flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=True,
        )
        policy, value = prediction_network(obs)
        assert policy.shape == torch.Size([batch_size, action_space_size])
        assert value.shape == torch.Size([batch_size, output_support_size])


if __name__ == "__main__":
    # Run a single configuration of ``test_prediction_network`` directly, without
    # pytest, for quick manual debugging. Calling the test method (rather than
    # re-building the network here) makes this script use exactly the same
    # ``PredictionNetwork`` constructor arguments as the test, including
    # ``continuous_action_space=False``.
    debug_config = dict(
        action_space_size=2,
        batch_size=100,
        num_res_blocks=3,
        num_channels=3,
        value_head_channels=8,
        policy_head_channels=8,
        fc_value_layers=[16],
        fc_policy_layers=[16],
        output_support_size=2,
    )
    print('=' * 20)
    print(debug_config)
    print('=' * 20)
    TestAlphaZeroModel().test_prediction_network(**debug_config)
