'''
Produces the ImageNet evaluation-server submission files for the test set by cascading timm models on their max-softmax confidence.
'''

import argparse
import json
import math
import os

from PIL import Image
import timm
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from config import path_imagenet_test, batch_size, workers


def pil_loader(path: str) -> Image.Image:
    """Read the image file at *path* and return it converted to RGB.

    The file is opened in binary mode, and the RGB conversion happens while
    the file handle is still open. The handle is closed when the function
    returns.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
        return rgb_image


class dataset_ImageNet(Dataset):
    """Unlabelled image dataset over every entry of a single directory.

    Entries are indexed in sorted path order, so item ``i`` refers to the
    same file on every run.

    Args:
        path: directory containing the images; a trailing separator is
            optional.
        transform: optional callable applied to each loaded RGB PIL image.
    """
    def __init__(self, path, transform=None):
        # os.path.join instead of ``path + name``: a directory given without a
        # trailing separator used to yield invalid paths such as "dirimg.JPEG".
        # Every entry shares the same prefix, so the sort order is still the
        # sorted file-name order.
        self.paths = sorted(os.path.join(path, name) for name in os.listdir(path))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load item *idx* as an RGB image and apply ``transform`` if one is set."""
        image = pil_loader(self.paths[idx])
        if self.transform:
            image = self.transform(image)
        return image


@torch.no_grad()
def infer(model, path_logits, device=0):
    """Run the pretrained timm network *model* over the test images and save its logits.

    Preprocessing follows the network's ``default_cfg``:
      1. Resize to ``floor(side / crop_pct)`` with bicubic interpolation if the
         config says ``'bicubic'``, otherwise bilinear.
      2. Center-crop to ``side``.
      3. Convert to a tensor and normalize with the configured mean/std.

    ``side`` is the last entry of ``test_input_size`` when the config has that
    key, otherwise of ``input_size``.

    Args:
        model: timm model name; also used as the output file-name suffix.
        path_logits: path prefix; logits are saved to ``path_logits + model + '.pt'``.
        device: device that the network and each batch are moved to.

    The saved tensor concatenates the per-batch CPU outputs along dim 0, in
    dataset order (no shuffling).
    """
    net = timm.create_model(model, pretrained=True).to(device).eval()
    cfg = net.default_cfg

    use_test_size = 'test_input_size' in cfg
    input_size = cfg['test_input_size'] if use_test_size else cfg['input_size']
    if use_test_size:
        print('Using test input size', input_size)

    interpolation = (transforms.InterpolationMode.BICUBIC
                     if cfg['interpolation'] == 'bicubic'
                     else transforms.InterpolationMode.BILINEAR)
    side = input_size[-1]
    resize_to = int(math.floor(side / cfg['crop_pct']))
    tf = transforms.Compose([
        transforms.Resize(resize_to, interpolation=interpolation),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Normalize(cfg['mean'], cfg['std']),
    ])
    print('Starting model', model, 'with transform:\n', tf)

    loader = torch.utils.data.DataLoader(
        dataset_ImageNet(path=path_imagenet_test, transform=tf),
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )
    batch_outputs = [net(batch.to(device)).to('cpu') for batch in loader]
    torch.save(torch.cat(batch_outputs, 0), path_logits + model + '.pt')
    print(f'Inference for {model} finished!')


def inference_data(model, path_logits, path_infer):
    """Summarize the saved logits of *model* and write the summary as JSON.

    Reads ``path_logits + model + '.pt'``, a 2-D logits tensor (presumably
    ``(num_images, num_classes)`` as saved by ``infer``). Writes
    ``path_infer + model + '.txt'``, a JSON list of five per-image lists in
    this order:

      0. top-5 class indices by logit, highest first
      1. softmax entropy
      2. maximum softmax probability
      3. softmax margin (top-1 minus top-2 probability)
      4. logit margin (top-1 minus top-2 logit)
    """
    # map_location='cpu' lets logits saved from any device load on CPU-only hosts.
    logits = torch.load(path_logits + model + '.pt', map_location='cpu')
    # Compute the softmax once and reuse it; previously it was recomputed three times.
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    top2_probs = probs.topk(2, 1)[0]
    top2_logits = logits.topk(2, 1)[0]
    ea = [logits.topk(5, 1)[1].tolist(),
          (-(probs * log_probs).sum(dim=1)).tolist(),
          # The top-1 column of topk is exactly the row-wise maximum probability.
          top2_probs[:, 0].tolist(),
          (top2_probs[:, 0] - top2_probs[:, 1]).tolist(),
          (top2_logits[:, 0] - top2_logits[:, 1]).tolist()]
    with open(path_infer + model + '.txt', 'w') as f:
        json.dump(ea, f, indent=2)


def _load_infer_stats(path_infer, model):
    """Load the per-image statistics list that ``inference_data`` wrote for *model*."""
    with open(path_infer + model + '.txt', 'r') as f:
        return json.load(f)


def _cascade(stages, thresholds):
    """Choose each image's top-5 prediction from a max-softmax confidence cascade.

    ``stages`` are statistics lists as written by ``inference_data`` (index 0:
    top-5 class indices, index 2: max softmax probability), ordered from the
    first stage to the last. Image ``i`` takes stage ``k``'s prediction when
    that stage's max softmax is ``>= thresholds[k]``. Otherwise it falls
    through to the next stage. The last stage has no threshold and accepts
    every remaining image.

    Raises:
        ValueError: if the number of thresholds is not ``len(stages) - 1``, or
            if the stages cover different numbers of images.
    """
    if len(thresholds) != len(stages) - 1:
        raise ValueError(
            f'expected {len(stages) - 1} thresholds for {len(stages)} stages, '
            f'got {len(thresholds)}')
    num_images = len(stages[0][0])
    if any(len(stage[0]) != num_images or len(stage[2]) != num_images
           for stage in stages):
        raise ValueError('all cascade stages must cover the same number of images')
    predictions = []
    for i in range(num_images):
        # zip stops at the shorter input, so the last stage is never tested here.
        for stage, thresh in zip(stages, thresholds):
            if stage[2][i] >= thresh:
                predictions.append(stage[0][i])
                break
        else:
            # No earlier stage was confident enough, so the last stage decides.
            predictions.append(stages[-1][0][i])
    return predictions


def _write_submission(path, predictions, map_test):
    """Write one line per image: its top-5 labels remapped through *map_test*."""
    with open(path, 'w') as f:
        f.write('\n'.join(' '.join(str(map_test[z]) for z in top5)
                          for top5 in predictions))


def main():
    """Optionally run inference for every model, then write the four cascade submissions.

    With ``-s/--skip``, the statistics files under ``data/test`` from a
    previous run are reused instead of re-running the models.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--skip', action='store_true', help='skip inference, use when already done')
    args = parser.parse_args()

    os.makedirs('data/test/submissions', exist_ok=True)
    path_logits = 'data/test/test_logits_'
    path_infer = 'data/test/test_infer_'

    # infer the test set
    if not args.skip:
        models = ['tf_mobilenetv3_small_075',
                  'mobilenetv3_large_100_miil',
                  'tf_efficientnet_b3_ns',
                  'tf_efficientnet_b4_ns',
                  'beit_large_patch16_224',
                  'tf_efficientnet_l2_ns_475']
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for model in models:
            infer(model, path_logits, device)
            inference_data(model, path_logits, path_infer)

    # to map the labels from default sorted to ImageNet challenge order
    with open('data/test/map_test.txt', 'r') as f:
        map_test = json.load(f)

    # Each submission is a max-softmax cascade, chosen per the original notes
    # for "maximum improvement": (output file, models from the first stage to
    # the last, one threshold per stage except the last). The number after
    # each model name appears to be its per-image compute cost (FLOPs/MACs);
    # TODO confirm.
    submissions = [
        # Submission 1: tf_efficientnet_b4_ns (4492947056),
        # beit_large_patch16_224 (61603132416).
        # NOTE(review): the original notes gave this threshold as
        # 0.5657278299331665, but the code has always used the value below.
        # Confirm which one is intended.
        ('submission-1-EN4-BL1.txt',
         ['tf_efficientnet_b4_ns', 'beit_large_patch16_224'],
         [0.5657511353492737]),
        # Submission 2: tf_mobilenetv3_small_075 (46016336),
        # mobilenetv3_large_100_miil (225436416).
        ('submission-2-MXS-ML.txt',
         ['tf_mobilenetv3_small_075', 'mobilenetv3_large_100_miil'],
         [0.45507436990737915]),
        # Submission 3: mobilenetv3_large_100_miil (225436416),
        # tf_efficientnet_b3_ns (1874915424).
        ('submission-3-ML-EN3.txt',
         ['mobilenetv3_large_100_miil', 'tf_efficientnet_b3_ns'],
         [0.7111721634864807]),
        # Submission 4: tf_efficientnet_b4_ns (4492947056),
        # beit_large_patch16_224 (61603132416),
        # tf_efficientnet_l2_ns_475 (172113352288).
        # This file was formerly misnamed "submission-3-EN4-BL1-EL2.txt",
        # which broke the submission-N numbering.
        ('submission-4-EN4-BL1-EL2.txt',
         ['tf_efficientnet_b4_ns', 'beit_large_patch16_224',
          'tf_efficientnet_l2_ns_475'],
         [0.6299011707305908, 0.49049681425094604]),
    ]

    # Several submissions share models; load each statistics file only once.
    stats = {}
    for filename, cascade, thresholds in submissions:
        for model in cascade:
            if model not in stats:
                stats[model] = _load_infer_stats(path_infer, model)
        predictions = _cascade([stats[model] for model in cascade], thresholds)
        _write_submission(f'data/test/submissions/{filename}', predictions, map_test)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()