import torch
import numpy as np
import torch.nn as nn
import threading


class SmallNet(nn.Module):
    """Fully connected MLP with ``Tanh`` activations.

    Layout: ``Linear(net_in_dim, inner_dim)``, ``Tanh``, then
    ``num_hidden_layers`` x [``Linear(inner_dim, inner_dim)``, ``Tanh``], then a
    final ``Linear(inner_dim, net_out_dim)`` with no output activation.

    Args:
        net_in_dim: Size of the last dimension of the input.
        inner_dim: Width of every hidden layer.
        net_out_dim: Size of the last dimension of the output.
        device: Torch device the parameters (and converted inputs) live on.
        num_hidden_layers: Number of ``inner_dim -> inner_dim`` layers between
            the input and output projections (default 8, the original depth).

    Raises:
        ValueError: If ``num_hidden_layers`` is negative.
    """

    def __init__(self, net_in_dim, inner_dim, net_out_dim, device, *, num_hidden_layers=8):
        super().__init__()
        if num_hidden_layers < 0:
            raise ValueError(f"num_hidden_layers must be >= 0, got {num_hidden_layers}")
        self.net_in_dim = net_in_dim
        self.inner_dim = inner_dim
        self.net_out_dim = net_out_dim
        self.num_hidden_layers = num_hidden_layers
        self.device = device

        layers = [nn.Linear(net_in_dim, inner_dim), nn.Tanh()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(inner_dim, inner_dim), nn.Tanh()]
        layers.append(nn.Linear(inner_dim, net_out_dim))
        # Modules are created in the same order as an explicit list would be, so
        # with the default depth the parameter names (net.0 ... net.18) and the
        # RNG-driven weight initialisation match the original hand-written layout.
        self.net = nn.Sequential(*layers)
        self.to(device)

    def forward(self, x):
        """Run the MLP on ``x``.

        Args:
            x: A NumPy array, a ``torch.Tensor`` or a nested list, whose last
                dimension is ``net_in_dim``. It is converted to float32 on
                ``self.device``.

        Returns:
            A float32 tensor whose last dimension is ``net_out_dim``.
        """
        # as_tensor accepts NumPy arrays (like the old torch.from_numpy path)
        # and also tensors/lists, which from_numpy rejects with a TypeError.
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        return self.net(x)


class SubThread(threading.Thread):
    """Thread that repeatedly trains a ``SmallNet`` on random regression data.

    Each step draws a random input batch and a random target, computes the
    mean-squared error and takes one Adam step. The loop runs until
    :meth:`stop_run` is called.

    Args:
        device: Torch device for the network and targets.
        net_in_dim: Input feature size (default 64).
        inner_dim: Hidden layer width (default 256).
        net_out_dim: Output feature size (default 128).
        batch_size: Samples per training step (default 64).
        lr: Adam learning rate (default 1e-4).
    """

    def __init__(self, device, *, net_in_dim=64, inner_dim=256, net_out_dim=128,
                 batch_size=64, lr=1e-4):
        super().__init__()
        self.net_in_dim = net_in_dim
        self.inner_dim = inner_dim
        self.net_out_dim = net_out_dim
        self.device = device
        self.sub_net = SmallNet(
            net_in_dim=self.net_in_dim,
            inner_dim=self.inner_dim,
            net_out_dim=self.net_out_dim,
            device=self.device,
        )
        self.batch_size = batch_size
        # An Event is the explicit cross-thread signal for "please stop"; it is
        # set by stop_run() from the controlling thread and polled in run().
        self._stop_event = threading.Event()
        # NOTE(review): eps=1e-2 is far larger than Adam's default (1e-8);
        # kept as-is, presumably deliberate — confirm.
        self.optim = torch.optim.Adam(
            self.sub_net.parameters(), lr=lr, eps=1e-2,
        )

    def run(self):
        """Train until :meth:`stop_run` is called."""
        while not self._stop_event.is_set():
            self._train_step()

    def _train_step(self):
        """Do one optimisation step on a fresh random batch; return the detached loss."""
        batch = np.random.random((self.batch_size, self.net_in_dim))
        target_np = np.random.random((self.batch_size, self.net_out_dim))
        prediction = self.sub_net(batch)
        # np.random.random yields float64. Match the prediction's dtype/device so
        # the loss is not silently promoted to float64 (the network is float32).
        target = torch.from_numpy(target_np).to(
            device=prediction.device, dtype=prediction.dtype
        )
        loss = torch.mean((prediction - target) ** 2)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.detach()

    def stop_run(self):
        """Ask the training loop to exit; it stops after the current step finishes."""
        self._stop_event.set()
