import torch
import sys
import os
import yaml
import pandas as pd
import argparse
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import models
import time
from .data import load_data



class MLP(torch.nn.Module):
    """Three-layer perceptron head: 2048 -> 512 -> 256 -> ``out_dim``.

    A ReLU follows each of the first two linear layers; the last layer emits
    raw (unnormalised) scores. The layer stack is stored as ``self.backdone``
    -- the attribute name is kept as-is because it determines the
    ``state_dict`` keys of any saved checkpoint.

    Args:
        out_dim: Width of the final linear layer (default 10).
    """

    def __init__(self, out_dim=10):
        super().__init__()
        widths = (2048, 512, 256)
        layers = []
        # Hidden layers: Linear followed by ReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output layer has no activation.
        layers.append(nn.Linear(widths[-1], out_dim))
        self.backdone = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (last dimension 2048) through the stack."""
        return self.backdone(x)


class myResNet18(nn.Module):
    """ResNet-18 feature extractor followed by a 512 -> 2048 linear projection.

    The torchvision classifier head (``fc``) and the stem max-pool are replaced
    with ``nn.Identity``. The backbone output is flattened per sample and
    projected to 2048-d by ``fc1``.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``ResNet18_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super().__init__()
        builder_kwargs = {'num_classes': 1000}
        if pre_tra == True:
            builder_kwargs['weights'] = 'ResNet18_Weights.DEFAULT'
        self.backbone = torchvision.models.resnet18(**builder_kwargs)
        self.backbone.fc = nn.Identity()
        self.backbone.maxpool = nn.Identity()
        self.fc1 = nn.Linear(512, 2048)

    def forward(self, x):
        """Return ``fc1`` applied to the flattened backbone features."""
        feats = self.backbone(x)
        flat = feats.view(feats.shape[0], -1)
        return self.fc1(flat)

class myResNet34(nn.Module):
    """ResNet-34 feature extractor followed by a 512 -> 2048 linear projection.

    The torchvision classifier head (``fc``) and the stem max-pool are replaced
    with ``nn.Identity``. The backbone output is flattened per sample and
    projected to 2048-d by ``fc1``.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``ResNet34_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super(myResNet34, self).__init__()
        if pre_tra == True:
            # The weights name must match the builder (ResNet34) or torchvision
            # rejects the string.
            self.backbone = torchvision.models.resnet34(weights='ResNet34_Weights.DEFAULT', num_classes=1000)
        else:
            self.backbone = torchvision.models.resnet34(num_classes=1000)
        self.backbone.fc = nn.Identity()
        self.backbone.maxpool = nn.Identity()
        self.fc1 = nn.Linear(512, 2048)

    def forward(self, x):
        """Return ``fc1`` applied to the flattened backbone features."""
        h = self.backbone(x)
        h = h.view(h.size(0), -1)
        y = self.fc1(h)
        return y

class myResNet50(nn.Module):
    """ResNet-50 trunk used directly as a feature extractor.

    The classification head (``fc``) and the stem max-pool are swapped for
    ``nn.Identity``, so ``forward`` returns the backbone's pooled features
    with no extra projection.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``ResNet50_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super().__init__()
        # weights=None is the torchvision builder default (random init).
        weights = 'ResNet50_Weights.DEFAULT' if pre_tra == True else None
        self.backbone = torchvision.models.resnet50(weights=weights, num_classes=1000)
        for attr in ('fc', 'maxpool'):
            setattr(self.backbone, attr, nn.Identity())

    def forward(self, x):
        """Return the backbone's output for ``x`` unchanged."""
        features = self.backbone(x)
        return features


class myResNet101(nn.Module):
    """ResNet-101 trunk used directly as a feature extractor.

    The classification head (``fc``) and the stem max-pool are swapped for
    ``nn.Identity``, so ``forward`` returns the backbone's pooled features
    with no extra projection.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``ResNet101_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super().__init__()
        build = torchvision.models.resnet101
        if pre_tra == True:
            net = build(weights='ResNet101_Weights.DEFAULT', num_classes=1000)
        else:
            net = build(num_classes=1000)
        net.fc = nn.Identity()
        net.maxpool = nn.Identity()
        self.backbone = net

    def forward(self, x):
        """Return the backbone's output for ``x`` unchanged."""
        return self.backbone(x)

class myDenseNet121(nn.Module):
    """DenseNet-121 backbone whose 1000-way output is projected to 2048-d.

    Unlike the ResNet wrappers, the torchvision classification head stays in
    place: DenseNet's head is called ``classifier`` (not ``fc``), so the
    backbone still emits 1000 values per sample and ``fc1`` maps them to 2048.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``DenseNet121_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super(myDenseNet121, self).__init__()
        if pre_tra == True:
            self.backbone = torchvision.models.densenet121(weights='DenseNet121_Weights.DEFAULT', num_classes=1000)
        else:
            # Same architecture as the pretrained branch, so the 1000-d output
            # matches fc1's input size.
            self.backbone = torchvision.models.densenet121(num_classes=1000)
        # NOTE(review): DenseNet has no ``fc`` or ``maxpool`` attributes; these
        # assignments only register unused, parameter-free Identity modules and
        # do not change the forward pass. Kept for parity with the ResNet wrappers.
        self.backbone.fc = nn.Identity()
        self.backbone.maxpool = nn.Identity()
        self.fc1 = nn.Linear(1000, 2048)

    def forward(self, x):
        """Return ``fc1`` applied to the flattened backbone output."""
        h = self.backbone(x)
        h = h.view(h.size(0), -1)
        y = self.fc1(h)
        return y

class myVGG16(nn.Module):
    """VGG-16 backbone whose 1000-way output is projected to 2048-d.

    The torchvision classification head stays in place: VGG's head is called
    ``classifier`` (not ``fc``), so the backbone still emits 1000 values per
    sample and ``fc1`` maps them to 2048.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``VGG16_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super(myVGG16, self).__init__()
        if pre_tra == True:
            self.backbone = torchvision.models.vgg16(weights='VGG16_Weights.DEFAULT', num_classes=1000)
        else:
            # Same architecture as the pretrained branch, so the 1000-d output
            # matches fc1's input size.
            self.backbone = torchvision.models.vgg16(num_classes=1000)
        # NOTE(review): VGG has no ``fc`` or ``maxpool`` attributes; these
        # assignments only register unused, parameter-free Identity modules and
        # do not change the forward pass. Kept for parity with the ResNet wrappers.
        self.backbone.fc = nn.Identity()
        self.backbone.maxpool = nn.Identity()
        self.fc1 = nn.Linear(1000, 2048)

    def forward(self, x):
        """Return ``fc1`` applied to the flattened backbone output."""
        h = self.backbone(x)
        h = h.view(h.size(0), -1)
        y = self.fc1(h)
        return y

class myVIT_b(nn.Module):
    """ViT-B/16 backbone whose 1000-way output is projected to 2048-d.

    The torchvision classification head stays in place: the ViT head is
    called ``heads`` (not ``fc``), so the backbone still emits 1000 values per
    sample and ``fc1`` maps them to 2048.

    Args:
        pre_tra: When ``True``, start from torchvision's default pretrained
            weights (``ViT_B_16_Weights.DEFAULT``); otherwise random init.
    """

    def __init__(self, pre_tra=True):
        super(myVIT_b, self).__init__()
        if pre_tra == True:
            self.backbone = torchvision.models.vit_b_16(weights='ViT_B_16_Weights.DEFAULT', num_classes=1000)
        else:
            # Same architecture as the pretrained branch, so the 1000-d output
            # matches fc1's input size.
            self.backbone = torchvision.models.vit_b_16(num_classes=1000)
        # NOTE(review): VisionTransformer has no ``fc`` or ``maxpool`` attributes;
        # these assignments only register unused, parameter-free Identity modules
        # and do not change the forward pass. Kept for parity with the ResNet wrappers.
        self.backbone.fc = nn.Identity()
        self.backbone.maxpool = nn.Identity()
        self.fc1 = nn.Linear(1000, 2048)

    def forward(self, x):
        """Return ``fc1`` applied to the flattened backbone output."""
        h = self.backbone(x)
        h = h.view(h.size(0), -1)
        y = self.fc1(h)
        return y

    

def Model_zoo(arch, pre_tra=True):
    """Instantiate the backbone wrapper registered under ``arch``.

    Args:
        arch: One of ``'res18'``, ``'res34'``, ``'res50'``, ``'res101'``,
            ``'dense121'``, ``'VGG16'``, ``'VIT_b'``.
        pre_tra: Forwarded to the wrapper; ``True`` starts from torchvision's
            default pretrained weights.

    Returns:
        The constructed ``nn.Module`` wrapper.

    Raises:
        ValueError: If ``arch`` is not one of the names above.
    """
    if arch == 'res18':
        return myResNet18(pre_tra=pre_tra)
    if arch == 'res34':
        return myResNet34(pre_tra=pre_tra)
    if arch == 'res50':
        return myResNet50(pre_tra=pre_tra)
    if arch == 'res101':
        return myResNet101(pre_tra=pre_tra)
    if arch == 'dense121':
        return myDenseNet121(pre_tra=pre_tra)
    if arch == 'VGG16':
        return myVGG16(pre_tra=pre_tra)
    if arch == 'VIT_b':
        return myVIT_b(pre_tra=pre_tra)
    # Fail loudly instead of falling through to an unbound local.
    raise ValueError(
        f"Unknown arch {arch!r}; expected one of "
        "'res18', 'res34', 'res50', 'res101', 'dense121', 'VGG16', 'VIT_b'"
    )



class Projection(torch.nn.Module):
    """Projection head mapping 2048-d inputs to a 256-d embedding.

    Layout: Linear(2048, 512) -> ReLU -> Linear(512, 256) -> ReLU ->
    Linear(256, 256), with no activation on the output. The stack is stored
    as ``self.backdone`` (name kept because it fixes the ``state_dict`` keys).
    """

    def __init__(self):
        super().__init__()
        dims = [2048, 512, 256, 256]
        stack = []
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            # A ReLU separates consecutive linear layers.
            if idx:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(d_in, d_out))
        self.backdone = nn.Sequential(*stack)

    def forward(self, x):
        """Project ``x`` (last dimension 2048) to 256 dimensions."""
        out = self.backdone(x)
        return out



def _load_moco_backbone(device):
    """Return a torchvision ResNet-50 loaded from the MoCo checkpoint, all weights frozen."""
    model = torchvision.models.resnet50()

    # Freeze everything except the linear head. The head is swapped for
    # nn.Identity below, so in the returned model every parameter is frozen.
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False

    # The head is discarded below, but these calls draw from the global RNG;
    # they are kept so seeded randomness later in the run is unchanged.
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    encoder_path = "./moco_official.pth.tar"
    checkpoint = torch.load(encoder_path, map_location=device)

    # Keep only query-encoder weights (minus its 'fc' head), stripping the
    # 'module.encoder_q.' prefix; every original key is deleted.
    state_dict = checkpoint['state_dict']
    for k in list(state_dict.keys()):
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            state_dict[k[len("module.encoder_q."):]] = state_dict[k]
        del state_dict[k]

    # strict=False is needed because the head weights are intentionally absent,
    # but any other missing key means the checkpoint did not match and the
    # backbone would silently stay randomly initialised.
    result = model.load_state_dict(state_dict, strict=False)
    missing = set(result.missing_keys) - {'fc.weight', 'fc.bias'}
    if missing:
        raise RuntimeError(
            f"MoCo checkpoint {encoder_path!r} is missing backbone weights: {sorted(missing)}"
        )

    model.fc = nn.Identity()
    model.maxpool = nn.Identity()
    return model


def _load_simclr_backbone(device):
    """Return the SimCLR ResNet-50 (1x) loaded from its checkpoint and moved to ``device``."""
    # NOTE(review): resnet50x1 is neither defined nor imported in this file;
    # confirm it is made available elsewhere before using encoder='simclr'.
    model = resnet50x1()
    model = model.to(device)
    encoder_path = './simclr_official.pth'
    checkpoint = torch.load(encoder_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.fc = nn.Identity()
    model.maxpool = nn.Identity()
    return model


def load_pretrain(encoder, device, arch):
    """Build a backbone initialised according to a pre-training scheme.

    Args:
        encoder: ``'moco'`` (ResNet-50 from ``./moco_official.pth.tar``, all
            weights frozen), ``'simclr'`` (``resnet50x1`` from
            ``./simclr_official.pth``) or ``'supervised'`` (``Model_zoo(arch)``
            with torchvision's default pretrained weights).
        device: ``map_location`` for ``torch.load``; the SimCLR model is also
            moved there.
        arch: Architecture key for ``Model_zoo``; only used for ``'supervised'``.

    Returns:
        The backbone ``nn.Module``. For ``'moco'`` and ``'simclr'`` the ``fc``
        head and stem ``maxpool`` are replaced with ``nn.Identity``.

    Raises:
        ValueError: If ``encoder`` is not one of the names above.
        RuntimeError: If the MoCo checkpoint lacks backbone weights.
    """
    if encoder == 'moco':
        return _load_moco_backbone(device)
    if encoder == 'simclr':
        return _load_simclr_backbone(device)
    if encoder == 'supervised':
        return Model_zoo(arch)
    raise ValueError(
        f"Unknown encoder {encoder!r}; expected 'moco', 'simclr' or 'supervised'"
    )






def load_encoder(arch, encoder_path, device):
    """Rebuild a ResNet wrapper and load a saved ``state_dict`` into it.

    Args:
        arch: One of ``'res18'``, ``'res34'``, ``'res50'``, ``'res101'``.
        encoder_path: Path to a file that ``torch.load`` turns into a
            ``state_dict`` for the matching wrapper.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The wrapper with the saved weights loaded (strict key matching).

    Raises:
        ValueError: If ``arch`` is not one of the names above.
    """
    # pre_tra=False: the strict load below overwrites every parameter and
    # buffer, so downloading torchvision's pretrained weights first is wasted
    # work (and an avoidable network dependency).
    if arch == 'res18':
        backbone = myResNet18(pre_tra=False)
    elif arch == 'res34':
        backbone = myResNet34(pre_tra=False)
    elif arch == 'res50':
        backbone = myResNet50(pre_tra=False)
    elif arch == 'res101':
        backbone = myResNet101(pre_tra=False)
    else:
        raise ValueError(
            f"Unknown arch {arch!r}; expected 'res18', 'res34', 'res50' or 'res101'"
        )

    backbone.load_state_dict(torch.load(encoder_path, map_location=device))
    return backbone



























