import torch
from torch import nn
from models.initializer import initialize_model
from SSL.model_aug import Augmenter

def get_backbone(config):
    """Build the feature backbone described by ``config`` and report its output width.

    Names listed in ``models.pretrained.model_url`` (and ``'clip'``) are delegated to
    ``get_enc``. Any other ``config.weights_model`` goes through ``initialize_model``.
    In that case the classifier ``fc`` layer is replaced by ``nn.Identity`` and
    ``initialize_encoder`` from ``train_ssl`` is applied to the result.

    Args:
        config: run configuration; reads ``weights_model`` and ``feat_dim`` here
            (plus whatever ``get_enc`` / ``initialize_model`` /
            ``initialize_encoder`` read).

    Returns:
        ``(backbone, rep_dim)``. ``rep_dim`` is the width of the features the
        backbone emits. On the ``initialize_model`` path it is the input width of
        the original ``fc`` layer.
    """
    from models.pretrained import model_url, get_enc

    handled_by_get_enc = ('clip', *model_url.keys())
    if config.weights_model in handled_by_get_enc:
        encoder, width = get_enc(config, config.weights_model, get_rep_dim=True)
        return encoder, width

    # d_out may be None. Featurizers are not used because they don't exist for all models.
    encoder = initialize_model(config, d_out=config.feat_dim)
    width = encoder.fc.weight.shape[1]
    encoder.fc = nn.Identity()

    # Imported lazily, presumably to avoid a circular import with train_ssl.
    from train_ssl import initialize_encoder
    initialize_encoder(config, encoder)
    return encoder, width

def drop_adapter(config, adapter):
    """Swap the adapter slot of a backbone for an identity layer, in place.

    For CLIP backbones the adapter sits at ``attnpool[1]``, matching the
    ``nn.Sequential(attnpool, adapter)`` wrapping done in ``TSSL``. For every other
    backbone it lives at ``fc``.

    Args:
        config: reads ``weights_model``.
        adapter: the backbone module that carries the adapter (mutated in place).
    """
    is_clip = config.weights_model == 'clip'
    if is_clip:
        adapter.attnpool[1] = torch.nn.Identity()
        return
    adapter.fc = torch.nn.Identity()

class Adapter(nn.Module):
    """Bottleneck MLP whose output width equals its input width.

    Layout (kept as ``self.arch`` so parameter names are ``arch.0`` / ``arch.1`` /
    ``arch.3``): Linear(input_dim -> hidden_dim) -> BatchNorm1d -> ReLU ->
    Linear(hidden_dim -> input_dim).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        down_proj = nn.Linear(input_dim, hidden_dim)
        norm = nn.BatchNorm1d(hidden_dim)
        up_proj = nn.Linear(hidden_dim, input_dim)
        self.arch = nn.Sequential(down_proj, norm, nn.ReLU(), up_proj)

    def forward(self, x):
        # Inputs are cast to float32 first (e.g. half-precision backbone features).
        # Presumably x is 2-D (batch, input_dim), as BatchNorm1d follows a Linear.
        as_fp32 = x.float()
        return self.arch(as_fp32)


class SimCLR_projection_MLP(nn.Module):
    """SimCLR projection head: Linear -> ReLU -> Linear.

    The hidden layer keeps the input width (``in_dim``). Only the final Linear maps
    to ``out_dim`` (default 256).
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        width = in_dim  # hidden width equals the input width
        self.layer1 = nn.Sequential(nn.Linear(in_dim, width), nn.ReLU(inplace=True))
        self.layer2 = nn.Linear(width, out_dim)

    def forward(self, x):
        return self.layer2(self.layer1(x))

class SimSiam_projection_MLP(nn.Module):
    """SimSiam projection head built from up to three Linear + BatchNorm1d stages.

    Args:
        in_dim: width of the incoming backbone features.
        out_dim: width of the projected embedding.
        num_layers: number of stages ``forward`` applies:
            1 -> layer3 only (Linear(in_dim, out_dim) + BN)
            2 -> layer1 -> layer3
            3 -> layer1 -> layer2 -> layer3

    All three stages are always constructed, even those ``forward`` does not use for
    the chosen ``num_layers``. The final BatchNorm has ``affine=False`` (SimSiam paper,
    page 5, paragraph 2).

    Raises:
        ValueError: if ``num_layers`` is not 1, 2 or 3. Other values would otherwise
            make ``forward`` silently return its input unchanged.
    """

    def __init__(self, in_dim, out_dim, num_layers=2):
        super().__init__()
        if num_layers not in (1, 2, 3):
            raise ValueError(f'num_layers must be 1, 2 or 3, got {num_layers}')
        # With a single stage, layer3 consumes the raw in_dim features, so its input
        # width (hidden_dim) must be in_dim. Otherwise the hidden width is out_dim.
        hidden_dim = out_dim if num_layers > 1 else in_dim
        self.num_layers = num_layers

        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )

        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim, affine=False)  # Page:5, Paragraph:2
        )

    def forward(self, x):
        if self.num_layers >= 2:
            x = self.layer1(x)
        if self.num_layers == 3:
            x = self.layer2(x)
        # For num_layers == 1, layer3 is Linear(in_dim, out_dim) because hidden_dim == in_dim.
        return self.layer3(x)

def get_projector(model_name, in_dim, out_dim, num_layers=2):
    """Return the projection head for ``model_name`` ('simsiam' or 'simclr').

    ``num_layers`` is only used by the SimSiam head. For any other name a message
    is printed and ``None`` is returned.
    """
    if model_name == 'simclr':
        return SimCLR_projection_MLP(in_dim, out_dim)
    if model_name == 'simsiam':
        return SimSiam_projection_MLP(in_dim, out_dim, num_layers)
    # NOTE(review): callers receive None here; nn.Sequential(backbone, None) will
    # fail only when it is first called.
    print(f'no projector found for {model_name}')
    return None


class prediction_MLP(nn.Module):
    """Predictor head: bottleneck MLP ``in_dim -> in_dim // 4 -> in_dim``.

    ``layer1`` is Linear + BatchNorm1d + ReLU into the bottleneck. ``layer2`` is a
    plain Linear back to ``in_dim``.
    """

    def __init__(self, in_dim=2048):
        super().__init__()
        bottleneck = in_dim // 4
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, bottleneck),
            nn.BatchNorm1d(bottleneck),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Linear(bottleneck, in_dim)

    def forward(self, x):
        hidden = self.layer1(x)
        return self.layer2(hidden)


class TSSL(nn.Module):
    """Shared base for the two-view SSL models (SimSiam, SimCLR).

    Builds:
      * ``self.encoder = nn.Sequential(backbone, projector)``. The projector is
        ``self.encoder[1]``.
      * ``self.predictor``, a ``prediction_MLP`` over ``args.feat_dim`` features.
      * when ``args.model_aug`` is set, ``self.augmenter`` plus ``self.tr_encoder``
        (initially the same module as ``self.encoder``).

    With ``args.adapt`` the backbone parameters are frozen and a trainable ``Adapter``
    is installed. Otherwise ``backbone.fc`` is set to ``nn.Identity``.

    Reads from ``args``: ``algorithm``, ``weights_model``, ``adapt``, ``feat_dim``,
    ``num_proj_layers``, ``model_aug`` (plus whatever ``get_backbone`` and
    ``Augmenter`` read). Subclasses implement ``forward``.
    """

    def __init__(self, args):
        super().__init__()
        self.name = args.algorithm
        backbone, self.rep_dim = get_backbone(args)
        if args.adapt:
            print('added adaptor to backbone')
            last_layer_name = 'fc' if args.weights_model != 'clip' else 'attnpool.0'
            keep_trainable = {f'{last_layer_name}.weight', f'{last_layer_name}.bias'}
            # Freeze every backbone parameter except the last layer's own weight/bias.
            # The adapter is created after this loop, so its parameters stay trainable.
            for name, param in backbone.named_parameters():
                if name not in keep_trainable:
                    param.requires_grad = False

            adapter = Adapter(input_dim=self.rep_dim, hidden_dim=128)
            # Place the adapter next to the backbone instead of hard-coding .cuda(),
            # which crashes on CPU-only hosts. A later model.to(device) still moves both.
            first_param = next(backbone.parameters(), None)
            if first_param is not None:
                adapter = adapter.to(first_param.device)
            if args.weights_model == 'clip':
                # The CLIP attention-pool output is fed through the adapter.
                backbone.attnpool = nn.Sequential(backbone.attnpool, adapter)
            # NOTE: for CLIP this also registers the same adapter under ``fc``. It is
            # left as-is because removing it would change state_dict keys.
            backbone.fc = adapter
        else:
            backbone.fc = nn.Identity()

        projector = get_projector(self.name, self.rep_dim, args.feat_dim, args.num_proj_layers)  # Note: to access projector, just use self.encoder[1]

        self.encoder = nn.Sequential(
            backbone,
            projector
        )

        self.predictor = prediction_MLP(args.feat_dim)

        self.model_aug = args.model_aug
        if self.model_aug:
            self.augmenter = Augmenter(args, None)
            self.tr_encoder = self.encoder

    def reinitialize_backbone(self, config):
        """Replace ``self.encoder[0]`` with a freshly built backbone from ``config``.

        NOTE(review): the adapter/freezing done in ``__init__`` is not re-applied. A
        ``tr_encoder`` produced by ``update_tr_encoder`` may still reference the old
        backbone; confirm whether callers expect that.
        """
        self.encoder[0] = get_backbone(config)[0]

    def update_tr_encoder(self):
        """Regenerate ``self.tr_encoder`` from ``self.encoder`` via the augmenter.

        Only valid when ``model_aug`` was enabled; ``self.augmenter`` does not exist otherwise.
        """
        self.tr_encoder = self.augmenter.augment(self.encoder)

    def _forward_aug_feats(self, x1, x2):
        """Encode both views, using ``tr_encoder`` when model augmentation is on."""
        encoder = self.tr_encoder if self.model_aug else self.encoder
        feats = encoder(x1), encoder(x2)
        if self.model_aug:
            # Release cached allocator blocks after running the augmented encoder.
            torch.cuda.empty_cache()
        return feats

    def forward(self, x1, x2):
        """Two-view forward pass; must be implemented by subclasses."""
        raise NotImplementedError(f'{type(self).__name__} must implement forward()')


class SimSiam(TSSL):
    """SimSiam: predictions from the online encoder, stop-gradient targets.

    ``forward`` returns a dict with detached targets ``z1``/``z2`` and predictions
    ``p1``/``p2``. With ``model_aug`` the targets come from ``tr_encoder`` instead of
    ``encoder``.
    """

    def __init__(self, args):
        super().__init__(args=args)

    def forward(self, x1, x2):
        # Order matters for BatchNorm running statistics: x1 is encoded before x2.
        z1, z2 = self.encoder(x1), self.encoder(x2)
        p1, p2 = self.predictor(z1), self.predictor(z2)

        if self.model_aug:
            # Predictions stay tied to the online encoder; only targets are replaced.
            z1, z2 = self._forward_aug_feats(x1, x2)

        targets = {'z1': z1.detach(), 'z2': z2.detach()}
        return {**targets, 'p1': p1, 'p2': p2}


class SimCLR(TSSL):
    """SimCLR over two views, optionally with a stop-gradient on the first view.

    ``forward`` returns ``{'z1': ..., 'z2': ...}``. With ``args.stop_grad`` the first
    view's embedding carries no gradient. With ``model_aug`` the first view always goes
    through ``tr_encoder``. The second view goes through ``tr_encoder`` only when
    ``stop_grad`` is off; otherwise it goes through the online ``encoder``.
    """

    def __init__(self, args):
        super().__init__(args=args)
        self.stop_grad = args.stop_grad

    def forward(self, x1, x2):
        z1, z2 = self._forward_aug_feats(x1, x2)
        return {'z1': z1, 'z2': z2}

    def _forward_aug_feats(self, x1, x2):
        """Encode both views according to ``model_aug`` / ``stop_grad``."""
        first_encoder = self.tr_encoder if self.model_aug else self.encoder
        if self.stop_grad:
            # no_grad gives the same values as running the encoder and then calling
            # .detach(), but skips building an autograd graph that would be discarded.
            with torch.no_grad():
                z1 = first_encoder(x1)
            z2 = self.encoder(x2)
        else:
            z1, z2 = first_encoder(x1), first_encoder(x2)
        if self.model_aug:
            # Release cached allocator blocks after running the augmented encoder.
            torch.cuda.empty_cache()
        return z1, z2

