import torch
import torch.nn as nn
from src.utils.env_dataset import EnvDataset


class BaseModel(nn.Module):
    """One-hidden-layer MLP: Flatten -> Linear -> Mish -> Linear.

    The layers are held in ``self.seq`` (an ``nn.Sequential``), giving
    state_dict keys ``seq.1.*`` and ``seq.3.*``. Keep this layout so that
    previously saved state dicts still load.

    Args:
        input_dim: number of features after flattening every non-batch
            dimension of the input.
        hidden_layer_width: width of the single hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_layer_width, output_dim):
        super().__init__()
        layers = [
            nn.Flatten(),  # collapse all non-batch dims into one feature axis
            nn.Linear(input_dim, hidden_layer_width),
            nn.Mish(inplace=True),
            nn.Linear(hidden_layer_width, output_dim),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the batch ``x``."""
        out = self.seq(x)
        return out
    

class NNModel:
    """MLP regressor trained on a bounded, priority-pruned dataset.

    Training data lives on ``storage_device`` (CPU) and is moved to
    ``device`` one minibatch at a time. When :meth:`register_new_data` grows
    the dataset past ``max_dataset_size``, only the highest-priority samples
    are kept.
    """

    def __init__(
            self,
            num_epochs,
            minibatch_size,
            input_dim,
            output_dim,
            train_inputs=None,
            train_targets=None,
            priorities=None,
            max_dataset_size=1000,
            hidden_layer_width=200,
            learning_rate=0.1,
            device='cpu'
        ):
        """Create the model and its (possibly empty) training dataset.

        Args:
            num_epochs: epochs run by each call to :meth:`train`.
            minibatch_size: requested batch size (capped at the dataset size).
            input_dim: flattened input feature count.
            output_dim: output feature count.
            train_inputs, train_targets: optional initial data; must be given
                together. If both are omitted, the dataset starts empty.
            priorities: optional per-sample priorities for the initial data.
                Defaults to zeros when initial data is given without them.
                Ignored when no initial data is given.
            max_dataset_size: maximum number of samples retained.
            hidden_layer_width: width of the hidden layer.
            learning_rate: Adam learning rate.
            device: device the network runs on.

        Raises:
            ValueError: if exactly one of ``train_inputs`` / ``train_targets``
                is provided.
        """
        self.num_epochs = num_epochs
        self.minibatch_size = minibatch_size
        self.max_dataset_size = max_dataset_size
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lr = learning_rate
        self.device = device
        self.storage_device = 'cpu'

        if train_inputs is None and train_targets is None:
            train_inputs = torch.empty((0, self.input_dim), device=self.storage_device)
            train_targets = torch.empty((0, self.output_dim), device=self.storage_device)
            priorities = torch.zeros(0, device=self.storage_device)
        elif train_inputs is None or train_targets is None:
            raise ValueError("train_inputs and train_targets must be provided together")
        elif priorities is None:
            # register_new_data concatenates onto self.priorities, so it must
            # be a tensor with one entry per initial sample, never None.
            priorities = torch.zeros(len(train_inputs), device=self.storage_device)
        else:
            # Keep priorities on the same device register_new_data uses.
            priorities = torch.as_tensor(priorities, device=self.storage_device)
        self.train_dataset = EnvDataset(train_inputs, train_targets)
        self.priorities = priorities

        self.model = BaseModel(input_dim, hidden_layer_width, output_dim).to(self.device)

    def train(self):
        """Fit the network on the stored dataset for ``num_epochs`` epochs.

        Uses Adam (``lr=self.lr``) with MSE loss. A fresh optimizer is built
        on every call, so Adam's moment estimates do not persist between calls.

        Returns:
            tuple: ``(num_epochs, losses)``, where ``losses[i]`` is the mean
            minibatch loss of epoch ``i``. ``losses`` is empty when
            ``num_epochs`` is 0.

        Raises:
            ValueError: if the dataset is empty.
        """
        if len(self.train_dataset) == 0:
            raise ValueError("cannot train on an empty dataset")
        minibatch_size = min(self.minibatch_size, len(self.train_dataset))
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=minibatch_size,
            # With drop_last=True and no shuffling, the same trailing samples
            # would be skipped every epoch; shuffling rotates what is dropped.
            shuffle=True,
            # Pinned memory only speeds host->CUDA copies (and warns otherwise).
            pin_memory=torch.device(self.device).type == 'cuda',
            drop_last=True
        )
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        loss_function = nn.MSELoss()

        self.model.train()  # enter training mode
        losses = []
        for _ in range(self.num_epochs):
            batch_losses = []
            for input_batch, target_batch in train_loader:
                optimizer.zero_grad()
                output = self.model(input_batch.to(self.device))
                batch_loss = loss_function(output, target_batch.to(self.device))
                batch_loss.backward()
                optimizer.step()
                batch_losses.append(batch_loss.item())
            # drop_last with batch_size <= len(dataset) guarantees >= 1 batch.
            losses.append(sum(batch_losses) / len(batch_losses))
        if losses:
            print(f"Final NN loss: {losses[-1]}")

        return self.num_epochs, losses

    def predict(self, input):
        """Return the network output for ``input`` in eval mode without gradients.

        ``input`` is moved to ``self.device`` first, so CPU tensors work even
        when the model is on an accelerator. The result stays on ``self.device``.
        """
        self.model.eval()
        with torch.no_grad():
            return self.model(input.to(self.device))

    def register_new_data(self, inputs, targets, priorities):
        """Prepend new samples to the dataset and cap its size.

        When the combined dataset exceeds ``max_dataset_size``, the samples
        with the highest priorities are kept. Because new samples are placed
        first and the sort is stable, ties are resolved in favour of new data.

        Returns:
            int: the dataset size after the update.
        """
        inputs = torch.cat((inputs.to(self.storage_device), self.train_dataset.inputs), dim=0)
        targets = torch.cat((targets.to(self.storage_device), self.train_dataset.targets), dim=0)
        self.priorities = torch.cat((priorities.to(self.storage_device), self.priorities))

        if len(inputs) > self.max_dataset_size:  # keep high value samples
            _, order = torch.sort(self.priorities, descending=True, stable=True)
            keep = order[:self.max_dataset_size]
            inputs = inputs[keep]
            targets = targets[keep]
            self.priorities = self.priorities[keep]

        self.train_dataset.inputs = inputs
        self.train_dataset.targets = targets
        return len(self.train_dataset)

    def save_model(self, path='model_state.pth'):
        """Write the network's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path='model_state.pth'):
        """Load a ``state_dict`` from ``path`` into the network.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on a
        GPU can be loaded on a CPU-only machine.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints
        # (newer PyTorch versions offer weights_only=True for this case).
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict)