import os
from subprocess import call
import numpy as np
import torch
from art.utils import load_cifar10, load_mnist
from tqdm import tqdm
import pickle

from timm.models import create_model

from torch import FloatTensor, div
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


def assign_to_device(cuda_dev_id):
    """
    Select the torch device to use for computation.

    :param cuda_dev_id: index of the desired CUDA device, as an int or a numeric string;
        ``None`` requests the default CUDA device
    :return: ``torch.device("cpu")`` when CUDA is unavailable, ``torch.device("cuda", id)``
        when the id is a valid device index, and the default ``torch.device("cuda")``
        for a missing, non-numeric or out-of-range id
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    default_device = torch.device("cuda")
    if cuda_dev_id is None:
        return default_device

    try:
        index = int(cuda_dev_id)
    except (TypeError, ValueError):
        # A non-numeric id (e.g. "gpu0") falls back to the default device instead of crashing.
        return default_device

    # Negative and too-large indices are both invalid requests.
    if not 0 <= index < torch.cuda.device_count():
        return default_device
    return torch.device("cuda", index)


def load_model(dataset_name, checkpoints_dir, device, transformer=False, ckpt_n=70000, model_type='ViT-B_16', checkpoint_name=None):
    """
    Build the vision transformer for ``dataset_name``, load its checkpoint and return it in eval mode.

    :param dataset_name: one of ``'cifar10'``, ``'cifar100'`` or ``'tiny'``
    :param checkpoints_dir: directory holding the ``transformers`` checkpoint folder; for the CIFAR
        models it is concatenated directly with ``'transformers/...'`` (so it must end with a path
        separator), while for ``'tiny'`` a ``'/'`` is inserted
    :param device: device the model is moved to
    :param transformer: unused, kept for backward compatibility
    :param ckpt_n: checkpoint number used in the CIFAR checkpoint file name
    :param model_type: key into ``models.modeling.CONFIGS`` (CIFAR models only)
    :param checkpoint_name: checkpoint file name (``'tiny'`` only)
    :return: the loaded model, on ``device`` and in eval mode
    :raises ValueError: if ``dataset_name`` is not supported
    """
    cifar_num_classes = {'cifar10': 10, 'cifar100': 100}
    if dataset_name != 'tiny' and dataset_name not in cifar_num_classes:
        raise ValueError("Unsupported dataset_name: {!r}".format(dataset_name))

    if dataset_name == 'tiny':
        model = create_model('vit_large_patch16_384', pretrained=False, drop_path_rate=0.1)
        model.reset_classifier(num_classes=200)
        # Load on CPU first so a checkpoint saved on a GPU can be restored on any host.
        checkpoint = torch.load(
            '{}/transformers/{}'.format(checkpoints_dir, checkpoint_name),
            map_location=torch.device('cpu'),
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
    else:
        # Project-local import, only needed for the CIFAR models.
        from models.modeling import VisionTransformer, CONFIGS

        config = CONFIGS[model_type]
        path = '{}transformers/{}_checkpoint_{}.bin'.format(checkpoints_dir, dataset_name, ckpt_n)
        model = VisionTransformer(
            config, img_size=224, zero_head=True, num_classes=cifar_num_classes[dataset_name]
        ).to(device)
        # load_state_dict copies into the existing parameters, so snapshot the head beforehand.
        old_weights = model.head.weight.detach().clone()
        model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
        # Sanity check: the classification head must have been overwritten by the checkpoint.
        assert not torch.equal(old_weights, model.head.weight.detach())

    model.eval()
    return model


def _cifar_loaders(dataset_cls, num_classes, data_dir, batch_size):
    """Build train/test loaders and one-hot labels for a torchvision CIFAR dataset class."""
    from torchvision import transforms

    img_size = 224
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset_test = dataset_cls(root=data_dir, train=False, download=True, transform=transform_test)
    dataset_train = dataset_cls(root=data_dir, train=True, download=True, transform=transform_train)

    # The test loader always uses batches of 100; only the train loader honours batch_size.
    x_test = DataLoader(dataset_test, batch_size=100, shuffle=False)
    x_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False)
    y_train, y_test = toCat_onehot(
        y_train=dataset_train.targets, y_test=dataset_test.targets, numclasses=num_classes
    )
    return x_train, y_train, x_test, y_test


def _tiny_split(pickle_path, batch_size, num_workers):
    """Load one pickled ``(data, labels)`` split and return its loader and labels as a numpy array."""
    from torchvision import transforms

    # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files.
    with open(pickle_path, 'rb') as f:
        data, labels = pickle.load(f)

    resize = transforms.Compose([
        transforms.Resize(384, interpolation=InterpolationMode.BICUBIC),
    ])
    normalize = transforms.Compose([
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        ),
    ])
    dataset = ImageNetDataset(data, labels.type(torch.LongTensor), resize, normalize=normalize)
    loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return loader, labels.numpy()


def load_data(dataset_name, transformer=False, data_dir='data', batch_size=100):
    """
    Build the train/test data loaders and one-hot labels for a dataset.

    :param dataset_name: one of ``'cifar10'``, ``'cifar100'`` or ``'tiny'``
    :param transformer: unused, kept for backward compatibility
    :param data_dir: torchvision download root for the CIFAR datasets; for ``'tiny'`` the directory
        containing ``train_dataset.pkl`` and ``val_dataset.pkl``
    :param batch_size: batch size of the train loader (the test loader always uses 100)
    :return: ``(x_train, y_train, x_test, y_test)`` where the ``x_*`` are unshuffled ``DataLoader``
        objects and the ``y_*`` are one-hot numpy arrays
    :raises ValueError: if ``dataset_name`` is not supported
    """
    if dataset_name in ('cifar10', 'cifar100'):
        from torchvision import datasets

        if dataset_name == 'cifar10':
            return _cifar_loaders(datasets.CIFAR10, 10, data_dir, batch_size)
        return _cifar_loaders(datasets.CIFAR100, 100, data_dir, batch_size)

    if dataset_name == 'tiny':
        x_train, y_train = _tiny_split('{}/train_dataset.pkl'.format(data_dir), batch_size, 8)
        x_test, y_test = _tiny_split('{}/val_dataset.pkl'.format(data_dir), 100, 4)
        y_train, y_test = toCat_onehot(
            y_train=y_train.astype(int), y_test=y_test.astype(int), numclasses=200
        )
        return x_train, y_train, x_test, y_test

    raise ValueError("To consider {}".format(dataset_name))



def find_values(filename, l, loss):
    """
    Extract two numeric columns from the ``Loss <loss>`` lines of a log file.

    Only lines whose zero-based index in the file has parity ``l`` are considered. From each
    such line containing ``'Loss <loss>'``, every number is extracted and the 9th and 10th
    (indices 8 and 9) are collected and printed.

    :param filename: path of the log file to read
    :param l: parity (``0`` or ``1``) of the line indices to keep
    :param loss: loss name searched for as ``'Loss <loss>'``
    :return: ``(T_0, T_1)``, the lists of the 9th and 10th numbers of each selected line
    :raises IndexError: if a selected line holds fewer than ten numbers
    """
    import re

    number_pattern = re.compile(r"-?[\d.]+(?:e-?\d+)?")
    marker = 'Loss {}'.format(loss)
    T_0 = []
    T_1 = []

    # The context manager guarantees the log file is closed, even if parsing fails.
    with open(filename, "r") as log_file:
        for i, line in enumerate(log_file):
            if (i % 2) != l:
                continue
            s = line.strip()
            if marker not in s:
                continue
            numbers = number_pattern.findall(s)
            T_0.append(float(numbers[8]))
            T_1.append(float(numbers[9]))
            print(numbers[8], numbers[9])
    return T_0, T_1


def get_output_bs_layer(layer, X, batch_size=100, desc=''):
    """
    Run ``layer`` over ``X`` in batches and collect the outputs in one numpy array.

    :param layer: callable applied to float tensors built from slices of ``X``
    :param X: input array indexed along its first axis; its ``shape`` must have at least four
        entries, because a ``(1, shape[1], shape[2], shape[3])`` zero tensor is used to probe
        the output shape
    :param batch_size: number of samples sent through ``layer`` at once
    :param desc: label shown on the progress bar
    :return: numpy array holding ``layer``'s output for every sample of ``X``
    """
    import math

    from tqdm.auto import tqdm

    # Run on the device that holds the layer's parameters; otherwise prefer CUDA when present.
    device = None
    if isinstance(layer, torch.nn.Module):
        first_param = next(layer.parameters(), None)
        if first_param is not None:
            device = first_param.device
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # A single dummy sample reveals the per-sample output shape.
        probe = layer(torch.zeros(1, X.shape[1], X.shape[2], X.shape[3]).float().to(device))
        out_size = list(probe.shape)
        out_size[0] = X.shape[0]
        preds = np.zeros(tuple(out_size))

        num_batches = math.ceil(X.shape[0] / batch_size)
        for i in tqdm(range(num_batches), position=0, leave=True, ncols=50, colour='red', ascii=True, desc=desc):
            start = i * batch_size
            stop = start + batch_size
            batch = torch.tensor(X[start:stop]).float().to(device)
            preds[start:stop] = layer(batch).cpu().detach().numpy()
    return preds


def get_prediction_transformers(args, model, train_loader, test_loader, X_test_adv):
    """
    Compute the model logits on the train, test and adversarial inputs as numpy arrays.

    :param args: object forwarded to ``extraction_transformers``
    :param model: model whose logits are extracted
    :param train_loader: loader processed with ``desc='train'``
    :param test_loader: loader processed with ``desc='test'``
    :param X_test_adv: adversarial inputs processed with ``desc='attack'``
    :return: ``(train_logits, test_logits, adversarial_logits, train_labels)`` as numpy arrays
    """
    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    train_logits, train_labels = extraction_transformers(args, train_loader, model, desc='train')
    test_logits = extraction_transformers(args, test_loader, model, desc='test')
    adv_logits = extraction_transformers(args, X_test_adv, model, desc='attack')

    return as_numpy(train_logits), as_numpy(test_logits), as_numpy(adv_logits), as_numpy(train_labels)


class ImageNetDataset(Dataset):
    """
    Map-style dataset over pre-loaded samples and their labels.

    Each sample is optionally transformed, converted to a float tensor, divided by 255 and
    then optionally normalized.
    """

    def __init__(self, dataset, labels, transform=None, normalize=None):
        """
        :param dataset: indexable collection of samples
        :param labels: indexable collection of labels, same length as ``dataset``
        :param transform: optional callable applied to the raw sample
        :param normalize: optional callable applied after the division by 255
        """
        super().__init__()
        n_samples, n_labels = len(dataset), len(labels)
        assert n_samples == n_labels
        self.dataset, self.labels = dataset, labels
        self.transform, self.normalize = transform, normalize

    def __len__(self):
        """Return the number of samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx`` together with its label."""
        sample = self.dataset[idx]
        sample = self.transform(sample) if self.transform else sample

        scaled = div(sample.type(FloatTensor), 255)
        if self.normalize:
            scaled = self.normalize(scaled)

        return scaled, self.labels[idx]

def extraction_transformers(args, loader, model, desc, batch_size=100):
    """
    Run ``model`` over ``loader`` without gradients and return the concatenated logits on CPU.

    :param args: object exposing ``device``, the device inputs are moved to
    :param loader: for ``desc='attack'`` an array exposing ``shape`` and slicing, consumed in
        chunks of ``batch_size``; for ``'train'`` / ``'test'`` an iterable of ``(data, target)``
    :param model: callable returning the logits of a batch
    :param desc: ``'attack'``, ``'train'`` or ``'test'``
    :param batch_size: chunk size used to slice ``loader`` (``'attack'`` only)
    :return: the logits tensor; for ``'train'`` a tuple ``(logits, y_train)`` where ``y_train``
        is the one-hot encoding of the targets. Empty input yields empty tensors.
    :raises ValueError: if ``desc`` is not one of the supported modes
    """
    import math

    if desc not in ('attack', 'train', 'test'):
        raise ValueError("Unsupported desc: {!r}".format(desc))

    # Collect per-batch results and concatenate once, instead of re-copying on every batch.
    logits_chunks = []
    label_chunks = []
    with torch.no_grad():
        if desc == 'attack':
            num_batches = math.ceil(loader.shape[0] / batch_size)
            for i in tqdm(range(num_batches)):
                start = i * batch_size
                # Every slice honours batch_size (it used to fall back to 100 after batch 0).
                data = torch.tensor(loader[start:start + batch_size]).to(args.device)
                logits_chunks.append(model(data).detach().cpu())
        else:
            for data, target in tqdm(loader, "Extracting features"):
                batch_logits = model(data.to(args.device)).detach().cpu()
                logits_chunks.append(batch_logits)
                if desc == 'train':
                    label_chunks.append(torch.nn.functional.one_hot(
                        torch.as_tensor(target), num_classes=batch_logits.shape[1]))

    logits = torch.cat(logits_chunks, dim=0) if logits_chunks else torch.empty(0)
    if desc != 'train':
        return logits
    y_train = torch.cat(label_chunks, dim=0) if label_chunks else torch.empty(0, dtype=torch.long)
    return logits, y_train


def save_status(probs, path, attack):
    """
    Save ``probs`` as ``<path>/probs_<attack>.npy``, creating ``path`` if needed.

    The destination file name is printed before the array is written.

    :param probs: array to store
    :param path: output directory
    :param attack: name embedded in the file name
    """
    destination = '{}/probs_{}.npy'.format(path, attack)
    os.makedirs(path, exist_ok=True)
    print(destination)
    np.save(destination, probs)


def load_status(attack, path):
    """
    Load the array stored in ``<path>/probs_<attack>.npy``.

    :param attack: name embedded in the file name
    :param path: directory containing the file
    :return: the array read by ``np.load``
    """
    source = '{}/probs_{}.npy'.format(path, attack)
    return np.load(source)


def print_current_time(s=None):
    """
    Print the current local time as ``HH:MM:SS``.

    :param s: optional string; when given, the output is ``"<s> : <time>"``
    """
    from datetime import datetime

    stamp = datetime.now().strftime("%H:%M:%S")
    if s is None:
        print(stamp)
        return
    print(" : ".join((s, stamp)))
        
def get_prediction_by_bs(model, X, num_classes, batch_size=500):
    """
    Run ``model`` over ``X`` in batches and collect the predictions in a numpy array.

    :param model: callable applied to float tensors built from slices of ``X``
    :param X: input array, sliced along its first axis
    :param num_classes: size of the second axis of the returned array
    :param batch_size: number of samples sent through ``model`` at once
    :return: numpy array of shape ``(X.shape[0], num_classes)`` holding the model outputs
    """
    import math

    # Run on the device that holds the model's parameters; otherwise prefer CUDA when present.
    device = None
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    preds = np.zeros((X.shape[0], num_classes))
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for i in range(math.ceil(X.shape[0] / batch_size)):
            start = i * batch_size
            stop = start + batch_size
            batch = torch.tensor(X[start:stop]).float().to(device)
            preds[start:stop] = model(batch).cpu().detach().numpy()
    return preds
    
def toCat_onehot(y_train, y_test, numclasses):
    """
    One-hot encode the train and test labels.

    :param y_train: integer class indices of the train set
    :param y_test: integer class indices of the test set
    :param numclasses: number of classes, i.e. the width of each one-hot row
    :return: ``(y_train, y_test)`` as arrays made of rows of the identity matrix
    """
    def encode(labels):
        # Indexing the identity matrix by label picks the matching one-hot row.
        return np.eye(numclasses)[labels]

    return encode(y_train), encode(y_test)
