import os
import torch
import torchvision.transforms.functional as TF
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import argparse 
from my_models import model_dict
import numpy as np
import random
import json 
import PIL

# - - - - - - - - - - - - - - - - DATASET RELATED - - - - - - - - - - - - - - - - #


class Dataset(datasets.ImageFolder):
    """Image-folder dataset that yields two randomly cropped views per image.

    Every item is a triple ``(view_a, view_b, target)``. Both views come from
    the same decoded image, but the transform pipeline is applied twice, so
    each view gets its own random resized crop.
    """

    def __init__(self, root):
        """Index the images under ``root`` and build the augmentation pipeline."""
        super().__init__(root=root)

        # Fixed per-channel mean/std applied after conversion to a tensor.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # Installed after the parent constructor ran; __getitem__ below is
        # the only place in this class that applies it.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.08, 1), ratio=(0.75, 1.3333333333333333)),
            transforms.ToTensor(),
            self.normalize,
        ])

    def __getitem__(self, index):
        """Return two independently transformed views and the class index."""
        path, target = self.imgs[index]
        image = self.loader(path)

        # Evaluated left to right, so the first view is drawn before the second.
        first_view, second_view = self.transform(image), self.transform(image)
        return first_view, second_view, target


def get_val_loader(args):
    """Create the DataLoader over the ``val`` sub-directory of ``args.data``.

    Reads ``args.data``, ``args.batch_size`` and ``args.workers``. Samples are
    served in dataset order (no shuffling) and the last partial batch is kept.
    """
    val_dir = os.path.join(args.data, 'val')
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )
    return torch.utils.data.DataLoader(Dataset(val_dir), **loader_options)



# - - - - - - - - - - - - - - - - LOADING RELATED - - - - - - - - - - - - - - - - #


def load_ckpt(model, ckpt):
    """Load the weights stored in checkpoint file ``ckpt`` into ``model`` in place.

    The checkpoint must be a mapping with a ``'state_dict'`` entry. Keys may
    optionally carry the ``'module.'`` prefix that ``DataParallel`` adds when a
    wrapped model is saved; that prefix is stripped on a second attempt.

    Args:
        model: the ``torch.nn.Module`` that receives the weights.
        ckpt: path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: if the keys still do not match after prefix stripping.
    """
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    # Loading onto the CPU avoids failures when the checkpoint was saved on a
    # device that is not available here; load_state_dict copies the values
    # into the model's existing parameters wherever they live.
    checkpoint = torch.load(ckpt, map_location='cpu')
    state_dict = checkpoint['state_dict']

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: retry with the DataParallel 'module.' prefix removed
        # instead of wrapping the model in a throwaway DataParallel.
        prefix = 'module.'
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(stripped)


def get_model(arch, ckpt):
    """Build the ``arch`` model, load its weights and put it on the GPU in eval mode.

    Args:
        arch: key into ``model_dict`` selecting the model constructor.
        ckpt: either the literal string ``'pretrained'`` (the constructor is
            called with ``pretrained=True``) or a checkpoint path handed to
            ``load_ckpt``.

    Returns:
        The model, switched to eval mode and moved to CUDA.

    Raises:
        ValueError: if ``ckpt`` is ``None``.
    """
    # Fail early with a readable message instead of letting torch.load(None)
    # blow up later; the CLI default for --ckpt is None.
    if ckpt is None:
        raise ValueError(
            "No checkpoint given: pass a checkpoint path or 'pretrained' as ckpt."
        )

    if ckpt == 'pretrained':
        model = model_dict[arch](pretrained=True)
    else:
        model = model_dict[arch]()
        load_ckpt(model, ckpt)

    model.eval()
    model.cuda()
    return model


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

def main(args):
    """Report how often the model predicts the same class for two views of an image.

    For every validation image the loader yields two independently augmented
    views. The model is run on both, and the two predictions are counted as
    agreeing when their argmax classes match. Agreement is averaged per
    ground-truth class first, then over classes, and printed as a percentage.

    Args:
        args: parsed command-line namespace; ``arch`` and ``ckpt`` select the
            model, the rest is consumed by ``get_val_loader``.
    """
    model = get_model(args.arch, args.ckpt)
    val_loader = get_val_loader(args)

    # Size the counters from the dataset instead of assuming 1000 classes.
    num_classes = len(val_loader.dataset.classes)
    per_class_total_count = torch.zeros(num_classes, dtype=torch.long)
    per_class_agreement_count = torch.zeros(num_classes, dtype=torch.long)

    for images0, images1, target in val_loader:
        images0 = images0.cuda()
        images1 = images1.cuda()

        with torch.no_grad():
            output0 = model(images0)
            output1 = model(images1)

        # Softmax is monotonic, so the argmax over the raw logits selects the
        # same class as the argmax over the softmax confidences.
        pred0 = output0['logits'].argmax(dim=1)
        pred1 = output1['logits'].argmax(dim=1)
        agree = (pred0 == pred1).cpu()

        # Whole-batch counting replaces the per-sample Python loop.
        per_class_total_count += torch.bincount(target, minlength=num_classes)
        per_class_agreement_count += torch.bincount(target[agree], minlength=num_classes)

    # Only average over classes that actually occurred; a class with zero
    # samples would otherwise contribute 0/0 = NaN and poison the mean.
    seen = per_class_total_count > 0
    per_class_agreement = (
        per_class_agreement_count[seen].float() / per_class_total_count[seen].float()
    )
    print('Overall test agreement for the augmentation is: ', per_class_agreement.mean()*100)
 





if __name__ == "__main__":

    # Command-line interface: dataset root, model selection and loader settings.
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument(
        '--data', default='PATH TO IMAGENET', metavar='DIR', help='path to dataset'
    )
    parser.add_argument('--arch', default='resnet18', metavar='ARCH')
    parser.add_argument('--ckpt', type=str, default=None, help='ckpt for the model')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N')
    parser.add_argument('--workers', type=int, default=16)
    args = parser.parse_args()

    # Seed every random number generator the script imports with the same
    # value before running the evaluation.
    seed = 0
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    main(args)


