from PIL import Image
import random
import os
import argparse
import inspect
from datetime import datetime

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter

class ContrastiveLearningViewGenerator(object):
    """Apply one transform to the same input several times and return all results.

    `base_transform` is called `n_views` times (two by default) on the input,
    so a random transform yields that many different views of it.
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        """Return a list holding `n_views` transformed versions of `x`."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views

class AverageMeter(object):
    """Keep the latest value of a metric together with its running weighted mean."""

    def __init__(self, name='', fmt=':f'):
        # `fmt` is a format-spec suffix (e.g. ':.3f') spliced into the __str__ template.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set the latest value, mean, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store `val` as the latest value, counted `n` times, and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. '{name} {val:f} ({avg:f})', then fills it from the instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators and make cuDNN deterministic.

    :param seed: value handed to every seeding call and written to PYTHONHASHSEED
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NumPy, torch CPU, current CUDA device, then all CUDA devices (multi-GPU).
    remaining_seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in remaining_seeders:
        seeder(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def floatTensor2Image(float_tensor):
    """
    Convert a float image tensor into a PIL image.

    :param float_tensor: channel-first tensor (it is transposed with (1, 2, 0) below);
        values are assumed to be in the range [-1, 1]
    :return: PIL.Image built from the rescaled uint8 data in channel-last layout
    """
    # detach() and cpu() let the conversion work for tensors that require grad or
    # live on another device; Tensor.numpy() would raise for those otherwise.
    numpy_array = float_tensor.detach().cpu().numpy()
    # Rescale [-1, 1] to [0, 255]. Clip before the cast so values slightly outside
    # the assumed range saturate instead of overflowing in the uint8 conversion.
    rescaled_array = np.clip((numpy_array + 1.0) * 127.5, 0, 255).astype(np.uint8)
    # Transpose to channel-last format, which Image.fromarray expects
    image = Image.fromarray(rescaled_array.transpose(1, 2, 0))
    return image

def str2bool(v):
    """
    Interpret a command-line style value as a boolean.

    :param v: a bool (returned unchanged) or a string, matched case-insensitively
    :return: True for 'yes', 'true', 't', 'y', '1'; False for 'no', 'false', 'f', 'n', '0'
    :raises argparse.ArgumentTypeError: if the string is none of the accepted spellings
    """
    # A value that is already a bool has no .lower(); pass it through unchanged.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def strip_state_dict(state_dict, strip_key='module.'):
    """
    Strip `strip_key` ('module.' by default) from start of state_dict keys
    Useful if model has been trained as DataParallel model

    :param state_dict: dict to rename keys in; it is modified in place
    :param strip_key: prefix to remove; an empty prefix leaves the dict untouched
    :return: the same `state_dict` object
    """
    # Every key starts with '', so an empty prefix would rename each key to itself
    # and then delete it, emptying the dict. Treat it as "nothing to strip".
    if not strip_key:
        return state_dict

    prefix_len = len(strip_key)
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for k in list(state_dict):
        if k.startswith(strip_key):
            state_dict[k[prefix_len:]] = state_dict.pop(k)

    return state_dict


def get_dino_head_weights(pretrain_path):
    """
    Collect the projection-head entries of the 'teacher' state dict in a DINO checkpoint.

    :param pretrain_path: Path to full DINO pretrained checkpoint as in https://github.com/facebookresearch/dino
     'full_ckpt'
    :return: weights only for the projection head
    """
    # NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
    # No map_location is passed here.
    checkpoint = torch.load(pretrain_path)
    teacher_weights = checkpoint['teacher']

    # Head entries other than the last layer, with the 'head.' prefix removed.
    head_state_dict = {
        name: tensor
        for name, tensor in teacher_weights.items()
        if 'head' in name and 'last_layer' not in name
    }
    head_state_dict = strip_state_dict(head_state_dict, strip_key='head.')

    # Deal with weight norm: last-layer entries are keyed by the third
    # dot-separated component of their original name.
    last_layer_params = {
        name.split('.')[2]: tensor
        for name, tensor in teacher_weights.items()
        if 'last_layer' in name
    }

    weight_shape = last_layer_params['weight'].shape
    dummy_linear = torch.nn.Linear(
        in_features=weight_shape[1],
        out_features=weight_shape[0],
        bias=False,
    )
    dummy_linear.load_state_dict(last_layer_params)
    # Wrap the layer in weight norm and export whatever its state dict then contains.
    dummy_linear = torch.nn.utils.weight_norm(dummy_linear)

    head_state_dict.update(
        ('last_layer.' + name, tensor)
        for name, tensor in dummy_linear.state_dict().items()
    )

    return head_state_dict

def transform_moco_state_dict(obj, num_classes):
    """
    Re-key the query-encoder entries of a MoCo state dict.

    :param obj: Moco State Dict; only keys starting with 'module.encoder_q.' are kept
    :param num_classes: output size of the randomly re-initialised 'fc' weight and bias
    :return: new dict, intended to be compatible with standard resnet architecture
    """
    query_prefix = "module.encoder_q."
    converted = {}

    for full_key, tensor in obj.items():
        # Entries outside the query encoder are dropped.
        if not full_key.startswith(query_prefix):
            continue

        name = full_key.replace(query_prefix, "")

        # 'fc.2*' entries are discarded entirely.
        if name.startswith("fc.2"):
            continue

        # 'fc.0*' entries become 'fc.*' and get fresh random values sized for num_classes.
        if name.startswith("fc.0"):
            name = name.replace("0.", "")
            if "weight" in name:
                tensor = torch.randn((num_classes, tensor.size(1)))
            elif "bias" in name:
                tensor = torch.randn((num_classes,))

        converted[name] = tensor

    return converted


def init_experiment(args, runner_name=None, exp_id=None):
    """
    Create the log and checkpoint directories for a run and record them on `args`.

    :param args: namespace-like object. `exp_root` is read, and `dataset_name` is read
        when `exp_id` is None. `log_dir`, `model_path` and `writer` are set on it.
    :param runner_name: sequence of path components joined under `args.exp_root`;
        derived from this module's directory when None
    :param exp_id: name of the log sub-directory; a timestamp-based name is generated when None
    :return: the same `args` object
    """
    if runner_name is None:
        # NOTE(review): this splits the *directory path* of this file on '.' and keeps the
        # last two pieces. If the path contains no '.', the single piece is an absolute
        # path and os.path.join below discards args.exp_root. Left unchanged because
        # existing directory layouts depend on it; confirm this is intended.
        runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
    root_dir = os.path.join(args.exp_root, *runner_name)
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(root_dir, exist_ok=True)

    # Either generate a unique experiment ID, or use one which is passed
    if exp_id is None:
        # Read the clock once so every field of the name comes from the same instant.
        stamp = datetime.now()
        now = '{:02d}-{:02d}_{:02d}-{:02d}_{}'.format(
            stamp.day, stamp.month, stamp.hour, stamp.minute, args.dataset_name)
        log_dir = os.path.join(root_dir, 'log', now)
        # On a name clash, fall back to a date plus seconds.milliseconds identifier.
        while os.path.exists(log_dir):
            stamp = datetime.now()
            now = '({:02d}.{:02d}.{}_|_'.format(stamp.day, stamp.month, stamp.year) + \
                  stamp.strftime("%S.%f")[:-3] + ')'
            log_dir = os.path.join(root_dir, 'log', now)
    else:
        log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
    os.makedirs(log_dir, exist_ok=True)
    args.log_dir = log_dir

    # Instantiate directory to save models and training logs to
    model_root_dir = os.path.join(args.log_dir, 'checkpoints')
    os.makedirs(model_root_dir, exist_ok=True)
    args.model_path = os.path.join(model_root_dir, 'model.pt')
    args.writer = SummaryWriter(log_dir=args.log_dir)

    print(f'Experiment saved to: {args.log_dir}')
    print(f'Runner name: {runner_name}')

    return args