import math
import warnings

import torch
import numpy as np
import os
import json
from tqdm import tqdm
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import Pool
import torch.nn.functional as F


def load_problem(name):
    """Look up the problem class registered under ``name``.

    :param name: problem key, one of 'pdtrp', 'pdcvrp', 'pdtrptw', 'pdcvrptw'.
    :return: the matching problem class (not an instance).
    :raises AssertionError: if ``name`` is not a registered problem.
    """
    # Imported lazily so this utility module does not import the problem
    # definitions until one is actually requested.
    from problems import PDTRP, PDCVRP, PDTRPTW, PDCVRPTW

    registry = dict(
        pdtrp=PDTRP,
        pdcvrp=PDCVRP,
        pdtrptw=PDTRPTW,
        pdcvrptw=PDCVRPTW,
    )
    selected = registry.get(name)
    assert selected is not None, "Currently unsupported problem: {}!".format(name)
    return selected


def torch_load_cpu(load_path):
    """Deserialize the object stored at ``load_path`` with all tensors placed on the CPU.

    NOTE: ``weights_only=False`` unpickles arbitrary Python objects, so only
    load checkpoints from trusted sources.
    """
    def _keep_storage(storage, _location):
        # Returning the storage unchanged keeps it on the CPU regardless of
        # the device it was saved from.
        return storage

    return torch.load(load_path, map_location=_keep_storage, weights_only=False)


def move_to(var, device):
    """Recursively move ``var`` to ``device``.

    Dicts are rebuilt with each value moved. Plain lists and tuples are
    rebuilt (as ``list``/``tuple``) with each element moved. Anything else is
    moved with its own ``.to(device)`` method.

    :param var: a tensor-like object with ``.to``, or a dict/list/tuple of them.
    :param device: target device passed to ``.to``.
    :return: a structure of the same shape with every leaf moved.
    """
    if isinstance(var, dict):
        return {k: move_to(v, device) for k, v in var.items()}
    # Sequences that define their own ``.to`` (e.g. NamedTuple-based state
    # objects) are left to that method rather than being taken apart.
    if isinstance(var, (list, tuple)) and not hasattr(var, 'to'):
        moved = [move_to(v, device) for v in var]
        return moved if isinstance(var, list) else tuple(moved)
    return var.to(device)


def _load_model_file(load_path, model):
    """Loads the model with parameters from the file and returns optimizer state dict if it is in the file"""

    # Load the model parameters from a saved state
    load_optimizer_state_dict = None
    print('\nLoading model from {}'.format(load_path))

    load_data = torch.load(
        os.path.join(
            os.getcwd(),
            load_path
        ), map_location=lambda storage, loc: storage, weights_only=False)

    if isinstance(load_data, dict):
        load_optimizer_state_dict = load_data.get('optimizer', None)
        load_model_state_dict = load_data.get('model', load_data)
    else:
        load_model_state_dict = load_data.state_dict()

    state_dict = model.state_dict()

    state_dict.update(load_model_state_dict)

    model.load_state_dict(state_dict)

    return model, load_optimizer_state_dict


def load_args(filename):
    """Read a run's ``args.json`` and fill in defaults for older runs.

    Backwards-compatibility rules applied:
      * missing ``data_distribution`` -> ``None``; an ``op_<dist>`` problem
        name is split into ``problem='op'`` and ``data_distribution='<dist>'``
        (a bare ``op`` keeps ``data_distribution=None``),
      * missing ``knn_strat`` -> ``None``,
      * missing ``aggregation_graph`` -> ``"mean"``,
      * ``problem == 'tsp'`` is rewritten to ``'pdtrp'`` with a warning.

    :param filename: path to the JSON file.
    :return: the args dict.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        args = json.load(f)

    if 'data_distribution' not in args:
        args['data_distribution'] = None
        probl, *dist = args['problem'].split("_")
        if probl == "op":
            args['problem'] = probl
            # Guard against a bare "op" with no distribution suffix, which
            # would otherwise raise IndexError.
            args['data_distribution'] = dist[0] if dist else None

    args.setdefault('knn_strat', None)
    args.setdefault('aggregation_graph', "mean")

    if args['problem'] == 'tsp':
        warnings.warn("TSP problem is deprecated, please use PDTRP instead.")
        args['problem'] = 'pdtrp'

    return args


def load_model(path, epoch=None, extra_logging=False, route_id=None):
    """Build a ``DynamicAttentionModel`` and load checkpoint parameters into it.

    :param path: either a checkpoint file (its directory must contain
        ``args.json``) or a run directory containing ``args.json`` and
        ``epoch-<N>.pt`` checkpoints.
    :param epoch: epoch to load when ``path`` is a directory; defaults to the
        highest ``N`` among ``epoch-<N>.pt`` files.
    :param extra_logging: forwarded to the model constructor.
    :param route_id: if given, the problem class is instantiated and
        ``get_route_info(route_id, device='cpu')`` is called on the instance,
        which is then passed to the model instead of the class.
    :return: ``(model, args)`` with the model in eval mode.
    :raises AssertionError: if ``path`` is neither a file nor a directory.
    :raises ValueError: if ``path`` is a directory with no ``epoch-<N>.pt``
        files and ``epoch`` is not given.
    """
    from nets.dynamic_attention_model import DynamicAttentionModel
    from nets.encoders.gnn_encoder import GNNEncoder

    if os.path.isfile(path):
        model_filename = path
        path = os.path.dirname(model_filename)
    elif os.path.isdir(path):
        if epoch is None:
            epochs = []
            for filename in os.listdir(path):
                stem, ext = os.path.splitext(filename)
                prefix, _, number = stem.partition('-')
                # Only 'epoch-<N>.pt' names count; other .pt files (e.g.
                # 'best.pt') are ignored instead of breaking int() parsing.
                if ext == '.pt' and prefix == 'epoch' and number.isdecimal():
                    epochs.append(int(number))
            if not epochs:
                raise ValueError("No 'epoch-<N>.pt' checkpoints found in {}".format(path))
            epoch = max(epochs)
        model_filename = os.path.join(path, 'epoch-{}.pt'.format(epoch))
    else:
        # Explicit raise (same exception type as before) so the check is not
        # stripped under `python -O`.
        raise AssertionError("{} is not a valid directory or file".format(path))

    args = load_args(os.path.join(path, 'args.json'))

    problem = load_problem(args['problem'])

    if route_id is not None:
        problem = problem()
        problem.get_route_info(route_id, device='cpu')

    model = DynamicAttentionModel(
        problem=problem,
        embedding_dim=args['embedding_dim'],
        encoder_class=GNNEncoder,
        n_encode_layers=args['n_encode_layers'],
        aggregation=args['aggregation'],
        aggregation_graph=args['aggregation_graph'],
        normalization=args['normalization'],
        learn_norm=args['learn_norm'],
        track_norm=args['track_norm'],
        gated=args['gated'],
        n_heads=args['n_heads'],
        tanh_clipping=args['tanh_clipping'],
        mask_inner=True,
        mask_logits=True,
        mask_graph=False,
        checkpoint_encoder=args['checkpoint_encoder'],
        shrink_size=args['shrink_size'],
        edge_features=args['edge_features'],
        use_time_feature=args['use_time_feature'],
        functional_time_encoding=args['functional_time_encoding'],
        knn_strat=args['knn_strat'],
        neighbors=args['neighbors'],
        scale_times=args['scale_times'],
        use_arrival_lstm=args['use_arrival_lstm'],
        use_arrival_times=args['use_arrival_times'],
        recursively_remove_visited_nodes=args['recursively_remove_visited_nodes'],
        extra_logging=extra_logging,
    )

    # _load_model_file merges the checkpoint's parameters over the freshly
    # initialised ones and handles dict, bare-state-dict and pickled-module
    # checkpoints, so no separate pre-load is needed.
    model, *_ = _load_model_file(model_filename, model)

    model.eval()  # Put in eval mode

    return model, args

def load_model_static(path, epoch=None, extra_logging=False):
    """Build a static ``AttentionModel`` and load checkpoint parameters into it.

    :param path: either a checkpoint file (its directory must contain
        ``args.json``) or a run directory containing ``args.json`` and
        ``epoch-<N>.pt`` checkpoints.
    :param epoch: epoch to load when ``path`` is a directory; defaults to the
        highest ``N`` among ``epoch-<N>.pt`` files.
    :param extra_logging: forwarded to the model constructor.
    :return: ``(model, args)`` with the model in eval mode.
    :raises AssertionError: if ``path`` is neither a file nor a directory, or
        if ``args.json`` names an unknown model or encoder.
    :raises ValueError: if ``path`` is a directory with no ``epoch-<N>.pt``
        files and ``epoch`` is not given.
    """
    from nets.attention_model import AttentionModel
    from nets.encoders.gnn_encoder import GNNEncoder

    if os.path.isfile(path):
        model_filename = path
        path = os.path.dirname(model_filename)
    elif os.path.isdir(path):
        if epoch is None:
            epochs = []
            for filename in os.listdir(path):
                stem, ext = os.path.splitext(filename)
                prefix, _, number = stem.partition('-')
                # Only 'epoch-<N>.pt' names count; other .pt files (e.g.
                # 'best.pt') are ignored instead of breaking int() parsing.
                if ext == '.pt' and prefix == 'epoch' and number.isdecimal():
                    epochs.append(int(number))
            if not epochs:
                raise ValueError("No 'epoch-<N>.pt' checkpoints found in {}".format(path))
            epoch = max(epochs)
        model_filename = os.path.join(path, 'epoch-{}.pt'.format(epoch))
    else:
        # Explicit raise (same exception type as before) so the check is not
        # stripped under `python -O`.
        raise AssertionError("{} is not a valid directory or file".format(path))

    args = load_args(os.path.join(path, 'args.json'))

    problem = load_problem(args['problem'])

    # Report the requested name, not the (None) lookup result.
    model_name = args.get('model', 'attention')
    model_class = {
        'attention': AttentionModel,
    }.get(model_name, None)
    assert model_class is not None, "Unknown model: {}".format(model_name)

    encoder_name = args.get('encoder', 'gnn')
    encoder_class = {
        'gnn': GNNEncoder,
    }.get(encoder_name, None)
    assert encoder_class is not None, "Unknown encoder: {}".format(encoder_name)

    model = model_class(
        problem=problem,
        embedding_dim=args['embedding_dim'],
        encoder_class=encoder_class,
        n_encode_layers=args['n_encode_layers'],
        aggregation=args['aggregation'],
        aggregation_graph=args['aggregation_graph'],
        normalization=args['normalization'],
        learn_norm=args['learn_norm'],
        track_norm=args['track_norm'],
        gated=args['gated'],
        n_heads=args['n_heads'],
        tanh_clipping=args['tanh_clipping'],
        mask_inner=True,
        mask_logits=True,
        mask_graph=False,
        checkpoint_encoder=args['checkpoint_encoder'],
        shrink_size=args['shrink_size'],
        extra_logging=extra_logging
    )

    # _load_model_file merges the checkpoint's parameters over the freshly
    # initialised ones and handles dict, bare-state-dict and pickled-module
    # checkpoints, so no separate pre-load is needed.
    model, *_ = _load_model_file(model_filename, model)

    model.eval()  # Put in eval mode

    return model, args

def parse_softmax_temperature(raw_temp):
    """Parse a softmax temperature from a number string or a whitespace-separated file.

    If ``raw_temp`` names an existing file, the value in the first column of
    its last row is returned. Otherwise ``raw_temp`` is parsed as a float.

    :param raw_temp: a float literal (as a string) or a path to a text file.
    :return: the temperature as a float.
    """
    if os.path.isfile(raw_temp):
        # ndmin=2 keeps single-line / single-value files 2-D so that
        # [-1, 0] indexing works for them too.
        return float(np.loadtxt(raw_temp, ndmin=2)[-1, 0])
    return float(raw_temp)


def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True):
    """Apply ``func`` to a slice of ``dataset`` in a worker pool.

    Each call of ``func`` receives one tuple ``(directory, name, *instance)``,
    where ``name`` is the instance's index in the full ``dataset``,
    zero-padded to a common width.

    :param func: callable run once per instance; a ``None`` result counts as
        a failure.
    :param directory: passed through as the first element of each task tuple.
    :param dataset: sliceable sequence of instances; each instance is
        unpacked into its task tuple.
    :param opts: options object; reads ``cpus``, ``n``, optional ``offset``
        and ``progress_bar_mininterval``.
    :param use_multiprocessing: use a process pool when more than one CPU is
        requested; otherwise a thread pool is used.
    :return: ``(results, num_cpus)``.
    :raises AssertionError: if any call returned ``None``.
    """
    # os.cpu_count() returns None when the count cannot be determined, which
    # would break the `num_cpus > 1` comparison below; fall back to 1.
    num_cpus = (os.cpu_count() or 1) if opts.cpus is None else opts.cpus

    # Pad names to the width of the largest index in the full dataset.
    w = len(str(len(dataset) - 1))
    offset = getattr(opts, 'offset', None) or 0
    end = offset + opts.n if opts.n is not None else len(dataset)
    ds = dataset[offset:end]
    tasks = [
        (directory, str(i + offset).zfill(w), *problem)
        for i, problem in enumerate(ds)
    ]

    pool_cls = Pool if use_multiprocessing and num_cpus > 1 else ThreadPool
    with pool_cls(num_cpus) as pool:
        results = list(tqdm(
            pool.imap(func, tasks),
            total=len(ds),
            mininterval=opts.progress_bar_mininterval,
            ascii=True,
        ))

    failed = [str(i + offset) for i, res in enumerate(results) if res is None]
    assert len(failed) == 0, "Some instances failed: {}".format(" ".join(failed))
    return results, num_cpus


def do_batch_rep(v, n):
    """Tile ``v`` ``n`` times along its first dimension, recursing into containers.

    The whole batch is repeated as a block: rows ``[b0, b1]`` with ``n=2``
    become ``[b0, b1, b0, b1]``. Dicts, lists and tuples are rebuilt with each
    element tiled.

    :param v: a tensor (anything with ``.size()`` and ``.expand``) or a
        dict/list/tuple of them.
    :param n: number of repetitions.
    :return: a contiguous tensor of shape ``(n * v.size(0), *v.size()[1:])``,
        or a container of such tensors.
    """
    if isinstance(v, dict):
        return {key: do_batch_rep(val, n) for key, val in v.items()}
    if isinstance(v, list):
        return [do_batch_rep(item, n) for item in v]
    if isinstance(v, tuple):
        return tuple(do_batch_rep(item, n) for item in v)

    shape = v.size()
    tiled = v.unsqueeze(0).expand(n, *shape)
    return tiled.contiguous().view(-1, *shape[1:])


def sample_many(inner_func, get_cost_func, input, batch_rep=1, iter_rep=1):
    """Sample ``batch_rep * iter_rep`` solutions per instance and keep the cheapest.

    :param inner_func: called as ``inner_func(input)``; must return a pair
        whose second element is the solution tensor ``pi``.
    :param get_cost_func: called as ``get_cost_func(input, pi)``; must return
        a pair whose first element is the per-row cost.
    :param input: batch (tensor or container of tensors) passed to
        ``do_batch_rep``.
    :param batch_rep: copies of the batch processed in a single call.
    :param iter_rep: number of sequential calls.
    :return: ``(best_pis, best_costs)``: for each instance, the lowest-cost
        solution (right-padded with zeros to the longest sampled length) and
        its cost.
    """
    input = do_batch_rep(input, batch_rep)

    cost_chunks = []
    tour_chunks = []
    for _rep in range(iter_rep):
        _, pi = inner_func(input)
        cost, _ = get_cost_func(input, pi)
        # do_batch_rep tiles the whole batch batch_rep times, so rows are
        # laid out as (batch_rep, batch_size); transpose to put the instance first.
        cost_chunks.append(cost.view(batch_rep, -1).t())
        tour_chunks.append(pi.view(batch_rep, -1, pi.size(-1)).transpose(0, 1))

    # Solutions from different calls may differ in length; right-pad with
    # zeros (F.pad default) before concatenating along the sample dimension.
    longest = max(chunk.size(-1) for chunk in tour_chunks)
    padded = [F.pad(chunk, (0, longest - chunk.size(-1))) for chunk in tour_chunks]
    # (batch_size, batch_rep * iter_rep, longest)
    all_tours = torch.cat(padded, 1)
    # (batch_size, batch_rep * iter_rep)
    all_costs = torch.cat(cost_chunks, 1)

    best_costs, best_idx = all_costs.min(-1)
    rows = torch.arange(all_tours.size(0), device=best_idx.device)
    best_tours = all_tours[rows, best_idx]

    return best_tours, best_costs


def get_best(sequences, cost, ids=None, batch_size=None):
    """Select the lowest-cost solution for each instance.

    Without ``ids``, all rows belong to a single instance: the cheapest row is
    returned as length-1 slices ``(sequences[i:i+1], cost[i:i+1])``.

    With ``ids``, ``ids[k]`` is the instance index of row ``k``. Equal ids
    must be contiguous (e.g. ``[0, 0, 0, 1, 1, 2, ...]`` when 3 solutions were
    found for instance 0, 2 for instance 1, ...).

    :param sequences: indexable solutions, one per row.
    :param cost: numpy array of per-row costs.
    :param ids: optional numpy array of per-row instance indices.
    :param batch_size: length of the output lists; defaults to the number of
        id groups. Instances with no row get ``None`` and ``math.inf``.
    :return: ``(best_sequences, best_costs)`` lists indexed by instance.
    """
    if ids is None:
        best = cost.argmin()
        return sequences[best:best + 1, ...], cost[best:best + 1, ...]

    # Offsets at which each run of equal ids starts.
    starts = np.hstack([0, np.flatnonzero(ids[1:] != ids[:-1]) + 1])
    group_min = np.minimum.reduceat(cost, starts)
    group_len = np.diff(np.hstack([starts, len(ids)]))

    # All rows whose cost equals the minimum of their group.
    hits = np.flatnonzero(np.repeat(group_min, group_len) == cost)

    n_out = len(group_len) if batch_size is None else batch_size
    best_pos = np.full(n_out, -1, dtype=int)
    # Writing in reverse order means that, on ties, the earliest row is
    # written last; this relies on NumPy's (in practice) last-write-wins
    # behaviour for repeated fancy indices.
    reversed_hits = hits[::-1]
    best_pos[ids[reversed_hits]] = reversed_hits

    best_sequences = [sequences[i] if i >= 0 else None for i in best_pos]
    best_costs = [cost[i] if i >= 0 else math.inf for i in best_pos]
    return best_sequences, best_costs

def scale_time(time, time_horizon):
    """
    Scale time to roughly a [0, 1] range by dividing it by the time horizon.

    :param time: a torch tensor, a numpy array, or a scalar convertible with ``float()``
    :param time_horizon: maximum time horizon
    :return: the scaled value; a tensor or array for tensor/array input,
        otherwise a Python float
    """
    if isinstance(time, np.ndarray):
        # np.array(...) copies the input (and drops ndarray subclasses).
        return np.array(time) / time_horizon
    if isinstance(time, torch.Tensor):
        return time / time_horizon
    return float(time) / time_horizon

def count_trailing_zeros(tensor):
    """Count the zero entries at the end of ``tensor`` after flattening it.

    :param tensor: a torch tensor of any shape; it is flattened in row-major order.
    :return: number of consecutive zeros at the end (the full length if every
        entry is zero, 0 for an empty tensor).
    """
    flat = tensor.flatten()
    nonzero_positions = (flat != 0).nonzero(as_tuple=True)[0]
    if nonzero_positions.numel() == 0:
        return len(flat)
    # Everything after the last non-zero entry is a trailing zero.
    return len(flat) - 1 - int(nonzero_positions[-1])