# Install E2CNN library using "pip install git+https://github.com/QUVA-Lab/e2cnn@c77faef49fd1bf12ccf538a63cac201a89f16c6"

import torch
import torch.nn as nn
from torch.nn import AdaptiveAvgPool2d
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.transforms import RandomRotation, Pad, Resize, ToTensor, Compose

from e2cnn import gspaces
from e2cnn import nn as enn
from e2cnn.nn import init
from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from argparse import ArgumentParser
from e2cnn.nn import EquivariantModule
from e2cnn.gspaces import *

import time, os, copy, random, sys, math, pickle
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from typing import List, Tuple, Any, Union
 
# Select the compute device once at import time: the GPU when CUDA is usable,
# otherwise the CPU. The chosen device name is echoed to stdout.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


class E2CNN_VGG11(torch.nn.Module):
    """VGG11-style classifier built from e2cnn steerable convolutions.

    The network stacks eight ``R2Conv -> InnerBatchNorm -> ELU`` blocks over the
    ``gspaces.FlipRot2dOnR2(N=N)`` group space, interleaved with four
    2x2 / stride-2 pointwise max-pools. The features are then group-pooled,
    averaged over the spatial dimensions and classified by a small MLP head.

    Args:
        n_classes: Number of output logits of the classification head.
        k: Width multiplier applied to the number of regular fields of every
            convolutional block (``int(width * k)``).
        N: Passed to ``gspaces.FlipRot2dOnR2`` as the order of the discrete
            rotation group.
    """

    def __init__(self, n_classes=100, k=1, N=8):
        super(E2CNN_VGG11, self).__init__()

        # Entries 0-7: regular fields per conv block (before scaling by k);
        # entry 8: hidden width of the MLP head; entry 9: number of classes.
        # The last entry used to be a hard-coded 100, which silently ignored
        # the ``n_classes`` argument.
        self.chnls = [25, 50, 100, 100, 200, 200, 200, 200, 1024, n_classes]
        # Kernel sizes in network order: 3 for convolutions, 2 for max-pools.
        # The final entry belonged to a fifth max-pool that is no longer built.
        self.ksize = [3,2, 3,2, 3,3,2, 3,3,2, 3,3,2]
        self.pd = [1]*8      # padding of each of the eight convolutions
        self.strd = [2]*5    # stride of each max-pool (last entry unused)
        self.N = N
        self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        # The raw image enters as three trivial (scalar) fields, one per channel.
        self.input_type = enn.FieldType(self.r2_act, 3*[self.r2_act.trivial_repr])

        widths = [int(c * k) for c in self.chnls[:8]]

        # NOTE: submodule names and creation order are kept exactly as before so
        # that existing checkpoints (state_dict keys) and seeded initialisation
        # stay compatible.
        self.block1 = self._conv_block(self.input_type, widths[0], self.ksize[0], self.pd[0])
        self.pool1 = self._max_pool(self.block1.out_type, self.ksize[1], self.strd[0])

        self.block2 = self._conv_block(self.block1.out_type, widths[1], self.ksize[2], self.pd[1])
        self.pool2 = self._max_pool(self.block2.out_type, self.ksize[3], self.strd[1])

        self.block3 = self._conv_block(self.block2.out_type, widths[2], self.ksize[4], self.pd[2])
        self.block4 = self._conv_block(self.block3.out_type, widths[3], self.ksize[5], self.pd[3])
        self.pool3 = self._max_pool(self.block4.out_type, self.ksize[6], self.strd[2])

        self.block5 = self._conv_block(self.block4.out_type, widths[4], self.ksize[7], self.pd[4])
        self.block6 = self._conv_block(self.block5.out_type, widths[5], self.ksize[8], self.pd[5])
        self.pool4 = self._max_pool(self.block6.out_type, self.ksize[9], self.strd[3])

        self.block7 = self._conv_block(self.block6.out_type, widths[6], self.ksize[10], self.pd[6])
        self.block8 = self._conv_block(self.block7.out_type, widths[7], self.ksize[11], self.pd[7])

        # Group pooling followed by a global spatial average replaces the fifth
        # max-pool of a plain VGG11.
        self.gpool = enn.GroupPooling(self.block8.out_type)
        c = self.gpool.out_type.size
        self.pool5 = AdaptiveAvgPool2d(1)

        # Fully connected classification head.
        self.fully_net = nn.Sequential(
            nn.Linear(c, self.chnls[8]),
            nn.BatchNorm1d(self.chnls[8]),
            nn.ELU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(self.chnls[8], self.chnls[9]),
        )

    def _conv_block(self, in_type, n_fields, kernel_size, padding):
        """Build one ``R2Conv -> InnerBatchNorm -> ELU`` block with ``n_fields`` regular fields."""
        out_type = enn.FieldType(self.r2_act, n_fields*[self.r2_act.regular_repr])
        return enn.SequentialModule(
            enn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),
            enn.InnerBatchNorm(out_type),
            enn.ELU(out_type, inplace=True),
        )

    def _max_pool(self, in_type, kernel_size, stride):
        """Build a pointwise max-pool over ``in_type``, wrapped in a ``SequentialModule``."""
        return enn.SequentialModule(
            enn.PointwiseMaxPool(in_type, kernel_size=kernel_size, stride=stride)
        )

    def forward(self, input: torch.Tensor):
        """Classify a batch of images.

        Args:
            input: Plain tensor that is wrapped into a ``GeometricTensor`` of
                ``self.input_type`` (three trivial fields, i.e. 3 channels).

        Returns:
            A tuple ``(logits, feature_map)``: ``logits`` is the output of the
            MLP head and ``feature_map`` is the group-pooled tensor taken
            before the global spatial average.
        """
        x = enn.GeometricTensor(input, self.input_type)

        x = self.pool1(self.block1(x))
        x = self.pool2(self.block2(x))
        x = self.pool3(self.block4(self.block3(x)))
        x = self.pool4(self.block6(self.block5(x)))
        x = self.block8(self.block7(x))

        # Leave the e2cnn world: from here on we work with plain tensors.
        feature_map = self.gpool(x).tensor
        pooled = self.pool5(feature_map)
        logits = self.fully_net(pooled.reshape(pooled.shape[0], -1))

        return logits, feature_map

class OneCycle:
    """Cyclic learning-rate schedule with a cosine warm-up and cosine decay.

    One cycle lasts ``nb`` calls to :meth:`calc`. During the first quarter of
    the cycle (``step_len = int(nb / 4)`` steps) the rate rises from the floor
    value ``high_lr / div`` to ``high_lr``; for the remainder it falls back to
    the floor. When a cycle completes, ``high_lr`` and ``div`` are both divided
    by ``dec_scale`` and the step counter restarts.

    Args:
        nb: Number of steps in one cycle.
        max_lr: Peak learning rate of the first cycle.
        low_lr: Learning rate at the start and end of the cycle.
        dec_scale: Divisor applied to ``high_lr`` and ``div`` after each cycle.
    """

    def __init__(self, nb, max_lr, low_lr, dec_scale):
        self.nb = nb
        self.high_lr = max_lr
        # The floor is stored as a ratio, not a value: floor = high_lr / div.
        self.div = max_lr / low_lr
        self.dec_scale = dec_scale
        self.step_len = int(self.nb / 4)
        self.iteration = 0
        self.lrs = []  # history of every rate produced so far

    def calc(self):
        """Return the learning rate for the current step and advance the counter."""
        current = self.calc_lr_cosine()
        self.iteration += 1
        return current

    def calc_lr_cosine(self):
        """Compute the rate for ``self.iteration``, record it in ``self.lrs`` and return it.

        Does not advance the step counter, except for resetting it to zero
        when the end of the cycle is reached.
        """
        floor_lr = self.high_lr / self.div
        step = self.iteration

        # First step of the very first cycle: start at the floor.
        if step == 0:
            self.lrs.append(floor_lr)
            return floor_lr

        # End of a cycle: emit the floor once more, shrink the peak for the
        # next cycle and restart the counter.
        if step == self.nb:
            self.iteration = 0
            self.lrs.append(floor_lr)
            self.high_lr = self.high_lr / self.dec_scale
            self.div = self.div / self.dec_scale
            return floor_lr

        half_range = 0.5 * (self.high_lr - floor_lr)
        if step > self.step_len:
            # Decay phase: cosine from the peak back down to the floor.
            ratio = (step - self.step_len) / (self.nb - self.step_len)
            lr = floor_lr + half_range * (1 + math.cos(math.pi * ratio))
        else:
            # Warm-up phase: cosine from the floor up to the peak.
            ratio = step / self.step_len
            lr = self.high_lr - half_range * (1 + math.cos(math.pi * ratio))

        self.lrs.append(lr)
        return lr


# Hyper-parameters of the run.
params = {
    'weight_decay_rate' : 1e-7,   # weight decay handed to Adam
    'cyc_size' : 70,              # number of epochs covered by one LR cycle
    'max_lr' : 5e-3,              # peak learning rate of the cycle
    'low_lr' : 1e-5,              # learning rate at the start/end of the cycle
    'dec_scale' : 1,              # peak divisor applied after each cycle
    'batch_size' : 64,
    'mode' : 'train',
    'tot_epochs' : 90,
}

# Both splits only convert the images to tensors; no augmentation is applied here.
data_transforms = {
    split: transforms.Compose([transforms.ToTensor()])
    for split in ('train', 'val')
}
data_dir = 'ImageNet-rot-ref-masked/'
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}
# Only the training split is shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=params['batch_size'],
        shuffle=(split == 'train'),
        num_workers=2,
    )
    for split in ('train', 'val')
}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

model = E2CNN_VGG11().to(device)
classification_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params['low_lr'],
    weight_decay=params['weight_decay_rate'],
)
# One LR cycle spans `cyc_size` epochs, counted in optimizer steps (batches).
onecycle = OneCycle(
    nb=math.ceil(dataset_sizes['train'] / params['batch_size']) * params['cyc_size'],
    max_lr=params['max_lr'],
    low_lr=params['low_lr'],
    dec_scale=params['dec_scale'],
)

# Main loop: one training pass and one augmented validation pass per epoch,
# followed by checkpointing of the model, optimizer and LR-schedule state.
best_val_acc = 0
for epoch in range(params['tot_epochs']):
    ### Training
    model.train()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in dataloaders['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # The learning rate follows the one-cycle schedule for the first
        # `cyc_size` epochs; afterwards the last value set is kept.
        if epoch < params['cyc_size']:
            lr = onecycle.calc()
            for g in optimizer.param_groups:
                g['lr'] = lr
        outputs, _ = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = classification_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by the batch size so the epoch average
        # stays correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    train_loss = running_loss / dataset_sizes['train']
    train_acc = running_corrects.double() / dataset_sizes['train']

    ### Validation
    model.eval()
    with torch.no_grad():
        val_running_loss1 = 0.0
        val_running_corrects1 = 0
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            tbs = inputs.shape[0]
            # For every image draw a random number of quarter turns (0-3) and
            # a flip mode (0: none, 1: flip along dim -2, 2: flip along dim -1).
            ar1 = random.choices(range(0,4),k=tbs)
            ar2 = random.choices(range(0,3),k=tbs)
            augmented = []
            for image, quarter_turns, flip_mode in zip(inputs, ar1, ar2):
                augmented_input = torch.rot90(image, quarter_turns, [-2, -1])
                if flip_mode == 1:
                    augmented_input = torch.flip(augmented_input, (-2,))
                elif flip_mode == 2:
                    augmented_input = torch.flip(augmented_input, (-1,))
                augmented.append(augmented_input)
            # Evaluate on the original batch followed by its augmented copies.
            # Concatenating requires square images, since rot90 swaps H and W.
            inputs = torch.cat((inputs, torch.stack(augmented)), 0)
            labels = torch.cat((labels,labels))
            # Forward pass on the validation batch. It was previously missing,
            # so the metrics below were computed from the stale `outputs` of
            # the last training batch.
            outputs, _ = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss1 = classification_loss(outputs,labels)
            val_running_loss1 += loss1.item() * inputs.size(0)
            val_running_corrects1 += torch.sum(preds == labels.data)
        # Every validation image is scored twice (original + augmented copy).
        val_loss = val_running_loss1 / (dataset_sizes['val']*2)
        val_acc = val_running_corrects1.double() / (dataset_sizes['val']*2)

        print(f"Epoch {epoch}: Train Acc {train_acc}, Train Loss {train_loss} | ; | Val Acc {val_acc}, Val Loss {val_loss}")
        # Always keep the latest state so training can be resumed.
        torch.save(model.state_dict(),"last_model.ckpt")
        torch.save(optimizer.state_dict(),"last_adam_setting.ckpt")
        if val_acc>best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(),"best_model_weights.ckpt")
        # Persist the LR schedule position alongside the checkpoints.
        with open("cyclelr_object.pkl","wb") as f:
            pickle.dump(onecycle,f)
