import os
import random
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as vision_transforms


def entropy(x):
    """Compute the entropy of ``x`` reduced over dimension 1.

    A constant 1e-8 is added inside the logarithm so that zero entries give a
    finite log instead of ``-inf`` (which would turn ``0 * log 0`` into NaN).

    Args:
        x: Input tensor; presumably each slice along dim 1 is a probability
            vector (e.g. softmax output) — TODO confirm with callers.

    Returns:
        Tensor equal to ``-sum(x * log(x + 1e-8), dim=1)``.
    """
    log_x = torch.log(x + 1e-8)
    return -(x * log_x).sum(dim=1)

def cross_entropy(x, y):
    """Compute the cross entropy of ``y`` relative to ``x``, reduced over dim 1.

    As in ``entropy``, 1e-8 is added inside the logarithm to keep zero
    entries of ``y`` finite.

    Args:
        x: Weighting tensor (presumably target probabilities — TODO confirm).
        y: Tensor whose log is taken (presumably predicted probabilities).

    Returns:
        Tensor equal to ``-sum(x * log(y + 1e-8), dim=1)``.
    """
    log_y = (y + 1e-8).log()
    weighted = x * log_y
    return -torch.sum(weighted, dim=1)

def set_seeds(seed: int = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Seed value applied to every random number generator.
    """
    # cuBLAS reads this variable when its handle is created; set it before
    # anything that could touch CUDA so deterministic matmuls are possible.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    os.environ["PYTHONHASHSEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different (non-deterministic) kernels between runs.
    torch.backends.cudnn.benchmark = False

    # Catch only the version-compatibility failures instead of a bare
    # ``except:`` that would also hide unrelated errors.
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # Older torch: ``use_deterministic_algorithms`` exists but lacks ``warn_only``.
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        # Even older torch: only the legacy ``set_deterministic`` API exists.
        torch.set_deterministic(True)


class ProxyModel(nn.Module):
    """Frozen wrapper that forwards only the first element of a model's output.

    Attack libraries such as torchattacks expect a model to return exactly one
    ``(N, C)`` tensor (N inputs, C classes), as most torchvision models do.
    The wrapped model presumably returns a sequence of outputs, so ``forward``
    keeps only index 0 — check that this element really has shape ``(N, C)``.
    """

    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.model.to(device)
        # Freeze every parameter; gradients are only needed w.r.t. the inputs.
        self.model.requires_grad_(False)

    def forward(self, x):
        outputs = self.model(x)
        return outputs[0]


@torch.no_grad()
def show(original_inputs, fakes, original_labels, fake_labels, idx2label, figsize=(32, 12), save_figure=False,
         filename=None):
    """Plot original images (top row) above their adversarial versions (bottom row).

    Args:
        original_inputs: Image batch with 4 dims, or a single 3-dim image.
        fakes: Adversarial images laid out like ``original_inputs``.
        original_labels: Predicted class index per original image (tensor,
            array or sequence; a 0-d tensor is accepted for a single image).
        fake_labels: Predicted class index per adversarial image.
        idx2label: Mapping from class index to a human-readable name.
        figsize: Figure size in inches, ``(width, height)``.
        save_figure: If True, also save the figure as a timestamped PDF.
        filename: Prefix of the saved PDF name; empty when not given.
    """
    # Promote a single 3-dim image to a batch of one.
    original_inputs = original_inputs if len(original_inputs.shape) == 4 else original_inputs.unsqueeze(dim=0)
    fake_imgs = fakes if len(fakes.shape) == 4 else fakes.unsqueeze(dim=0)
    # Flatten labels so a 0-d label tensor (single image) can be indexed by column.
    original_labels = torch.as_tensor(original_labels).reshape(-1)
    fake_labels = torch.as_tensor(fake_labels).reshape(-1)

    # Count columns only after batching: for an unbatched image shape[0] is
    # the channel count, not 1, which previously caused an IndexError.
    n = original_inputs.shape[0]

    fig, axs = plt.subplots(nrows=2, ncols=n, squeeze=False)
    fig.set_size_inches(*figsize)

    rows = (
        (0, original_inputs, original_labels, "Original prediction"),
        (1, fake_imgs, fake_labels, "Adversarial prediction"),
    )
    for i in range(n):
        for row, images, labels, caption in rows:
            label = labels[i].item()
            ax = axs[row, i]
            ax.imshow(np.asarray(vision_transforms.functional.to_pil_image(images[i])))
            ax.set_title(f"{caption} \n(label: {label}, \n'{idx2label[label]}')", fontsize=8)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if save_figure:
        prefix = filename if filename else ""
        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        # Save this figure explicitly instead of whatever pyplot considers "current".
        fig.savefig(f"{prefix}_{timestamp}.pdf", dpi=200, bbox_inches="tight", transparent=True)

    plt.show()


def get_random_key(d):
    """Return a key of ``d`` chosen uniformly at random.

    Uses Python's ``random`` module, so the choice is reproducible after
    ``random.seed`` (e.g. via ``set_seeds``). Previously this always returned
    the 13th key, which was neither random nor valid for small dicts.

    Args:
        d: Mapping to draw a key from.

    Returns:
        One of the keys of ``d``.

    Raises:
        IndexError: If ``d`` is empty.
    """
    keys = list(d)
    if not keys:
        raise IndexError("cannot pick a key from an empty mapping")
    return random.choice(keys)

def find_final_model(obj_list):
    """Return the first final-checkpoint file name in ``obj_list``.

    Args:
        obj_list: Iterable of file names (e.g. the output of ``os.listdir``).

    Returns:
        The first name that ends with ``.torch`` and contains ``"final"``,
        or None when there is no such name.
    """
    matches = (name for name in obj_list if name.endswith(".torch") and "final" in name)
    return next(matches, None)

def name_parser(model_path, base_dir="ats/tuned_models"):
    """Parse tuning metadata from a tuned-model directory name.

    The directory name is split on underscores and read from the end as
    ``..._<tune_mode>_<tune_technique>_<dataset>``. When the name contains
    "oxford", the dataset spans the last two parts (e.g. ``oxford_flowers``).

    Args:
        model_path: Name of the model directory inside ``base_dir``.
        base_dir: Directory that holds the tuned-model folders.

    Returns:
        Dict with keys ``dataset_name``, ``tune_mode``, ``tune_technique`` and
        ``model_path`` (path to the final ``.torch`` checkpoint), or None when
        the directory contains no final checkpoint.

    Raises:
        IndexError: If the name has too few underscore-separated parts.
        FileNotFoundError: If ``base_dir/model_path`` does not exist.
    """
    parts = model_path.split("_")
    if "oxford" in model_path:
        # Oxford dataset names contain an underscore, so every field shifts left by one.
        dataset_name = "_".join(parts[-2:])
        tune_technique = parts[-3]
        tune_mode = parts[-4]
    else:
        dataset_name = parts[-1]
        tune_technique = parts[-2]
        tune_mode = parts[-3]

    model_dir = f"{base_dir}/{model_path}"
    checkpoint = find_final_model(os.listdir(model_dir))
    if checkpoint is None:
        # Report the directory that was searched (the checkpoint name is None here).
        print(f"Final model not found for {model_path}")
        return None

    return {
        "dataset_name": dataset_name.strip(),
        "tune_mode": tune_mode,
        "tune_technique": tune_technique,
        "model_path": f"{model_dir}/{checkpoint}",
    }


# Feature-eval settings appended when tune_mode contains "finetune": trunk frozen, trunk + head evaluated.
_CFG_FINETUNE = (
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=True",
)

# Feature-eval settings appended when tune_mode contains "fulltune": nothing frozen.
_CFG_FULLTUNE = (
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=False",
)

_ALEXNET_MAIN_CONFIG = "eval_alexnet_8gpu_transfer_in1k_linear.yaml"
_RESNET_MAIN_CONFIG = "eval_resnet_8gpu_transfer_in1k_linear.yaml"
_DATASETS_ROOT = "ats/data/datasets/"

# NOTE(review): these weight-loading values apply to every model family.
# Per-family alternatives (key "model_state_dict", prefixes "trunk." /
# "trunk.base_model.") used to be computed and then always overwritten by
# these values, so they were dropped — confirm this is intended.
_STATE_DICT_KEY_NAME = "classy_state_dict"
_APPEND_PREFIX = "trunk.base_model._feature_blocks."


def _resolve_backbone(path, dataset, technique):
    """Map a lower-cased model path to ``(main_config, extra_config, remove_prefix)``.

    Families are tested in a fixed order, so a path matching several keywords
    resolves to the first matching branch. Returns None when nothing matches.
    """
    models = f"+config/models/{dataset}"
    resnet_jigsaw = (_RESNET_MAIN_CONFIG, f"{models}/resnet50=resnet_jigsaw_{technique}.yaml", "")

    if "alex" in path:
        for method in ("colorization", "jigsaw"):
            if method in path:
                return (_ALEXNET_MAIN_CONFIG, f"{models}/alexnet=alexnet_{method}_linear_{technique}.yaml", "")
    elif "resnet" in path or "rn50" in path or "deepcluster" in path:
        if "deepcluster" in path:
            # DeepCluster is the only family configured to strip a "module." prefix.
            return (_RESNET_MAIN_CONFIG, f"{models}/resnet50=resnet_jigsaw_{technique}.yaml", "module.")
        if "colorization" in path:
            return (_RESNET_MAIN_CONFIG, f"{models}/resnet50=resnet_colorization_{technique}.yaml", "")
        if any(method in path for method in ("jigsaw", "swav", "simclr", "npid", "rotnet")):
            # All of these reuse the ResNet-50 jigsaw model config.
            return resnet_jigsaw
    elif "npid" in path:
        return resnet_jigsaw
    elif "deit" in path:
        return (_RESNET_MAIN_CONFIG, f"{models}/deit=dino_deit_s16_{technique}.yaml", "")
    elif "xcit" in path:
        return (_RESNET_MAIN_CONFIG, f"{models}/xcit=dino_xcit_s16_{technique}.yaml", "")
    return None


def get_config_from_name(c):
    """Build the hydra command-line overrides for evaluating a tuned model.

    Args:
        c: Metadata dict with keys ``dataset_name``, ``tune_mode``,
            ``tune_technique`` and ``model_path`` (the shape returned by
            ``name_parser``).

    Returns:
        List of override strings: main and model configs, test-data settings,
        batch/distributed settings, weight-loading settings and — when
        ``tune_mode`` contains "fulltune" or "finetune" — the matching
        feature-eval settings.

    Raises:
        KeyError: If ``c`` lacks one of the required keys.
        ValueError: If ``model_path`` matches no supported model family/method.
    """
    dataset_name = c['dataset_name']
    tune_mode = c['tune_mode']
    tune_technique = c['tune_technique']
    model_path = c['model_path']
    dataset_lower = dataset_name.lower()

    resolved = _resolve_backbone(model_path.lower(), dataset_lower, tune_technique)
    if resolved is None:
        # Previously this surfaced as an UnboundLocalError deep in the command build.
        raise ValueError(f"Unsupported model family/method in model path: {model_path!r}")
    main_config, extra_config, remove_prefix = resolved

    dataset_mode = "disk_folder" if "oxford" in dataset_lower else "torchvision_dataset"

    dataset_path = _DATASETS_ROOT
    if "oxford_flowers" in dataset_lower:
        dataset_path += "oxford_flowers/"
    elif "oxford_pets" in dataset_lower:
        dataset_path += "oxford_pets/"

    # Built as a list (rather than a string later split on spaces) so a path
    # containing spaces stays a single override.
    commands = [
        "hydra.verbose=true",
        f"config={main_config}",
        extra_config,
        f"config.DATA.TEST.DATA_PATHS=[{dataset_path}]",
        f"config.DATA.TEST.DATA_SOURCES=[{dataset_mode}]",
        f"config.DATA.TEST.LABEL_SOURCES=[{dataset_mode}]",
        f"config.DATA.TEST.DATASET_NAMES=[{dataset_name}]",
        "config.DATA.TEST.BATCHSIZE_PER_REPLICA=64",
        "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=64",
        "config.DISTRIBUTED.NUM_NODES=1",
        "config.DISTRIBUTED.NUM_PROC_PER_NODE=1",
        f'config.MODEL.WEIGHTS_INIT.PARAMS_FILE="{model_path}"',
        f'config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME="{_STATE_DICT_KEY_NAME}"',
        f'config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="{_APPEND_PREFIX}"',
        f'config.MODEL.WEIGHTS_INIT.REMOVE_PREFIX="{remove_prefix}"',
    ]

    if "fulltune" in tune_mode:
        commands.extend(_CFG_FULLTUNE)
    elif "finetune" in tune_mode:
        commands.extend(_CFG_FINETUNE)

    return commands