# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from .admm_layers import Flatten
# =============================================================================
# nn examples
# =============================================================================

def mnist_model_cnn():
    """Build a small convolutional classifier as an ``nn.Sequential``.

    Layout: Conv(1->16, k=4, s=2, p=1) -> ReLU -> Conv(16->32, k=4, s=2, p=1)
    -> ReLU -> Flatten -> Linear(32*7*7 -> 100) -> ReLU -> Linear(100 -> 10).

    The first linear layer expects 32*7*7 features, which is what the two
    stride-2 convolutions produce for a single-channel 28x28 input.

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above.
    """
    flat_features = 32 * 7 * 7
    layers = [
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=4,
                  stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4,
                  stride=2, padding=1),
        nn.ReLU(),
        # Flatten comes from .admm_layers; presumably it collapses the
        # non-batch dimensions so the linear layers can consume the features.
        Flatten(),
        nn.Linear(in_features=flat_features, out_features=100),
        nn.ReLU(),
        nn.Linear(in_features=100, out_features=10),
    ]
    return nn.Sequential(*layers)

def mnist_cnn2fc_madry():
    """Build a fully connected classifier as an ``nn.Sequential``.

    Layout: Flatten -> Linear(1*28*28 -> 16*14*14) -> ReLU
    -> Linear(16*14*14 -> 32*7*7) -> ReLU -> Linear(32*7*7 -> 100) -> ReLU
    -> Linear(100 -> 10).

    The hidden widths 16*14*14 and 32*7*7 equal the flattened activation
    sizes of the two convolutions in ``mnist_model_cnn`` for a 28x28 input.

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above.
    """
    widths = [1 * 28 * 28, 16 * 14 * 14, 32 * 7 * 7, 100, 10]

    layers = [Flatten()]
    # Every hidden layer is followed by a ReLU; the output layer is not.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    layers.append(nn.Linear(widths[-2], widths[-1]))

    return nn.Sequential(*layers)
    



def mnist_500():
    """Build a one-hidden-layer classifier as an ``nn.Sequential``.

    Layout: Flatten -> Linear(28*28 -> 500) -> ReLU -> Linear(500 -> 10).

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above.
    """
    in_features = 28 * 28
    hidden_units = 500
    num_outputs = 10
    return nn.Sequential(
        Flatten(),
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Linear(hidden_units, num_outputs),
    )



def cifar_model_large():
    """Build the larger convolutional classifier as an ``nn.Sequential``.

    Layout: Conv(3->32, k=3, s=1, p=1) -> ReLU -> Conv(32->32, k=4, s=2, p=1)
    -> ReLU -> Conv(32->64, k=3, s=1, p=1) -> ReLU
    -> Conv(64->64, k=4, s=2, p=1) -> ReLU -> Flatten
    -> Linear(64*8*8 -> 512) -> ReLU -> Linear(512 -> 512) -> ReLU
    -> Linear(512 -> 10).

    The first linear layer expects 64*8*8 features, which is what the conv
    stack produces for a 3-channel 32x32 input.

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above.
    """
    # Layers are instantiated in the same order they appear in the network.
    conv_stack = [
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
    ]
    flatten = Flatten()
    fc_head = [
        nn.Linear(64 * 8 * 8, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    ]
    return nn.Sequential(*conv_stack, flatten, *fc_head)

def cifar_model_small():
    """Build the smaller convolutional classifier as an ``nn.Sequential``.

    Layout: Conv(3->16, k=4, s=2, p=1) -> ReLU -> Conv(16->32, k=4, s=2, p=1)
    -> ReLU -> Flatten -> Linear(32*8*8 -> 100) -> ReLU -> Linear(100 -> 10).

    The first linear layer expects 32*8*8 features, which is what the two
    stride-2 convolutions produce for a 3-channel 32x32 input.

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above.
    """
    conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4,
                      stride=2, padding=1)
    conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4,
                      stride=2, padding=1)
    flatten = Flatten()
    hidden = nn.Linear(in_features=32 * 8 * 8, out_features=100)
    output = nn.Linear(in_features=100, out_features=10)

    return nn.Sequential(
        conv1, nn.ReLU(),
        conv2, nn.ReLU(),
        flatten,
        hidden, nn.ReLU(),
        output,
    )


def sa_dqn(num_actions=6, input_ch=1, width=1):
    """Build a three-conv, two-linear Q-network as an ``nn.Sequential``.

    Layout: Conv(input_ch->32w, k=8, s=4) -> ReLU -> Conv(32w->64w, k=4, s=2)
    -> ReLU -> Conv(64w->64w, k=3, s=1) -> ReLU -> Flatten
    -> Linear(3136w -> 512w) -> ReLU -> Linear(512w -> num_actions),
    where ``w`` is ``width``.

    Args:
        num_actions: Size of the output layer (one value per action).
        input_ch: Number of channels in the input image. Defaults to 1,
            which was previously hard-coded.
        width: Integer multiplier applied to every hidden channel/feature
            count. Defaults to 1, which was previously hard-coded.

    Returns:
        A newly constructed ``nn.Sequential`` with the layers above. With the
        default arguments the architecture is identical to the original
        fixed-size network.
    """
    # 3136 == 64 * 7 * 7: the unpadded conv stack maps an 84x84 input to a
    # 7x7 map (84 -> 20 -> 9 -> 7) with 64 * width channels.
    conv_features = 3136 * width

    model = nn.Sequential(
        nn.Conv2d(input_ch, 32 * width, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32 * width, 64 * width, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64 * width, 64 * width, kernel_size=3, stride=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(conv_features, 512 * width),
        nn.ReLU(),
        nn.Linear(512 * width, num_actions),
    )

    return model





