import torch
from torch_geometric.nn import GCNConv, global_mean_pool, aggr
from vanilla_vae import VanillaVAE
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import argparse
from torch.nn import Linear


def parse_args():
    """Parse the command-line options of this training script.

    Options:
        --env-id: environment name (default 'RoadRunner-v5').
        --no-state: boolean flag, False unless given.
        --load-dataset-name: dataset file name (default None).
        --load-VAE-name: VAE checkpoint name (default None).
        --save-model-name: suffix for the saved model name (default '').

    Returns:
        The argparse.Namespace produced from sys.argv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-id", type=str, default='RoadRunner-v5')
    parser.add_argument('--no-state', default=False, action='store_true')

    # The remaining options are plain string options that differ only in
    # their flag and default value.
    string_options = (
        ("--load-dataset-name", None),
        ("--load-VAE-name", None),
        ("--save-model-name", ''),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)

    return parser.parse_args()

class GCN_pooling_Net(torch.nn.Module):
    """Graph network: three GCN layers, a GraphMultisetTransformer pooling
    stage and a two-layer linear head.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of the output of the final linear layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Node encoder: in_channels -> 64 -> 32 -> 64.
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 64)

        # Pooling stage applied to the 64-dimensional node embeddings.
        pool_kwargs = dict(
            in_channels=64,
            hidden_channels=64,
            out_channels=64,
            Conv=GCNConv,
            num_nodes=3,  # avg_num_nodes
            pooling_ratio=0.25,
            pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
            num_heads=4,
            layer_norm=False,
        )
        self.pool = aggr.GraphMultisetTransformer(**pool_kwargs)

        # Linear head: 64 -> 64 -> out_channels.
        self.lin1 = Linear(64, 64)
        self.lin2 = Linear(64, out_channels)

    def forward(self, data):
        """Run the network on `data`, reading its `x`, `edge_index` and
        `batch` attributes, and return the output of the last linear layer.
        """
        edges = data.edge_index
        membership = data.batch

        # Each graph convolution is followed by a ReLU.
        h = data.x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h, edges))
        # Dropout is only active while the module is in training mode.
        h = F.dropout(h, p=0.2, training=self.training)

        pooled = self.pool(x=h, index=membership, edge_index=edges)

        hidden = F.relu(self.lin1(pooled))
        hidden = F.dropout(hidden, p=0.2, training=self.training)
        return self.lin2(hidden)

class GCN_pooling_Net_large(torch.nn.Module):
    """Wider variant of the GCN + GraphMultisetTransformer network.

    Layer widths: GCN in_channels -> 96 -> 64 -> 55, pooling on 55 channels
    (hidden size 128), then linear layers 55 -> 128 -> out_channels.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of the output of the final linear layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Node encoder.
        self.conv1 = GCNConv(in_channels, 96)
        self.conv2 = GCNConv(96, 64)
        self.conv3 = GCNConv(64, 55)

        # Pooling stage applied to the 55-dimensional node embeddings.
        pool_kwargs = dict(
            in_channels=55,
            hidden_channels=128,
            out_channels=55,
            Conv=GCNConv,
            num_nodes=3,  # avg_num_nodes
            pooling_ratio=0.25,
            pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
            num_heads=4,
            layer_norm=False,
        )
        self.pool = aggr.GraphMultisetTransformer(**pool_kwargs)

        # Linear head.
        self.lin1 = Linear(55, 128)
        self.lin2 = Linear(128, out_channels)

    def forward(self, data):
        """Run the network on `data`, reading its `x`, `edge_index` and
        `batch` attributes, and return the output of the last linear layer.
        """
        node_features = data.x
        edges = data.edge_index
        membership = data.batch

        # Three graph convolutions, each followed by a ReLU.
        h = F.relu(self.conv1(node_features, edges))
        h = F.relu(self.conv2(h, edges))
        h = F.relu(self.conv3(h, edges))
        # Dropout is only active while the module is in training mode.
        h = F.dropout(h, p=0.2, training=self.training)

        pooled = self.pool(x=h, index=membership, edge_index=edges)

        hidden = F.dropout(
            F.relu(self.lin1(pooled)), p=0.2, training=self.training)
        return self.lin2(hidden)


class MyData(Data):
    """`Data` subclass with a custom concatenation dimension for the 'dx' key."""

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return None for the 'dx' key; defer to `Data` for every other key.

        NOTE(review): per the PyG batching docs a None cat dim makes the
        attribute get a new leading batch dimension instead of being
        concatenated -- confirm against the installed PyG version.
        """
        if key != 'dx':
            return super().__cat_dim__(key, value, *args, **kwargs)
        return None

if __name__ == '__main__':
    # Training script: load the dataset (and, unless --no-state is given, a
    # pretrained VAE), turn every sample into small 4-node graphs, and train
    # a GCN model to predict the arg-max action of row 0 of each sample.
    import os  # only needed by the script entry point

    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.no_state:
        args.save_model_name = 'nostate_' + args.save_model_name
    if args.load_VAE_name is None:
        args.load_VAE_name = args.env_id+'_'
    if args.load_dataset_name is None:
        args.load_dataset_name = args.env_id+'_'

    # Number of actions for each supported environment.
    num_actions_by_env = {
        'RoadRunner-v5': 18,
        'Riverraid-v5': 18,
        'SpaceInvaders-v5': 6,
    }
    if args.env_id not in num_actions_by_env:
        # An unsupported value (not a wrong type) was passed on the CLI.
        raise ValueError('env-id not recognized')
    num_actions = num_actions_by_env[args.env_id]
    # These environments get a task identifier appended to every feature row.
    uses_task_id = args.env_id in ('Riverraid-v5', 'SpaceInvaders-v5')

    taskid_digits = 10

    # Read and create dataset
    # map_location keeps loading working on a machine without the device the
    # file was saved from, and guarantees the dataset tensors live on the
    # same device as the feature tensors they are stacked with below.
    dataset = torch.load('dataset/'+args.load_dataset_name, map_location=device)
    print ('Dataset loaded')
    if not args.no_state:
        VAE = VanillaVAE(4,num_actions, hidden_dims = [32, 64, 128, 256, 512], out_dim_temp = 3).to(device)
        VAE.load_state_dict(
            torch.load('saved_models/VAE_'+args.load_VAE_name, map_location=device))
        VAE.eval()
        print ('VAE loaded')

    # add task identifier for Riverraid-v5 and SpaceInvaders-v5
    if uses_task_id:
        input_dim = num_actions + taskid_digits
        model = GCN_pooling_Net_large(input_dim, num_actions).to(device)
    else:
        input_dim = num_actions
        model = GCN_pooling_Net(input_dim, num_actions).to(device)
    if args.no_state:
        # All-zero placeholder used instead of the VAE state feature.
        fake_state_feature = torch.zeros(input_dim, device=device)

    if uses_task_id:
        num_samples_total = len(dataset['logits'])
        num_tasks = len(dataset['logits'][0])
        # For task index t, append `taskid_digits` columns filled with
        # t / num_tasks to row t of every sample.
        ids = torch.cat(
            [torch.full((num_samples_total, 1, taskid_digits),
                        task / num_tasks, device=device)
             for task in range(num_tasks)],
            dim=1,
        )
        dataset['logits'] = torch.cat((dataset['logits'], ids), dim=2)

    def build_GNN_dataset():
        """Build a shuffled DataLoader of 4-node graphs from `dataset`.

        Each graph's nodes are: row 1 of the sample, one further row `t` of
        the sample, the state feature (VAE output or zero placeholder) and an
        all-ones feature.  Nodes 0, 1 and 2 each have an edge to node 3.
        The label is the arg-max over the action columns of row 0.
        At most the first 9000 samples are used.
        """
        length = 9000
        num_samples = min(length, len(dataset['obs']))
        and_feature = torch.ones(input_dim, device=device)
        # The topology is identical for every graph, so build it once.
        edge_index = torch.tensor([[0, 1, 2], [3, 3, 3]], device=device)

        if not args.no_state:
            with torch.no_grad():
                obs_hidden = VAE(dataset['obs'][:length])[-1]
                if uses_task_id:
                    # Pad the state feature with zeros so it matches the
                    # width of the task-id-extended rows.
                    temp = torch.zeros((len(obs_hidden), taskid_digits)).to(device)
                    obs_hidden = torch.cat((obs_hidden, temp), dim=1)

        data_list = []
        for i in range(num_samples):
            sample = dataset['logits'][i]
            state_feature = fake_state_feature if args.no_state else obs_hidden[i]
            # Only the first `num_actions` columns are action scores; with
            # task ids appended, the extra columns must not be eligible as
            # labels (they could win the arg-max when all scores are
            # negative and produce an out-of-range class index).
            _, y = torch.max(sample[0:1, :num_actions], 1)
            # RoadRunner pairs row 1 with row 2 only; the task-id
            # environments pair row 1 with every row from index 2 onwards.
            other_rows = range(2, len(sample)) if uses_task_id else (2,)
            for t in other_rows:
                x = torch.stack((sample[1], sample[t], state_feature, and_feature))
                data_list.append(
                    MyData(x=x.detach(), edge_index=edge_index, y=y.detach()).to(device))
        loader = DataLoader(data_list, batch_size=128, shuffle=True)
        return loader

    loader = build_GNN_dataset()
    print ("GNN dataset built")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    save_path = 'saved_models/GNN_' + args.env_id+'_'+args.save_model_name
    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    best_acc = 0
    # Report the number of graphs (len(loader) would be the batch count).
    print ('Number of data: ', len(loader.dataset))
    for epoch in range(1000):
        model.train()
        for data in loader:
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

        # Evaluate after every epoch.  NOTE: accuracy is measured on the same
        # loader used for training; there is no held-out split here.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in loader:
                predicted = model(data).argmax(dim=1)
                total += data.y.size(0)
                correct += (predicted == data.y).sum().item()
        acc = correct / total
        print(epoch,'Accuracy of the network on the test images: {} %'.format(100 * acc))
        # Loss of the last training batch of this epoch.
        print ('loss: ', loss)
        if acc >= best_acc:
            torch.save(model.state_dict(), save_path)
            print('GNN model saved: ', save_path)
            best_acc = acc
