import torch
import torchvision
from torch import nn
from utils import gather_from_all
from escnn import nn as enn
from escnn import gspaces
from equivision.models import c4resnet50, c4resnet18

class ESCNNEncoder(nn.Module):
    """Wrap an equivision ``c4resnet18`` / ``c4resnet50`` backbone.

    The forward pass runs the backbone, takes the ``.tensor`` attribute of
    its output and flattens everything after the batch dimension.

    Args:
        arch: ``'escnn18'`` selects ``c4resnet18``; ``'escnn50'`` selects
            ``c4resnet50``.
        use_gpool: Forwarded to the backbone constructor. Also selects which
            width attribute this wrapper exposes (see below).

    Attributes:
        backbone: The constructed (non-pretrained) backbone.
        order: Copied from ``backbone.order``.
        num_out_trivial_repr: Set only when ``use_gpool`` is truthy.
        num_out_regular_repr: Set only when ``use_gpool`` is falsy.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """

    def __init__(self, arch='escnn18', use_gpool=False):
        super().__init__()

        # Validate before touching the backbone so that a typo in `arch`
        # fails here instead of as a missing-attribute error further down.
        if arch == 'escnn18':
            build_backbone = c4resnet18
        elif arch == 'escnn50':
            build_backbone = c4resnet50
        else:
            raise ValueError(
                f"unsupported arch {arch!r}; expected 'escnn18' or 'escnn50'"
            )

        self.backbone = build_backbone(pretrained=False, use_gpool=use_gpool)

        # NOTE(review): both attributes are read from
        # `backbone.num_out_regular_repr`; presumably the backbone reports
        # its post-pooling width under that same name -- confirm.
        if use_gpool:
            self.num_out_trivial_repr = self.backbone.num_out_regular_repr
        else:
            self.num_out_regular_repr = self.backbone.num_out_regular_repr

        self.order = self.backbone.order

    def forward(self, x):
        """Return the backbone output flattened to ``[batch, -1]``."""
        features = self.backbone(x).tensor
        return features.flatten(1)


class PredictorEqv(nn.Module):
    """Three 1x1 ``enn.R2Conv`` layers with ``InnerBatchNorm`` + ``ReLU`` between them.

    Args:
        in_type: Field type accepted by the first convolution.
        hidden_type: Field type used by both hidden stages.
        out_type: Field type produced by the last convolution.
    """

    def __init__(self, in_type, hidden_type, out_type):
        super().__init__()

        def conv1x1(src_type, dst_type):
            # Every convolution in this head uses the same 1x1 geometry.
            return enn.R2Conv(src_type, dst_type, kernel_size=1, stride=1, padding=0)

        stages = [
            conv1x1(in_type, hidden_type),
            enn.InnerBatchNorm(hidden_type),
            enn.ReLU(hidden_type),
            conv1x1(hidden_type, hidden_type),
            enn.InnerBatchNorm(hidden_type),
            enn.ReLU(hidden_type),
            conv1x1(hidden_type, out_type),
        ]
        self.net = enn.SequentialModule(*stages)

    def forward(self, x):
        """Apply the sequential head to ``x`` and return its output."""
        return self.net(x)


class ProjectionMLP(nn.Module):
    """Projection head: three bias-free linear layers, each followed by BatchNorm1d.

    An in-place ReLU follows the first two Linear/BatchNorm pairs; the last
    pair has no activation.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of the two hidden layers.
        out_dim: Output feature width.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        last_stage = len(widths) - 2

        layers = []
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out, bias=False))
            layers.append(nn.BatchNorm1d(fan_out))
            if stage != last_stage:
                layers.append(nn.ReLU(inplace=True))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``self.net(x)``."""
        return self.net(x)
    

class PredictionMLP(nn.Module):
    """Prediction head: one hidden stage followed by a biased linear output layer.

    The hidden stage is a bias-free Linear, BatchNorm1d and in-place ReLU.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of the hidden stage.
        out_dim: Output feature width.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        hidden_stage = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        output_layer = nn.Linear(hidden_dim, out_dim)
        self.net = nn.Sequential(*hidden_stage, output_layer)

    def forward(self, x):
        """Return ``self.net(x)``."""
        return self.net(x)
        

class Identity(nn.Module):
    """Parameter-free module whose forward pass returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` itself."""
        return x


class GuidedSimCLR(nn.Module):
    """Two-view contrastive model (InfoNCE) with an optional orientation head.

    ``args`` must provide:
        arch: ``'escnn18'`` / ``'escnn50'`` (wrapped in ``ESCNNEncoder``) or
            ``'resnet18'`` / ``'resnet50'`` (torchvision, ``fc`` replaced by
            an identity).
        use_gpool: Read for the ESCNN archs only.
        beta: Read for the ESCNN archs without group pooling; when truthy
            the equivariant predictor (``in_type`` / ``predictor_eqv``) is built.
        connector: ``'softmax'``, ``'identity'``, ``'tanh'`` or ``'shift'``.
            Any other value leaves ``self.connector`` unset.

    The guided path (``forward`` with a truthy ``beta``) needs ``in_type``,
    ``predictor_eqv``, ``order`` and a connector, i.e. an ESCNN arch without
    group pooling that was constructed with a truthy ``args.beta``.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
    """

    def __init__(self, args):
        super().__init__()

        if args.arch in ('escnn18', 'escnn50'):
            self.backbone = ESCNNEncoder(args.arch, use_gpool=args.use_gpool)
            self.order = self.backbone.order

            if args.use_gpool:
                feature_dim = self.backbone.num_out_trivial_repr
            else:
                num_out_regular_repr = self.backbone.num_out_regular_repr
                # The encoder flattens its output, so the feature width is
                # (number of regular fields) x (group order).
                feature_dim = num_out_regular_repr * self.order

                if args.beta:
                    self.gspace = gspaces.rot2dOnR2(N=self.order)
                    regular = self.gspace.regular_repr
                    self.in_type = enn.FieldType(self.gspace, num_out_regular_repr * [regular])
                    hidden_type = enn.FieldType(self.gspace, 512 * [regular])
                    out_type_eqv = enn.FieldType(self.gspace, [regular])
                    self.predictor_eqv = PredictorEqv(self.in_type, hidden_type, out_type_eqv)

        elif args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512

        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048

        else:
            raise ValueError(f"unsupported arch {args.arch!r}")

        self.projector = ProjectionMLP(feature_dim, 2048, 128)

        if args.connector == 'softmax':
            self.connector = torch.nn.Softmax(dim=1)
        elif args.connector == 'identity':
            self.connector = Identity()
        elif args.connector == 'tanh':
            self.connector = torch.nn.Tanh()
        elif args.connector == 'shift':
            self.connector = None
            # Row i holds the index pattern of a cyclic shift by i.
            shift_patterns = torch.stack(
                [torch.roll(torch.arange(self.order), shifts=-i) for i in range(self.order)]
            )
            # A non-persistent buffer follows the module across devices
            # (instead of being pinned to CUDA at construction time) and
            # stays out of the state_dict, so checkpoint keys are unchanged.
            self.register_buffer('permute_tensor', shift_patterns, persistent=False)

    def extract_guided_output(self, RX):
        """Predict orientation logits and project the re-aligned features.

        Args:
            RX: 2-D backbone features ``[b, c]``; ``c`` is reshaped into
                ``c // order`` groups of ``order`` entries.

        Returns:
            ``(eqv_logit, out)``: the flattened predictor output and
            ``self.projector`` applied to the re-aligned features.
        """
        b, c = RX.shape
        num_fields = c // self.order

        RX_type_eqv = self.in_type(RX.reshape([b, c, 1, 1]))
        eqv_logit = self.predictor_eqv(RX_type_eqv).tensor.flatten(1)
        RX_re = RX.reshape([b, num_fields, self.order])

        if self.connector is not None:
            # Soft alignment: weight every cyclic shift of the features by
            # the connector score for that shift and sum them.
            eqv_score = self.connector(eqv_logit)
            permuted_reprs = torch.stack(
                [torch.roll(RX_re, shifts=-i, dims=2).reshape([b, c]) for i in range(self.order)],
                dim=-1,
            )
            # Drop only the trailing singleton; a bare squeeze() would also
            # remove the batch dimension when b == 1.
            HX = torch.matmul(permuted_reprs, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        else:
            # Hard alignment: apply the single shift chosen by argmax.
            eqv_idx = torch.argmax(eqv_logit, dim=1)
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, num_fields, -1)
            HX = RX_re.gather(2, batch_perm).reshape([b, c])

        out = self.projector(HX)
        return eqv_logit, out

    def forward(self, x1, x2, beta=0.1):
        """Compute the training loss for two views.

        Args:
            x1, x2: The two input views, each passed through ``self.backbone``.
            beta: Weight of the orientation loss. A falsy value skips the
                guided path and projects the backbone features directly.

        Returns:
            ``(loss, con_loss, ori_loss)`` where ``loss`` is
            ``con_loss + beta * ori_loss`` for a truthy ``beta`` and
            ``con_loss`` otherwise (``ori_loss`` is then ``torch.tensor(0)``).
        """
        RX1 = self.backbone(x1)
        RX2 = self.backbone(x2)

        if beta:
            eqv_logit1, out1 = self.extract_guided_output(RX1)
            eqv_logit2, out2 = self.extract_guided_output(RX2)
            # Symmetrised cross-entropy between the two views' orientation logits.
            ori_loss = cross_entropy_AB(eqv_logit1, eqv_logit2) / 2 + cross_entropy_AB(eqv_logit2, eqv_logit1) / 2
        else:
            out1 = self.projector(RX1)
            out2 = self.projector(RX2)

        # Symmetrised InfoNCE between the two projections.
        con_loss = infoNCE(out1, out2) / 2 + infoNCE(out2, out1) / 2

        loss = con_loss
        if beta:
            loss = loss + beta * ori_loss
        else:
            ori_loss = torch.tensor(0)

        return loss, con_loss, ori_loss


class GuidedSimSiam(nn.Module):
    """Two-view SimSiam-style model with an optional orientation head.

    ``args`` must provide:
        arch: ``'escnn18'`` / ``'escnn50'`` (wrapped in ``ESCNNEncoder``) or
            ``'resnet18'`` / ``'resnet50'`` (torchvision, ``fc`` replaced by
            an identity).
        use_gpool: Read for the ESCNN archs only.
        beta: Read for the ESCNN archs without group pooling; when truthy
            the equivariant predictor (``in_type`` / ``predictor_eqv``) is built.
        connector: ``'softmax'``, ``'identity'``, ``'tanh'`` or ``'shift'``.
            Any other value leaves ``self.connector`` unset.

    The guided path (``forward`` with a truthy ``beta``) needs ``in_type``,
    ``predictor_eqv``, ``order`` and a connector, i.e. an ESCNN arch without
    group pooling that was constructed with a truthy ``args.beta``.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
    """

    def __init__(self, args):
        super().__init__()

        if args.arch in ('escnn18', 'escnn50'):
            self.backbone = ESCNNEncoder(args.arch, use_gpool=args.use_gpool)
            self.order = self.backbone.order

            if args.use_gpool:
                feature_dim = self.backbone.num_out_trivial_repr
            else:
                num_out_regular_repr = self.backbone.num_out_regular_repr
                # The encoder flattens its output, so the feature width is
                # (number of regular fields) x (group order).
                feature_dim = num_out_regular_repr * self.order

                if args.beta:
                    self.gspace = gspaces.rot2dOnR2(N=self.order)
                    regular = self.gspace.regular_repr
                    self.in_type = enn.FieldType(self.gspace, num_out_regular_repr * [regular])
                    hidden_type = enn.FieldType(self.gspace, 512 * [regular])
                    out_type_eqv = enn.FieldType(self.gspace, [regular])
                    self.predictor_eqv = PredictorEqv(self.in_type, hidden_type, out_type_eqv)

        elif args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512

        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048

        else:
            raise ValueError(f"unsupported arch {args.arch!r}")

        # SimSiam heads: projector followed by a bottleneck predictor.
        self.projector = ProjectionMLP(feature_dim, 2048, 2048)
        self.predictor = PredictionMLP(2048, 512, 2048)

        if args.connector == 'softmax':
            self.connector = torch.nn.Softmax(dim=1)
        elif args.connector == 'identity':
            self.connector = Identity()
        elif args.connector == 'tanh':
            self.connector = torch.nn.Tanh()
        elif args.connector == 'shift':
            self.connector = None
            # Row i holds the index pattern of a cyclic shift by i.
            shift_patterns = torch.stack(
                [torch.roll(torch.arange(self.order), shifts=-i) for i in range(self.order)]
            )
            # A non-persistent buffer follows the module across devices
            # (instead of being pinned to CUDA at construction time) and
            # stays out of the state_dict, so checkpoint keys are unchanged.
            self.register_buffer('permute_tensor', shift_patterns, persistent=False)

    def extract_guided_output(self, RX):
        """Predict orientation logits and project the re-aligned features.

        Args:
            RX: 2-D backbone features ``[b, c]``; ``c`` is reshaped into
                ``c // order`` groups of ``order`` entries.

        Returns:
            ``(eqv_logit, out)``: the flattened predictor output and
            ``self.projector`` applied to the re-aligned features.
        """
        b, c = RX.shape
        num_fields = c // self.order

        RX_type_eqv = self.in_type(RX.reshape([b, c, 1, 1]))
        eqv_logit = self.predictor_eqv(RX_type_eqv).tensor.flatten(1)
        RX_re = RX.reshape([b, num_fields, self.order])

        if self.connector is not None:
            # Soft alignment: weight every cyclic shift of the features by
            # the connector score for that shift and sum them.
            eqv_score = self.connector(eqv_logit)
            permuted_reprs = torch.stack(
                [torch.roll(RX_re, shifts=-i, dims=2).reshape([b, c]) for i in range(self.order)],
                dim=-1,
            )
            # Drop only the trailing singleton; a bare squeeze() would also
            # remove the batch dimension when b == 1.
            HX = torch.matmul(permuted_reprs, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        else:
            # Hard alignment: apply the single shift chosen by argmax.
            eqv_idx = torch.argmax(eqv_logit, dim=1)
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, num_fields, -1)
            HX = RX_re.gather(2, batch_perm).reshape([b, c])

        out = self.projector(HX)
        return eqv_logit, out

    def forward(self, x1, x2, beta=0.1):
        """Compute the training loss for two views.

        Args:
            x1, x2: The two input views, each passed through ``self.backbone``.
            beta: Weight of the orientation loss. A falsy value skips the
                guided path and projects the backbone features directly.

        Returns:
            ``(loss, con_loss, ori_loss)`` where ``loss`` is
            ``con_loss + beta * ori_loss`` for a truthy ``beta`` and
            ``con_loss`` otherwise (``ori_loss`` is then ``torch.tensor(0)``).
        """
        RX1 = self.backbone(x1)
        RX2 = self.backbone(x2)

        if beta:
            eqv_logit1, out1 = self.extract_guided_output(RX1)
            eqv_logit2, out2 = self.extract_guided_output(RX2)
            # Symmetrised cross-entropy between the two views' orientation logits.
            ori_loss = cross_entropy_AB(eqv_logit1, eqv_logit2) / 2 + cross_entropy_AB(eqv_logit2, eqv_logit1) / 2
        else:
            out1 = self.projector(RX1)
            out2 = self.projector(RX2)

        # SimSiam: each view's prediction is compared with the other view's
        # projection (the loss helper detaches its second argument).
        p1 = self.predictor(out1)
        p2 = self.predictor(out2)
        con_loss = negative_cosine_similarity_loss(p1, out2) / 2 + negative_cosine_similarity_loss(p2, out1) / 2

        loss = con_loss
        if beta:
            loss = loss + beta * ori_loss
        else:
            ori_loss = torch.tensor(0)

        return loss, con_loss, ori_loss


def infoNCE(nn, p, temperature=0.2):
    """Cross-entropy over temperature-scaled similarities of two embedding sets.

    Both inputs are L2-normalised along ``dim=1`` and passed through
    ``gather_from_all``. Row ``i`` of the similarity matrix is then
    classified against target index ``i``.

    Args:
        nn: First embedding batch (rows of the logit matrix).
        p: Second embedding batch (columns of the logit matrix).
        temperature: Divisor applied to the similarity logits.

    Returns:
        The scalar cross-entropy loss.
    """
    # NOTE: the parameter `nn` shadows the module-level `torch.nn` alias
    # inside this function, hence the fully qualified `torch.nn.functional`.
    # The name is kept so keyword callers are unaffected.
    anchors = torch.nn.functional.normalize(nn, dim=1)
    positives = torch.nn.functional.normalize(p, dim=1)
    anchors = gather_from_all(anchors)
    positives = gather_from_all(positives)

    logits = anchors @ positives.T / temperature

    # Create the targets on the logits' device instead of hard-coding CUDA.
    n = positives.shape[0]
    labels = torch.arange(0, n, dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, labels)

def CrossEntropy(score):
    """Cross-entropy of ``score`` against an all-zero class target.

    ``score`` is first passed through ``gather_from_all``; every row of the
    result is then classified against class index 0.

    Args:
        score: Logits whose first dimension indexes samples.

    Returns:
        The scalar cross-entropy loss.
    """
    score = gather_from_all(score)
    # Build the target on the same device as the logits instead of
    # hard-coding CUDA.
    target = torch.zeros(score.shape[0], dtype=torch.int64, device=score.device)
    return torch.nn.functional.cross_entropy(score, target)

def cross_entropy_AB(A, B, eps=1e-6):
    """Cross-entropy between ``softmax(A)`` and ``softmax(B)``, averaged over rows.

    Both inputs go through ``gather_from_all`` before the softmax along
    ``dim=1``. Neither side is detached.

    Args:
        A: Logits whose softmax supplies the weights.
        B: Logits whose softmax is placed inside the logarithm.
        eps: Added to ``softmax(B)`` before taking the log.

    Returns:
        The sum of the per-row losses divided by the number of gathered rows.
    """
    logits_a = gather_from_all(A)
    logits_b = gather_from_all(B)
    num_rows = logits_a.shape[0]

    weights = torch.softmax(logits_a, dim=1)
    log_probs = torch.log(torch.softmax(logits_b, dim=1) + eps)

    per_row = -(weights * log_probs).sum(1)
    return per_row.sum(0) / num_rows

def negative_cosine_similarity_loss(p, z):
    """Negated mean cosine similarity between ``p`` and a detached ``z``.

    Both inputs go through ``gather_from_all``. ``z`` is detached after
    gathering, so no gradient reaches it through this loss.

    Args:
        p: Predictions.
        z: Targets, treated as constants.

    Returns:
        ``-mean(cosine_similarity(p, z))`` computed along the last dimension.
    """
    predictions = gather_from_all(p)
    targets = gather_from_all(z).detach()
    similarity = torch.nn.functional.cosine_similarity(predictions, targets, dim=-1)
    return -similarity.mean()


    
