import copy
import json
import os
import random
from pathlib import Path
import utils.tensor_utils as TensorUtils
import numpy as np
import torch
import torch.nn as nn
import warnings
from natsort import natsorted
import datetime

def get_experiment_dir(cfg, evaluate=False, allow_overlap=False):
    """Build the output directory path and a flat name for an experiment.

    The path is assembled as
    ``<prefix>/<suite_name>/<benchmark_name>/<algo name>/<exp_name>``,
    followed by the variant name (when it is not ``None``), the seed (omitted
    when it equals 10000) and finally either a timestamped ``run_*``
    component or ``stage_<stage>``.

    Args:
        cfg: Config object exposing ``output_prefix``, ``task.suite_name``,
            ``task.benchmark_name``, ``algo.name``, ``exp_name``,
            ``variant_name``, ``seed`` and ``make_unique_experiment_dir``.
            ``stage`` and ``training.resume`` are read only when a unique
            directory is not requested.
        evaluate: When true, the path is nested under ``<prefix>/evaluate``.
        allow_overlap: When true, an already existing ``stage_*`` directory
            is accepted instead of being rejected.

    Returns:
        A ``(experiment_dir, experiment_name)`` tuple. ``experiment_name`` is
        the part of ``experiment_dir`` that follows ``cfg.output_prefix``,
        with ``/`` replaced by ``_``.

    Raises:
        AssertionError: If a unique directory is not requested, overlap is
            not allowed, training is not resuming, and the directory already
            exists.
    """
    prefix = cfg.output_prefix
    if evaluate:
        prefix = os.path.join(prefix, 'evaluate')

    experiment_dir = (
        f"{prefix}/{cfg.task.suite_name}/{cfg.task.benchmark_name}/"
        f"{cfg.algo.name}/{cfg.exp_name}"
    )
    if cfg.variant_name is not None:
        experiment_dir += f'/{cfg.variant_name}'

    # The seed 10000 is treated as the default and gets no path component.
    if cfg.seed != 10000:
        experiment_dir += f'/{cfg.seed}'

    if cfg.make_unique_experiment_dir:
        timestamp = datetime.datetime.now().strftime("%Y.%m.%d_%H-%M-%S")
        experiment_dir += f"/run_{timestamp}"
    else:
        experiment_dir += f'/stage_{cfg.stage}'

        # Raise explicitly rather than via ``assert`` so the guard is still
        # enforced when Python runs with -O (which strips assert statements).
        if (
            not allow_overlap
            and not cfg.training.resume
            and os.path.exists(experiment_dir)
        ):
            raise AssertionError(
                f'cfg.make_unique_experiment_dir=false but {experiment_dir} is already occupied'
            )

    # Drop as many leading components as cfg.output_prefix itself contains.
    prefix_depth = len(cfg.output_prefix.split('/'))
    experiment_name = "_".join(experiment_dir.split("/")[prefix_depth:])
    return experiment_dir, experiment_name


def get_latest_checkpoint(checkpoint_dir):
    """Resolve a checkpoint path from either a file or a directory.

    Args:
        checkpoint_dir: Path to a checkpoint file, or to a directory that
            contains checkpoint files.

    Returns:
        ``checkpoint_dir`` itself when it is a regular file. Otherwise the
        path of the file directly inside ``checkpoint_dir`` whose name comes
        last in ``natsorted`` order.

    Raises:
        IndexError: If the directory contains no regular files.
    """
    if not os.path.isfile(checkpoint_dir):
        # Only regular files directly inside the directory are candidates.
        candidates = natsorted([
            name
            for name in os.listdir(checkpoint_dir)
            if os.path.isfile(os.path.join(checkpoint_dir, name))
        ])
        return os.path.join(checkpoint_dir, candidates[-1])
    return checkpoint_dir


def soft_load_state_dict(model, loaded_state_dict):
    """Load the entries of a checkpoint that are compatible with ``model``.

    For every key in ``model.state_dict()`` the checkpoint value is used when
    it exists and either has no ``size`` attribute or has the same size as the
    model's current value. Otherwise the model's current value is kept and a
    warning is emitted. Checkpoint keys that the model does not have are
    reported with a warning and ignored.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        loaded_state_dict: Mapping from parameter name to checkpoint value.
    """
    current_model_dict = model.state_dict()
    new_state_dict = {}

    for key, current_value in current_model_dict.items():
        if key not in loaded_state_dict:
            warnings.warn(f'Model parameter {key} does not exist in checkpoint. Skipping')
            new_state_dict[key] = current_value
            continue

        loaded_value = loaded_state_dict[key]
        # Values without a ``size`` attribute are passed through unchecked.
        if not hasattr(loaded_value, 'size') or loaded_value.size() == current_value.size():
            new_state_dict[key] = loaded_value
        else:
            # Trailing space after the first fragment keeps the shape and the
            # word "into" from running together in the emitted message.
            warnings.warn(
                f'Cannot load checkpoint parameter {key} with shape {loaded_value.shape} '
                f'into model with corresponding parameter shape {current_value.shape}. Skipping'
            )
            new_state_dict[key] = current_value

    for key in loaded_state_dict:
        if key not in current_model_dict:
            warnings.warn(f'Loaded checkpoint parameter {key} does not exist in model. Skipping')

    model.load_state_dict(new_state_dict)


def map_tensor_to_device(data, device):
    """Apply ``safe_device`` to ``data`` through ``TensorUtils.map_tensor``.

    Args:
        data: Structure handed to ``TensorUtils.map_tensor``; presumably a
            nested container of tensors (verify against that helper).
        device: Target device forwarded to ``safe_device``.

    Returns:
        Whatever ``TensorUtils.map_tensor`` returns for the mapped data.
    """
    def _move(item):
        return safe_device(item, device=device)

    return TensorUtils.map_tensor(data, _move)

def process_inputs(device, dtype, inputs):
    """Move every tensor found in ``inputs`` to ``device`` and cast to ``dtype``.

    Lists, tuples (including namedtuples) and dicts are traversed
    recursively; any other value is returned unchanged.

    Args:
        device: Target device passed to ``Tensor.to``.
        dtype: Target dtype passed to ``Tensor.to``.
        inputs: Mapping from name to a tensor, a nested container, or any
            other value.

    Returns:
        A new dict with the same keys as ``inputs``. Lists and tuples are
        rebuilt with their original type; nested dicts are returned as plain
        ``dict`` objects.
    """
    def process_input(input_data):
        if isinstance(input_data, torch.Tensor):
            return input_data.to(device, dtype)
        if isinstance(input_data, (list, tuple)):
            converted = [process_input(item) for item in input_data]
            if hasattr(input_data, '_fields'):
                # namedtuple constructors take one positional argument per
                # field, not a single iterable.
                return type(input_data)(*converted)
            return type(input_data)(converted)
        if isinstance(input_data, dict):
            return {key: process_input(value) for key, value in input_data.items()}
        return input_data

    return {key: process_input(value) for key, value in inputs.items()}

def process_outputs(key_output, **outputs):
    """Detach every tensor found in the outputs and move it to the CPU.

    Lists, tuples (including namedtuples) and dicts are traversed
    recursively; any other value is returned unchanged.

    Args:
        key_output: Primary output; a tensor, a nested container, or any
            other value.
        **outputs: Additional named outputs processed the same way.

    Returns:
        A ``(processed_key_output, processed_outputs)`` tuple, where
        ``processed_outputs`` is a dict with the same keys as ``outputs``.
    """
    def process_output(output_data):
        if isinstance(output_data, torch.Tensor):
            return output_data.detach().cpu()
        if isinstance(output_data, (list, tuple)):
            converted = [process_output(item) for item in output_data]
            if hasattr(output_data, '_fields'):
                # namedtuple constructors take one positional argument per
                # field, not a single iterable.
                return type(output_data)(*converted)
            return type(output_data)(converted)
        if isinstance(output_data, dict):
            return {key: process_output(value) for key, value in output_data.items()}
        return output_data

    processed_extras = {key: process_output(value) for key, value in outputs.items()}
    return process_output(key_output), processed_extras

def safe_device(x, device="cpu"):
    """Move ``x`` to ``device``, falling back to the CPU when CUDA is missing.

    Args:
        x: Object exposing ``cpu()`` and ``to(device)`` (e.g. a tensor).
        device: Device name such as ``"cpu"`` or ``"cuda:0"``, or a
            ``torch.device`` instance.

    Returns:
        ``x.cpu()`` when the device is ``"cpu"``, or when a CUDA device is
        requested but ``torch.cuda.is_available()`` is false; otherwise
        ``x.to(device)``.
    """
    # Normalise so torch.device objects are handled like their string form.
    device_name = str(device)
    if device_name == "cpu":
        return x.cpu()
    if "cuda" in device_name and not torch.cuda.is_available():
        return x.cpu()
    # Covers available CUDA devices and any other backend; previously a
    # non-CPU, non-CUDA device silently produced None.
    return x.to(device)


def extract_state_dicts(inp):
    """Recursively replace objects that have ``state_dict()`` by its result.

    Dicts and lists are traversed and rebuilt as a plain ``dict`` / ``list``.
    Any other value is replaced by ``value.state_dict()`` when it has a
    ``state_dict`` attribute and returned unchanged otherwise. Tuples are not
    traversed.

    Args:
        inp: A dict, a list, or a leaf value.

    Returns:
        The structure with state dicts extracted as described above.
    """
    if isinstance(inp, dict):
        return {key: extract_state_dicts(value) for key, value in inp.items()}
    if isinstance(inp, list):
        return [extract_state_dicts(item) for item in inp]
    return inp.state_dict() if hasattr(inp, 'state_dict') else inp


def save_state(state_dict, path):
    """Serialize ``state_dict`` to ``path`` with ``torch.save``.

    The input is first passed through ``extract_state_dicts`` and the result
    is what gets written.

    Args:
        state_dict: Value or nested dict/list structure to save.
        path: Destination handed to ``torch.save``.
    """
    torch.save(extract_state_dicts(state_dict), path)


def load_state(path, map_location=None):
    """Load an object saved with ``torch.save``.

    Args:
        path: Location of the saved file, handed to ``torch.load``.
        map_location: Optional device mapping forwarded to ``torch.load``.
            When ``None``, ``'cuda:0'`` is used if CUDA is available (the
            historical behaviour) and ``'cpu'`` otherwise.

    Returns:
        The object returned by ``torch.load``.
    """
    if map_location is None:
        # Prefer GPU 0 as before, but stay usable on CPU-only hosts, where
        # mapping to 'cuda:0' makes torch.load fail.
        map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.load(path, map_location=map_location)


def torch_save_model(model, optimizer, scheduler, model_path, cfg=None):
    """Write a training checkpoint to ``model_path`` with ``torch.save``.

    The saved dict holds the ``state_dict()`` of the model, the optimizer and
    the scheduler under ``"model_state_dict"``, ``"optimizer_state_dict"`` and
    ``"scheduler_state_dict"``, plus ``cfg`` under ``"cfg"``.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        model_path: Destination handed to ``torch.save``.
        cfg: Optional configuration stored as-is alongside the state dicts.
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "cfg": cfg,
    }
    torch.save(checkpoint, model_path)


def torch_load_model(model_path, map_location=None):
    """Load a checkpoint dict and return its four standard entries.

    Args:
        model_path: Location of the checkpoint, handed to ``torch.load``.
        map_location: Optional device mapping forwarded to ``torch.load``.
            The default ``None`` keeps ``torch.load``'s own default; pass
            e.g. ``"cpu"`` to load a checkpoint saved on a GPU onto a host
            without one.

    Returns:
        A tuple ``(model_state_dict, optimizer_state_dict,
        scheduler_state_dict, cfg)`` read from the checkpoint keys of the
        same names.

    Raises:
        KeyError: If any of the four keys is missing from the checkpoint.
    """
    checkpoint = torch.load(model_path, map_location=map_location)
    return (
        checkpoint["model_state_dict"],
        checkpoint["optimizer_state_dict"],
        checkpoint["scheduler_state_dict"],
        checkpoint["cfg"],
    )
