import torch

try:
    from e2cnn import gspaces  # 1
    from e2cnn import nn as enn
except:
    print('e2cnn not available')
from torch import nn
from torch.nn import functional as f
from torch.nn.functional import cross_entropy
from generalization_study.coordconv import CoordConv2d
from generalization_study.weakly_supervised_models import PCL, \
    BetaVAE, SlowVAE, AdaGVAE
from generalization_study.model_big_transfer import \
    get_weights, ResNetV2


def get_model(model_name, number_classes, number_channels, number_latents,
              args, dataset=None):
    if model_name == 'vanilla':
        model = VanillaCNN(number_classes=number_classes,
                           number_channels=number_channels)
    elif model_name == 'deeper_cnn':
        model = DeeperCNN(number_classes=number_classes,
                           number_channels=number_channels)

    elif model_name == 'mlp':
        model = MLP(number_classes=number_classes,
                    number_channels=number_channels)
    elif model_name == 'transformer':
        model = SpatialTransformer(number_classes=number_classes,
                                   number_channels=number_channels)
    elif model_name == 'coordconv':
        model = CoordConvNet(number_classes=number_classes,
                             number_channels=number_channels)
    elif model_name == 'rotation':
        model = RotationInvariantCNN(
            number_classes=number_classes,
            number_channels=number_channels,
            number_rotations=args.number_rotations,
            feature_reduce_factor=args.feature_reduce_factor)
    elif model_name == 'implicit':
        model = ImplicitMLP(number_classes=number_classes,
                            number_channels=number_channels,
                            image_dimension=dataset.data.shape[-1],
                            z_dimensions=100)
    elif model_name == 'densenet':
        from torchvision.models import densenet121
        import os
        os.environ['TORCH_HOME'] = '/home/anonymous/lanonymous/.torch/'
        model = densenet121(pretrained=args.pretrained, progress=False)
        # re-init and adapt last layer
        model.classifier = torch.nn.Linear(model.classifier.in_features,
                                           number_classes)
        model = ToRGBWrapper(model)
    elif model_name == 'big_transfer_rn50':
        # You could use other variants, such as R101x3 or R152x4 here,
        # but it is not advisable in a colab.
        model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1,
                         head_size=number_classes, zero_head=True)
        if args.pretrained:
            weights = get_weights('BiT-M-R50x1')
            model.load_from(weights)
        if args.only_train_last_layer:
            for p in list(model.parameters())[:-2]:
                p.requires_grad = False
    elif model_name == 'big_transfer_rn101':
        # You could use other variants, such as R101x3 or R152x4 here,
        # but it is not advisable in a colab.
        model = ResNetV2(ResNetV2.BLOCK_UNITS['r101'], width_factor=3,
                         head_size=number_classes, zero_head=True)
        print('data parallel')
        if args.pretrained:
            weights = get_weights('BiT-M-R101x3')
            model.load_from(weights)
        if args.only_train_last_layer:
            for p in list(model.parameters())[:-2]:
                p.requires_grad = False
        # model = torch.nn.DataParallel(model)
# unsupervised models
    elif model_name == 'pcl':
        model = PCL(number_latents=number_latents,
                    number_channels=number_channels)
    elif model_name == 'betavae':
        model = BetaVAE(number_latents=number_latents,
                        number_channels=number_channels,
                        beta=args.vae_beta)
    elif model_name == 'slowvae':
        model = SlowVAE(number_latents=number_latents,
                        number_channels=number_channels,
                        gamma=args.slowvae_gamma, beta=args.vae_beta,
                        rate_prior=args.slowvae_rate)
    elif model_name == 'adagvae':
        model = AdaGVAE(number_latents=number_latents,
                        number_channels=number_channels,
                        beta=args.vae_beta)
    else:
        raise Exception(f'Model {args.model} is not defined')
    return model


class ToRGBWrapper(nn.Module):
    """Wrap a model so that single-channel inputs are replicated to 3 channels.

    Inputs whose channel dimension (dim 1) is 1 are concatenated with
    themselves three times before being forwarded; any other input is passed
    to the wrapped model unchanged.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[1] == 1:
            x = torch.cat([x, x, x], dim=1)
        return self.model(x)


class View(nn.Module):
    """Reshape the input to a fixed target shape via ``Tensor.view``."""

    def __init__(self, size):
        super().__init__()
        # Target shape handed verbatim to ``Tensor.view`` (may contain -1).
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        reshaped = tensor.view(target_shape)
        return reshaped


class MLP(nn.Module):
    def __init__(self, number_classes: int, number_channels: int,
                 image_size: int = 64):
        """
        Fully connected baseline: flattens the image and applies a stack of
        ReLU-activated linear layers (four of width 90, one of width 45)
        followed by a linear classifier.
        Args:
            number_classes: number of classes in the dataset
            number_channels: number channels of the input image
            image_size: side length of the square input image (default 64)
        """
        super().__init__()
        hidden = 90
        # Layer order (and hence state_dict keys net.0 ... net.10) matches
        # the original fixed-size definition.
        self.net = nn.Sequential(
            nn.Linear(image_size * image_size * number_channels, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(True),
            nn.Linear(hidden // 2, number_classes),  # B, number_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten everything but the batch dimension.
        x = x.flatten(1)
        return self.net(x)


class VanillaCNN(nn.Module):
    def __init__(self, number_classes: int, number_channels: int):
        """
        Plain strided ConvNet (four stride-2 4x4 convolutions, one 4x4
        convolution down to 1x1, then a linear classifier), modelled on the
        encoders of the Locatello disentanglement library.
        Spatial sizes in the comments assume a 64x64 input; for other sizes
        the final View would fold leftover spatial positions into the batch.
        Args:
            number_classes: number of classes in the dataset
            number_channels: number channels of the input image
        """
        super().__init__()
        layers = [
            nn.Conv2d(number_channels, 32, kernel_size=4, stride=2,
                      padding=1),  # B,  32, 32, 32
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=4, stride=2,
                      padding=1),  # B,  32, 16, 16
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2,
                      padding=1),  # B,  64,  8,  8
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=4, stride=2,
                      padding=1),  # B,  64,  4,  4
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=4, stride=1),  # B, 256,  1,  1
            nn.ReLU(inplace=True),
            View((-1, 256 * 1 * 1)),  # B, 256
            nn.Linear(256, number_classes),  # B, number_classes
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.net(x)
        return logits


class DeeperCNN(nn.Module):
    def __init__(self, number_classes: int, number_channels: int):
        """
        Deeper variant of VanillaCNN: seven convolutions followed by a
        two-layer linear head. The two stride-2 convolutions with padding 5
        leave the 8x8 feature map size unchanged.
        Spatial sizes in the comments assume a 64x64 input; for other sizes
        the View would fold leftover spatial positions into the batch.
        Args:
            number_classes: number of classes in the dataset
            number_channels: number channels of the input image
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(number_channels, 32, kernel_size=4, stride=2,
                      padding=1),  # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=4, stride=2,
                      padding=1),  # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2,
                      padding=1),  # B,  64,  8,  8
            nn.ReLU(True),
            # (8 + 2*5 - 4) // 2 + 1 = 8: size preserved.
            nn.Conv2d(64, 64, kernel_size=4, stride=2,
                      padding=5),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=4, stride=2,
                      padding=5),  # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=4, stride=2,
                      padding=1),  # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, kernel_size=4, stride=1),  # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),  # B, 256
            nn.Linear(256, 64),  # B, 64
            nn.ReLU(True),
            nn.Linear(64, number_classes),  # B, number_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# inspired by https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
# License: BSD
# Author: Ghassen Hamrouni
class SpatialTransformer(VanillaCNN):
    """VanillaCNN preceded by a learned spatial transformer.

    A small localisation network predicts a 2x3 affine matrix per sample; the
    input is resampled with that transform and then classified by the
    inherited ``self.net``. The hard-coded ``10 * 3 * 3`` embedding size
    corresponds to 64x64 inputs.
    """

    def __init__(self, number_classes: int, number_channels: int):
        super().__init__(number_classes=number_classes,
                         number_channels=number_channels)

        # Localisation network; for a 64x64 input the spatial size goes
        # 64 -> 58 -> 29 -> 25 -> 12 -> 7 -> 3.
        self.infer_affine_matrix_embedding = nn.Sequential(
            nn.Conv2d(number_channels, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(10, 10, kernel_size=6),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Regress the 6 entries of the 2 x 3 affine transformation matrix.
        self.infer_affine_matrix = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2))

        # Start from the identity transform: zero weights, identity bias.
        self.infer_affine_matrix[2].weight.data.zero_()
        self.infer_affine_matrix[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def apply_spatial_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` with the per-sample affine transform it predicts."""
        batch_size = x.shape[0]
        xs = self.infer_affine_matrix_embedding(x)
        # Flatten per sample. Using the explicit batch size makes an
        # unexpected input size fail in the Linear layer instead of silently
        # mixing features across samples.
        xs = xs.reshape(batch_size, -1)
        theta = self.infer_affine_matrix(xs).view(batch_size, 2, 3)

        grid = f.affine_grid(theta, x.size(), align_corners=False)
        # Same align_corners convention as affine_grid (explicit to avoid
        # relying on grid_sample's default).
        return f.grid_sample(x, grid, align_corners=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.apply_spatial_transform(x)
        return self.net(x)


class RotationInvariantCNN(VanillaCNN):
    def __init__(self, number_classes: int, number_channels: int,
                 number_rotations: int = 8, feature_reduce_factor: int = 1):
        """
        Rotation-equivariant counterpart of VanillaCNN built with e2cnn.
        The five convolutions mirror VanillaCNN's layout but operate on
        regular-representation feature fields over ``number_rotations``
        discrete planar rotations (``gspaces.Rot2dOnR2``). A final
        GroupPooling pools each of the 256 fields over the rotation
        dimension before the linear classifier.
        Spatial sizes in the comments assume a 64x64 input.
        Args:
            number_classes: number of classes in the dataset
            number_channels: number channels of the input image
            number_rotations: number of discrete rotations in the symmetry
                group
            feature_reduce_factor: divisor applied to the number of fields
                of the first four layers (the 256-field last layer is not
                reduced)
        """
        # NOTE: VanillaCNN.__init__ builds a plain ``self.net`` that is
        # overwritten below.
        super().__init__(number_classes=number_classes,
                         number_channels=number_channels)
        s = gspaces.Rot2dOnR2(number_rotations)  # r = number_rotations
        # Each input channel is a rotation-invariant (trivial) scalar field.
        self.c_in = enn.FieldType(s, [s.trivial_repr] * number_channels)

        frf = feature_reduce_factor
        c_hidden_1 = enn.FieldType(s, [s.regular_repr] * int(32 / frf))
        c_hidden_2 = enn.FieldType(s, [s.regular_repr] * int(32 / frf))
        c_hidden_3 = enn.FieldType(s, [s.regular_repr] * int(64 / frf))
        c_hidden_4 = enn.FieldType(s, [s.regular_repr] * int(64 / frf))
        c_hidden_5 = enn.FieldType(s, [s.regular_repr] * int(256))
        # R2Conv positional args after the field types are
        # (kernel_size, padding, stride), matching the shape comments.
        self.net = nn.Sequential(
            enn.R2Conv(self.c_in, c_hidden_1, 4, 1, 2),  # B, int(32/frf)*r, 32, 32
            enn.ReLU(c_hidden_1, True),
            enn.R2Conv(c_hidden_1, c_hidden_2, 4, 1, 2),  # B, int(32/frf)*r, 16, 16
            enn.ReLU(c_hidden_2, True),
            enn.R2Conv(c_hidden_2, c_hidden_3, 4, 1, 2),  # B, int(64/frf)*r,  8,  8
            enn.ReLU(c_hidden_3, True),
            enn.R2Conv(c_hidden_3, c_hidden_4, 4, 1, 2),  # B, int(64/frf)*r,  4,  4
            enn.ReLU(c_hidden_4, True),
            enn.R2Conv(c_hidden_4, c_hidden_5, 4, 0, 1),  # B, 256*r,  1,  1
            enn.ReLU(c_hidden_5, True),
            enn.GroupPooling(c_hidden_5)  # 256*number_rotations -> 256
        )
        self.view = View((-1, 256 * 1 * 1))  # B, 256
        self.fc_out = nn.Linear(256, number_classes)  # B, number_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Wrap the plain tensor so e2cnn modules know its field types.
        x = enn.GeometricTensor(x, self.c_in)
        x = self.net(x).tensor  # from e2cnn tensor to pytorch tensor
        return self.fc_out(self.view(x))


class ImplicitMLP(nn.Module):
    """Classifier that embeds every pixel independently and mean-pools.

    Each pixel's normalised (row, col) coordinate is concatenated with its
    channel values and fed through a shared per-pixel MLP. The per-pixel
    embeddings are averaged over the image and mapped to class logits by a
    small readout MLP.
    """

    def __init__(self, number_classes: int, number_channels: int,
                 image_dimension: int = 64, z_dimensions: int = 20):
        """
        Args:
            number_classes: number of output classes.
            number_channels: number of channels of the input image.
            image_dimension: side length of the square input image.
            z_dimensions: width of the per-pixel embedding and readout layers.
        """
        super().__init__()
        act_fct = nn.ReLU  # Sine or Hybrid (defined below) are alternatives
        self.z_dimensions = z_dimensions
        self.net = nn.Sequential(
            nn.Linear(2 + number_channels, 40),
            act_fct(),
            nn.Linear(40, 40),
            act_fct(),
            nn.Linear(40, self.z_dimensions),
            act_fct(),
        )
        self.linear_0 = nn.Linear(self.z_dimensions, self.z_dimensions)
        self.read_out_activation = nn.ReLU()
        self.linear_1 = nn.Linear(self.z_dimensions, self.z_dimensions)
        self.linear_2 = nn.Linear(self.z_dimensions, number_classes)

        self.image_dimension = image_dimension
        rows, cols = torch.meshgrid(torch.arange(self.image_dimension),
                                    torch.arange(self.image_dimension))
        # Integer pixel coordinates of shape (2, H, W). A non-persistent
        # buffer follows .to(device)/.cuda() with the module (no hard-coded
        # CUDA) and is kept out of the state dict, so checkpoints are
        # unaffected.
        self.register_buffer('grid', torch.stack([rows, cols], dim=0),
                             persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, number_classes).

        Assumes ``x`` is (batch, number_channels, image_dimension,
        image_dimension) — TODO confirm against the data pipeline.
        """
        bs = x.shape[0]
        # Normalise coordinates to [0, 1) and broadcast over the batch
        # without materialising bs copies.
        grid_broad = (self.grid / float(self.image_dimension))[None].expand(
            bs, -1, -1, -1)
        stacked_input = torch.cat([grid_broad, x], dim=1).permute(0, 2, 3, 1)
        # Shared MLP on every pixel: (B*H*W, 2 + C) -> (B*H*W, z).
        out = self.net(stacked_input.flatten(0, 2))
        # Mean-pool the per-pixel embeddings of each image.
        out = out.view(bs, -1, out.shape[-1]).mean(dim=1)
        out = self.read_out_activation(self.linear_0(out))
        out = self.read_out_activation(self.linear_1(out))
        return self.linear_2(out)

    def loss(self, logits, labels):
        """Mean cross-entropy between ``logits`` and integer ``labels``."""
        return cross_entropy(logits, labels, reduction='mean')


class Sine(nn.Module):
    """Periodic activation computing ``sin(w0 * x)`` element-wise."""

    def __init__(self, w0=1.):
        super().__init__()
        # Frequency multiplier applied before the sine.
        self.w0 = w0

    def forward(self, x):
        scaled = x * self.w0
        return scaled.sin()


class Hybrid(nn.Module):
    """Sine on the first half of the features, ReLU on the remainder.

    Features are split along dim 1; when that size is odd, the extra
    feature goes to the ReLU half.
    """

    def __init__(self, w0=1.):
        super().__init__()
        # Frequency multiplier for the sine half.
        self.w0 = w0

    def forward(self, x):
        split_at = x.shape[1] // 2
        periodic_part = x[:, :split_at]
        rectified_part = x[:, split_at:]
        pieces = [
            torch.sin(self.w0 * periodic_part),
            torch.nn.functional.relu(rectified_part),
        ]
        return torch.cat(pieces, dim=1)


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one via ``view``."""

    def forward(self, input):
        batch_size = input.shape[0]
        return input.view(batch_size, -1)


class CoordConvNet(VanillaCNN):
    """CoordConv classifier.

    A 1x1 CoordConv2d is followed by five padded 5x5 convolutions. A max-pool
    whose window spans the whole (64x64) feature map comes next, and then a
    two-layer linear head. ``forward`` is inherited from VanillaCNN and runs
    ``self.net``.
    """

    def __init__(self, number_classes: int, number_channels: int):
        super().__init__(number_classes=number_classes,
                         number_channels=number_channels)

        pad = 2  # 2 is same padding for a 5x5 kernel
        # Each of the five 5x5 convs shrinks the map by 2 * (2 - pad) pixels,
        # so this window covers the whole remaining map of a 64x64 input.
        pool_window = 64 - 2 * 5 * (2 - pad)

        layers = [CoordConv2d(number_channels, 16, 1), nn.ReLU()]
        for in_channels, out_channels in ((16, 16), (16, 16), (16, 16),
                                          (16, 16), (16, 32)):
            layers.append(nn.Conv2d(in_channels, out_channels, 5,
                                    padding=pad))
            layers.append(nn.ReLU())
        layers.extend([
            nn.MaxPool2d(pool_window),
            Flatten(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, number_classes),
        ])
        self.net = nn.Sequential(*layers)
