# Usage: python mnist.py --batch-size 50 --epochs 50 --lr 1e-3 --gamma 1. --seed 2023 --save-model --log-interval 400 --no-test
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts
# from typing import Optional
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import time
import csv
import pandas
import time
import statistics
import lion_pytorch 
# import ranger

# Small constant added to the denominators of the Co* activations below; it
# keeps the division finite when the norm in the denominator is zero.
eps = 1e-8

# def _spec_gap(x):
#     return x - eps + 2 * eps * (x > 0)

def ReLU(x):
    """Element-wise rectifier: negative entries of ``x`` are replaced by zero."""
    return torch.relu(x)

def Id(x):
    """Identity activation: return ``x`` unchanged."""
    return x

def SiLU(x):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``, element-wise.

    Written with ``torch.sigmoid`` rather than the algebraically equal
    ``x / (1 + exp(-x))``.
    """
    gate = torch.sigmoid(x)
    return x * gate

def CoReLU(x):
    """CoReLU activation, applied along dim 1 of ``x``.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``first / (||rest||_2 + eps)`` clamped to
    ``[0, 1]``, where the 2-norm is taken over dim 1.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_norm = rest.norm(p=2, dim=1, keepdim=True)
    gate = torch.clamp(axis / (rest_norm + eps), min=0, max=1)
    return torch.cat((axis, gate * rest), dim=1)

def CoSiLU(x):
    """CoSiLU activation, applied along dim 1 of ``x``.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``sigmoid(first / (||rest||_2 + eps))``, where
    the 2-norm is taken over dim 1.
    """
    axis, rest = x[:, :1], x[:, 1:]
    ratio = axis / (rest.norm(p=2, dim=1, keepdim=True) + eps)
    gate = ratio.sigmoid()
    return torch.cat((axis, gate * rest), dim=1)

def CoReLU_neg(x):
    """Two-sided CoReLU variant, applied along dim 1 of ``x``.

    Same as ``CoReLU`` except that the scaling factor
    ``first / (||rest||_2 + eps)`` is clamped to ``[-1, 1]`` instead of
    ``[0, 1]``, so a negative first slice flips the sign of the rest.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_norm = rest.norm(p=2, dim=1, keepdim=True)
    gate = torch.clamp(axis / (rest_norm + eps), min=-1, max=1)
    return torch.cat((axis, gate * rest), dim=1)

def CoSiLU_neg(x):
    """Two-sided CoSiLU variant, applied along dim 1 of ``x``.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``erf(first / (||rest||_2 + eps))``; ``erf``
    takes values in ``(-1, 1)``, so the factor can be negative.
    """
    axis, rest = x[:, :1], x[:, 1:]
    ratio = axis / (rest.norm(p=2, dim=1, keepdim=True) + eps)
    gate = ratio.erf()
    return torch.cat((axis, gate * rest), dim=1)

def CoReLU_inf(x):
    """CoReLU variant whose denominator is a maximum over dim 1.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``first / (max(rest) + eps)`` clamped to
    ``[0, 1]``.

    NOTE(review): the denominator is the signed maximum over dim 1, not the
    maximum absolute value, so it can be zero or negative -- confirm that
    this is intended.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_max = rest.max(dim=1, keepdim=True).values
    gate = torch.clamp(axis / (rest_max + eps), min=0, max=1)
    return torch.cat((axis, gate * rest), dim=1)

def CoSiLU_inf(x):
    """CoSiLU variant whose denominator is a maximum over dim 1.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``sigmoid(first / (max(rest) + eps))``.

    NOTE(review): the denominator is the signed maximum over dim 1, not the
    maximum absolute value, so it can be zero or negative -- confirm that
    this is intended.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_max = rest.max(dim=1, keepdim=True).values
    gate = torch.sigmoid(axis / (rest_max + eps))
    return torch.cat((axis, gate * rest), dim=1)

def CoReLU_1(x):
    """CoReLU variant registered under the ``_1`` suffix.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``first / (||rest||_2 + eps)`` clamped to
    ``[0, 1]``.

    NOTE(review): this computes exactly what ``CoReLU`` computes (2-norm).
    If the ``_1`` suffix is meant to denote a 1-norm, ``p`` should be 1 --
    confirm with the author before changing it.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_norm = rest.norm(p=2, dim=1, keepdim=True)
    gate = (axis / (rest_norm + eps)).clamp(min=0, max=1)
    return torch.cat((axis, gate * rest), dim=1)

def CoSiLU_1(x):
    """CoSiLU variant registered under the ``_1`` suffix.

    The first slice of dim 1 is passed through unchanged. The remaining
    slices are multiplied by ``sigmoid(first / (||rest||_2 + eps))``.

    NOTE(review): this computes exactly what ``CoSiLU`` computes (2-norm).
    If the ``_1`` suffix is meant to denote a 1-norm, ``p`` should be 1 --
    confirm with the author before changing it.
    """
    axis, rest = x[:, :1], x[:, 1:]
    rest_norm = rest.norm(p=2, dim=1, keepdim=True)
    gate = torch.sigmoid(axis / (rest_norm + eps))
    return torch.cat((axis, gate * rest), dim=1)

# Registry of the activations to compare. main() trains one model per entry,
# in insertion order, and wraps each key in '$...$' to label its loss curve,
# so the keys double as LaTeX legend text. Commented-out entries are skipped.
# NOTE(review): the commented keys containing '\i' / '\p' would raise
# invalid-escape warnings if re-enabled as-is; raw strings would avoid that.
nonlinearities = { 
                  'CoReLU':CoReLU, 
#                   'CoSiLU':CoSiLU, 
#                   'CoReLU_1':CoReLU_1, 
#                   'CoSiLU_1':CoSiLU_1, 
#                   'CoReLU_\infty':CoReLU_inf, 
#                   'CoSiLU_\infty':CoSiLU_inf, 
#                   'CoReLU_\pm':CoReLU_neg, 
#                   'CoSiLU_\pm':CoSiLU_neg,
                  'ReLU':ReLU, 
#                   'SiLU':SiLU, 
#                   'Id':Id, 
                  }
    
class MLP(nn.Module):
    """Fully connected classifier: 1024 -> 128 -> 128 -> 10, log-softmax output.

    The input is flattened from dim 1 onwards, so each sample must hold 1024
    values. ``nonlinearity`` is a key of the module-level ``nonlinearities``
    registry; the selected callable is applied after ``fc1`` and ``fc2``.

    The attribute names ``fc1``, ``fc2`` and ``fc10`` are part of the
    checkpoint format (they become ``state_dict`` keys) and must not change.
    """

    def __init__(self, nonlinearity = 'ReLU'):
        super().__init__()
        width = 128
        # Layers are created in the order fc1, fc2, fc10 so that a fixed seed
        # yields the same initial weights.
        self.fc1 = nn.Linear(1024, width)
        self.fc2 = nn.Linear(width, width)
        self.fc10 = nn.Linear(width, 10)
        self.nonlinearity = nonlinearities[nonlinearity]

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        hidden = x.flatten(1)
        for layer in (self.fc1, self.fc2):
            hidden = self.nonlinearity(layer(hidden))
        logits = self.fc10(hidden)
        return F.log_softmax(logits, dim=1)
    
class CNN(nn.Module): #deep unalignable
    """Unpadded five-convolution classifier with log-softmax output.

    NOTE: a second class named ``CNN`` is defined further down in this file.
    It rebinds the name, so this definition is not the one ``main()``
    instantiates.

    ``nonlinearity`` is a key of the module-level ``nonlinearities``
    registry; the selected callable is applied after each of the first four
    convolution stages. For a 32x32 input the feature map shrinks to 1x1
    after the last pooling step.
    """

    def __init__(self, nonlinearity = 'ReLU'):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 128, 3, 1)
        self.conv2 = nn.Conv2d(128, 128, 3, 1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1)
        self.conv4 = nn.Conv2d(128, 128, 1)
        self.conv5 = nn.Conv2d(128, 10, 1)
        self.nonlinearity = nonlinearities[nonlinearity]

    def forward(self, x):
        """Return log-probabilities over the 10 classes.

        Size-1 spatial dimensions are squeezed out of the result, giving
        shape ``(batch, 10)`` when the final feature map is 1x1.
        """
        x = self.nonlinearity(self.conv1(x))
        x = self.nonlinearity(F.max_pool2d(self.conv2(x), 2))
        x = self.nonlinearity(F.max_pool2d(self.conv3(x), 2))
        x = self.nonlinearity(F.max_pool2d(self.conv4(x), 2))
        x = F.max_pool2d(self.conv5(x), 2)
        x = F.log_softmax(x, dim=1)
        # Squeeze only the two spatial dims. A bare .squeeze() would also drop
        # the batch dim when the batch holds a single sample, yielding shape
        # (10,) and breaking nll_loss / argmax(dim=1) downstream.
        return x.squeeze(3).squeeze(2)

class CNN(nn.Module):
    """Padded five-convolution classifier with instance normalisation.

    Each of the first four stages is conv -> 2x2 max-pool -> InstanceNorm2d
    (no affine parameters) -> nonlinearity; the fifth stage is a 1x1 conv to
    10 channels followed by a last 2x2 max-pool and a log-softmax over the
    channel dimension. For a 32x32 input the convolutions produce maps of
    side 32, 16, 8, 4 and 2, so the final pooled map is 1x1.

    ``nonlinearity`` is a key of the module-level ``nonlinearities``
    registry. The attribute names (``conv1``..``conv5``, ``norm1``..``norm4``)
    become ``state_dict`` keys and must not change.
    """

    def __init__(self, nonlinearity = 'ReLU'):
        super().__init__()
        C = 128
        # Creation order is kept so that a fixed seed gives the same weights.
        self.conv1 = nn.Conv2d(1, C, kernel_size=7, stride=1, padding=3, dilation=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=5, stride=1, padding=2, dilation=1)
        self.conv3 = nn.Conv2d(C, C, kernel_size=3, stride=1, padding=1, dilation=1)
        self.conv4 = nn.Conv2d(C, C, kernel_size=1)
        self.conv5 = nn.Conv2d(C, 10, kernel_size=1)
        self.norm1 = nn.InstanceNorm2d(C, affine=False)
        self.norm2 = nn.InstanceNorm2d(C, affine=False)
        self.norm3 = nn.InstanceNorm2d(C, affine=False)
        self.norm4 = nn.InstanceNorm2d(C, affine=False)
        self.nonlinearity = nonlinearities[nonlinearity]

    def forward(self, x):
        """Return log-probabilities over the 10 classes.

        Size-1 spatial dimensions are squeezed out of the result, giving
        shape ``(batch, 10)`` when the final feature map is 1x1.
        """
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )
        for conv, norm in stages:
            x = self.nonlinearity(norm(F.max_pool2d(conv(x), 2)))
        x = F.max_pool2d(self.conv5(x), 2)
        x = F.log_softmax(x, dim=1)
        # Squeeze only the two spatial dims. A bare .squeeze() would also drop
        # the batch dim when the batch holds a single sample, yielding shape
        # (10,) and breaking nll_loss / argmax(dim=1) downstream.
        return x.squeeze(3).squeeze(2)

def train(args, model, device, train_loader, optimizer, epoch, prev_losses):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    The scalar loss of every batch is appended to ``prev_losses`` (the list
    is mutated; nothing is returned). Every ``args.log_interval`` batches a
    progress line is printed, and when ``args.dry_run`` is set the pass stops
    right after the first such line.
    """
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        prev_losses.append(loss_value)

        if step % args.log_interval != 0:
            continue

        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss_value))
        if args.dry_run:
            break

@torch.no_grad()
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the mean NLL loss.

    Prints the average loss and the accuracy over the whole dataset. The
    decorator already disables gradient tracking for the entire call. The
    caller is responsible for switching the model to eval mode.
    """
    loss_sum = 0
    hits = 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        log_probs = model(inputs)
        # Sum (not mean) per batch so that dividing by the dataset size below
        # gives the exact per-sample average.
        loss_sum += F.nll_loss(log_probs, labels, reduction='sum').item()
        predicted = log_probs.argmax(dim=1, keepdim=True)
        hits += predicted.eq(labels.view_as(predicted)).sum().item()

    total = len(test_loader.dataset)
    mean_loss = loss_sum / total
    accuracy = 100. * hits / total
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        mean_loss, hits, total, accuracy))
    return mean_loss


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--network', type=str, default='cnn',
                        help='cnn or mlp (default: cnn)')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimizer (default: adam)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--gamma', type=float, default=0.6, metavar='M',
                        help='Learning rate step gamma (default: 0.6)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-test', action='store_true', default=False,
                        help='no testing each epoch')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=23, metavar='S',
                        help='random seed (default: 23)')
    parser.add_argument('--log-interval', type=int, default=500, metavar='N',
                        help='how many batches to wait before logging training status')
    # type=bool would turn every non-empty string (including "False") into
    # True and would reject the bare flag. nargs='?' with const=True accepts
    # both `--save-model` and `--save-model false`.
    parser.add_argument('--save-model', type=_str2bool, nargs='?', const=True,
                        default=True, metavar='BOOL',
                        help='save the current model (default: true)')
    parser.add_argument('--resume', type=str, default="",
                        help='load the model')
    return parser.parse_args()


def _build_optimizer(name, params, lr):
    """Return the optimizer called ``name`` over ``params``.

    Raises NotImplementedError for an unknown name and ImportError when the
    optional ``ranger`` package is requested but not installed.
    """
    if name == "adadelta":
        return optim.Adadelta(params, lr=lr)
    if name == "adam":
        return optim.Adam(params, lr=lr)
    if name == "adamw":
        return optim.AdamW(params, lr=lr)
    if name == "ranger":
        # The module-level import of ranger is commented out, so import it
        # lazily here instead of failing with a NameError.
        try:
            import ranger
        except ImportError as err:
            raise ImportError(
                "--optimizer ranger requires the 'ranger' package") from err
        return ranger.Ranger(params, lr=lr)
    if name == "sgd":
        return optim.SGD(params, lr=lr, weight_decay=1e-5)
    if name == "lion":
        return lion_pytorch.Lion(params, lr=lr, weight_decay=1e-2)
    raise NotImplementedError("unknown optimizer: {!r}".format(name))


def main():
    """Train one model per registered nonlinearity on MNIST and record losses.

    For every key of ``nonlinearities`` a fresh model is seeded, trained for
    ``--epochs`` epochs and evaluated. The first-epoch loss curves are
    written to a PDF under ``fig/``, all recorded losses to ``./loss.csv``,
    and (unless disabled) the weights to ``ckpt/``.
    """
    import os  # only needed here, to create the output directories

    args = _parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        # NOTE(review): shuffling is only switched on in the CUDA branch, so
        # CPU runs iterate the training set in a fixed order -- confirm that
        # this is intended before changing it.
        cuda_kwargs = {'num_workers': 16,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Resize(32, antialias=True)
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    fig = go.Figure()
    data = {}
    for nonlinearity in nonlinearities:
        print("Training {}:".format(nonlinearity))
        torch.manual_seed(args.seed) # reproducibility
        Model = CNN if args.network == "cnn" else MLP
        model = Model(nonlinearity=nonlinearity).to(device)
        if args.resume:
            # map_location lets a checkpoint saved on GPU load on a CPU run.
            model.load_state_dict(torch.load(args.resume, map_location=device))
        # The optimizer is built whether or not a checkpoint was loaded; it
        # used to be skipped on --resume, which crashed at the scheduler.
        optimizer = _build_optimizer(args.optimizer, model.parameters(), args.lr)

        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        ts = []
        prev_losses = []
        for epoch in range(1, args.epochs + 1):
            tic = time.time()
            train(args, model, device, train_loader, optimizer, epoch, prev_losses)
            ts.append(time.time() - tic)

            scheduler.step()
            if epoch == 1:
                idx = torch.arange(0, len(prev_losses))
                curve = torch.tensor(prev_losses)
                # The figure gets a snapshot of epoch 1 only. `data` keeps a
                # reference to the list itself, which later epochs keep
                # extending, so the CSV ends up with every recorded step.
                data[nonlinearity] = prev_losses
                fig.add_trace(go.Scatter(x=idx, y=curve,
                            mode='lines',opacity=0.6,
                            name='${}$'.format(nonlinearity)
                            ))
            if not args.no_test:
                model.eval()
                test(model, device, test_loader)
                model.train()

        if len(ts) > 1:
            print("Training time for an epoch: {:.4f} ± {:.4f} s\n".format(statistics.mean(ts),statistics.stdev(ts)))

        if args.no_test:
            # Per-epoch testing was skipped, so evaluate once at the end.
            model.eval()
            test(model, device, test_loader)

        if args.save_model:
            # Create the directory first so a finished run is not lost to a
            # missing folder.
            os.makedirs("ckpt", exist_ok=True)
            torch.save(model.state_dict(), "ckpt/mnist_{}_{}_{}_seed_{}.pt".format(args.network,nonlinearity,args.optimizer,args.seed))

    fig.update_xaxes(title_text='Step')
    fig.update_yaxes(title_text='Loss',range=[0,2.5])
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="right",
        x=0.99
    ))
    os.makedirs("fig", exist_ok=True)
    name = 'fig/{}_lr_{}_bs_{}_gamma_{}.pdf'.format(type(optimizer).__name__, args.lr, args.batch_size, args.gamma)
    # Written twice with a pause in between, as in the original; presumably a
    # workaround for the first PDF export rendering before MathJax has
    # finished loading -- TODO confirm.
    fig.write_image(name)
    time.sleep(2)
    fig.write_image(name)
    df = pandas.DataFrame(data)
    df.to_csv("./loss.csv", sep=',',index=False)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
