import src.attacks.affine as a
import src.attacks.affine_linspace as lin
import numpy as np
import torch
import torchvision
import os

import src.data.IS.imagenet_s.imagenet_s as i_s
import kornia.geometry.transform as t
from torchvision import transforms

# Configure the affine-attack helpers for ImageNet inputs: select the ImageNet
# preset and fix both the target size and the crop size to 224 px.
# NOTE(review): a.config is bound to the very same object as a.config_imagenet,
# so the item assignments below also modify a.config_imagenet.
a.config = a.config_imagenet
a.config['target_size'] = 224
a.config['crop_size'] = 224


tt = transforms.ToTensor()

# Placeholder paths -- replace with the local dataset locations.
imagenet_root_path = 'IMAGENET ROOT HERE'
dataset_root = 'IMAGENET S ROOT HERE'

# Parameter set '300' of ImageNet-S (presumably the 300-class split).
params = i_s.get_param('300')
num_classes = params['num_classes']

# The class-name files live under <dataset_root>/names. Deriving the path from
# dataset_root means the root placeholder only has to be edited in one place.
name_list = os.path.join(dataset_root, 'names', params['names'])

subdir = 'validation-segmentation'
gt_dir = os.path.join(dataset_root, params['dir'], subdir)
dataset = i_s.ImageNetSEvalDataset(imagenet_root_path, gt_dir, name_list, transform=tt,
    use_new_labels=True, simple_items=True, prefilter_items=True,
    transform_mask_to_img_classes=True)

import src.data.statistics as stat

# Map every dataset index to its label (the third element of each item) and
# install the mapping on the statistics module as stat.label_map.
# Each item must unpack into exactly three values; the first two are ignored.
label_map = dict()
for idx in range(len(dataset)):
    _img, _gt_mask, label_map[idx] = dataset[idx]

stat.label_map = label_map


# Every model is moved to, and evaluated on, the first CUDA device.
eval_device = torch.device('cuda', 0)

# NOTE(review): not referenced in the visible part of this script.
state_dict = dict()

import gc
import math

# Use half of the available intra-op CPU threads, but never fewer than one:
# on a single-thread setup `// 2` yields 0, which torch.set_num_threads rejects.
torch.set_num_threads(max(1, torch.get_num_threads() // 2))

import random

import tqdm


# model name -> {"results_rotated": ..., "results_translation": ...};
# populated by eval_model() and pickled by save().
mae_results = dict()

import pickle

def save(results=None, dump_file='./data/imagenet/shift_happens.p'):
    """Pickle the accumulated evaluation results to ``dump_file``.

    The data is first written to a temporary sibling file, which then
    replaces ``dump_file`` via ``os.replace``. The previous results are
    therefore only overwritten once the new pickle has been fully written, and
    the file handle is always closed.

    Args:
        results: object to pickle; defaults to the module-level ``mae_results``.
        dump_file: destination path of the pickle file.
    """
    if results is None:
        results = mae_results
    out_dir = os.path.dirname(dump_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    tmp_file = dump_file + '.tmp'
    with open(tmp_file, 'wb') as f:
        pickle.dump(results, f)
    os.replace(tmp_file, dump_file)

#import numpy.typing as npt
import src.attacks.affine as a
from skimage.transform import resize

# Upper bound on the number of samples kept for evaluation (effectively "all").
num_elems = 10000000


def get_h_w(coords):
    """Return the (height, width) extent of a box given as two (start, end) pairs.

    The first pair is treated as the height axis, the second as the width axis.
    """
    return coords[0][1] - coords[0][0], coords[1][1] - coords[1][0]


def _bounds_fit_crop(loaded):
    """Return True if the bounds box (item 2 of ``loaded``) fits inside the crop box (item 3)."""
    _img, _mask, bounds, crop, _start = loaded
    h_b, w_b = get_h_w(bounds)
    h_c, w_c = get_h_w(crop)
    return h_b <= h_c and w_b <= w_c


# Load, pad and convert the validation images in random order.
# - A sample whose preprocessing raises any Exception is skipped and counted.
# - A bounds box that still does not fit the crop box after one retry is a
#   hard error (AssertionError is re-raised).
data = []
num_skipped = 0
idx_list = list(range(len(dataset)))
random.shuffle(idx_list)
for i in idx_list:
    try:
        I, gt_uint, in_label = dataset[i]
        loaded = a.load_and_pad_imagenet(I, gt_uint, in_label)
        if not _bounds_fit_crop(loaded):
            # Retry once and re-check the *new* result. The old code asserted on
            # the sizes from the first attempt, so the retry could never help.
            loaded = a.load_and_pad_imagenet(I, gt_uint, in_label)
            if not _bounds_fit_crop(loaded):
                # Explicit raise instead of `assert`, which `python -O` strips.
                raise AssertionError(f"bounds box does not fit crop box for sample {i}")
        element = a.to_torch(*loaded)
        imgs, masks, bounds, crops, start_coords = element
        # NOTE(review): the results below are unused; presumably they are kept so
        # that samples on which these helpers raise are skipped by the handler.
        zooms_bounds, _ = a.get_zoom_bounds(crops, start_coords)
        eff_zoom = a.get_zooms(bounds, crops)
        data.append((element, in_label, i))
        if len(data) >= num_elems:
            break
    except AssertionError:
        raise
    except Exception:
        # Best-effort loading: count the sample as skipped and move on. Unlike a
        # bare `except:`, KeyboardInterrupt/SystemExit still propagate.
        num_skipped += 1

print(f"loaded {len(data)} samples, skipped {num_skipped}")

# Directory passed as save_dir to the lin.*_linspace evaluations.
save_dir = './data/imagenet/shift_happens'
from pathlib import Path
# parents=True: a fresh checkout without ./data/imagenet must not crash here.
Path(save_dir).mkdir(parents=True, exist_ok=True)
# save_dir = None  # alternative: pass no save directory to the evaluations

def print_base_worst(model, data, name, mode):
    """Print base-case vs. worst-case accuracy for one set of evaluation results.

    ``base`` and ``worst`` are treated as accuracies in [0, 1]: they are printed
    as percentages, and ``1 - x`` is used as the error rate. The printed error
    ratio is the worst-case error divided by the base-case error.

    Args:
        model: printed as-is to identify the model (e.g. its name).
        data: per-sample results, passed to ``stat.adaptive_worst_case`` and
            ``stat.adaptive_base_case``. If it is empty, only a note is printed.
        name: human-readable name of the result subset.
        mode: forwarded to ``stat.adaptive_base_case`` (callers here pass
            "rotation" or "trans").
    """
    if len(data) == 0:
        print(f"{model} with {name}: no samples, skipped")
        return
    worst = stat.adaptive_worst_case(data)
    base = stat.adaptive_base_case(data, mode)
    base_error = 1 - base
    worst_error = 1 - worst
    if base_error == 0:
        # Perfect base accuracy: the ratio is undefined (0/0) or unbounded.
        error_ratio = float('nan') if worst_error == 0 else float('inf')
    else:
        error_ratio = worst_error / base_error
    print(f"{model} with {name}: base case {100*base:.1f}, worst case {100*worst:.1f}, diff {100*(base - worst):.1f}, error ratio {error_ratio:.2f}")

def eval_model(model, model_name, batch_size_model):
    """Run the rotation and translation grid evaluations for one model.

    Prints base/worst-case statistics for the full result sets and for filtered
    subsets, then stores the raw results in the module-level ``mae_results``
    under ``model_name``.

    Args:
        model: callable applied to input batches (normalisation included by the caller).
        model_name: name used in printed output, passed to the ``lin`` routines,
            and used as the key in ``mae_results``.
        batch_size_model: model batch size forwarded to the ``lin`` routines.
    """
    args = (model, model_name, data, eval_device, batch_size_model)

    results_rotation = lin.rotation_linspace(*args, batch_size_rotation=500, resolution=500, do_resize=False, save_dir=save_dir)
    # Print the model *name*; the callable would print as "<function <lambda> ...>".
    print_base_worst(model_name, results_rotation, "rotation", "rotation")

    results_rotation_filtered = [d for d in results_rotation if stat.gt_30(d)]
    print_base_worst(model_name, results_rotation_filtered, "rotation deg>=30", "rotation")

    results_translation = lin.translation_linspace(*args, batch_size_translation=700, resolution=250, save_dir=save_dir)
    print_base_worst(model_name, results_translation, "translation", "trans")

    # Filtered translation results must use the translation base-case mode, like
    # the unfiltered ones above. The label matches the strict `> 100` filter.
    results_translation_filtered = [d for d in results_translation if stat.max_freedom(d) > 100]
    print_base_worst(model_name, results_translation_filtered, "translation freed.>100", "trans")

    mae_results[model_name] = {
        "results_rotated": results_rotation,
        "results_translation": results_translation,
    }


# torchvision ImageNet classifiers to evaluate, keyed by the name used in the
# printed output and in mae_results.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept for compatibility with older versions.
shmodels_dict = {
    "vgg16": torchvision.models.vgg16(pretrained=True),
    "resnet18": torchvision.models.resnet18(pretrained=True),
    "resnet50": torchvision.models.resnet50(pretrained=True),
}

# Standard ImageNet normalisation. It is applied inside the model wrapper, so the
# evaluation code feeds un-normalised tensors (presumably the [0, 1] output of ToTensor).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

b_size = 600  # model batch size forwarded to eval_model

success = []
for i, (name, model) in enumerate(shmodels_dict.items(), start=1):
    try:
        # nn.Module.eval() and .to() both return the module itself.
        net = model.eval().to(eval_device)
        model_fun = lambda inp: net(normalize(inp))
        eval_model(model_fun, name, b_size)
        print(f"finished model {i}/{len(shmodels_dict)}: {name}")
        save()
        success.append(name)
    except Exception:
        # Top-level boundary: report the failure (the traceback already includes
        # the message) and continue with the next model.
        import traceback
        traceback.print_exc()
    finally:
        # .to() moves the parameters in place, so without this every evaluated
        # model would stay resident on the GPU while the next one runs.
        model.to('cpu')
        gc.collect()
        torch.cuda.empty_cache()

print(success)