import torch
import torch.nn as nn
from vanilla_vae import VanillaVAE
import argparse

def parse_args():
    """Parse the command-line options of the MLP training script.

    Returns:
        The ``argparse.Namespace`` with ``env_id``, ``no_state``,
        ``load_dataset_name``, ``load_VAE_name`` and ``save_model_name``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--env-id", type=str, default='RoadRunner-v5')
    cli.add_argument('--no-state', default=False, action='store_true')

    # Plain string options, registered in the order they appear in --help.
    string_options = (
        ("--load-dataset-name", None),
        ("--load-VAE-name", None),
        ("--save-model-name", ''),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Width of the input feature vector.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores produced per input.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the unnormalized class scores for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

if __name__ == '__main__':
    # Script entry point: build (feature, label) pairs from a saved dataset of
    # per-sample logits (optionally joined with a VAE-derived state feature),
    # train an MLP classifier on them, and checkpoint the model to disk.
    args = parse_args()
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.load_VAE_name is None:
        args.load_VAE_name = args.env_id+'_'
    if args.load_dataset_name is None:
        args.load_dataset_name = args.env_id+'_'
    if args.no_state:
        args.save_model_name = 'nostate_'+args.save_model_name

    # Number of discrete actions for each supported environment.
    actions_per_env = {
        'RoadRunner-v5': 18,
        'Riverraid-v5': 18,
        'SpaceInvaders-v5': 6,
    }
    if args.env_id not in actions_per_env:
        raise TypeError('env-id not recognized')
    num_actions = actions_per_env[args.env_id]
    # For these environments every logits row gets an extra task-id column and
    # every remaining row (index >= 2) is paired with row 1 as a separate sample.
    uses_task_id = args.env_id in ('Riverraid-v5', 'SpaceInvaders-v5')

    # Hyper-parameters
    num_classes = num_actions
    hidden_size = 64
    num_epochs = 1000
    batch_size = 128
    learning_rate = 0.001

    save_path = 'saved_models/MLP_'+args.env_id+'_'+args.save_model_name

    # Read the dataset. map_location keeps its tensors on the same device as
    # the tensors created below, so the torch.cat calls cannot hit a
    # CPU/GPU mismatch when the file was saved from a different device.
    dataset = torch.load('dataset/'+args.load_dataset_name, map_location=device)

    # Each input is the concatenation of three vectors of this width:
    # two logits rows and one state feature.
    feature_width = num_actions + 1 if uses_task_id else num_actions
    input_size = feature_width * 3
    if args.no_state:
        # All-zero stand-in for the state feature.
        fake_state_feature = torch.zeros(feature_width).to(device)

    if uses_task_id:
        # Append to every logits row a column holding row_index / num_rows.
        num_samples = len(dataset['logits'])
        num_tasks = len(dataset['logits'][0])
        ids = torch.cat(
            [torch.ones((num_samples, 1, 1)) * (task / num_tasks)
             for task in range(num_tasks)],
            dim=1,
        ).to(device)
        dataset['logits'] = torch.cat((dataset['logits'], ids), dim=2)

    # Upper bound on the number of dataset entries that are used.
    length = 9000
    dataset_mlp = []

    if not args.no_state:
        VAE = VanillaVAE(4, num_actions, hidden_dims=[32, 64, 128, 256, 512], out_dim_temp=3).to(device)
        VAE.load_state_dict(torch.load('saved_models/VAE_'+args.load_VAE_name, map_location=device))
        VAE.eval()

        with torch.no_grad():
            # The last element of the VAE output is used as the state feature
            # (presumably a per-observation vector of width num_actions --
            # TODO confirm against VanillaVAE).
            obs_hidden = VAE(dataset['obs'][:length])[-1]
            if uses_task_id:
                # Pad with a zero column so the width matches the id-augmented logits rows.
                temp = torch.zeros((len(obs_hidden), 1)).to(device)
                obs_hidden = torch.cat((obs_hidden, temp), dim=1)

    # build MLP dataset
    for i in range(min(length, len(dataset['obs']))):
        sample_logits = dataset['logits'][i]
        state_feature = fake_state_feature if args.no_state else obs_hidden[i]

        # The label is the argmax of row 0, restricted to the real action
        # columns: the appended task-id column (value 0 for row 0) must never
        # win, otherwise the label would equal num_actions and fall outside
        # the range of classes the MLP predicts.
        _, y = torch.max(sample_logits[0:1, :num_actions], 1)
        y = y.squeeze(0)

        # RoadRunner pairs row 1 with row 2 only; the other environments pair
        # row 1 with every row from index 2 onwards.
        partner_rows = range(2, len(sample_logits)) if uses_task_id else (2,)
        for t in partner_rows:
            x = torch.cat((sample_logits[1], sample_logits[t], state_feature))
            dataset_mlp.append((x, y))

    # NOTE: there is no held-out split; the reported accuracy is measured on
    # the training samples themselves.
    train_dataset = test_dataset = dataset_mlp

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    model = MLP(input_size, hidden_size, num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Move tensors to the configured device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

        # Every 10 epochs: report accuracy and write a checkpoint.
        if epoch % 10 == 0:
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                print('Accuracy of the network on the {} evaluation samples: {} %'
                      .format(total, 100 * correct / total))

            torch.save(model.state_dict(), save_path)
            print ('MLP model saved: ', save_path)

    # The periodic checkpoint last fires at an epoch index divisible by 10, so
    # save once more to keep the weights produced by the final epochs.
    torch.save(model.state_dict(), save_path)
    print ('MLP model saved: ', save_path)