import argparse
import torch
import os
import random
from torch.utils.data import Subset
import logging

# Only let PIL report warnings and above, so its debug/info messages do not
# flood DEBUG-level logging (such as the configuration done by setup_logging).
logging.getLogger(name="PIL").setLevel(level=logging.WARNING)

# Module-level default torch device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def load_txt(path: str, encoding=None) -> list:
    """Read a text file and return its lines with trailing newlines removed.

    Args:
        path: Path of the file to read.
        encoding: Text encoding passed to ``open``. ``None`` (the default)
            uses the platform's locale encoding, as before.

    Returns:
        A list with one string per line. Only ``'\\n'`` is stripped from the
        right of each line; other whitespace is kept.
    """
    # A context manager guarantees the handle is closed. The previous version
    # left it open until garbage collection.
    with open(path, encoding=encoding) as handle:
        return [line.rstrip('\n') for line in handle]

def str2bool(v):
    """Interpret ``v`` as a boolean, for use as an argparse ``type=`` callable.

    Booleans are returned unchanged. Strings are matched case-insensitively
    against ``yes/true/t/y/1`` (True) and ``no/false/f/n/0`` (False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.

    NOTE: ``str2bool`` is defined a second time later in this module; that
    later definition is the one bound once the module has executed.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
    
def setup_logging(checkpoint_path):
    """Configure root logging to a file and the console, then return a logger.

    Records are formatted as ``[YYYY/MM/DD HH:MM:SS] - message``. They go to
    ``<checkpoint_path>/record.log`` and to a ``StreamHandler`` (stderr by
    default). The root level is set to DEBUG.

    Args:
        checkpoint_path: Existing directory that will hold ``record.log``.

    Returns:
        The logger named after this module.

    NOTE: ``logging.basicConfig`` does nothing if the root logger already has
    handlers. In that case the file handler created here is unused, although
    ``record.log`` is still created when the handler is constructed.
    """
    log_file = os.path.join(checkpoint_path, 'record.log')
    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=handlers,
    )
    return logging.getLogger(__name__)

def create_experiment_directories(dataset, method, experiment_id, root='./checkpoint'):
    """Create (if needed) and return the directory for one experiment run.

    The layout is ``<root>/<dataset>/<method>/<experiment_id>``. Existing
    directories are left untouched.

    Args:
        dataset: Dataset name (first path level under ``root``).
        method: Method name (second level).
        experiment_id: Run identifier (leaf directory).
        root: Base directory; defaults to ``'./checkpoint'`` as before.

    Returns:
        The path of the leaf directory, as built by ``os.path.join``.
    """
    check_path = os.path.join(root, dataset, method, experiment_id)
    # makedirs already creates every missing parent, so a single call replaces
    # the previous three.
    os.makedirs(check_path, exist_ok=True)
    return check_path

def str2bool(v):
    """Convert ``v`` into a bool (suitable as an argparse ``type=`` callable).

    ``bool`` inputs pass straight through. Otherwise ``v.lower()`` must be one
    of ``yes/true/t/y/1`` (-> True) or ``no/false/f/n/0`` (-> False). Any other
    value raises ``argparse.ArgumentTypeError``.

    NOTE: this re-defines an equivalent ``str2bool`` that appears earlier in
    this module; because it comes later, this is the binding callers get.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    for accepted, verdict in (
        (('yes', 'true', 't', 'y', '1'), True),
        (('no', 'false', 'f', 'n', '0'), False),
    ):
        if normalized in accepted:
            return verdict
    raise argparse.ArgumentTypeError('Boolean value expected.')
    
def shuffle_labels(label, num_classes=None):
    """Return random class labels that differ from ``label`` element-wise.

    Each output entry is drawn uniformly from the classes other than the
    corresponding input class. Classes are the integers ``0 .. num_classes-1``.
    If ``num_classes`` is not given, it is inferred as ``label.max() + 1`` (as
    before). That undercounts when the largest class is absent from ``label``.

    Args:
        label: Integer tensor of class indices (any shape).
        num_classes: Optional total number of classes.

    Returns:
        A ``torch.long`` tensor with ``label``'s shape, on ``label``'s device.
        With a single class no different label exists, so zeros are returned.
        An empty ``label`` yields an empty result.
    """
    if num_classes is None:
        if label.numel() == 0:
            return torch.empty(label.size(), dtype=torch.long, device=label.device)
        num_classes = int(torch.max(label).item()) + 1
    if num_classes <= 1:
        return torch.zeros(label.size(), dtype=torch.long, device=label.device)
    # Adding a non-zero offset in [1, num_classes) modulo num_classes always
    # picks a different class, uniformly. The old "draw, then bump collisions
    # by one" made label+1 twice as likely. Sampling on label.device also
    # avoids a cross-device comparison.
    offset = torch.randint(1, num_classes, label.size(), device=label.device)
    return (label.long() + offset) % num_classes

def extract_subset(dataset, num_subset: int, random_subset: bool, seed: int = 0):
    """Wrap ``num_subset`` items of ``dataset`` in a ``torch.utils.data.Subset``.

    Args:
        dataset: Any sized, indexable dataset.
        num_subset: Number of items to keep.
        random_subset: If True, sample indices without replacement using a
            generator seeded with ``seed``. If False, take the first
            ``num_subset`` indices.
        seed: Seed for the random selection; the default 0 reproduces the
            indices the previous version chose.

    Returns:
        A ``Subset`` over the chosen indices.

    Raises:
        ValueError: (random mode) if ``num_subset`` exceeds ``len(dataset)``.
    """
    if random_subset:
        # A private generator yields the same sample as seeding the global
        # RNG, without resetting the process-wide ``random`` state.
        rng = random.Random(seed)
        indices = rng.sample(range(len(dataset)), num_subset)
    else:
        indices = list(range(num_subset))
    return Subset(dataset, indices)

def record_path_words(record_path, record_words):
    """Echo ``record_words`` to stdout and append it to ``record_path``.

    The text is written as-is (no newline is added). The file is opened in
    ``"a+"`` mode, so it is created if missing. Returns ``None``.
    """
    print(record_words)
    with open(record_path, "a+") as handle:
        handle.write(record_words)

def format_time(seconds):
    """Format a duration in seconds as its two largest non-zero units.

    Units are days (``'D'``), hours (``'h'``), minutes (``'m'``) and whole
    seconds (``'s'``). At most the first two non-zero units are emitted, e.g.
    ``64 -> '1m4s'`` and ``93605 -> '1D2h'``. Fractional seconds are
    truncated. A duration with no positive unit (under one second, zero, or
    negative) gives ``'0s'`` instead of an empty string.

    Example::

        cur_time = time.time()
        time.sleep(64)
        last_time = time.time()
        print(format_time(last_time - cur_time))   # e.g. '1m4s'
    """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)

    parts = [
        f'{value}{unit}'
        for value, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'), (secondsf, 's'))
        if value > 0
    ]
    # Without the fallback, sub-second durations would render as ''.
    return ''.join(parts[:2]) or '0s'

def dict2namespace(config):
    """Recursively convert a mapping into an ``argparse.Namespace``.

    Each key becomes an attribute (set via ``setattr``). Values that are
    ``dict`` instances are converted recursively. All other values, including
    lists, are stored unchanged.
    """
    result = argparse.Namespace()
    for name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(result, name, converted)
    return result

def restore_checkpoint(ckpt_dir, state, device):
    """Load a saved checkpoint file into ``state`` in place and return it.

    Args:
        ckpt_dir: Path of the checkpoint file passed to ``torch.load``.
        state: Dict whose ``'optimizer'``, ``'model'`` and ``'ema'`` entries
            expose ``load_state_dict``. Its ``'step'`` entry is overwritten.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        The same ``state`` dict, for chaining. Callers that ignore the return
        value are unaffected.

    NOTE: ``torch.load`` is pickle-based, so only load checkpoints you trust.
    """
    loaded_state = torch.load(ckpt_dir, map_location=device)
    state['optimizer'].load_state_dict(loaded_state['optimizer'])
    # strict=False: mismatched keys in the model weights are tolerated rather
    # than raising. The EMA call below uses that object's default behavior.
    state['model'].load_state_dict(loaded_state['model'], strict=False)
    state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state

def diff2clf(x, is_imagenet=False):
    """Map values from the [-1, 1] range to [0, 1] (computes ``x / 2 + 0.5``).

    ``is_imagenet`` is accepted for interface compatibility but is not used.
    """
    halved = x / 2
    return halved + 0.5

def clf2diff(x):
    """Map values from the [0, 1] range to [-1, 1] (computes ``(x - 0.5) * 2``)."""
    centered = x - 0.5
    return centered * 2


