import logging, os, sys, gc, time, re
from datetime import datetime
import torch, random, numpy
from torchvision import datasets, transforms
from torch.utils.data import Dataset, TensorDataset
import torch.optim as optim
import matplotlib.pyplot as plt
import argparse

from sklearn.model_selection import KFold

from models.basic import *
from models.resnets import *
from models.vit import *

def get_timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')

def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    """Return a logger that writes to ``log_dir/log_filename`` and to stdout.

    ``logging.getLogger(name)`` hands back the same logger object for the same
    ``name``, so handlers are only attached when an equivalent one is not
    already present; calling this function repeatedly therefore does not
    duplicate every log line.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        name: Logger name passed to ``logging.getLogger``.
        log_filename: File name inside ``log_dir``.
        level: Level set on the logger.

    Returns:
        ``(logger, formatter)`` where ``formatter`` uses the file-handler
        format (the one that includes the logger name).
    """
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # FileHandler stores the absolute path in ``baseFilename``; compare on that.
    log_path = os.path.abspath(os.path.join(log_dir, log_filename))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so exclude it explicitly when
    # looking for an existing stdout handler.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

    print('Log directory: ', log_dir)
    return logger, formatter


def get_loader(args, shuffle = True, use_cuda = True):
    """Create the train/test ``DataLoader`` pair for ``args.dataset``.

    Args:
        args: Namespace-like object. ``batch_size``, ``test_batch_size`` and
            ``dataset`` are always read; ``model`` is read only for datasets
            that have a dedicated "vit" preprocessing path, and
            ``pre_train_ckpt`` only for "FC-FMNIST-Inv".
        shuffle: When truthy the training loader shuffles; the test loader
            never does.
        use_cuda: Accepted for call compatibility; it is not read here.

    Returns:
        ``(train_loader, test_loader, num_classes, num_channels)``.
        ``test_loader`` is ``None`` for "FC-FMNIST-Inv".

    Raises:
        NotImplementedError: ``args.dataset`` is not one of the handled names.
    """
    data_root = '../../autodl-tmp/data'
    cifar10_stats = ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

    train_kwargs = {'batch_size': args.batch_size, 'num_workers': 4,
                    'shuffle': True if shuffle else False}
    test_kwargs = {'batch_size': args.test_batch_size, 'num_workers': 4,
                   'shuffle': False}

    def use_vit():
        # Evaluated lazily so datasets without a vit path never touch args.model.
        return args.model == "vit"

    def tensor_pipeline(mean, std, *middle):
        # ToTensor -> optional tensor-level ops -> Normalize.
        return transforms.Compose(
            [transforms.ToTensor(), *middle, transforms.Normalize(mean, std)])

    def vit_pipeline(mean, std, to_rgb=False):
        # Resize/crop to 224 before ToTensor, then Normalize; optionally
        # repeat a single channel three times.
        steps = [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        if to_rgb:
            steps.append(
                transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x))
        return transforms.Compose(steps)

    def gray_pipeline(mean, std):
        if use_vit():
            return vit_pipeline(mean, std, to_rgb=True)
        return tensor_pipeline(mean, std)

    def rgb28_pipeline(mean, std):
        if use_vit():
            return vit_pipeline(mean, std)
        return tensor_pipeline(mean, std, transforms.Resize(28))

    def dataset_pair(factory, tf, train_sel, test_sel):
        # Train split is constructed first, then the test split.
        train_ds = factory(data_root, download=True, transform=tf, **train_sel)
        test_ds = factory(data_root, download=True, transform=tf, **test_sel)
        return train_ds, test_ds

    by_train_flag = ({'train': True}, {'train': False})

    # Stays None only for the branch that has no test split.
    test_dataset = None
    which = args.dataset

    if which == "MNIST":
        train_dataset, test_dataset = dataset_pair(
            datasets.MNIST, tensor_pipeline((0.1307,), (0.3081,)), *by_train_flag)
        num_classes, num_channels = 10, 1
    elif which == "EMNIST-Letters":
        train_dataset, test_dataset = dataset_pair(
            datasets.EMNIST, gray_pipeline((0.1724,), (0.3311,)),
            {'split': "letters", 'train': True},
            {'split': "letters", 'train': False})
        num_classes, num_channels = 37, 1
    elif which == "EMNIST-Letters-shuffle":
        train_dataset, test_dataset = dataset_pair(
            datasets.EMNIST,
            tensor_pipeline((0.1724,), (0.3311,), SwapImageHalves()),
            {'split': "letters", 'train': True},
            {'split': "letters", 'train': False})
        num_classes, num_channels = 37, 1
    elif which == "EMNIST-Balanced":
        train_dataset, test_dataset = dataset_pair(
            datasets.EMNIST, tensor_pipeline((0.1753,), (0.3334,)),
            {'split': "balanced", 'train': True},
            {'split': "balanced", 'train': False})
        num_classes, num_channels = 47, 1
    elif which == "FMNIST":
        train_dataset, test_dataset = dataset_pair(
            datasets.FashionMNIST, gray_pipeline((0.2860,), (0.3530,)),
            *by_train_flag)
        num_classes, num_channels = 10, 1
    elif which == "FC-FMNIST-Inv":
        outputfolder = "/root/autodl-tmp/datafree/"+ args.pre_train_ckpt[30:].replace("/","__") + "/inverse_100.0.pt"
        x = torch.load(outputfolder)
        # Every sample gets label 0.
        y = torch.zeros(x.shape[0],dtype = torch.long)
        train_dataset = TensorDataset(x,y)
        num_classes, num_channels = 10, 1
    elif which == "FMNIST-shuffle":
        train_dataset, test_dataset = dataset_pair(
            datasets.FashionMNIST,
            tensor_pipeline((0.2860,), (0.3530,), SwapImageHalves()),
            *by_train_flag)
        num_classes, num_channels = 10, 1
    elif which == "Cifar10":
        train_dataset, test_dataset = dataset_pair(
            datasets.CIFAR10, rgb28_pipeline(*cifar10_stats), *by_train_flag)
        num_classes, num_channels = 10, 3
    elif is_cifar10_shot(which):
        tf = rgb28_pipeline(*cifar10_stats)
        shot = which.split("-")[1]
        if use_vit():
            datasetpath = f'../../autodl-tmp/data/cifar-10-shot/cifar10_{shot}shot_224.pt'
        else:
            datasetpath = f'../../autodl-tmp/data/cifar-10-shot/cifar10_{shot}shot.pt'
        train_dataset = torch.load(datasetpath)
        test_dataset = datasets.CIFAR10(data_root, train=False, download=True,
                                        transform=tf)
        # The whole few-shot set is served as one batch; no shuffle key is set.
        train_kwargs = {'batch_size': len(train_dataset), 'num_workers': 4}
        num_classes, num_channels = 10, 3
    elif which == "Cifar10-imb":
        tf = rgb28_pipeline(*cifar10_stats)
        if use_vit():
            datasetpath = f'../../autodl-tmp/data/cifar-10-imb/cifar10_10_224.pt'
        else:
            datasetpath = f'../../autodl-tmp/data/cifar-10-imb/cifar10_10.pt'
        train_dataset = torch.load(datasetpath)
        test_dataset = datasets.CIFAR10(data_root, train=False, download=True,
                                        transform=tf)
        num_classes, num_channels = 10, 3
    elif which == "Cifar100":
        train_dataset, test_dataset = dataset_pair(
            datasets.CIFAR100,
            tensor_pipeline((0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762),
                            transforms.Resize(28)),
            *by_train_flag)
        num_classes, num_channels = 100, 3
    elif which == "DTD":
        train_dataset, test_dataset = dataset_pair(
            datasets.DTD,
            tensor_pipeline((0.5283, 0.4738, 0.4231), (0.2689, 0.2596, 0.2669),
                            transforms.Resize(224), transforms.CenterCrop(224)),
            {'split': "train"}, {'split': "test"})
        num_classes, num_channels = 47, 3
    elif which == "Pet":
        train_dataset, test_dataset = dataset_pair(
            datasets.OxfordIIITPet,
            tensor_pipeline((0.4845, 0.4529, 0.3958), (0.2686, 0.2645, 0.2735),
                            transforms.Resize(224), transforms.CenterCrop(224)),
            {'split': "trainval"}, {'split': "test"})
        num_classes, num_channels = 37, 3
    elif which == "SVHN":
        train_dataset, test_dataset = dataset_pair(
            datasets.SVHN,
            rgb28_pipeline((0.4381, 0.4442, 0.4734), (0.1983, 0.2013, 0.1972)),
            {'split': "train"}, {'split': "test"})
        num_classes, num_channels = 10, 3
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    if test_dataset is None:
        test_loader = None
    else:
        test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    return train_loader, test_loader, num_classes, num_channels

def get_model(args, mode = "P"):
    """Build the network named by ``args.model`` with an ``args.num_classes`` head.

    When ``args.finetune_flag`` is truthy, ``mode`` decides when the state
    dict at ``args.pre_train_ckpt`` is loaded (always onto CPU):

    * ``"P"``: before the head (and, for the ResNet variants, the input
      layer) is replaced. For the fc/conv families the head is first resized
      to ``args.previous_num_classes`` so the checkpoint fits.
    * ``"C"``: after the final head is in place.

    "resnet18_sketch" only has the ``"P"`` load; "vit" and "vit-mix" load no
    checkpoint here and are passed through ``freeze_vit``.

    Raises:
        AssertionError: ``mode`` is neither ``"P"`` nor ``"C"``.
        NotImplementedError: ``args.model`` is not a handled name.
    """
    assert mode in ["P","C"]
    arch = args.model

    def finetuning(stage):
        return args.finetune_flag and mode == stage

    def restore(net):
        net.load_state_dict(torch.load(args.pre_train_ckpt, map_location= "cpu"))

    if arch in ("fc", "fc3d", "conv", "conv3d"):
        # Head-swap family: build, optionally load with the old head size,
        # install the final head, optionally load again.
        if arch == "fc":
            model, swap_head = FcNet(args), Fc_change_head
        elif arch == "fc3d":
            model, swap_head = FcNet3D(args), Fc_change_head
        elif arch == "conv":
            model, swap_head = ConvNet(args), Conv_change_head
        else:
            model, swap_head = ConvNet3D(args), Conv_change_head

        if finetuning("P"):
            model = swap_head(model, args.previous_num_classes)
            restore(model)
        model = swap_head(model, args.num_classes)
        if finetuning("C"):
            restore(model)
    elif arch in ("resnet18", "resnet18-EMNIST-Letters", "resnet18-FMNIST"):
        # ResNet family: load (mode P) on the freshly created model, then
        # replace the input layer and head, then load (mode C).
        if arch == "resnet18":
            model = create_model(get_model_name(args.pre_train_ckpt))
        elif arch == "resnet18-EMNIST-Letters":
            model = create_model_EMNIST_Letters()
        else:
            model = create_model_FMNIST()

        if finetuning("P"):
            restore(model)
        model = Resnet_change_tail(model, args.num_channels)
        model = Resnet_change_head(model, args.num_classes)
        if finetuning("C"):
            restore(model)
    elif arch == "resnet18_sketch":
        model = create_model_sketch()
        if finetuning("P"):
            restore(model)
        model = Resnet_change_head(model, args.num_classes)
    elif arch == "vit":
        model = create_vit_model(get_vit_name(args.pre_train_ckpt))
        model = freeze_vit(vit_change_head(model, args.num_classes))
    elif arch == "vit-mix":
        model = create_vit_mix()
        model = freeze_vit(vit_change_head(model, args.num_classes))
    else:
        raise NotImplementedError

    return model

def get_submodel(model, args):
    """Wrap ``model`` in the sub-network class matching ``args.model``.

    Returns ``model`` itself when ``args.feature_index`` is ``-1``; otherwise
    the wrapper is constructed with ``(model, args.feature_index)``.

    Raises:
        NotImplementedError: no wrapper exists for ``args.model``.
    """
    layer = args.feature_index
    if layer == -1:
        return model

    arch = args.model
    if arch == "fc":
        return SubFcNet(model, layer)
    if arch == "fc3d":
        return SubFcNet3D(model, layer)
    if arch in ("resnet18", "resnet18-FMNIST", "resnet18-EMNIST-Letters"):
        return SubResnet18(model, layer)
    if arch in ("detr", "detr-seg"):
        # Imported on demand so the DETR package is only needed for these models.
        from bin.detr.models.detr import SubDETR
        return SubDETR(model, layer)
    raise NotImplementedError

def randomness_control(seed):
    """Seed Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.

    Prints the seed, seeds ``random``, ``numpy.random``, ``torch`` and all
    CUDA devices with it, then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.
    """
    print("seed",seed)
    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def plot_matrix(matrix, path):
    """Save ``matrix`` as an 'inferno' heat map with a colour bar to ``path``.

    The figure is closed once it has been written (or if drawing/saving
    fails): pyplot keeps every figure it creates alive until it is closed,
    so repeated calls would otherwise accumulate open figures.

    Args:
        matrix: 2-D array-like accepted by ``Axes.imshow``.
        path: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(matrix, cmap='inferno')
        fig.colorbar(image)
        fig.savefig(path)
    finally:
        plt.close(fig)

def get_filename(path):
    """Return the last component of ``path`` with its final extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(path))
    return stem

def get_num_class_from_name(path):
    """Return the class count for the dataset directory that appears in ``path``.

    The markers are checked in a fixed order and the first one contained in
    ``path`` wins.

    Raises:
        NotImplementedError: none of the known markers occurs in ``path``.
    """
    marker_to_classes = (
        ("/FMNIST/", 10),
        ("/EMNIST-Letters/", 37),
        ("/EMNIST-Balanced/", 47),
        ("/MNIST/", 10),
    )
    for marker, class_count in marker_to_classes:
        if marker in path:
            return class_count
    raise NotImplementedError

def measure_time_memory(f):
    """Decorator that prints the run time and CUDA peak memory of each call.

    After every call of ``f`` two lines are printed: the elapsed wall-clock
    time, and the difference between the peak CUDA memory reached during the
    call and the memory allocated just before it (both reported as 0 when
    CUDA is unavailable). The return value of ``f`` is passed through, and an
    exception raised by ``f`` propagates without anything being printed.

    Args:
        f: Callable to instrument.

    Returns:
        A wrapper with the same call signature; ``functools.wraps`` keeps
        ``f``'s ``__name__`` and ``__doc__`` on it.
    """
    from functools import wraps

    @wraps(f)
    def wrapped(*args, **kwargs):
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            start_memory = torch.cuda.memory_allocated()
            # Reset the peak counter so max_memory_allocated() below reflects
            # this call only (replaces deprecated reset_max_memory_allocated).
            torch.cuda.reset_peak_memory_stats()
        else:
            start_memory = 0

        # perf_counter is monotonic, unlike time.time().
        start_time = time.perf_counter()
        result = f(*args, **kwargs)
        elapsed = time.perf_counter() - start_time

        end_memory = torch.cuda.max_memory_allocated() if cuda_available else 0

        print(f"Function {f.__name__} executed in {elapsed:.4f} seconds.")
        print(f"Memory usage increased by {(end_memory - start_memory) / (1024 ** 2):.2f} MB to {(end_memory) / (1024 ** 2):.2f} MB.")

        return result
    return wrapped

def get_real_parents(listp, listc):
    """Map each entry of ``listc`` to the index of its parent in ``listp``.

    Each parent name has everything after its last ``.`` removed; a child is
    assigned to the first parent whose stripped name is a substring of it.

    Returns:
        A ``long`` tensor with one parent index per entry of ``listc``.

    Raises:
        RuntimeError: some child contains none of the parent names.
    """
    stems = [parent.rsplit(".",1)[0] for parent in listp]
    parent_indices = []
    for child in listc:
        hit = next(
            (position for position, stem in enumerate(stems) if stem in child),
            None,
        )
        if hit is None:
            raise RuntimeError("Not True Label")
        parent_indices.append(hit)
    return torch.tensor(parent_indices).long()

def is_cifar10_shot(s):
    """Return True when ``s`` is ``Cifar10-<k>`` with k in {1, 5, 10, 20, 50}."""
    return re.match(r'^Cifar10-(1|5|10|20|50)$', s) is not None

class SwapImageHalves(torch.nn.Module):
    """
    A transform class that reorders the image along dimension 1.

    Despite the name, ``forward`` does not exchange two halves: it gathers the
    28 positions of dimension 1 in a fixed, hard-coded permutation
    (presumably the row axis of a (C, 28, W) tensor; confirm against callers).
    """

    # Position i of the output is taken from position _ROW_ORDER[i] of the input.
    _ROW_ORDER = [14, 18, 11, 23, 17, 0, 3, 20, 27, 16, 24, 19, 8, 1,
                  22, 12, 25, 21, 10, 15, 4, 13, 7, 2, 5, 6, 9, 26]

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x[:, self._ROW_ORDER, :]

def unfreeze_vit(model):
    """Mark every parameter of ``model`` as trainable and return ``model``.

    Sets ``requires_grad = True`` on each parameter in place and prints a
    confirmation line.
    """
    for weight in model.parameters():
        weight.requires_grad = True

    print("Vit model unfreezed")
    return model

def get_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` for ``model``.

    Only parameters with ``requires_grad`` set are handed to the optimizer,
    which is configured with ``lr=args.lr`` and ``weight_decay=args.wd``.

    Raises:
        NotImplementedError: ``args.optimizer`` is neither "Adam" nor "SGD".
    """
    if args.optimizer == "Adam":
        optimizer_cls = optim.Adam
    elif args.optimizer == "SGD":
        optimizer_cls = optim.SGD
    else:
        raise NotImplementedError

    trainable = (p for p in model.parameters() if p.requires_grad)
    return optimizer_cls(trainable, lr=args.lr, weight_decay=args.wd)

def make_args(args):
    """Return a shallow copy of ``args`` whose ``dataset`` is ``args.previous_dataset``.

    The copy is a fresh ``argparse.Namespace``; ``args`` itself is not modified.
    """
    cloned = argparse.Namespace(**vars(args))
    cloned.dataset = args.previous_dataset
    return cloned