import torch
from torch import nn
import torch.nn.functional as F
from geoopt import Stiefel, ManifoldParameter, Euclidean
from manifolds import EuclideanMod
import math
import numpy as np
import argparse

import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from geoopt.optim import RiemannianSGD


# TODO:
# 1. add more fc layers before the last layer
# 2. carefully consider weight initialization
# 3. use fewer conv filters and fewer orthogonal filters


def MiniimageNetFeats(hidden_size):
    """Build a four-block convolutional feature extractor.

    Every block is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) ->
    BatchNorm2d``. The batch norm is affine and does not track running
    statistics, so it always normalizes with the current batch. The first
    block maps 3 input channels to ``hidden_size``; the other three keep
    ``hidden_size`` channels. The result is flattened to ``(batch, features)``.

    Args:
        hidden_size: number of channels produced by each conv block.

    Returns:
        ``nn.Sequential`` containing the four blocks followed by ``nn.Flatten()``.
    """
    in_channels = [3] + [hidden_size] * 3

    def block(c_in):
        # Sub-module order (conv, relu, pool, norm) fixes the state_dict keys.
        conv = nn.Conv2d(c_in, hidden_size, 3, padding=1)
        norm = nn.BatchNorm2d(hidden_size, momentum=1., affine=True,
                              track_running_stats=False)
        return nn.Sequential(conv, nn.ReLU(inplace=True), nn.MaxPool2d(2), norm)

    # Blocks are built in order, so parameter initialization draws from the
    # RNG in the same sequence as building them one by one.
    return nn.Sequential(*(block(c) for c in in_channels), nn.Flatten())



class CNN(nn.Module):
    """Two-layer convolutional feature extractor with Stiefel-constrained kernels.

    Each convolution kernel is stored as a ``(in_channels * ks * ks, hidden_size)``
    matrix on the Stiefel manifold. It is transposed and reshaped into a
    ``(hidden_size, in_channels, ks, ks)`` conv weight at forward time, so each
    row of the transposed matrix becomes one filter. Each convolution is
    followed by ReLU and a batch normalization whose affine weight/bias are
    Euclidean manifold parameters.

    Args:
        hidden_size: number of output channels of both convolutions.
        ic: number of input channels of the first convolution.
        ks: square kernel side length.
        padding: zero padding applied on every side by both convolutions.
    """

    # Number of conv layers applied in ``forward`` (used by ``compute_outdim``).
    _num_conv_layers = 2

    def __init__(self, hidden_size, ic=3, ks=3, padding=1):
        super().__init__()

        self.ic = ic
        self.hidden_size = hidden_size
        self.ks = ks
        self.pad = padding
        self.stride = 1

        self.stiefel = Stiefel(canonical=False)

        # First conv: ic -> hidden_size channels.
        self.conv0_kernel = ManifoldParameter(self.stiefel.random(ic * ks * ks, hidden_size),
                                              manifold=self.stiefel)
        # Second conv: hidden_size -> hidden_size channels.
        self.conv1_kernel = ManifoldParameter(self.stiefel.random(hidden_size * ks * ks, hidden_size),
                                              manifold=self.stiefel)

        # Batch-norm affine parameters (scale initialized to 1, shift to 0).
        self.bn0_w = ManifoldParameter(torch.ones(hidden_size), manifold=Euclidean(ndim=1))
        self.bn0_b = ManifoldParameter(torch.zeros(hidden_size), manifold=Euclidean(ndim=1))
        self.bn1_w = ManifoldParameter(torch.ones(hidden_size), manifold=Euclidean(ndim=1))
        self.bn1_b = ManifoldParameter(torch.zeros(hidden_size), manifold=Euclidean(ndim=1))

    def compute_outdim(self, insize):
        """Return the spatial side length after the conv layers in ``forward``.

        Uses the standard conv output formula, floored per layer as
        ``F.conv2d`` does.

        Args:
            insize: input height/width (square inputs).

        Returns:
            Output height/width as an ``int``.
        """
        outdim = insize
        for _ in range(self._num_conv_layers):
            outdim = (outdim - self.ks + 2 * self.pad) // self.stride + 1
        return outdim

    def conv_layer(self, x, conv_param, bn_w=None, bn_b=None):
        """Apply convolution -> ReLU -> batch normalization.

        Batch norm always normalizes with the statistics of the current batch.
        No running statistics are kept, which matches
        ``nn.BatchNorm2d(..., track_running_stats=False)``. Train and eval mode
        therefore produce the same normalization, and no device-specific
        buffers are needed.

        Args:
            x: input feature map, presumably ``(batch, in_channels, H, W)``.
            conv_param: conv weight of shape ``(out_channels, in_channels, ks, ks)``.
            bn_w: optional per-channel batch-norm scale.
            bn_b: optional per-channel batch-norm shift.

        Returns:
            The normalized feature map.
        """
        x = F.relu(F.conv2d(x, conv_param, stride=self.stride, padding=self.pad), inplace=True)
        return F.batch_norm(x, None, None, weight=bn_w, bias=bn_b,
                            training=True, eps=1e-5)

    def forward(self, x):
        """Run both conv layers and flatten to ``(batch, hidden_size * H' * W')``."""
        # The first kernel maps ``ic`` input channels, not ``hidden_size``.
        w0 = self.conv0_kernel.transpose(-1, -2).reshape(
            self.hidden_size, self.ic, self.ks, self.ks)
        x = self.conv_layer(x, w0, self.bn0_w, self.bn0_b)

        w1 = self.conv1_kernel.transpose(-1, -2).reshape(
            self.hidden_size, self.hidden_size, self.ks, self.ks)
        x = self.conv_layer(x, w1, self.bn1_w, self.bn1_b)

        return torch.flatten(x, 1)


class FC(nn.Module):
    """Affine classification head computing ``x @ weight + bias``.

    Args:
        input_size: number of input features (last dimension of ``x``).
        num_class: number of output logits.
    """

    def __init__(self, input_size, num_class):
        super().__init__()
        # Weight is drawn from a standard normal; bias starts at zero.
        weight_init = torch.randn(input_size, num_class)
        self.weight = ManifoldParameter(weight_init, manifold=EuclideanMod(ndim=2))
        bias_init = torch.zeros(num_class)
        self.bias = ManifoldParameter(bias_init, manifold=EuclideanMod(ndim=1))

    def forward(self, x):
        """Return logits of shape ``(..., num_class)`` for inputs ``(..., input_size)``."""
        logits = torch.matmul(x, self.weight)
        return logits + self.bias



class Net(nn.Module):
    """LeNet-style classifier whose two conv kernels live on the Stiefel manifold.

    Architecture: conv(3->6, 5x5) -> ReLU -> maxpool -> BatchNorm ->
    conv(6->16, 5x5) -> ReLU -> maxpool -> flatten -> fc(400->120) -> ReLU ->
    fc(120->84) -> ReLU -> fc(84->num_classes).

    The ``16 * 5 * 5`` fc input size assumes 32x32 inputs:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stiefel = Stiefel(canonical=False)

        # conv1 kernel stored as a (3*5*5, 6) Stiefel matrix; each column is one filter.
        self.conv1_kernel = ManifoldParameter(self.stiefel.random(3 * 5 * 5, 6),
                                              manifold=self.stiefel)
        self.pool = nn.MaxPool2d(2, 2)
        # conv2 kernel stored as a (6*5*5, 16) Stiefel matrix.
        self.conv2_kernel = ManifoldParameter(self.stiefel.random(6 * 5 * 5, 16),
                                              manifold=self.stiefel)

        # First fully connected layer as explicit Euclidean parameters,
        # initialized with very small uniform values.
        self.fc1_w = ManifoldParameter(torch.empty(16 * 5 * 5, 120).uniform_(-0.001, 0.001),
                                       manifold=Euclidean(ndim=2))
        self.fc1_b = ManifoldParameter(torch.empty(120).uniform_(-0.001, 0.001),
                                       manifold=Euclidean(ndim=1))

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

        # Batch norm applied after the first conv block.
        self.bn = nn.BatchNorm2d(6)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Transposing makes each Stiefel column a row, reshaped into one filter.
        w1 = self.conv1_kernel.T.reshape(6, 3, 5, 5)
        w2 = self.conv2_kernel.T.reshape(16, 6, 5, 5)

        x = self.pool(F.relu(F.conv2d(x, w1)))
        x = self.bn(x)
        x = self.pool(F.relu(F.conv2d(x, w2)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(x @ self.fc1_w + self.fc1_b)
        x = F.relu(self.fc2(x))
        return self.fc3(x)








if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # --eta_x, --eta_y, --lower_iter, --hygrad_opt, --ns_gamma and --ns_iter
    # are parsed but not used by the CIFAR-10 run below.
    parser.add_argument('--eta_x', type=float, default=0.01)
    parser.add_argument('--eta_y', type=float, default=0.01)
    parser.add_argument('--lower_iter', type=int, default=50)
    parser.add_argument('--epoch', type=int, default=20,
                        help='number of training epochs')
    parser.add_argument('--hygrad_opt', type=str, default='ns', choices=['hinv', 'cg', 'ns', 'ad'])
    parser.add_argument('--ns_gamma', type=float, default=0.01)
    parser.add_argument('--ns_iter', type=int, default=50)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # ``device`` is a module-level global because this block is not wrapped
    # in a function.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(args.seed)
    print(device)
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Test on CIFAR-10 classification.
    epochs = args.epoch
    batch_size = 64

    # Scale pixels from [0, 1] to [-1, 1] per channel.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    model = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    # RiemannianSGD keeps ManifoldParameters on their manifolds after each step.
    optimizer = RiemannianSGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(epochs):  # loop over the dataset multiple times
        model.train()

        train_loss = 0.0
        n_train_batches = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_train_batches += 1

        # Mean per-batch training loss (guarded against an empty loader).
        train_loss /= max(n_train_batches, 1)

        model.eval()
        val_loss = 0.0
        val_acc = 0
        total = 0
        n_val_batches = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                val_acc += (pred == labels).sum().item()
                val_loss += loss.item()
                n_val_batches += 1

        print(f'[{epoch}] loss: {train_loss:.3f}, '
              f'val loss: {val_loss / max(n_val_batches, 1):.3f}, '
              f'val acc: {val_acc / max(total, 1):.4f}')

    print('Finished Training')