import os
import random
import signal
import subprocess
import pickle
import sys
import time
import argparse
from functools import partial

from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data 
from torchvision import models
from tqdm import tqdm

import numpy as np


class LoadedResNet(nn.Module):
    """ResNet-50 backbone initialised from a pretrained checkpoint.

    The classification head (``fc``) is deleted after construction, so
    ``forward`` returns the flattened output of the average-pool layer.

    Args:
        model_name: Name of the pretraining method. It selects the key prefix
            stripped by ``rename``; ``'simclr-v2'`` additionally remaps the
            checkpoint keys through ``./simclr_keys.pkl``.
        pretrained_weights: Path of the checkpoint handed to ``torch.load``.
        gpu: CUDA device index the network is moved to.

    Raises:
        KeyError: If ``model_name`` is not in the prefix table of ``rename``.
    """

    def __init__(self, model_name, pretrained_weights, gpu):
        super().__init__()
        self.model_name = model_name
        print('MODEL NAME', model_name)

        self.model = models.resnet50(pretrained=False).cuda(gpu)
        del self.model.fc

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_weights, map_location=torch.device('cpu'))
        # Some checkpoints nest the weights under a wrapper key; unwrap the
        # first one that is present (at most one level).
        for wrapper_key in ('state_dict', 'model', 'resnet'):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break

        if model_name == 'simclr-v2':
            # The pickle maps torchvision parameter names to the names used
            # in the checkpoint. NOTE: pickle.load executes arbitrary code;
            # the mapping file must be trusted.
            with open('./simclr_keys.pkl', 'rb') as mapping_file:
                simclr_dict = pickle.load(mapping_file)
            state_dict = {key: state_dict[tf_key] for key, tf_key in simclr_dict.items()}

        # Strip wrapper prefixes and drop head/projection weights so the keys
        # match the headless torchvision ResNet-50.
        state_dict = self.rename(state_dict)
        state_dict = self.remove_keys(state_dict)
        self.model.load_state_dict(state_dict)

        self.model.eval()
        print("num parameters:", sum(p.numel() for p in self.model.parameters()))

    def forward(self, x):
        """Run the backbone up to the average pool and flatten from dim 1."""
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)

        return x

    def rename(self, d):
        """Return a copy of ``d`` with the method-specific key prefix removed.

        Only keys that *start with* the prefix are changed; every other key
        is copied through unchanged.

        Args:
            d: Mapping from parameter name to value.

        Returns:
            A new dict with the renamed keys.

        Raises:
            KeyError: If ``self.model_name`` has no entry in the prefix table.
        """
        unwanted_prefixes = {
            'supervised': '',
            'simclr-v1': '',
            'simclr-v2': '',
            'byol': '',
            'swav': 'module.',
            'deepcluster-v2': 'module.',
            'sela-v2': 'module.',
            'moco-v1': 'module.encoder_q.',
            'moco-v2': 'module.encoder_q.',
            'cmc': 'module.encoder1.',
            'infomin': 'module.encoder.',
            'insdis': 'module.encoder.',
            'pirl': 'module.encoder.',
            'pcl-v1': 'module.encoder_q.',
            'pcl-v2': 'module.encoder_q.',
        }
        prefix = unwanted_prefixes[self.model_name]
        prefix_len = len(prefix)
        # startswith (not a substring test) so that a prefix occurring in the
        # middle of a key does not cause its leading characters to be cut.
        return {
            (key[prefix_len:] if key.startswith(prefix) else key): value
            for key, value in d.items()
        }

    def remove_keys(self, d):
        """Drop jigsaw, projection, prototype, fc, linear and head entries.

        ``d`` is modified in place and also returned. Matching is by
        substring, and a message is printed for every removed key.
        """
        jigsaw_markers = ('module.jigsaw', 'module.head_jig')
        head_markers = ('projection', 'prototypes', 'fc', 'linear', 'head')
        # Iterate over a snapshot of the keys because entries are popped.
        for key in list(d):
            if any(marker in key for marker in jigsaw_markers):
                print('warning, jigsaw stream in model')
                d.pop(key)
            elif any(marker in key for marker in head_markers):
                print(f'removed {key}')
                d.pop(key)
        return d


def get_model(gpu, args):
    """Build the ``LoadedResNet`` described by ``args`` on CUDA device ``gpu``.

    Reads ``args.model_name`` and ``args.pretrained_weights``.
    """
    return LoadedResNet(
        model_name=args.model_name,
        pretrained_weights=args.pretrained_weights,
        gpu=gpu,
    )


def main_worker(gpu, args):
    """Per-process worker: extract backbone features on one GPU and save them.

    Joins the NCCL process group, builds the model and the augmented dataset
    selected by ``args.transform_type``, runs every batch of this rank's
    shard through the model without gradients, and writes the features,
    labels and dataset indices to one compressed ``.npz`` file per rank.

    Args:
        gpu: Local GPU index supplied by ``torch.multiprocessing.spawn``;
            it is added to ``args.rank`` to form this process's rank.
        args: Parsed command-line namespace. ``args.rank`` is updated in
            place, and ``args.n_augmentations`` is forced to 1 when
            ``args.transform_type == 'none'``.

    Raises:
        ValueError: If ``args.transform_type`` is not a known transform name.
    """
    # Validate first so a bad option fails with a clear message instead of a
    # NameError on an unbound ``dataset`` later on.
    transform_classes = {
        'all': Transform,
        'crop': Transform_crop,
        'colorjitter': Transform_colorjitter,
        'rotate': Transform_rotate,
        'translate': Transform_translate,
        'none': Transform_none,
    }
    if args.transform_type not in transform_classes:
        raise ValueError(
            f"Unknown transform_type {args.transform_type!r}; "
            f"expected one of {sorted(transform_classes)}")

    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = get_model(gpu, args)

    # Loading data with type of transformation
    if args.transform_type == 'none':
        # Without augmentation, a single view per image is extracted.
        args.n_augmentations = 1
    transform = transform_classes[args.transform_type](args.n_augmentations)
    dataset = ReturnIndexDataset(os.path.join(args.data_path, args.data_type), transform)

    # NOTE(review): DistributedSampler may repeat samples to give every rank
    # the same count; the saved ``indices`` allow de-duplication -- confirm.
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)

    print(f"Transforming with {args.transform_type} (x{args.n_augmentations})")
    # num_workers is fixed at 0 here; ``args.workers`` is not used.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size_per_gpu, shuffle=False,
        num_workers=0, sampler=sampler,
        collate_fn=partial(replicate_collator, n_augmentations=args.n_augmentations),
        pin_memory=True)

    start_time = time.time()
    encoder_feat = []
    labels = []
    indices = []
    with torch.no_grad():
        for y, label, index in tqdm(loader):
            e = model(y.cuda(gpu))
            # Move features to the CPU right away so they do not pile up on
            # the GPU.
            encoder_feat.append(e.cpu())
            labels.append(label)
            indices.append(index)

    log_dir = os.path.join(args.log_dir, args.model_name, args.transform_type)
    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(log_dir, exist_ok=True)

    filename = os.path.join(log_dir, f"{args.model_name}_{args.transform_type}_n_aug_{args.n_augmentations}_trained_{args.trained}_ddp_rank_{torch.distributed.get_rank()}.npz")
    print(f'-------------- Saving Encoder & Projector Embeddings to: {filename} --------------')
    np.savez_compressed(
        filename,
        backbone_features=torch.cat(encoder_feat, 0).numpy(),
        labels=torch.cat(labels, 0).numpy(),
        indices=torch.cat(indices, 0).numpy())

    print('Running Time --->: {:.2f} sec.'.format(time.time() - start_time))
 

def handle_sigusr1(signum, frame):
    """SIGUSR1 handler: ask SLURM to requeue this job, then exit the process.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: Always, after the requeue command has been issued.
    """
    # Best effort: the command's exit status is deliberately not checked.
    os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
    # sys.exit() rather than the bare exit() builtin: the latter is injected
    # by the ``site`` module for interactive use and may not exist.
    sys.exit()


def handle_sigterm(signum, frame):
    """SIGTERM handler that intentionally does nothing.

    Installing it replaces the default SIGTERM action with a no-op.
    NOTE(review): presumably this keeps the job alive until the SIGUSR1
    requeue handler runs -- confirm.
    """
    return None


def off_diagonal(x):
    """Return the off-diagonal elements of a square matrix as a 1-D tensor.

    Dropping the last element of the flattened ``n x n`` matrix and reshaping
    the rest to ``(n - 1, n + 1)`` puts every remaining diagonal element in
    column 0; slicing that column away leaves only off-diagonal entries, in
    row-major order.
    """
    rows, cols = x.shape
    assert rows == cols
    flat = x.flatten()
    realigned = flat[:-1].view(rows - 1, rows + 1)
    without_diagonal = realigned[:, 1:]
    return without_diagonal.flatten()


class GaussianBlur:
    """Blur an image with probability ``p``.

    When applied, the value passed to ``ImageFilter.GaussianBlur`` is drawn
    uniformly from [0.1, 2.0).
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        apply_blur = random.random() < self.p
        if not apply_blur:
            return img
        sigma = 0.1 + 1.9 * random.random()
        blur_filter = ImageFilter.GaussianBlur(sigma)
        return img.filter(blur_filter)


class Solarization:
    """Apply ``ImageOps.solarize`` to an image with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img


class Transform:
    """Full augmentation pipeline producing several views of one image.

    Each view goes through: random resized crop to 224 (bicubic), random
    horizontal flip, colour jitter applied with probability 0.8, random
    grayscale (p=0.2), Gaussian blur (p=1.0), solarization (p=0.0), tensor
    conversion and per-channel normalisation.
    """

    def __init__(self, n_augmentations):
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            Solarization(p=0.0),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        # One independently transformed view per requested augmentation.
        return [self.transform(image) for _ in range(self.n_augmentations)]

class Transform_crop:
    """Crop-only augmentation: random resized crop to 224, then normalise.

    Calling an instance returns a list of ``n_augmentations`` views.
    """

    def __init__(self, n_augmentations=2):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        return [self.transform(image) for _ in range(self.n_augmentations)]

class Transform_colorjitter:
    """Colour-jitter-only augmentation on a fixed 224x224 resize.

    The jitter is applied with probability 0.8; the image is then converted
    to a tensor and normalised. Calling an instance returns a list of
    ``n_augmentations`` views.
    """

    def __init__(self, n_augmentations=2):
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.Resize((224,224), interpolation=Image.BICUBIC),
            transforms.RandomApply([jitter], p=0.8),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        return [self.transform(image) for _ in range(self.n_augmentations)]




class Transform_rotate:
    """Rotation-only augmentation on a fixed 224x224 resize.

    Applies ``RandomRotation(90)``, then tensor conversion and normalisation.
    Calling an instance returns a list of ``n_augmentations`` views.
    """

    def __init__(self, n_augmentations=2):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.Resize((224,224), interpolation=Image.BICUBIC),
            transforms.RandomRotation(90),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        return [self.transform(image) for _ in range(self.n_augmentations)]

class Transform_translate:
    """Translation-only augmentation on a fixed 224x224 resize.

    NOTE: a second class with this same name is defined further down in this
    file; because the later definition rebinds the name, that one is what
    the rest of the module actually uses. This earlier copy previously
    applied ``RandomRotation(90)`` (a copy of ``Transform_rotate``); it now
    uses the same ``RandomAffine`` translation as the later definition so the
    two agree.

    Args:
        n_augmentations: Number of views returned per image.
    """

    def __init__(self, n_augmentations=2):
        self.transform = transforms.Compose([
            transforms.Resize((224,224), interpolation=Image.BICUBIC),
            # Zero rotation; translate=(.3, .3) is the per-axis translation
            # setting passed to RandomAffine.
            transforms.RandomAffine(0, translate=(.3,.3)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.n_augmentations = n_augmentations

    def __call__(self, image):
        """Return a list of ``n_augmentations`` transformed views of ``image``."""
        return [self.transform(image) for _ in range(self.n_augmentations)]



class Transform_translate:
    """Translation-only augmentation on a fixed 224x224 resize.

    Applies ``RandomAffine(0, translate=(.3, .3))``, then tensor conversion
    and normalisation. Calling an instance returns a list of
    ``n_augmentations`` views.
    """

    def __init__(self, n_augmentations=2):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.Resize((224,224), interpolation=Image.BICUBIC),
            transforms.RandomAffine(0, translate=(.3,.3)),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        return [self.transform(image) for _ in range(self.n_augmentations)]


class Transform_none:
    """No random augmentation: resize to 224x224, convert and normalise.

    Calling an instance still returns a list of ``n_augmentations`` views.
    """

    def __init__(self, n_augmentations=2):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        steps = [
            transforms.Resize((224,224), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ]
        self.n_augmentations = n_augmentations
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        return [self.transform(image) for _ in range(self.n_augmentations)]






class ReturnIndexDataset(torchvision.datasets.ImageFolder):
    """``ImageFolder`` whose items also carry the index they were fetched at."""

    def __getitem__(self, idx):
        # The parent yields (sample, target); append the requested index.
        sample, target = super().__getitem__(idx)
        return sample, target, idx

def replicate_collator(batch, n_augmentations=2):
    """Collate ``(views, label, index)`` samples into flat stacked tensors.

    The first ``n_augmentations`` views of every sample are emitted in order,
    and the sample's label and index are repeated once per view, so the three
    returned tensors line up element by element.

    Returns:
        Tuple ``(data, labels, indices)`` of stacked tensors.
    """
    views, targets, positions = [], [], []
    for sample_views, target, position in batch:
        views.extend(sample_views[k] for k in range(n_augmentations))
        targets.extend(torch.tensor(target) for _ in range(n_augmentations))
        positions.extend(torch.tensor(position) for _ in range(n_augmentations))

    stacked_views = torch.stack(views)
    stacked_targets = torch.stack(targets)
    stacked_positions = torch.stack(positions)
    return stacked_views, stacked_targets, stacked_positions



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Feature Extraction')
    parser.add_argument('--model_name', default='supervised', type=str, help= 'Self-supervised model name: simclr-v1, simclr-v2, byol, swav,  \
                            deepcluster-v2, sela-v2 , moco-v1, moco-v2, cmc , infomin, insdis, pirl, pcl-v1, pcl-v2')
    #*****************************************************************************************************************************************
    parser.add_argument('--data_path', metavar='DIR', default ='/datasets/imagenet/'  ,
                         help='path to dataset')
    parser.add_argument('--pretrained_weights', default='/torch_hub/checkpoints/barlowtwins_checkpoint.pth',
                        metavar='DIR', help='path to checkpoint directory')   
    parser.add_argument('--log_dir', default= '/logs/NNK_augmentation_eval', 
                         help= 'directory to save')
    #*****************************************************************************************************************************************
    
    parser.add_argument('--workers', default=1, type=int, metavar='N', help='number of data loader workers')
    parser.add_argument('--batch_size_per_gpu', default=4, type=int, help='Per-GPU batch-size')
    parser.add_argument('--batch_norm', action='store_true')
    parser.add_argument('--no-batch_norm', action='store_false')
    parser.set_defaults(feature=True)

    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--no-bias', action='store_false')
    parser.set_defaults(feature=False)
    
    parser.add_argument('--projector', default='8192-8192-8192', type=str,
                        metavar='MLP', help='projector MLP')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')

    parser.add_argument("--checkpoint_key", default="model", type=str, help='Key to use in the checkpoint (example: "teacher")')                     
    parser.add_argument('--label', default=0, type=int, help='label to be extracted')
                      
    parser.add_argument('--extract_layer', default = 'layer1', type= str,
                        help= 'additional layer to extract embedding can be conv1 , bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc') 
    parser.add_argument('--transform_type', default= 'all', type= str,
                        help= ' Type of transformation selected can be: none, all, crop, colorjitter, rotate')
    parser.add_argument('--n_augmentations', default=50, type=int, metavar='N',
                        help='number of total transformation for a sample')
    parser.add_argument('--data_type', default='val', type=str,
                        help=' type of data to extract: train, val')
    parser.add_argument('--trained', action='store_true')
    parser.add_argument('--no-trained', action='store_false')
    parser.set_defaults(feature=True)

                        
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if 'SLURM_JOB_ID' in os.environ:
        # single-node and multi-node distributed training on SLURM cluster
        # requeue job on SLURM preemption
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)
        # find a common host name on all nodes
        # assume scontrol returns hosts in the same order on all nodes
        cmd = 'scontrol show hostnames ' + os.getenv('SLURM_JOB_NODELIST')
        stdout = subprocess.check_output(cmd.split())
        host_name = stdout.decode().splitlines()[0]
        args.rank = int(os.getenv('SLURM_NODEID')) * args.ngpus_per_node
        args.world_size = int(os.getenv('SLURM_NNODES')) * args.ngpus_per_node
        args.dist_url = f'tcp://{host_name}:58472'
    else:
        # single-node distributed training
        args.rank = 0
        args.dist_url = 'tcp://localhost:58472'
        args.world_size = args.ngpus_per_node
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
