import os
import argparse
import torch
from collections import OrderedDict

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets


def ensure_dir(path):
    """Create directory *path*, including missing parents, if it does not exist.

    ``exist_ok=True`` removes the check-then-create race of a separate
    ``os.path.exists`` test: if another process creates the directory between
    the check and ``makedirs``, the call still succeeds instead of raising
    ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)

def get_args():
    """Parse the training command line.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, with an extra
        ``tb_dir`` attribute set to ``<output_dir>/tb_dir``.
    """
    parser = argparse.ArgumentParser("ResNet-50 on Imagenet")
    # (flag, type, default) in registration order, so --help output is unchanged.
    option_table = (
        ("--local_rank", int, 0),
        ("--batch_size", int, 512),
        ("--gpu_nums", int, 8),
        ("--learning_rate", float, 1.0),
        ("--weight_decay", float, 1e-4),
        ("--momentum", float, 0.0),
        ("--enlarge_factor", float, 1.0),
        ("--Epochs", int, 90),
        ("--output_dir", str, "./output"),
        ("--train_dataset_dir", str, None),  # argparse's implicit default
    )
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)

    parsed = parser.parse_args()
    parsed.tb_dir = os.path.join(parsed.output_dir, "tb_dir")
    return parsed

def load_checkpoint(args, path):
    """Restore training state from a checkpoint file into ``args``.

    Args:
        args: namespace holding ``device`` (used as ``map_location``) and
            ``model``. ``optimizer``, ``scheduler`` and ``iter`` are restored
            only when ``args`` already has that attribute AND the checkpoint
            contains the matching entry.
        path: checkpoint location passed to ``torch.load``.

    Returns:
        The same ``args`` object, updated in place.
    """
    states = torch.load(path, map_location=args.device)

    # Checkpoints from save_checkpoint store parameter names WITHOUT the
    # "module." prefix that DataParallel/DistributedDataParallel add. Load
    # into the wrapped network itself so a wrapped model accepts them.
    model = args.model
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model = model.module

    model_state = states["model"]
    prefix = "module."
    # A checkpoint dumped from a wrapped model without stripping has every
    # key prefixed; normalise it so it also loads into the unwrapped network.
    if model_state and all(key.startswith(prefix) for key in model_state):
        model_state = {key[len(prefix):]: value for key, value in model_state.items()}
    model.load_state_dict(model_state)

    if "optimizer" in states and hasattr(args, "optimizer"):
        args.optimizer.load_state_dict(states["optimizer"])
    if "scheduler" in states and hasattr(args, "scheduler"):
        args.scheduler.load_state_dict(states["scheduler"])
    if "iter" in states and hasattr(args, "iter"):
        args.iter = states["iter"]

    return args


def save_checkpoint(args, path):
    """Save model, optimizer and scheduler state plus ``args.iter`` to *path*.

    A leading ``module.`` on model parameter names is stripped (e.g. from a
    DataParallel/DistributedDataParallel wrapper), so the stored weights can
    be loaded into an unwrapped network.

    Args:
        args: namespace holding ``model``, ``optimizer``, ``scheduler`` and ``iter``.
        path: destination file path (str/PathLike) or a writable file-like object.
    """
    model_state = OrderedDict()
    for name, tensor in args.model.state_dict().items():
        if name.split(".")[0] == "module":
            name = name[7:]  # drop "module."
        model_state[name] = tensor

    checkpoint = {
        "model": model_state,
        "iter": args.iter,
        "optimizer": args.optimizer.state_dict(),
        "scheduler": args.scheduler.state_dict(),
    }

    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file, then atomically rename it over *path*.
        # A crash mid-save then never leaves a truncated checkpoint behind.
        tmp_path = os.fspath(path) + ".tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        torch.save(checkpoint, path)
    print("save {} successfully!".format(path))

class DataIterator(object):
    """Wrap a dataloader so ``next()`` keeps producing batches indefinitely.

    When the underlying iterator is exhausted it is recreated, replaying the
    loader from the start. Each restart advances ``self.epoch``. If the loader
    exposes a ``sampler`` with ``set_epoch`` (as
    ``torch.utils.data.DistributedSampler`` does), the new epoch is forwarded
    so that shuffling changes between passes instead of repeating the first
    pass's order.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.epoch = 0
        self.iterator = iter(self.dataloader)

    def next(self):
        """Return the first two elements of the next batch (e.g. inputs, targets).

        Errors raised by the loader itself propagate unchanged. If the loader
        yields nothing even right after a restart, ``StopIteration`` propagates.
        """
        try:
            batch = next(self.iterator)
        except StopIteration:
            # End of a pass: update the sampler epoch BEFORE creating the new
            # iterator, because the sampler's order is drawn at iterator creation.
            self.epoch += 1
            sampler = getattr(self.dataloader, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(self.epoch)
            self.iterator = iter(self.dataloader)
            batch = next(self.iterator)
        return batch[0], batch[1]

def get_train_loader(local_rank, batch_size=512, gpu_nums=8, dataset_dir="./dataset"):
    """Build an endless, sharded iterator over the ImageNet training split.

    Args:
        local_rank: this process's rank, passed to ``DistributedSampler``.
            NOTE(review): ``DistributedSampler`` expects the *global* rank, so
            this is only correct for single-node jobs — confirm.
        batch_size: global batch size; each of the ``gpu_nums`` replicas loads
            ``batch_size // gpu_nums`` samples per step.
        gpu_nums: number of replicas the dataset is sharded across.
        dataset_dir: root directory passed to ``torchvision.datasets.ImageNet``.

    Returns:
        A ``DataIterator`` yielding augmented training batches (random resized
        crop to 224, colour jitter, horizontal flip, normalisation).
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    train_dataset = datasets.ImageNet(
        dataset_dir,
        split="train",
        transform=transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    train_sampler = data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=gpu_nums,
        rank=local_rank,
        shuffle=True,
        drop_last=True
    )

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size // gpu_nums,
        # Shuffling is done by the sampler; DataLoader rejects shuffle=True
        # together with an explicit sampler.
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
    )

    return DataIterator(train_loader)


def get_val_loader(dataset_dir="./dataset", batch_size=200):
    """Build an iterator over the ImageNet validation split.

    Images are resized to 256, center-cropped to 224, converted to tensors and
    normalised with the fixed per-channel mean/std below. The loader is not
    shuffled and keeps the last partial batch.

    Note: the returned ``DataIterator`` restarts the loader once it is
    exhausted. To see each image exactly once, stop after
    ``ceil(len(dataset) / batch_size)`` steps.

    Args:
        dataset_dir: root directory passed to ``torchvision.datasets.ImageNet``.
        batch_size: number of images per batch.
    """
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    eval_dataset = datasets.ImageNet(
        dataset_dir,
        split="val",
        transform=transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    eval_loader = data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
    )

    return DataIterator(eval_loader)


if __name__ == "__main__":
    # Debug entry point: running this file directly opens an interactive
    # IPython shell at module level, so the helpers above are available in it.
    import IPython

    IPython.embed()