import os 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.distributions.uniform import Uniform
import pickle
from collections import namedtuple
from itertools import chain, repeat
import numpy as np
from e2cnn import gspaces
from e2cnn import nn

class RCNNV4(torch.nn.Module):
    """Rotation-equivariant feature stage: two 3x3 conv blocks and one pool.

    The network acts on feature maps made of ``n_feats`` regular fields of
    the cyclic rotation group C8 (rotations by multiples of 45 degrees).
    Each conv block is ``R2Conv -> InnerBatchNorm -> ReLU`` and keeps the
    field type unchanged; the final antialiased average pool uses stride 2.

    Args:
        n_feats: number of regular C8 feature fields in the input and in
            every intermediate/output feature map.
    """

    def __init__(self, n_feats=48):
        super().__init__()
        # Imported locally under a distinct alias: this file also imports
        # ``torch.nn`` as ``nn``, so the bare name would be ambiguous.
        from e2cnn import nn as enn

        # Symmetry group of the model: planar rotations by 2*pi/8 (C8).
        self.r2_act = gspaces.Rot2dOnR2(N=8)

        # Input, hidden and output maps all share the same field type.
        feat_type = enn.FieldType(
            self.r2_act, [self.r2_act.regular_repr] * n_feats
        )
        self.input_type = feat_type

        def conv_bn_relu():
            # One equivariant 3x3 convolution (spatial size preserved by
            # padding=1), followed by batch norm and ReLU.
            return enn.SequentialModule(
                enn.R2Conv(feat_type, feat_type,
                           kernel_size=3, padding=1, bias=False),
                enn.InnerBatchNorm(feat_type),
                enn.ReLU(feat_type, inplace=True),
            )

        self.block5 = conv_bn_relu()
        self.block6 = conv_bn_relu()
        self.pool3 = enn.PointwiseAvgPoolAntialiased(
            feat_type, sigma=0.66, stride=2
        )

    def forward(self, input: torch.Tensor):
        """Run the stage on a plain tensor and return a plain tensor.

        ``input`` is wrapped in a ``GeometricTensor`` of type
        ``self.input_type``, passed through ``block5``, ``block6`` and
        ``pool3``, and the underlying ``torch.Tensor`` of the result is
        returned.
        """
        from e2cnn import nn as enn

        feats = enn.GeometricTensor(input, self.input_type)
        for stage in (self.block5, self.block6, self.pool3):
            feats = stage(feats)
        return feats.tensor
    
class RCNNV4_4(torch.nn.Module):
    """Four-stage rotation-equivariant CNN over regular C8 feature fields.

    Every stage consists of two ``R2Conv -> InnerBatchNorm -> ReLU`` blocks
    followed by an antialiased average pool:

    * stage 1: ``block1``, ``block2``, ``pool1`` (stride 2)
    * stage 2: ``block3``, ``block4``, ``pool2`` (stride 2)
    * stage 3: ``block5``, ``block6``, ``pool3`` (stride 2)
    * stage 4: ``block7``, ``block8``, ``pool4`` (stride 1)

    All layers use the same field type: ``n_feats`` regular fields of the
    cyclic rotation group C8 (rotations by multiples of 45 degrees).

    Args:
        n_feats: number of regular C8 feature fields in the input and in
            every intermediate/output feature map.
    """

    def __init__(self, n_feats=48):
        super().__init__()
        # Imported locally under a distinct alias: this file also imports
        # ``torch.nn`` as ``nn``, so the bare name would be ambiguous.
        from e2cnn import nn as enn

        # Symmetry group of the model: planar rotations by 2*pi/8 (C8).
        self.r2_act = gspaces.Rot2dOnR2(N=8)

        # A single field type describes the input and every layer output.
        feat_type = enn.FieldType(
            self.r2_act, [self.r2_act.regular_repr] * n_feats
        )
        self.input_type = feat_type

        def conv_bn_relu():
            # One equivariant 3x3 convolution (spatial size preserved by
            # padding=1), followed by batch norm and ReLU.
            return enn.SequentialModule(
                enn.R2Conv(feat_type, feat_type,
                           kernel_size=3, padding=1, bias=False),
                enn.InnerBatchNorm(feat_type),
                enn.ReLU(feat_type, inplace=True),
            )

        def avg_pool(stride):
            # Antialiased average pooling with the given stride.
            return enn.PointwiseAvgPoolAntialiased(
                feat_type, sigma=0.66, stride=stride
            )

        # Submodules are created in execution order, stage by stage.

        # Stage 1.
        self.block1 = conv_bn_relu()
        self.block2 = conv_bn_relu()
        # NOTE: pool1 is wrapped in a SequentialModule while pool2-4 are
        # bare modules; dropping the wrapper would rename its state-dict
        # entries, so the nesting is intentional here.
        self.pool1 = enn.SequentialModule(avg_pool(2))

        # Stage 2.
        self.block3 = conv_bn_relu()
        self.block4 = conv_bn_relu()
        self.pool2 = avg_pool(2)

        # Stage 3.
        self.block5 = conv_bn_relu()
        self.block6 = conv_bn_relu()
        self.pool3 = avg_pool(2)

        # Stage 4: the last pool only smooths (stride 1, no downsampling
        # by striding).
        self.block7 = conv_bn_relu()
        self.block8 = conv_bn_relu()
        self.pool4 = avg_pool(1)

    def forward(self, input: torch.Tensor):
        """Run all four stages on a plain tensor and return a plain tensor.

        ``input`` is wrapped in a ``GeometricTensor`` of type
        ``self.input_type``, passed through the stages in order, and the
        underlying ``torch.Tensor`` of the result is returned.
        """
        from e2cnn import nn as enn

        pipeline = (
            self.block1, self.block2, self.pool1,
            self.block3, self.block4, self.pool2,
            self.block5, self.block6, self.pool3,
            self.block7, self.block8, self.pool4,
        )

        feats = enn.GeometricTensor(input, self.input_type)
        for stage in pipeline:
            feats = stage(feats)
        return feats.tensor

