# Install E2CNN library using "pip install git+https://github.com/QUVA-Lab/e2cnn@c77faef49fd1bf12ccf538a63cac201a89f16c6"

import torch
import torch.nn as nn
from torch.nn import AdaptiveAvgPool2d
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.transforms import RandomRotation, Pad, Resize, ToTensor, Compose

from e2cnn import gspaces
from e2cnn import nn as enn
from e2cnn.nn import init
from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from argparse import ArgumentParser
from e2cnn.nn import EquivariantModule
from e2cnn.gspaces import *

import time, os, copy, random, sys, math, pickle
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from typing import List, Tuple, Any, Union
 
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=3,
            dilation=1, bias=False):
    """Build a 7x7 ``enn.R2Conv`` mapping ``in_type`` to ``out_type``.

    ``padding`` defaults to 3 (``kernel_size // 2``) and ``bias`` to False.
    ``sigma`` is left at None and the frequency cutoff handed to ``R2Conv``
    is the callable ``r -> 3 * r``.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=lambda r: 3 * r,
    )
    return enn.R2Conv(in_type, out_type, 7, **conv_kwargs)


def conv5x5(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=2,
            dilation=1, bias=False):
    """Build a 5x5 ``enn.R2Conv`` mapping ``in_type`` to ``out_type``.

    ``padding`` defaults to 2 (``kernel_size // 2``) and ``bias`` to False.
    ``sigma`` is left at None and the frequency cutoff handed to ``R2Conv``
    is the callable ``r -> 3 * r``.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=lambda r: 3 * r,
    )
    return enn.R2Conv(in_type, out_type, 5, **conv_kwargs)


def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """Build a 3x3 ``enn.R2Conv`` mapping ``in_type`` to ``out_type``.

    ``padding`` defaults to 1 (``kernel_size // 2``) and ``bias`` to False.
    ``sigma`` is left at None and the frequency cutoff handed to ``R2Conv``
    is the callable ``r -> 3 * r``.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=lambda r: 3 * r,
    )
    return enn.R2Conv(in_type, out_type, 3, **conv_kwargs)


def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """Build a 1x1 ``enn.R2Conv`` mapping ``in_type`` to ``out_type``.

    ``padding`` defaults to 0 and ``bias`` to False. ``sigma`` is left at
    None and the frequency cutoff handed to ``R2Conv`` is the callable
    ``r -> 3 * r``.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=lambda r: 3 * r,
    )
    return enn.R2Conv(in_type, out_type, 1, **conv_kwargs)


def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Return a FieldType made of regular representations of ``gspace``.

    The number of regular fields is ``int(planes / N)``, where ``N`` is the
    order of the fiber group. With ``fixparams`` set, ``planes`` is first
    multiplied by ``sqrt(N)``. ``N`` must be positive.
    """
    order = gspace.fibergroup.order()
    assert order > 0

    scaled = planes * math.sqrt(order) if fixparams else planes
    num_fields = int(scaled / order)

    return enn.FieldType(gspace, [gspace.regular_repr] * num_fields)

def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
    """Return a FieldType made of ``planes`` trivial representations of ``gspace``.

    With ``fixparams`` set, ``planes`` is first multiplied by the square root
    of the fiber group's order; the result is truncated to an int.
    """
    if fixparams:
        scaled = planes * math.sqrt(gspace.fibergroup.order())
    else:
        scaled = planes

    return enn.FieldType(gspace, [gspace.trivial_repr] * int(scaled))


# Lookup table from a field-type name to the function that builds it.
FIELD_TYPE = dict(
    trivial=trivial_feature_type,
    regular=regular_feature_type,
)

class WideBasic(enn.EquivariantModule):
    """Pre-activation residual block built from equivariant e2cnn layers.

    The main path is InnerBatchNorm -> ReLU -> 3x3 conv (strided) ->
    InnerBatchNorm -> ReLU -> 3x3 conv. The skip path is the identity, or a
    strided 1x1 conv applied to the pre-activated input whenever the stride
    is not 1 or the input and output field types differ.
    """

    def __init__(self,
                 in_type: enn.FieldType,
                 inner_type: enn.FieldType,
                 stride: int = 1,
                 out_type: enn.FieldType = None,
                 ):
        """
        Args:
            in_type: field type of the block input.
            inner_type: field type between the two 3x3 convolutions.
            stride: stride of the first convolution (and of the shortcut).
            out_type: field type of the block output; defaults to ``in_type``.
        """
        super().__init__()

        self.in_type = in_type
        self.out_type = in_type if out_type is None else out_type

        # First half: normalise, activate, then the (possibly strided) conv.
        self.bn1 = enn.InnerBatchNorm(self.in_type)
        self.relu1 = enn.ReLU(self.in_type, inplace=True)
        self.conv1 = conv3x3(self.in_type, inner_type, stride=stride)

        # Second half: same pattern on the inner type, always stride 1.
        self.bn2 = enn.InnerBatchNorm(inner_type)
        self.relu2 = enn.ReLU(inner_type, inplace=True)
        self.conv2 = conv3x3(inner_type, self.out_type)

        # A projection is only needed when the identity cannot be added as is.
        needs_projection = stride != 1 or self.in_type != self.out_type
        if needs_projection:
            self.shortcut = conv1x1(self.in_type, self.out_type, stride=stride, bias=False)
        else:
            self.shortcut = None

    def forward(self, x):
        """Apply the block to ``x`` and return main path plus skip path."""
        activated = self.relu1(self.bn1(x))

        out = self.conv1(activated)
        out = self.relu2(self.bn2(out))
        out = self.conv2(out)

        # The projection shortcut consumes the pre-activated input; the
        # identity shortcut adds the raw input.
        if self.shortcut is None:
            out += x
        else:
            out += self.shortcut(activated)

        return out

    def evaluate_output_shape(self, input_shape: Tuple):
        """Return the output shape for a 4-element ``input_shape``.

        The shape is delegated to the shortcut conv when one exists;
        otherwise the input shape is returned unchanged.
        """
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size

        if self.shortcut is None:
            return input_shape
        return self.shortcut.evaluate_output_shape(input_shape)

class ResNet18(torch.nn.Module):
    """ResNet-18-style classifier built from rotation-equivariant e2cnn layers.

    The network is a 3x3 stem convolution, four stages of two ``WideBasic``
    blocks each (regular-representation widths 24, 48, 96, 192 fields),
    InnerBatchNorm + ReLU, group pooling, global average pooling and a
    linear classifier.
    """

    def __init__(self, num_classes=100,
                 N: int = 8,
                 r: int = 0,
                 f: bool = False,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
        """
        Args:
            num_classes: size of the final linear layer's output.
            N: order passed to ``gspaces.Rot2dOnR2``.
            r: must be one of 0, 1, 2, 3; stored but otherwise unused here.
            f: accepted for interface compatibility; unused here.
            deltaorth: when True, every ``R2Conv`` is re-initialised with
                ``init.deltaorthonormal_init``.
            fixparams: stored but otherwise unused here.
            initial_stride: accepted for interface compatibility; unused
                here (the stem convolution always uses stride 1).
        """
        super(ResNet18, self).__init__()

        nStages = [24, 24, 48, 96, 192]
        self._fixparams = fixparams
        self._layer = 0
        self._N = N
        self.gspace = gspaces.Rot2dOnR2(N)
        assert r in [0, 1, 2, 3]
        self._r = r

        # Input: 3 trivial fields, i.e. a 3-channel image tensor.
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        self.in_type = r1
        r2 = enn.FieldType(self.gspace, int(nStages[0])*[self.gspace.regular_repr])
        # _in_type tracks the field type fed to the next block being built.
        self._in_type = r2
        self.conv1 = conv3x3(r1, r2, stride=1, padding=1, bias=False)
        self.layer1 = self._wide_layer(WideBasic, nStages[1], 2, stride=1)
        self.layer2 = self._wide_layer(WideBasic, nStages[2], 2, stride=2)
        self.layer3 = self._wide_layer(WideBasic, nStages[3], 2, stride=2)
        self.layer4 = self._wide_layer(WideBasic, nStages[4], 2, stride=2)
        # NOTE(review): momentum=0.9 weights the newest batch heavily if
        # InnerBatchNorm follows the torch BatchNorm convention - confirm
        # this is intended.
        self.bn = enn.InnerBatchNorm(self.layer4.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gpool = enn.GroupPooling(self.bn.out_type)
        c = self.gpool.out_type.size
        self.linear = torch.nn.Linear(c, num_classes)

        for module in self.modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

    def _wide_layer(self, block, planes: int, num_blocks: int, stride: int,
                    totrivial: bool = False
                    ) -> enn.SequentialModule:
        """Build one stage of ``num_blocks`` blocks with ``planes`` regular fields.

        Only the first block uses ``stride``; the remaining blocks use
        stride 1. ``self._in_type`` is advanced to the stage's output type.
        ``totrivial`` is accepted but unused.
        """
        self._layer += 1
        print("start building", self._layer)
        block_strides = [stride] + [1] * (num_blocks-1)
        layers = []
        inner_type = enn.FieldType(self.gspace, int(planes)*[self.gspace.regular_repr])
        for block_stride in block_strides:
            out_f = enn.FieldType(self.gspace, int(planes)*[self.gspace.regular_repr])
            layers.append(block(self._in_type, inner_type, block_stride, out_type=out_f))
            self._in_type = out_f
        print("layer", self._layer, "built")
        return enn.SequentialModule(*layers)

    def features(self, x):
        """Return the outputs of the first three stages for input tensor ``x``.

        The stages are chained directly, exactly as in ``forward``; this
        model defines no restriction layers between stages.
        """
        x = enn.GeometricTensor(x, self.in_type)
        out = self.conv1(x)
        x1 = self.layer1(out)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        return x1, x2, x3

    def forward(self, x):
        """Classify ``x``.

        Returns:
            A pair ``(logits, y)`` where ``logits`` is the linear layer's
            output and ``y`` is the group-pooled GeometricTensor taken
            before spatial average pooling.
        """
        x = enn.GeometricTensor(x, self.in_type)
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.bn(out)
        out = self.relu(out)
        out = self.gpool(out)
        y = out
        # Leave the e2cnn wrapper: the remaining layers are plain torch modules.
        out = out.tensor
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out, y

class OneCycle(object):
    """Cyclic cosine learning-rate schedule, advanced one step per ``calc()``.

    Within a cycle of ``nb`` steps the rate rises from the low rate
    (``high_lr / div``) to ``high_lr`` over the first ``int(nb / 4)`` steps,
    then decays back to the low rate over the remainder. When a cycle ends,
    ``high_lr`` and ``div`` are both divided by ``dec_scale``, which leaves
    the low rate unchanged. Every returned rate is also appended to ``lrs``.
    """

    def __init__(self, nb, max_lr, low_lr, dec_scale):
        self.nb = nb
        self.div = max_lr / low_lr
        self.high_lr = max_lr
        self.iteration = 0
        self.lrs = []
        self.dec_scale = dec_scale
        # Length of the warm-up phase: a quarter of the cycle, rounded down.
        self.step_len = int(nb / 4)

    def calc(self):
        """Return the rate for the current step and advance the step counter."""
        current = self.calc_lr_cosine()
        self.iteration += 1
        return current

    def calc_lr_cosine(self):
        """Compute, record and return the rate for ``self.iteration``."""
        step = self.iteration
        peak = self.high_lr
        floor = peak / self.div

        # First step of the schedule: start at the low rate.
        if step == 0:
            self.lrs.append(floor)
            return floor

        # End of a cycle: restart the counter and shrink the peak for the
        # next cycle, while still returning this cycle's low rate.
        if step == self.nb:
            self.iteration = 0
            self.lrs.append(floor)
            self.high_lr = peak / self.dec_scale
            self.div = self.div / self.dec_scale
            return floor

        span = peak - floor
        if step > self.step_len:
            # Decay phase: half-cosine from the peak down to the low rate.
            ratio = (step - self.step_len) / (self.nb - self.step_len)
            lr = floor + 0.5 * span * (1 + math.cos(math.pi * ratio))
        else:
            # Warm-up phase: half-cosine from the low rate up to the peak.
            ratio = step / self.step_len
            lr = peak - 0.5 * span * (1 + math.cos(math.pi * ratio))

        self.lrs.append(lr)
        return lr

# Hyper-parameters shared by the optimizer, the LR schedule and the data loaders.
params = dict(
    weight_decay_rate=1e-7,
    cyc_size=70,
    max_lr=5e-3,
    low_lr=1e-5,
    dec_scale=1,
    batch_size=64,
    mode='train',
    tot_epochs=90,
)

# Both splits use the same pipeline: the shared Resize(65), then ToTensor.
resize1 = Resize(65)
data_transforms = {
    split: transforms.Compose([resize1, transforms.ToTensor()])
    for split in ('train', 'val')
}
data_dir = 'ImageNet-rot-masked/'
# One ImageFolder per split, read from <data_dir>/train and <data_dir>/val.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}
# Only the training split is shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=params['batch_size'],
        shuffle=(split == 'train'),
        num_workers=2,
    )
    for split in ('train', 'val')
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['train', 'val']}

model = ResNet18().to(device)
classification_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params['low_lr'],
    weight_decay=params['weight_decay_rate'],
)
# The cycle length is (batches per epoch) * cyc_size optimizer steps.
onecycle = OneCycle(
    math.ceil(dataset_sizes['train'] / params['batch_size']) * params['cyc_size'],
    params['max_lr'],
    params['low_lr'],
    params['dec_scale'],
)

best_val_acc = 0
# Each entry pairs a fixed rotation by x degrees with the rotation by -x
# degrees, for x = 45, 90, ..., 315. Only the first of each pair is used below.
all_rotations = [(RandomRotation((x,x)),RandomRotation((-x,-x))) for x in range(45,360,45)]
for epoch in range(params['tot_epochs']):
    ### Training
    model.train()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in dataloaders['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # The schedule only drives the learning rate for the first cyc_size
        # epochs; afterwards the optimizer keeps the last rate it was given.
        if epoch < params['cyc_size']:
            lr = onecycle.calc()
            for g in optimizer.param_groups:
                g['lr'] = lr
        outputs, _ = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = classification_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean; weight it by batch size so the epoch
        # average below is per-sample.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)

    train_loss = running_loss / dataset_sizes['train']
    train_acc = running_corrects.double() / dataset_sizes['train']

    ### Validation
    model.eval()
    with torch.no_grad():
        val_running_loss1 = 0.0
        val_running_corrects1 = 0
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            tbs = inputs.shape[0]
            # Evaluate on each image plus one randomly chosen rotated copy of
            # it. The copies are built first and concatenated once, rather
            # than re-concatenating the growing batch once per sample.
            ar = random.choices(all_rotations, k=tbs)
            rotated = [ar[j][0](inputs[j]).unsqueeze(0) for j in range(tbs)]
            inputs = torch.cat([inputs] + rotated, 0)
            outputs, _ = model(inputs)
            _, preds = torch.max(outputs, 1)
            # Rotated copies follow the originals in the same order, so the
            # labels are simply repeated.
            labels = torch.cat((labels, labels))
            loss1 = classification_loss(outputs, labels)
            val_running_loss1 += loss1.item() * inputs.size(0)
            val_running_corrects1 += torch.sum(preds == labels)
        # Every validation image is seen twice (original + rotated copy).
        val_loss = val_running_loss1 / (dataset_sizes['val']*2)
        val_acc = val_running_corrects1.double() / (dataset_sizes['val']*2)

        print(f"Epoch {epoch}: Train Acc {train_acc}, Train Loss {train_loss} | ; | Val Acc {val_acc}, Val Loss {val_loss}")
        # Checkpoint the latest state every epoch, and the best model so far.
        torch.save(model.state_dict(),"last_model.ckpt")
        torch.save(optimizer.state_dict(),"last_adam_setting.ckpt")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(),"best_model_weights.ckpt")
        # Persist the schedule object so its step counter and history survive.
        with open("cyclelr_object.pkl","wb") as f:
            pickle.dump(onecycle,f)
