import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import torch
import numpy as np
import cv2



def model_select(model):
    """Build a torchvision classification network selected by name.

    The network is created with ``pretrained=True`` and a banner naming the
    selection is printed once construction has finished.

    Args:
        model: One of ``'vgg19'``, ``'vgg16'``, ``'resnet18'``,
            ``'resnet34'`` or ``'GoogLeNet'``.

    Returns:
        The constructed torchvision model.

    Raises:
        Exception: If ``model`` matches none of the supported names.
    """
    # (selector, factory attribute on torchvision.models, banner), checked in
    # this order with ``==`` so matching behaves like a plain if/elif chain.
    catalogue = (
        ('vgg19', 'vgg19',
         '================= SELECTED MODEL: vgg19 ================='),
        ('vgg16', 'vgg16',
         '================= SELECTED MODEL: vgg16 ================='),
        ('resnet18', 'resnet18',
         '================= SELECTED MODEL: res18 ================='),
        ('resnet34', 'resnet34',
         '================= SELECTED MODEL: res34 ================='),
        ('GoogLeNet', 'googlenet',
         '================= SELECTED MODEL: GoogleNet ================='),
    )

    for selector, factory_name, banner in catalogue:
        if model == selector:
            # Look the factory up lazily so an unknown name never touches
            # torchvision at all.
            network = getattr(torchvision.models, factory_name)(pretrained=True)
            print(banner)
            return network

    raise Exception('unknown model')

#================================ImageNet==================================================================================
def ILSVRC_class_name_list_extabish(class_root):
    """Read the first class name from every line of a class-index file.

    For each line, the text after the last ``:`` is taken, cut at the first
    ``,`` and stripped of surrounding spaces and quote characters. A line such
    as ``0: 'tench, Tinca tinca',`` therefore yields ``tench``.

    Exactly one entry is produced per line (blank lines give ``''``), so list
    positions keep following line numbers.

    Args:
        class_root: Path of the text file to parse.

    Returns:
        list of str: One class name per line, in file order.
    """
    class_name_list = []
    # The context manager closes the file even if parsing raises.
    with open(class_root) as class_file:
        for line in class_file:
            # Drop the line terminator first; otherwise a line without a comma
            # keeps its trailing newline and the closing quote is not stripped.
            entry = line.rstrip('\r\n')
            first_name = entry.split(':')[-1].split(',')[0]
            class_name_list.append(first_name.strip(' ').strip("'").strip('"'))
    return class_name_list

#load dataset
def ILSVRC2012_LOAD(root,batch_size=4,shuffle=True):
    """Build ImageFolder datasets and loaders for an ILSVRC2012-style tree.

    ``root`` must contain ``train/`` and ``val/`` directories readable by
    ``torchvision.datasets.ImageFolder`` and a ``class.txt`` file that is
    parsed with ``ILSVRC_class_name_list_extabish``.

    Args:
        root: Dataset root directory.
        batch_size: Batch size used by both data loaders.
        shuffle: Whether both data loaders shuffle their samples.

    Returns:
        dict: with the keys ``'train_dataset'``, ``'train_dataset_loader'``,
        ``'val_dataset'``, ``'val_dataset_loader'`` and ``'class_name_list'``.
    """
    # Announce the keys of the returned dictionary.
    print(' train_dataset \n train_dataset_loader \n val_dataset \n val_dataset_loader \n class_name_list')

    # Training pipeline: random crop + flip augmentation.
    train_data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # ToTensor scales pixel values into [0, 1]
        # Mean/std normalisation is disabled here:
        #transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    # Validation pipeline: deterministic resize + centre crop.
    val_data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Mean/std normalisation is disabled here:
        #transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    root_train = os.path.join(root, 'train')
    root_val = os.path.join(root, 'val')

    # Original author's note: 1281167 (presumably the sample count; not verified here).
    train_dataset = torchvision.datasets.ImageFolder(
        root=root_train, transform=train_data_transform)
    train_dataset_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    # Original author's note: 320292 (presumably the sample count; not verified here).
    val_dataset = torchvision.datasets.ImageFolder(
        root=root_val, transform=val_data_transform)
    # The validation loader must wrap the validation dataset; it previously
    # wrapped ``train_dataset`` and therefore served training samples.
    val_dataset_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)
    print(' Dataset has been loaded successfully ')

    # Load the class names that accompany the dataset.
    class_root = os.path.join(root, 'class.txt')
    class_name_list = ILSVRC_class_name_list_extabish(class_root)

    return {'train_dataset': train_dataset,
            'train_dataset_loader': train_dataset_loader,
            'val_dataset': val_dataset,
            'val_dataset_loader': val_dataset_loader,
            'class_name_list': class_name_list}


# NOTE: executed at import time. Importing this module calls ILSVRC2012_LOAD
# on the hard-coded path below, which builds the ImageFolder datasets/loaders
# and opens <root>/class.txt, so that directory has to be present.
dataset_dict=ILSVRC2012_LOAD(root='/home/liujiawei/local_pycharm/Data/ImageNet/val_seg')
# Class names parsed from class.txt, exposed as a module-level name.
class_name_list=dataset_dict['class_name_list']













#================================kitti==================================================================================

class KITTI_data_load(torch.utils.data.Dataset):
    """KITTI image dataset whose items are ``(image, index)`` pairs.

    Four annotation arrays are loaded from ``ann_path`` at construction time:
    ``GtboxLabel.npy``, ``GtboxTwoPoint.npy``, ``GtboxObjectScore.npy`` and
    ``FileList.npy``. Items only need ``FileList.npy`` (image names) and the
    length of ``GtboxLabel.npy``; the other arrays are kept as attributes.
    """

    def __init__(self,img_path,ann_path,transform=None):
        """Load the annotation arrays.

        Args:
            img_path: Directory that holds the ``.png`` images.
            ann_path: Directory that holds the four ``.npy`` annotation files.
            transform: Optional callable applied to every loaded image.
        """
        self.img_path = img_path
        self.ann_path = ann_path
        self.transform = transform
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point ann_path at annotation files from a trusted source.
        self.gtbox_label = np.load(
            os.path.join(self.ann_path, 'GtboxLabel.npy'), allow_pickle=True)
        self.gtbox_two_point = np.load(
            os.path.join(self.ann_path, 'GtboxTwoPoint.npy'), allow_pickle=True)
        self.gtbox_object_score = np.load(
            os.path.join(self.ann_path, 'GtboxObjectScore.npy'), allow_pickle=True)
        self.file_list = np.load(
            os.path.join(self.ann_path, 'FileList.npy'), allow_pickle=True)

    def __len__(self):
        """Return the number of samples (entries in ``GtboxLabel.npy``)."""
        return len(self.gtbox_label)

    def _image_name(self, idx):
        """Return the ``.png`` file name that belongs to annotation ``idx``."""
        # splitext swaps the real extension. str.strip('txt') removes any
        # leading/trailing 't' and 'x' characters and mangles names such as
        # 'test_01.txt'.
        return os.path.splitext(self.file_list[idx])[0] + '.png'

    def __getitem__(self,idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(img, idx)`` where ``img`` is the image as read by
            ``cv2.imread``, passed through ``transform`` when one was given.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_file = os.path.join(self.img_path, self._image_name(idx))
        # NOTE(review): cv2.imread yields BGR channel order; confirm that the
        # downstream transform/model expects that rather than RGB.
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure with None instead of raising.
            raise FileNotFoundError('cannot read image: {}'.format(img_file))

        if self.transform is not None:
            img = self.transform(img)
        return (img, idx)

class KITTI_ana_load(torch.utils.data.Dataset):
    """KITTI annotation dataset.

    Items are ``(img_name, gtbox_label, gtbox_two_point, object_score)``
    tuples taken from the arrays ``FileList.npy``, ``GtboxLabel.npy``,
    ``GtboxTwoPoint.npy`` and ``GtboxObjectScore.npy`` found in ``ann_path``.
    No image is part of the item, so images are not read from disk.
    """

    def __init__(self,img_path,ann_path,transform=None):
        """Load the annotation arrays.

        Args:
            img_path: Image directory; stored but not read by this class.
            ann_path: Directory that holds the four ``.npy`` annotation files.
            transform: Stored for interface parity; not applied because no
                image is returned.
        """
        self.img_path = img_path
        self.ann_path = ann_path
        self.transform = transform
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point ann_path at annotation files from a trusted source.
        self.gtbox_label = np.load(
            os.path.join(self.ann_path, 'GtboxLabel.npy'), allow_pickle=True)
        self.gtbox_two_point = np.load(
            os.path.join(self.ann_path, 'GtboxTwoPoint.npy'), allow_pickle=True)
        self.gtbox_object_score = np.load(
            os.path.join(self.ann_path, 'GtboxObjectScore.npy'), allow_pickle=True)
        self.file_list = np.load(
            os.path.join(self.ann_path, 'FileList.npy'), allow_pickle=True)

    def __len__(self):
        """Return the number of samples (entries in ``GtboxLabel.npy``)."""
        return len(self.gtbox_label)

    def __getitem__(self,idx):
        """Return the annotations of sample ``idx``.

        Returns:
            tuple: ``(img_name, img_label, img_gtbox_two_point,
            img_object_score)``; ``img_name`` is the ``.png`` file name and
            the remaining entries are row ``idx`` of the matching arrays.
        """
        # splitext swaps the real extension. str.strip('txt') removes any
        # leading/trailing 't' and 'x' characters and mangles names such as
        # 'test_01.txt'.
        img_name = os.path.splitext(self.file_list[idx])[0] + '.png'
        # The image used to be decoded and transformed here although it was
        # never returned; that dead work has been removed.
        return (img_name,
                self.gtbox_label[idx],
                self.gtbox_two_point[idx],
                self.gtbox_object_score[idx])

# KITTI object category names.
# NOTE(review): the order presumably matches the integer ids stored in the
# annotation arrays; confirm before indexing with it.
class_name = [
    'Car',
    'Van',
    'Truck',
    'Pedestrian',
    'Person_sitting',
    'Cyclist',
    'Tram',
    'Misc',
    'DontCare',
]




#==========================================================PROXY========================================================

def proxy_dataset_select(proxy_data, *, num_samples=40000):
    """Return the proxy dataset selected by name.

    Supported values of ``proxy_data``:

    * ``'ImageNet'``: the ``train_dataset`` produced by ``ILSVRC2012_LOAD``
      for a hard-coded root directory.
    * ``'uniform'``: tensor of values drawn uniformly from ``[0, 1)``.
    * ``'white'`` / ``'black'``: tensor of all ones / all zeros.
    * ``'ensemble'``: first third ones, second third zeros, remainder
      uniform noise (thirds are ``num_samples // 3`` long).
    * ``'MSCOCO'``: ``torchvision.datasets.CocoDetection`` resized to 224x224.
    * ``'KITTI'``: ``KITTI_data_load`` resized to 224x224.

    Args:
        proxy_data: Name of the proxy dataset.
        num_samples: Number of images in the synthetic tensors
            (``'uniform'``, ``'white'``, ``'black'``, ``'ensemble'``). The
            default of 40000 float32 images of 3x224x224 needs roughly 24 GB.
            Ignored by the file-backed datasets.

    Returns:
        A ``(num_samples, 3, 224, 224)`` float tensor for the synthetic
        choices, otherwise the selected dataset object.

    Raises:
        Exception: If ``proxy_data`` is not one of the supported names.
    """
    sample_shape = (num_samples, 3, 224, 224)

    if proxy_data == 'ImageNet':
        print('================= selected proxy data type is ImageNet =================')
        dataset_dict = ILSVRC2012_LOAD(root='/home/liujiawei/local_pycharm/Data/ImageNet/val_seg')
        dataset = dataset_dict['train_dataset']

    elif proxy_data == 'uniform':
        print('================= selected proxy data type is uniform =================')
        # uniform_ overwrites every element, so the initial fill is irrelevant.
        dataset = torch.empty(sample_shape).uniform_(0.0, 1.0)

    elif proxy_data == 'white':
        print('================= selected proxy data type is white =================')
        dataset = torch.ones(sample_shape)

    elif proxy_data == 'black':
        print('================= selected proxy data type is black =================')
        dataset = torch.zeros(sample_shape)

    elif proxy_data == 'ensemble':
        print('================= selected proxy data type is ensemble =================')
        third = num_samples // 3
        # Draw noise for the whole tensor first so the random stream (and
        # hence the noise kept in the last part) is the same as before, then
        # overwrite the first two thirds in place with scalar fills instead
        # of building multi-gigabyte temporary tensors.
        dataset = torch.empty(sample_shape).uniform_(0.0, 1.0)
        dataset[:third] = 1.0
        dataset[third:2 * third] = 0.0

    elif proxy_data == 'MSCOCO':
        print('================= selected proxy data type is MSCOCO =================')
        root = '/home/liujiawei/local_pycharm/Faster-RCNN/code/MSCOCO2017/coco/train2017'
        annFile = '/home/liujiawei/local_pycharm/Faster-RCNN/code/MSCOCO2017/coco/annotations/instances_train2017.json'
        dataset = torchvision.datasets.CocoDetection(
            root=root,
            annFile=annFile,
            transform=transforms.Compose([
                transforms.Resize(size=[224, 224]),
                transforms.ToTensor(),
            ]))

    elif proxy_data == 'KITTI':
        # Banner added for consistency with every other branch.
        print('================= selected proxy data type is KITTI =================')
        img_path = '/home/liujiawei/local_pycharm/Faster-RCNN/code/Data/KITTI/training/image_2'
        ann_path = '/home/liujiawei/local_pycharm/Faster-RCNN/code/Data/KITTI/annotation'

        dataset = KITTI_data_load(img_path, ann_path, transform=transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),  # original note: 368=24x16 1248=78x16
            transforms.ToTensor(),
        ]))

    else:
        raise Exception('unknown proxy dataset')

    return dataset