import argparse
import os

import numpy as np
import pretrainedmodels
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

import caltech_dataset
import chest_xray_dataset
import cub_200
import flowers102
import oxford_pets
import stanford_dogs

# Command-line interface. The seed is used to seed torch and numpy and is
# embedded in the checkpoint file names written by the training loop.
parser = argparse.ArgumentParser()
# Required: without it args.seed is None and torch.manual_seed() fails, but
# only after every pretrained model has already been instantiated.
parser.add_argument('--seed', dest='seed', type=int, required=True,
                    help='integer random seed for torch and numpy')
# `dataset` has no command-line flag; args.dataset is always None.
parser.set_defaults(dataset=None)
args = parser.parse_args()

# Names of the source (pre-training) and target (fine-tuning) datasets.
# target_dataset_name is used when building checkpoint file names.
source_model_dataset = 'imagenet'
target_dataset_name = 'flowers102'

# def _convert_image_to_rgb(image):
#     return image.convert("RGB")

# Image preprocessing applied to every sample of both splits: resize to 256,
# centre-crop to 224, convert to a tensor, then normalise each channel with
# the mean/std values below (the usual ImageNet statistics).
# NOTE: the optional RGB conversion step (_convert_image_to_rgb) is disabled.
_preprocess_steps = [
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess_fn = torchvision.transforms.Compose(_preprocess_steps)


# Architectures to fine-tune, keyed by the name used in logs and checkpoint
# file names. Insertion order matters: the position of each entry is the index
# into the per-architecture accuracy arrays filled in by the training loop.
_ARCHITECTURES = (
    ('resnet_50', torchvision.models.resnet50),
    ('resnet_152', torchvision.models.resnet152),
    ('mobilenet_v2', torchvision.models.mobilenet_v2),
    ('densenet_201', torchvision.models.densenet201),
    ('densenet_169', torchvision.models.densenet169),
    ('densenet_121', torchvision.models.densenet121),
    ('resnet_101', torchvision.models.resnet101),
)

# Every network is instantiated up front with its pretrained weights.
models = {
    arch_name: build(pretrained=True)
    for arch_name, build in _ARCHITECTURES
}

# Seed torch first, then numpy, with the value given on the command line,
# and echo it so the run log records which seed was used.
SEED = args.seed
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(SEED)
print(SEED)

# Test accuracy (percent) per architecture, indexed by position in `models`:
# at the final epoch, the best value seen over all epochs, and at epochs 50/75.
accuracies = np.zeros((len(models)))
best_accuracies = np.zeros((len(models)))
accuracies_50 = np.zeros((len(models)))
accuracies_75 = np.zeros((len(models)))

# TODO: when switching away from flowers, set LABEL_OFFSET to 0 and update
# NUM_CLASSES. The offset is subtracted from every label before the loss and
# accuracy computation (the flowers loader is assumed to yield 1-based labels).
NUM_CLASSES = 102
LABEL_OFFSET = 1
EPOCHS = 100
batch_size = 64
# Fall back to the CPU when no GPU is present instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _replace_classifier_head(model, model_name, num_classes):
    """Swap the final linear layer of `model` for a fresh `num_classes`-way one.

    The attribute holding the classifier depends on the architecture family,
    which is identified by a substring of `model_name`.

    Raises:
        ValueError: if `model_name` matches none of the known families, so an
            unmodified 1000-class head is never trained by accident.
    """
    if 'resnet' in model_name:
        model.fc = torch.nn.Linear(in_features=model.fc.in_features,
                                   out_features=num_classes, bias=True)
    elif 'mobilenet' in model_name:
        model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features,
                                              out_features=num_classes, bias=True)
    elif 'densenet' in model_name:
        model.classifier = torch.nn.Linear(in_features=model.classifier.in_features,
                                           out_features=num_classes, bias=True)
    else:
        raise ValueError(f'no classifier-head rule for architecture {model_name!r}')


def _evaluate(model, loader, device, label_offset=0):
    """Return the top-1 accuracy of `model` over `loader`, in percent.

    Puts the model in eval mode and leaves it there; the caller must switch
    back to train mode before resuming training.
    """
    total = 0
    correct = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device) - label_offset

            out = model(x)
            _, pred = torch.max(out, 1)
            total += y.shape[0]
            correct += (pred == y).sum().item()
    return 100 * correct / total


# The datasets and loaders do not depend on the architecture, so build them once.
train_ds = flowers102.Flowers102(root='/var/data/flowers102', split='train', transform=preprocess_fn)
test_ds = flowers102.Flowers102(root='/var/data/flowers102', split='test', transform=preprocess_fn)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)

# torch.save does not create missing directories.
os.makedirs('./models', exist_ok=True)

for k, model_name in tqdm(enumerate(models), total=len(models)):
    print(model_name)
    print(len(train_ds))

    model = models[model_name]
    _replace_classifier_head(model, model_name, NUM_CLASSES)

    # Full fine-tuning: every parameter is trainable.
    for param in model.parameters():
        param.requires_grad = True

    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    checkpoint_prefix = f'./models/imagenet_to_{target_dataset_name}_{model_name}_seed{SEED}'

    for epoch in range(EPOCHS):
        # The evaluation pass below switches to eval mode, so train mode has
        # to be restored every epoch; otherwise BatchNorm/Dropout layers would
        # run in inference mode from the second epoch onwards.
        model.train()
        running_loss = 0.0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device) - LABEL_OFFSET

            optimizer.zero_grad()

            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()
            # .item() detaches the value; summing the tensor itself would keep
            # every batch's autograd graph alive until the end of the epoch.
            running_loss += loss.item()

        # Mean training loss per batch for this epoch.
        print('[%d] loss: %.3f' %
              (epoch + 1, running_loss / max(len(train_loader), 1)))

        accuracy = _evaluate(model, test_loader, device, LABEL_OFFSET)

        if epoch + 1 == 50:
            accuracies_50[k] = accuracy

        if epoch + 1 == 75:
            accuracies_75[k] = accuracy

        if epoch + 1 == EPOCHS:
            accuracies[k] = accuracy
            torch.save(model.state_dict(), f'{checkpoint_prefix}_epoch{EPOCHS}.pth')

        # Keep a checkpoint of the best-performing epoch so far.
        if accuracy > best_accuracies[k]:
            best_accuracies[k] = accuracy
            torch.save(model.state_dict(), f'{checkpoint_prefix}_best.pth')

    # Release this architecture's weights and gradients from the device so the
    # finished models do not pile up in GPU memory.
    model.to('cpu')
    
# with open(f'./accuracies/{source_model_dataset}_to_{target_dataset_name}_arch_finetune_accuracies_seed{SEED}_epoch100.npy','wb') as f:
#     np.save(f,accuracies)
    
# with open(f'./accuracies/{source_model_dataset}_to_{target_dataset_name}_arch_finetune_accuracies_seed{SEED}_epoch50.npy','wb') as f:
#     np.save(f,accuracies_50)
    
# with open(f'./accuracies/{source_model_dataset}_to_{target_dataset_name}_arch_finetune_accuracies_seed{SEED}_epoch75.npy','wb') as f:
#     np.save(f,accuracies_75)

# with open(f'./accuracies/{source_model_dataset}_to_{target_dataset_name}_arch_finetune_accuracies_seed{SEED}_best.npy','wb') as f:
#     np.save(f,best_accuracies)