import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import open_clip
from open_clip.tokenizer import tokenize

import tqdm
import glob
import numpy as np
import os
from argparse import ArgumentParser
import kornia
import pickle

from imgnet_data import imagenet_templates as templates

# Torch device used for the model, the Fourier mask and every batch.
device = 'cuda:0'

# The 16 target classes, listed alphabetically. Column i of the zero-shot
# classifier corresponds to CLASS_NAMES[i]; the dataset's integer labels are
# assumed to index this same list (e.g. ImageFolder's sorted class folders).
CLASS_NAMES = [
    'airplane', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'car', 'cat',
    'chair', 'clock', 'dog', 'elephant', 'keyboard', 'knife', 'oven', 'truck',
]

# CLASS_NAMES plus a trailing 'na' entry, used to turn human-response class
# names parsed from file names into integer indices.
EXTENDED_CLASS_NAMES = [*CLASS_NAMES, 'na']

# Corruption names whose file-name condition/id fields are normalized
# differently when building optimal-response lookup keys.
NOISE_GENERALISATION_DATASETS = [
    "colour", "contrast", "high-pass", "low-pass",
    "phase-scrambling", "power-equalisation", "false-colour", "rotation",
    "eidolonI", "eidolonII", "eidolonIII", "uniform-noise",
]


class ImageFolderWithPaths(datasets.ImageFolder):
    """``ImageFolder`` whose items additionally carry the source file path.

    Each item is the 3-tuple ``(img, label, path)``: the parent class's
    ``(img, label)`` pair followed by the path stored at ``self.imgs[index][0]``.
    """

    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        source_path = self.imgs[index][0]
        return sample, target, source_path

def mask_fourier_spectrum(imgs, mask):
    """Scale the Fourier amplitude of ``imgs`` by ``mask`` and return the filtered images.

    Args:
        imgs: real image tensor; the 2-D FFT is taken over its last two dims.
            In this file it is a batch of ToTensor images (values in [0, 1]).
        mask: real tensor broadcastable against the spectrum, in unshifted
            ``torch.fft.fft2`` layout (index ``[..., 0, 0]`` is the DC term).

    Returns:
        Real tensor of the same shape as ``imgs``, clipped to [0, 1].
    """
    # get the fourier representation of the images
    spectrum = torch.fft.fft2(imgs)

    # Multiplying a complex coefficient by a real gain scales its amplitude and
    # keeps its phase, which matches amplitude*mask recombined with the original
    # phase. Doing it directly avoids the abs/angle/polar round trip; torch.polar
    # is documented as undefined for negative amplitudes, and the learned mask
    # is not constrained to be non-negative (a negative gain here simply flips
    # the coefficient's sign).
    filtered = torch.fft.ifft2(spectrum * mask)

    # keep the real part; any imaginary residue (from a mask that is not
    # conjugate-symmetric) is discarded
    return torch.clip(filtered.real, 0., 1.)

def get_zeroshot_cls_matrix(model):
    """Build the zero-shot text classifier: one L2-normalized embedding row per class.

    For every name in ``CLASS_NAMES`` (in order), each prompt template is filled
    in, the prompts are encoded with ``model.encode_text(..., normalize=True)``,
    averaged, and re-normalized. Runs under ``torch.no_grad()``.

    Returns:
        Tensor whose rows follow ``CLASS_NAMES`` order; presumably of shape
        ``(len(CLASS_NAMES), embed_dim)`` -- confirm against the model.
    """
    def _class_embedding(class_name):
        prompt_batch = [tpl.format(class_name) for tpl in templates]
        per_prompt = model.encode_text(tokenize(prompt_batch).to(device), normalize=True)
        # average over templates, then project back onto the unit sphere
        return F.normalize(per_prompt.mean(dim=0, keepdim=True), dim=-1)

    with torch.no_grad():
        rows = [_class_embedding(class_name) for class_name in CLASS_NAMES]

    return torch.cat(rows, dim=0)

def prepare_fourier_mask(args, fourier_mask, img_size):
    """Expand the learnable parameter into the Fourier mask, optionally blurred.

    Args:
        args: namespace; reads ``args.symmetric`` and ``args.blur_sigma``.
        fourier_mask: learnable tensor of shape ``(..., h, w)``. When
            ``args.symmetric`` it is one quadrant and must be
            ``(img_size // 2, img_size // 2)`` in its last two dims; otherwise
            it is used as the full mask.
        img_size: side length of the square model input.

    Returns:
        The mask, on the same device/dtype as ``fourier_mask``. In the symmetric
        case its last two dims are ``(img_size, img_size)``. When neither
        ``symmetric`` nor a positive ``blur_sigma`` is set, ``fourier_mask``
        itself is returned.

    Raises:
        ValueError: if ``args.symmetric`` and ``img_size`` is odd, or the quadrant
            does not have the expected spatial shape.
    """
    if args.symmetric:
        half = img_size // 2
        if img_size % 2 or tuple(fourier_mask.shape[-2:]) != (half, half):
            raise ValueError(
                f'symmetric mask needs an even img_size and a ({half}, {half}) quadrant; '
                f'got img_size={img_size}, quadrant shape={tuple(fourier_mask.shape[-2:])}')
        # Mirror the quadrant into the four corners:
        #   [ q          | flip_w(q)  ]
        #   [ flip_h(q)  | flip_hw(q) ]
        # Built with cat/flip so it stays differentiable, keeps the input's
        # device, and works for any leading (batch/channel) dims.
        top = torch.cat([fourier_mask, torch.flip(fourier_mask, [-1])], dim=-1)
        mask = torch.cat([top, torch.flip(top, [-2])], dim=-2)
        # NOTE(review): this layout satisfies mask[j] == mask[N-1-j] along each
        # axis, not the conjugate-symmetric mask[j] == mask[(N-j) % N] about the
        # DC term, so the filtered image can acquire an imaginary part that
        # mask_fourier_spectrum drops via torch.real. Kept unchanged so trained
        # and saved masks stay comparable -- confirm this is intended.
    else:
        mask = fourier_mask

    if args.blur_sigma > 0:
        # blur the fourier mask (kornia expects a 4-D (B, C, H, W) tensor and an
        # odd kernel size; img_size - 1 is odd for even img_size)
        mask = kornia.filters.gaussian_blur2d(mask,
                                              kernel_size=(img_size - 1, img_size - 1),
                                              sigma=(args.blur_sigma, args.blur_sigma))

    return mask
        
def _highest_scoring_wrong_class(logits, real_labels):
    """Return, per row of ``logits``, the highest-scoring class that is not the real label.

    Equivalently: the top-1 prediction when it is wrong, otherwise the top-2
    prediction. The ranking is detached, so the result carries no gradient.
    """
    ranked = torch.argsort(logits, dim=-1).detach()
    first, second = ranked[:, -1], ranked[:, -2]
    return torch.where(first == real_labels, second, first)


def _optimal_label_key(img_name):
    """Build the ``optimal_label_map`` lookup key for one image file name.

    The name is split on ``'__'``: field 3 is the corruption name, and the last
    field, split on ``'_'``, supplies the condition (its item 2) and the image id.
    For corruptions in ``NOISE_GENERALISATION_DATASETS`` the condition is
    normalized (log2 level for eidolon, float formatting for high-pass /
    uniform-noise, capitalized first letter for false-colour) and the id is the
    last two ``'_'`` parts joined with ``'-'``. The last 4 characters (the
    extension, e.g. ``.png``) are stripped from the id. Presumably this mirrors
    the key format of the optimal-responses pickle -- confirm against its producer.
    """
    splits = img_name.split('__')
    corruption = splits[3]
    tail = splits[-1].split('_')
    condition = tail[2]
    if corruption in NOISE_GENERALISATION_DATASETS:
        img_id = '-'.join(tail[-2:])
        if 'eidolon' in corruption:
            condition = str(int(np.log2(float(condition.split('-')[0]))))
        if corruption in ['high-pass', 'uniform-noise']:
            condition = str(float(condition))
            if condition[0] == '0':
                condition = condition.rstrip('0')
        if 'false-colour' in corruption:
            condition = condition[0].upper() + condition[1:]
    else:
        img_id = tail[-1]
    stem = img_id[:-len('.png')]
    return corruption + '_' + condition + '_' + stem


def train_one_epoch(args,
                    model,
                    zeroshot_matrix,
                    dataloader,
                    optimizer,
                    fourier_mask,
                    temperature,
                    optimal_label_map):
    """Run one optimization epoch of the Fourier mask and return per-batch losses.

    Each batch is filtered through the current mask, normalized with the model's
    image mean/std, embedded with ``model.encode_image`` and scored against
    ``zeroshot_matrix``; logits are divided by ``temperature``. The loss is a
    classification term (chosen by ``args.loss_fn``) plus ``args.l1_lambda``
    times the L1 norm of ``fourier_mask``. Only what ``optimizer`` holds is updated.

    Args:
        args: namespace; reads ``loss_fn`` and ``l1_lambda`` here, plus
            ``symmetric``/``blur_sigma`` via ``prepare_fourier_mask``.
        model: OpenCLIP model (uses ``visual.image_mean``, ``visual.image_std``,
            ``visual.image_size`` and ``encode_image``).
        zeroshot_matrix: class text embeddings, one row per class.
        dataloader: yields ``(img, label, path)`` batches (see ImageFolderWithPaths).
        optimizer: optimizer over ``fourier_mask`` (and optionally ``temperature``).
        fourier_mask: learnable mask parameter passed to ``prepare_fourier_mask``.
        temperature: tensor the logits are divided by.
        optimal_label_map: image key -> flag mapping; read only when
            ``args.loss_fn == 'optimal_labels'`` (entries equal to 1 select the
            real label as target).

    Returns:
        List of ``[cls_loss, l1_loss, total_loss]`` floats, one per batch.

    Raises:
        ValueError: if ``args.loss_fn`` is not ``'second'`` or ``'optimal_labels'``.
    """
    if args.loss_fn not in ('second', 'optimal_labels'):
        raise ValueError(f"unknown loss_fn {args.loss_fn!r}; expected 'second' or 'optimal_labels'")

    progress = tqdm.tqdm(dataloader)

    losses = []

    image_mean = torch.tensor(model.visual.image_mean).view(1, 3, 1, 1).to(device)
    image_std = torch.tensor(model.visual.image_std).view(1, 3, 1, 1).to(device)
    img_size = model.visual.image_size[0]

    for img, tg, paths in progress:

        optimizer.zero_grad()

        img = img.to(device)
        tg = tg.to(device)

        # rebuild the mask every step (fourier_mask changes after each optimizer.step())
        # and filter + normalize the images with it
        mask = prepare_fourier_mask(args, fourier_mask, img_size)
        img = mask_fourier_spectrum(img, mask)
        img = (img - image_mean) / image_std

        # cosine similarities against the class embeddings, scaled by temperature
        img_embs = model.encode_image(img, normalize=True)
        logits = (img_embs @ zeroshot_matrix.T) / temperature

        real_labels = tg
        # os.path.basename also handles platform path separators, unlike split('/')
        img_names = [os.path.basename(pth) for pth in paths]

        if args.loss_fn == 'second':
            # the human response is the file-name prefix before the first '__'
            human_labels = [name.split('__')[0] for name in img_names]
            numeric_human_labels = torch.tensor(
                [EXTENDED_CLASS_NAMES.index(label) for label in human_labels], device=device)
            is_human_right = (numeric_human_labels == tg).long()

            # where the human was right: push towards the real label
            loss1 = F.cross_entropy(logits, real_labels.long(), reduction='none')
            # where the human was wrong: push towards the model's best wrong class
            loss2 = F.cross_entropy(logits, _highest_scoring_wrong_class(logits, real_labels),
                                    reduction='none')

            cls_loss = torch.mean(loss1 * is_human_right + loss2 * (1 - is_human_right))

        else:  # 'optimal_labels'
            # default target is the model's best wrong class; images flagged 1 in
            # optimal_label_map use the real label instead
            optimal_labels = _highest_scoring_wrong_class(logits, real_labels)
            for idx, img_name in enumerate(img_names):
                if optimal_label_map[_optimal_label_key(img_name)] == 1:
                    optimal_labels[idx] = real_labels[idx]
            cls_loss = F.cross_entropy(logits, optimal_labels)

        # l1 regularization (sum of absolute mask values)
        l1_loss = fourier_mask.abs().sum()

        total_loss = cls_loss + args.l1_lambda * l1_loss

        total_loss.backward()
        optimizer.step()

        progress.set_description(f'{cls_loss.item()} | {l1_loss.item()}')

        losses.append([cls_loss.item(), l1_loss.item(), total_loss.item()])

    return losses

def _load_optimal_label_map(path='./optimal_responses.pkl'):
    """Load the optimal-responses pickle and flatten it into one lookup dict.

    The pickle holds a two-level nesting of mappings whose innermost mappings
    are merged (later entries win on key collisions) into a single dict keyed
    the way train_one_epoch builds its image keys.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
    with open(path, 'rb') as f:
        nested = pickle.load(f)
    flat = {}
    for outer in nested.values():
        for inner in outer.values():
            flat.update(inner)
    return flat


def main(args):
    """Optimize a Fourier mask in front of a frozen OpenCLIP zero-shot classifier.

    Trains for ``args.num_epochs`` epochs on ``<args.data_dir>/train``. The
    prepared mask is saved to ``args.save_dir`` whenever an epoch's mean total
    loss is <= the best so far; the per-batch loss history is saved at the end.
    """
    args.symmetric = bool(args.symmetric)
    args.learn_temperature = bool(args.learn_temperature)

    optimal_label_map = _load_optimal_label_map() if args.loss_fn == 'optimal_labels' else None

    # load the model; model_name is '<arch>__<pretrained tag>'
    arch = args.model_name.split('__')[0]
    pretrain_dataset = args.model_name.split('__')[1]
    model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=pretrain_dataset, cache_dir=args.clip_model_dir, device=device)
    # Only the mask (and optionally the temperature) is optimized. open_clip
    # creates models in train mode, so switch to eval; freeze the weights so
    # backward() does not allocate and keep accumulating .grad buffers for every
    # CLIP parameter (the optimizer never zeroes them). Gradients still flow
    # through the model to the mask because its input depends on the mask.
    model.eval()
    model.requires_grad_(False)
    model_resolution = model.visual.image_size[0]

    # setup data (images are only converted to tensors, not resized)
    # NOTE(review): images must already be model_resolution x model_resolution,
    # otherwise the mask will not broadcast against their spectra.
    transform = transforms.Compose([
        transforms.ToTensor()])
    ddir = os.path.join(args.data_dir, 'train')
    image_dataset = ImageFolderWithPaths(ddir, transform=transform)
    dataloader = DataLoader(image_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=8)

    # a symmetric mask only learns one quadrant
    filter_size = model_resolution // 2 if args.symmetric else model_resolution

    # initialize fourier mask: constant plus small Gaussian noise
    fourier_mask = torch.ones((1, 1, filter_size, filter_size), device=device)*args.init_constant
    fourier_mask += torch.randn(fourier_mask.shape, device=device)*np.sqrt(1e-5)
    fourier_mask.requires_grad = True

    # setup optimizer; the temperature is only a parameter when it is learned
    temperature = torch.tensor([args.temperature], dtype=torch.float, device=device,
                               requires_grad=args.learn_temperature)
    params = [fourier_mask, temperature] if args.learn_temperature else [fourier_mask]
    optimizer = torch.optim.Adam(params, lr=args.lr)

    # get zeroshot matrix classifier
    zeroshot_matrix = get_zeroshot_cls_matrix(model)

    # setup save paths for mask and loss
    os.makedirs(args.save_dir, exist_ok=True)
    exp_tag = f'{args.model_name}__loss_fn--{args.loss_fn}__temp--{args.temperature}__lr--{args.lr}__symmetric--{int(args.symmetric)}__blur--{args.blur_sigma}__l1_lambda--{args.l1_lambda}__init_constant--{args.init_constant}__batch_size--{args.batch_size}__learn_temperature--{int(args.learn_temperature)}'
    mask_path = os.path.join(args.save_dir, exp_tag+'__mask.pt')
    losses_path = os.path.join(args.save_dir, exp_tag+'__losses.npy')

    # train the mask
    losses = []
    best_loss = np.inf
    for epoch in range(args.num_epochs):

        print(f'epoch: {epoch}')
        loss = train_one_epoch(args, model, zeroshot_matrix, dataloader, optimizer, fourier_mask, temperature, optimal_label_map)
        avg_total_loss = sum(batch[-1] for batch in loss) / len(loss)
        losses += loss
        if avg_total_loss <= best_loss:
            # save the prepared (mirrored/blurred) mask, not the raw parameter
            mask = prepare_fourier_mask(args, fourier_mask, model_resolution)
            torch.save(mask.detach().cpu(), mask_path)
            best_loss = avg_total_loss
        print('')

    losses = np.asarray(losses)
    np.save(losses_path, losses)

if __name__ == '__main__':

    parser = ArgumentParser()
    # --data_dir and --save_dir have no usable default (main() joins paths onto
    # them), so require them up front instead of failing later with a TypeError.
    parser.add_argument('--data_dir', required=True, help='location where mvh data for optimization is stored')
    parser.add_argument('--save_dir', required=True, help='location where results are saved')
    parser.add_argument('--clip_model_dir', default='./clip_models', help='cache dir for clip models')
    parser.add_argument('--model_name', default='ViT-H-14__laion2b_s32b_b79k', help='name of openclip model')
    parser.add_argument('--loss_fn', default='optimal_labels', help='what loss function to use.', choices=['optimal_labels', 'second'])
    parser.add_argument('--temperature', default=0.0183, type=float, help='temperature for the logits')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # integer flags (0/1); main() converts them to bool
    parser.add_argument('--symmetric', type=int, default=1, help='constrain fourier filter to be symmetric')
    parser.add_argument('--blur_sigma', type=float, default=6.0, help='blur regularization for the optimizing the filter')
    parser.add_argument('--l1_lambda', type=float, default=0.0001, help='l1 regularization strength')
    parser.add_argument('--num_epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--init_constant', type=float, default=1., help='what to initialize filter to')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size for optimization')
    parser.add_argument('--learn_temperature', type=int, default=0, help='learn temperature or not')
    args = parser.parse_args()
    main(args)