import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import os
import torchvision.transforms as transforms
import torchvision.models as models
from dataset import TargetSamplesImagenet, SamplesFromImNames


# Transformations
class TwoCropTransform:
    """Produce two views of a single sample.

    The first view comes from the caller-supplied ``transform``; the second
    comes from a fixed augmentation pipeline built here (random resized crop
    to ``img_size``, random horizontal flip, colour jitter applied with
    probability 0.8, random grayscale with probability 0.2, then tensor
    conversion).
    """

    def __init__(self, transform, img_size):
        self.transform = transform
        self.img_size = img_size

        jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        augmentation_steps = [
            transforms.RandomResizedCrop(size=self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]
        self.data_transforms = transforms.Compose(augmentation_steps)

    def __call__(self, x):
        """Return ``[transform(x), augmented(x)]`` for the sample ``x``."""
        first_view = self.transform(x)
        second_view = self.data_transforms(x)
        return [first_view, second_view]

def rotation(input):
    """Rotate every sample of a batch by a random multiple of 90 degrees.

    The per-sample rotation labels are taken from a shuffled pool made of
    repeated ``[0, 1, 2, 3]`` entries (NumPy global RNG), truncated to the
    batch size, so the four classes stay close to balanced.

    Args:
        input: Batched tensor; each sample ``input[i]`` is rotated over its
            dims 1 and 2.

    Returns:
        ``(image, target)`` where ``image`` is a rotated copy of ``input``
        and ``target`` is a ``long`` tensor of quarter-turn counts on the
        same device as ``input``.
    """
    batch = input.shape[0]

    # Enough repetitions of the four labels to cover the whole batch.
    label_pool = [0, 1, 2, 3] * (int(batch / 4) + 1)
    shuffled = np.random.permutation(label_pool)
    target = torch.tensor(shuffled, device=input.device)[:batch].long()

    # Work on a copy so the caller's tensor is left untouched.
    image = torch.zeros_like(input).copy_(input)
    # NOTE(review): odd quarter-turns swap the two rotated dims, so the
    # assignment below presumably requires square samples -- confirm.
    for i, quarter_turns in enumerate(target):
        image[i, :, :, :] = torch.rot90(input[i, :, :, :], quarter_turns, [1, 2])

    return image, target


class FeatureExtractor():
    """Capture the output of one layer of ``model`` through a forward hook.

    ``name`` selects how a layer is looked up on the model; the supported
    values are ``'resnet50'``, ``'vgg19_bn'`` and ``'densenet121'``. The most
    recent hooked output is kept in ``self.activation`` (``None`` until the
    hook has fired).
    """

    def __init__(self, model, name):
        super().__init__()
        self.model = model
        self.activation = None
        self.name = name

    def get_activation(self, layer_name):
        """Register a forward hook on the requested layer.

        Lookup rules per ``self.name``:
          * ``'resnet50'``: last element of ``model.<layer_name>``.
          * ``'vgg19_bn'``: ``model.features[layer_name]``.
          * ``'densenet121'``: attribute ``layer_name`` of ``model.features``.

        Returns:
            Whatever ``register_forward_hook`` returns for the chosen layer.

        Raises:
            NotImplementedError: ``self.name`` is not one of the values above.
        """
        self.layer_name = layer_name
        arch = self.name

        if arch == 'resnet50':
            layer = getattr(self.model, layer_name)[-1]
        elif arch == 'vgg19_bn':
            layer = self.model.features[layer_name]
        elif arch == 'densenet121':
            layer = getattr(self.model.features, layer_name)
        else:
            raise NotImplementedError()
        self.target_layer = layer

        def _store_output(module, input, output):
            # Keep only the latest output of the hooked layer.
            self.activation = output

        return layer.register_forward_hook(_store_output)
    



def normalize(_t, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return a per-channel standardised copy of a batch of images.

    Channels 0, 1 and 2 (second dimension of ``_t``) are transformed as
    ``(x - mean[c]) / std[c]``. The input tensor itself is not modified.

    Args:
        _t: 4-D tensor indexed as ``[batch, channel, :, :]`` with at least
            three channels.
        mean: Per-channel means; entries 0..2 are used. The defaults are the
            statistics torchvision documents for its ImageNet-pretrained
            models.
        std: Per-channel standard deviations; entries 0..2 are used.

    Returns:
        A new tensor of the same shape as ``_t`` holding the normalised values.
    """
    # Out-of-place add yields a fresh tensor, so the in-place channel writes
    # below never touch the caller's data.
    t = _t + 0
    for channel in range(3):
        t[:, channel, :, :] = (t[:, channel, :, :] - mean[channel]) / std[channel]
    return t


def save_high_confidence_tar_samples_name(path_to_save, models_list, args, gpu_id=0):
    """Rank target-class images by ensemble confidence and save their file names.

    For every image served by ``TargetSamplesImagenet``, the softmax
    probability of class ``args.match_target`` is summed over all models in
    ``models_list``. Image file names are then written to ``path_to_save``,
    one per line, ordered from the highest to the lowest summed probability.

    Args:
        path_to_save: Destination text file (overwritten).
        models_list: Classifiers called on normalised image batches; they are
            expected to already live on CUDA device ``gpu_id``.
        args: Object providing ``match_dir`` (directory to search) and
            ``match_target`` (class index).
        gpu_id: CUDA device the image batches are moved to.

    Raises:
        IndexError: No entry of ``args.match_dir`` contains
            ``classID{args.match_target}_`` in its name.
    """
    folders = os.listdir(args.match_dir)
    # The first matching entry wins; the trailing underscore in the tag keeps
    # e.g. class 1 from matching a "classID10_..." folder.
    desired_folder = [folder for folder in folders if f'classID{args.match_target}_' in folder][0]
    print('desired_folder', desired_folder)
    tar_path = os.path.join(args.match_dir, desired_folder)
    data = TargetSamplesImagenet(tar_path, args.match_target)
    dataloader = torch.utils.data.DataLoader(data, batch_size=20, shuffle=False)

    batch_scores = []
    # Pure inference: disabling autograd avoids building a graph per forward pass.
    with torch.no_grad():
        for imgs, _ in dataloader:
            # Normalise once per batch instead of once per model.
            normalized = normalize(imgs.cuda(gpu_id))
            per_model_probs = [
                torch.softmax(model(normalized), dim=1)[:, args.match_target].cpu()
                for model in models_list
            ]
            batch_scores.append(sum(per_model_probs))

    # Concatenate once at the end; an empty dataset still yields an empty ranking.
    all_pred_probs = torch.cat(batch_scores, dim=0) if batch_scores else torch.tensor([])
    sorted_idx = torch.argsort(all_pred_probs, descending=True)

    # shuffle=False above keeps score positions aligned with dataset indices.
    # presumably get_path(idx) returns the image path for that index -- verify
    # against the dataset class.
    desired_paths = [os.path.basename(data.get_path(idx)) for idx in sorted_idx]

    with open(path_to_save, 'w') as out_file:
        out_file.write("\n".join(desired_paths))
    print('save_high_confidence_tar_samples_name done......................')




def get_prune_model(model, amount=0.02, pruning_type='random'):
    """Globally prune the weights of all ``Conv2d`` and ``Linear`` layers in place.

    Args:
        model: Module whose convolutional and linear ``weight`` tensors are pruned.
        amount: Pruning amount handed to ``prune.global_unstructured``. Must be
            ``<= 0.03`` for random pruning and ``>= 0.5`` otherwise.
        pruning_type: ``'random'`` selects ``RandomUnstructured``; any other
            value selects ``L1Unstructured``.

    Returns:
        The same ``model`` object, with the pruning masks baked into the weights.

    Raises:
        AssertionError: ``amount`` is outside the range allowed for the mode.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    targets = [
        (layer, 'weight')
        for _, layer in model.named_modules()
        if isinstance(layer, prunable_types)
    ]

    if pruning_type == 'random':
        assert amount <= 0.03
        method = prune.RandomUnstructured
    else:
        assert amount >=0.5
        method = prune.L1Unstructured

    prune.global_unstructured(targets, pruning_method=method, amount=amount)

    # Drop the mask/orig reparametrization so only plain pruned weights remain.
    for layer, param_name in targets:
        prune.remove(layer, param_name)
    return model



def model_l(model_name, _bool):
    """Build the torchvision model registered as ``model_name``.

    ``_bool`` is forwarded unchanged as the constructor's ``pretrained``
    keyword. A ``KeyError`` is raised when ``torchvision.models`` has no
    attribute called ``model_name``.
    """
    constructor = models.__dict__[model_name]
    return constructor(pretrained=_bool)
def get_pretrained_model(model_name):
    """Return the model ``model_name`` built through ``model_l`` with its flag set to True."""
    # Alternative that is currently not enabled: read the weights from a local
    # '<weights dir>/<model_name>.pth' file and apply them with load_state_dict.
    return model_l(model_name, True)
def get_model(model_name, surr_name, gpu_id = 0):
    """Return the requested model in eval mode, pruning and caching it if needed.

    When ``model_name`` contains ``'pruned'``, the base architecture (the text
    before the first underscore) is loaded and its weights are replaced by the
    checkpoint ``./pruned_models/<surr_name>/<model_name>.pth``. If that file
    does not exist yet, it is created first: names containing ``'pruned1'``
    use L1 pruning with amount 0.6, all others use random pruning with
    amount 0.02. Any other ``model_name`` is loaded directly.

    Args:
        model_name: Architecture name, optionally carrying a ``pruned`` tag.
        surr_name: Sub-directory name for cached pruned checkpoints; also
            used in the log line.
        gpu_id: Only controls logging; the "model loaded" line is printed
            when it equals 0.

    Returns:
        The model after ``.eval()``.
    """
    if 'pruned' not in model_name:
        model =  get_pretrained_model(model_name)
    else:
        cache_dir = f'./pruned_models/{surr_name}'
        os.makedirs(cache_dir, exist_ok=True)
        # NOTE(review): only the text before the first '_' is kept, so a name
        # such as 'vgg19_bn_pruned1' would resolve to 'vgg19' -- confirm this
        # matches the naming scheme used by callers.
        base_arch = model_name.split('_')[0]
        model = get_pretrained_model(base_arch)
        checkpoint = os.path.join(cache_dir, model_name+'.pth')

        if not os.path.exists(checkpoint):
            use_l1 = 'pruned1' in model_name
            model = get_prune_model(
                model,
                amount=0.6 if use_l1 else 0.02,
                pruning_type='l1unstructured' if use_l1 else 'random',
            )
            torch.save(model.state_dict(), checkpoint)
            print(f'getting {model_name} pruned model')

        # Always read the weights back from disk, including right after saving.
        model.load_state_dict(torch.load(checkpoint))

    if gpu_id == 0:
        print(f'{surr_name}: {model_name} model loaded')
    return model.eval()