
import torch.nn as nn
import torch
import torchvision.models as models
import os

class SimCLR(nn.Module):
    """SimCLR network: a torchvision ResNet encoder followed by an MLP projection head.

    The encoder's stem is adapted for CIFAR10-sized inputs: the first 7x7
    convolution is replaced by a 3x3 / stride-1 convolution and the first
    max-pooling layer is removed. The final fully connected layer is replaced
    by an identity, so the encoder returns its features directly.

    Args:
        backbone: Name of the torchvision constructor to use; one of
            ``"resnet18"``, ``"resnet34"``, ``"wide_resnet50_2"``.
        projection_dim: Output width of the projection head.
        input_channel: Number of channels of the input images.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    # Names of the torchvision constructors accepted as ``backbone``.
    _BACKBONES = ("resnet18", "resnet34", "wide_resnet50_2")

    def __init__(self, backbone, projection_dim=128, input_channel=3):
        super().__init__()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn a bad name into a confusing
        # AttributeError on `self.model_base` further down.
        if backbone not in self._BACKBONES:
            raise ValueError(
                "Unsupported backbone %r; expected one of: %s"
                % (backbone, ", ".join(self._BACKBONES))
            )
        # Randomly initialised encoder (no pretrained weights). `pretrained=False`
        # is kept rather than `weights=None` so older torchvision releases work.
        self.model_base = getattr(models, backbone)(pretrained=False)
        # Width of the encoder features, read before `fc` is replaced below.
        self.feature_dim = self.model_base.fc.in_features

        # Customize for CIFAR10: replace the 7x7 stem conv with a 3x3 conv
        # (stride 1, padding 1, 64 output channels) and remove the first max pooling.
        self.model_base.conv1 = nn.Conv2d(input_channel, 64, 3, 1, 1, bias=False)
        self.model_base.maxpool = nn.Identity()
        # Remove the final fully connected layer so the encoder outputs features.
        self.model_base.fc = nn.Identity()

        # Two-layer MLP projection head: feature_dim -> 2048 -> projection_dim.
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim),
        )

    def forward(self, x):
        """Encode ``x`` and project the resulting features.

        Returns:
            Tuple ``(feature, feature_proj)``: the encoder output (presumably
            ``(batch, feature_dim)`` -- confirm against the torchvision version)
            and its projection, whose last dimension is ``projection_dim``.
        """
        feature = self.model_base(x)
        feature_proj = self.projector(feature)
        return feature, feature_proj


class LinearClassiferModel(nn.Module):
    """Linear probe: a single ``nn.Linear`` head applied to an encoder's output.

    The encoder module is stored as-is (not copied), so its parameters are
    shared with whatever else holds a reference to it.

    Args:
        encoder: Module producing features whose last dimension is ``feature_dim``.
        feature_dim: Width of the encoder features fed to the linear head.
        n_classes: Number of output logits.
    """

    def __init__(self, encoder: nn.Module, feature_dim, n_classes=10):
        super().__init__()
        # Encoder is registered before the head; this fixes the state_dict key order.
        self.model_base = encoder
        self.feature_dim, self.n_classes = feature_dim, n_classes
        self.lin = nn.Linear(in_features=feature_dim, out_features=n_classes)

    def forward(self, x):
        """Return the class logits for input ``x``."""
        features = self.model_base(x)
        logits = self.lin(features)
        return logits


def load_ckp(model, optimizer, lr_scheduler, stage, folder="inProc_data", flag_load="123", map_location=None):
    """Restore model / optimizer / scheduler state from the last matching checkpoint in ``folder``.

    A file in ``folder`` is a candidate when its extension contains ``"pt"``
    (so both ``.pt`` and ``.pth`` match) and its stem contains ``stage``. When
    ``flag_load`` is not None, candidates are further restricted to those whose
    path contains ``flag_load``. Candidates are sorted by file name and the last
    one is loaded, so names embedding a sortable timestamp resolve to the newest.

    Args:
        model: Module restored from ``checkpoint['net']`` (strict key matching).
        optimizer: Restored from ``checkpoint['optimizer']``; None to skip.
        lr_scheduler: Restored from ``checkpoint['lr_scheduler']``; None to skip.
        stage: Substring that must appear in the checkpoint file stem.
        folder: Directory to search.
        flag_load: Substring that must appear in the checkpoint path; None accepts all.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            GPU-saved checkpoint on a CPU-only machine. None keeps torch's default.

    Returns:
        ``(model, optimizer, lr_scheduler, epoch_start)``, where ``epoch_start`` is
        ``checkpoint['epoch']``, or 0 when no checkpoint matched.

    Raises:
        FileNotFoundError: If ``folder`` does not exist (raised by ``os.listdir``).
        KeyError: If the checkpoint lacks a key required for a non-None argument.
    """
    def _matching_files(path, stage, suffix):
        """Return sorted paths in ``path`` whose extension contains ``suffix`` and stem contains ``stage``."""
        matches = []
        for file_name in sorted(os.listdir(path)):
            stem, ext = os.path.splitext(file_name)
            if suffix in ext and stage in stem:
                matches.append(os.path.join(path, file_name))
        return matches

    candidates = _matching_files(folder, stage, "pt")
    if flag_load is not None:
        candidates = [path for path in candidates if flag_load in path]

    if not candidates:
        # Make it visible that the model keeps its current (e.g. untrained) weights.
        print("load ckp - no checkpoint for stage=%r, flag_load=%r in %s; epoch_start = 0"
              % (stage, flag_load, folder))
        return model, optimizer, lr_scheduler, 0

    ckp_path = candidates[-1]
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckp_path, map_location=map_location)
    model.load_state_dict(checkpoint['net'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    epoch_start = checkpoint['epoch']
    print("load ckp - %s, epoch_start = %d" % (ckp_path, epoch_start))
    return model, optimizer, lr_scheduler, epoch_start



# --- Script: rebuild the linear classifier and restore its trained weights. ---
# Expose only CUDA device index 1 to this process; this takes effect only if
# CUDA has not been initialised yet in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_sim = SimCLR("resnet18")
model_sim.to(device)  # nn.Module.to moves parameters in place and returns self
# The classifier wraps the very same encoder object as model_sim, so loading
# the checkpoint below also updates model_sim.model_base.
model_classifar = LinearClassiferModel(model_sim.model_base, model_sim.feature_dim, n_classes=10)
model_classifar.to(device)
model_classifar, _, _, _ = load_ckp(
    model_classifar,
    optimizer=None,
    lr_scheduler=None,
    stage="RdSm-train",
    flag_load="2022-01-06_23-44-59",
)

print("OK")
