import traceback
import unittest

import torch

from Models import rnn_models
from config.experiments import all_models

# Shared sizes used by test_model() for every model under test.
# Passed to each model as `units`; also the width of the all-ones target.
model_size = 64
# Last dimension of the all-zeros input; also passed as the `input_size` kwarg.
input_size = 16
# First dimension of both the input and the target tensors.
batch_size = 128
# Second dimension of the input (presumably time steps, since the loss reads out[:, -1] — confirm).
seq_len = 10
# Number of forward/backward/optimizer-step iterations run per model.
test_steps = 3


def test_model(model_type, device='cpu'):
    """Smoke-test one model class from ``rnn_models`` with a few Adam steps.

    Instantiates ``rnn_models.<model_type>(units=model_size,
    input_size=input_size)`` on ``device``. It then runs ``test_steps``
    iterations that each feed the same all-zeros input of shape
    ``(batch_size, seq_len, input_size)`` through the model. Each iteration
    takes an MSE loss between ``out[:, -1]`` and an all-ones target of shape
    ``(batch_size, model_size)``, back-propagates it and applies an Adam step.

    Nothing is asserted explicitly. Any exception raised during
    construction, forward, backward or the optimizer step propagates to
    the caller, which is what makes the check fail.

    This is a helper called by ``TestModels``, not a test on its own. The
    ``__test__ = False`` flag set after the definition keeps pytest from
    collecting it.

    Args:
        model_type: Name of an attribute (model class) of ``rnn_models``.
        device: Torch device string the model and tensors are placed on.
    """
    model_cls = getattr(rnn_models, model_type)
    rnn = model_cls(units=model_size, input_size=input_size).to(device)
    optimizer = torch.optim.Adam(rnn.parameters())
    loss_fn = torch.nn.MSELoss()
    # Input and target never change between steps, so build them once.
    inputs = torch.zeros((batch_size, seq_len, input_size), device=device)
    target = torch.ones((batch_size, model_size), device=device)
    for _ in range(test_steps):
        out = rnn(inputs)
        # Models may return a tuple; only its first element is trained on.
        if isinstance(out, tuple):
            out = out[0]
        optimizer.zero_grad()
        # NOTE(review): assumes out is (batch, seq, units) so out[:, -1] is
        # the last step with shape (batch_size, model_size) — confirm per model.
        loss = loss_fn(out[:, -1], target)
        loss.backward()
        optimizer.step()


# Without this, pytest would collect the helper as a test (its name starts
# with "test_") and fail it, because `model_type` is not a known fixture.
test_model.__test__ = False


class TestModels(unittest.TestCase):
    """Run the ``test_model`` smoke check for every entry in ``all_models``."""

    def _run_all(self, device):
        """Smoke-test every model on ``device``, one subTest per model."""
        for model_type in all_models:
            with self.subTest(model_type=model_type, device=device):
                test_model(model_type, device=device)
                print("Success.")

    def test_all_models(self):
        """Smoke-test every model on CPU."""
        self._run_all('cpu')

    def test_all_models_cuda(self):
        """Smoke-test every model on CUDA; reported as skipped without a GPU.

        The skip is raised here, outside any subTest, rather than inside
        one. Runners without subtest support would otherwise let the
        SkipTest escape and abort the remaining checks.
        """
        if not torch.cuda.is_available():
            self.skipTest("Skipping, CUDA not available!")
        self._run_all('cuda')
