import torch
import numpy as np
import random
import os
import argparse as arg
import preprocess as pre
from model import *
import loss as L
from train import *
import json


def load_model(args):
    """
    Builds the predictor network and restores its weights from a checkpoint.

    Args:
        args: Namespace providing `predictor` (architecture name forwarded to
              `model_choice`), `num_classes`, `device`, `dataset` and
              `load_predictor` (checkpoint file name).

    Returns:
        The predictor moved to `args.device`, with the checkpoint's
        'state_dict' entry loaded into it.
    """
    # NOTE(review): depth is fixed to 4 here although a --depth option exists — confirm.
    predictor = model_choice(args.predictor, depth=4, args=args, n_agents=args.num_classes).to(args.device)
    checkpoint_path = f'Models_weights/{args.dataset}/{args.load_predictor}'
    # map_location remaps the stored tensors onto the target device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    checkpoints = torch.load(checkpoint_path, map_location=args.device)
    predictor.load_state_dict(checkpoints['state_dict'])
    return predictor


def top_k_random_allocation(experts, k):
    """
    Picks k experts by ranking all of them with independent random scores.

    Args:
        experts (list): Candidate experts to choose from.
        k (int): How many experts to keep.

    Returns:
        list: The first k experts once they are ordered by descending
              random score.
    """
    pool = list(experts)
    # One uniform draw per expert, taken in input order.
    draws = [random.random() for _ in pool]
    # Stable descending sort of the positions according to their draw.
    ranking = sorted(range(len(pool)), key=draws.__getitem__, reverse=True)
    return [pool[position] for position in ranking[:k]]


def random_seed(seed):
    """
    Seeds every random number generator this script relies on.

    Applies the same seed to PyTorch (CPU, current CUDA device, all CUDA
    devices), NumPy and Python's `random`, then switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed (int): Seed value handed to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def compute_scores_allocation(costs_agents, epsilon=1e-6, normalize=False):
    """
    Computes a scores allocation from costs with random tie-breaking.

    Args:
        costs_agents (torch.Tensor): Tensor of costs with shape (n_agents, N)
                                     where N is the number of queries. Integer
                                     and boolean tensors are accepted and are
                                     converted to float before noise is added.
        epsilon (float): A small constant to scale the random noise.
        normalize (bool): If True, normalize scores so that the best expert gets 1
                          and the worst gets 0.

    Returns:
        id_deferral (torch.Tensor): Tensor with shape (n_agents, N). Column q
                                    lists the agent indices of query q sorted
                                    by ascending cost (with ties broken randomly).
        scores_allocation (torch.Tensor): Float tensor with shape (n_agents, N).
                                          Entry (a, q) is the rank of agent a for
                                          query q (0 = lowest cost), so lower is
                                          better. If `normalize=True`, ranks are
                                          mapped to [0, 1] and the direction flips:
                                          the lowest-cost agent gets 1, the
                                          highest-cost agent gets 0.
    """
    # Transpose costs to shape (N, n_agents) so each row corresponds to a query.
    costs = costs_agents.T
    # rand_like is only defined for floating-point dtypes; promote integer or
    # boolean costs (e.g. 0-1 losses) so the tie-breaking noise can be added.
    if not torch.is_floating_point(costs):
        costs = costs.float()
    noisy_costs = costs + torch.rand_like(costs) * epsilon

    # Compute id_deferral: sort each row in ascending order of cost.
    id_deferral = torch.argsort(noisy_costs, dim=1, descending=False)

    # Create a tensor of rank indices: 0 for best, 1 for second best, etc.
    N, n_agents = id_deferral.shape
    ranks = torch.arange(n_agents, device=costs_agents.device).unsqueeze(0).expand(N, -1).float()

    # Scatter the ranks into a new tensor so that each agent gets its corresponding rank.
    # Each row of id_deferral is a permutation, so every slot of the empty tensor is written.
    scores_allocation = torch.empty_like(ranks)
    scores_allocation.scatter_(1, id_deferral, ranks)

    # Optionally normalize: best expert (rank 0) gets 1, worst gets 0.
    # With a single agent there is nothing to scale, so the rank is left as is.
    if normalize and n_agents > 1:
        scores_allocation = 1 - scores_allocation / (n_agents - 1)

    # Transpose back so both outputs follow the (n_agents, N) layout of the input.
    return id_deferral.T, scores_allocation.T

def setup_args():
    """
    Parses the command-line options of the script.

    After parsing, `beta` is converted to a list of floats and checked to
    contain exactly `n_agents - 1` values.

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        ValueError: If the number of beta values differs from `n_agents - 1`,
                    or if a beta value cannot be converted to float.
    """
    # NOTE(review): many help strings below look copy-pasted (e.g. 'batch_size'
    # on --seed, 'alpha' on --depth) — confirm the intended descriptions.
    parser = arg.ArgumentParser(description='L2D robust')
    parser.add_argument('--beta', nargs='+', default=[0.2, 0.2, 0.1, 0.1], help='beta')
    parser.add_argument('--n_agents', type=int, default=5, help='lambda_2')
    parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
    parser.add_argument('--batch_size_eval', type=int, default=16, help='batch_size')
    parser.add_argument('--seed', type=int, default=42, help='batch_size')
    parser.add_argument('--epochs', type=int, default=1, help='batch_size')
    parser.add_argument('--epochs_cardinality', type=int, default=10, help='batch_size')
    parser.add_argument('--lr', type=float, default=1e-3, help='lr')
    parser.add_argument('--name_run', type=str, default='top_k', help='batch_size')
    parser.add_argument('--name_project', type=str, default='top_k_testing', help='batch_size')
    parser.add_argument('--dataset', type=str, default='cifar100', help='mnist, cifar10, cifar100, pascal')
    parser.add_argument('--log_freq', type=int, default=10, help='evaluation frequency')
    parser.add_argument('--predictor', type=str, default='resnet', help='cnn/wideresnet/resnet18/resnet4')
    parser.add_argument('--load_predictor', type=str, default='model_resnet4_finalresnet-4.pth', help='loop=1, vectorized=0')
    parser.add_argument('--model', type=str, default='resnet', help='cnn/wideresnet/resnet18/resnet4_rej/MLP')
    parser.add_argument('--optimizer', type=str, default='Adam', help='AdamW/Adam/SGD')
    parser.add_argument('--dropout', type=float, default=0.0, help='vanilla/robust')
    parser.add_argument('--scheduler', type=str, default='None', help='linear/cosine/cyclelr')
    parser.add_argument('--subset_test', type=int, default=1, help='subset:1')
    parser.add_argument('--n_points', type=int, default=20, help='loop=1, vectorized=0')
    parser.add_argument('--device', type=str, default='cuda', help='overfit')
    parser.add_argument('--top_k', type=int, default=1, help='overfit')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='attack:1')
    parser.add_argument('--task', type=str, default='classification', help='classification/regression/multi_task')
    parser.add_argument('--alpha', type=int, default=1, help='alpha')
    parser.add_argument('--depth', type=int, default=16, help='alpha')
    parser.add_argument('--Lambda', type=float, default=0.01, help='alpha')
    parser.add_argument('--cardinality', type=int, default=1, help='alpha')

    args = parser.parse_args()

    # Values given on the command line arrive as strings; the default is already float.
    args.beta = [float(i) for i in args.beta]
    expected_len = args.n_agents - 1
    if len(args.beta) != expected_len:
        raise ValueError(
            f'beta must have n_agents - 1 = {expected_len} values, got {len(args.beta)}'
        )

    return args


def eval_random_optimal(test_loader, args, criterion, random_=False):
    """
    Evaluates a random or a cost-based allocation over the whole test set.

    Labels and agent predictions are read directly from `test_loader.dataset`
    (`targets` and `agents`); the loader is not iterated.

    Args:
        test_loader: Loader whose `dataset` exposes `targets` and `agents`.
        args: Namespace providing `n_agents`, `num_classes`, `task` and,
              for the random regression case, `upper`. `args.device` selects
              the device; 'cuda' is used when the attribute is missing.
        criterion: Object providing `get_agent_cost`; also forwarded to
                   `metrics_top_k_l2d`.
        random_ (bool): If True, allocation scores are random integers in
                        [0, num_classes); otherwise they are derived from the
                        agents' costs via `compute_scores_allocation`.

    Returns:
        dict: Aggregated metrics keyed by 'accuracy_top_k', 'loss_top_k',
              'acc_majority', 'acc_model', 'acc_weighted', 'MRR', 'proba_top_k',
              'deferral_ratio', 'querying_costs' and 'mean_querying_costs'.
    """
    # Use the configured device instead of a hard-coded 'cuda' so the evaluation
    # also runs on CPU-only machines.
    device = getattr(args, 'device', 'cuda')

    N = len(test_loader.dataset)
    K = args.n_agents + args.num_classes
    agents_ = torch.tensor(test_loader.dataset.agents.T, device=device)

    if args.task == 'regression':
        labels = torch.tensor(test_loader.dataset.targets, device=device).squeeze(1)
        if random_:
            # Drawn on CPU (same random stream as before), then moved next to the agents.
            model_l = (torch.rand(N, 1) * args.upper).squeeze(1).to(device)
        else:
            model_l = labels
        # NOTE(review): model_l is 1-D in this branch while the concatenation
        # below runs along dim=1 — confirm the regression path is exercised.
    else:
        labels = torch.tensor(test_loader.dataset.targets, device=device)
        # One row per sample, each holding 0..num_classes-1.
        model_l = torch.arange(args.num_classes, device=device).repeat(len(labels), 1)

    agents = torch.concatenate((model_l, agents_), dim=1)
    if random_:
        # Generate random scores for the top-k allocation.
        outputs = torch.randint(
            low=0, high=args.num_classes, size=(N, K), dtype=torch.float, device=device
        )
    else:
        cost_agents = criterion.get_agent_cost(None, labels, agents, selector=None)
        # NOTE(review): compute_scores_allocation returns transposed tensors;
        # confirm its layout matches the (N, K) scores of the random branch.
        id_deferral, outputs = compute_scores_allocation(cost_agents, epsilon=1e-6, normalize=True)

    (loss, loss_top_k, ids_top_k, acc_majority, acc_weighted, MRR,
     proba_top_k_np, accuracy_top_k, acc_model) = metrics_top_k_l2d(
        criterion, None, None, labels, agents, outputs)

    # Deferral ratio id top_k
    deferral_ratio, querying_costs = compute_deferral_ratio(ids_top_k, args)
    mean_querying_costs = np.array(querying_costs).sum(axis=1)

    # NOTE(review): 'acc_model' is not scaled by 100 unlike the other accuracies — confirm.
    dict_results = {'accuracy_top_k': np.mean(accuracy_top_k.cpu().numpy(), axis=0)*100, 'loss_top_k': np.mean(loss_top_k, axis=0),
                    'acc_majority':np.mean(acc_majority.cpu().numpy(), axis=0)*100, 'acc_model': np.mean(acc_model.cpu().numpy(), axis=0),
                    'acc_weighted': np.mean(acc_weighted.cpu().numpy(), axis=0)*100, 'MRR': np.mean(MRR), 'proba_top_k': np.mean(proba_top_k_np, axis=0),
                    'deferral_ratio': deferral_ratio, 'querying_costs': querying_costs, 'mean_querying_costs': mean_querying_costs}
    return dict_results

# def eval_optimum(test_loader, args, criterion):
#     labels = torch.tensor(test_loader.dataset.targets)
#     agents = torch.tensor(test_loader.dataset.agents.T)
#
#     costs_model = criterion.get_cost(agents[:, 0], labels, expert=False)
#     costs_experts = torch.zeros((args.n_agents - 1, len(labels)), device=agents.device)
#     costs_agents = torch.cat((costs_model.unsqueeze(dim=0), costs_experts), dim=0)
#     for i in range(1, args.n_agents):
#         costs_agents[i] = criterion.get_cost(agents[:, i], labels, expert=True, i=i)
#
#     # break the tie by adding a small noise
#     id_deferral, scores_allocation = compute_scores_allocation(costs_agents, epsilon=1e-6, normalize=True)
#     accuracy_top_k = criterion.top_k_cost_sensitive(scores_allocation, labels, agents)
#     acc_majority, acc_weighted = criterion.indicator_vote(scores_allocation, labels, agents)
#     loss_top_k, ids_top_k = criterion.true_deferral_top_k(scores_allocation, labels, agents)
#     deferral_ratio, querying_costs = compute_deferral_ratio(ids_top_k, args)
#     mean_querying_costs = np.array(querying_costs).sum(axis=1)
#
#     dict_results = {'accuracy_top_k': np.mean(accuracy_top_k, axis=1) * 100, 'loss_top_k': np.mean(loss_top_k, axis=1),
#                     'acc_majority': np.mean(acc_majority, axis=0) * 100,
#                     'acc_weighted': np.mean(acc_weighted, axis=0) * 100,
#                     'deferral_ratio': deferral_ratio, 'querying_costs': querying_costs,
#                     'mean_querying_costs': mean_querying_costs}
#     return dict_results

def json_save(results, args, metrics, name):
    """
    Converts per-metric (mean, std) statistics into a JSON-serializable dict.

    Despite its name, this function performs no file I/O: it only builds and
    returns the dictionary. `args` and `name` are accepted but not used.

    Args:
        results (dict): Maps each metric name to either a `(mean, std)` pair
                        or an already converted `{"mean": ..., "std": ...}`
                        dict, which makes the conversion safe to apply twice.
        args: Unused.
        metrics (iterable): Names of the metrics to include in the output.
        name: Unused.

    Returns:
        dict: `{metric: {"mean": ..., "std": ...}}` where array-like values
              have been turned into plain Python objects via `tolist()`.

    Raises:
        KeyError: If a requested metric is missing from `results`.
    """
    def _to_plain(value):
        # NumPy arrays/scalars expose tolist(); plain lists and floats pass through.
        return value.tolist() if hasattr(value, 'tolist') else value

    # Convert results to a JSON-serializable dictionary
    results_serializable = {}
    for m in metrics:
        entry = results[m]
        if isinstance(entry, dict):
            # Already in serialized form; unpacking it as a pair would yield its keys.
            mean_val, std_val = entry["mean"], entry["std"]
        else:
            mean_val, std_val = entry
        results_serializable[m] = {
            "mean": _to_plain(mean_val),
            "std": _to_plain(std_val)
        }
    return results_serializable

def _aggregate_runs(test_loader, args, criterion, name, random_, n_runs=10):
    """Evaluates one allocation mode under seeds 0..n_runs-1 and returns per-metric mean/std."""
    runs = []
    for seed in range(n_runs):
        # Reseed before every run so each repetition is reproducible on its own.
        random_seed(seed)
        runs.append(eval_random_optimal(test_loader, args, criterion, random_=random_))

    metrics = ['accuracy_top_k', 'loss_top_k', 'deferral_ratio', 'acc_majority', 'acc_weighted',
               'querying_costs', 'mean_querying_costs', 'MRR', 'proba_top_k']
    stats = {}
    for m in metrics:
        # Stack once per metric; mean and std are taken across the runs (axis 0).
        stacked = np.stack([run[m] for run in runs], axis=0)
        stats[m] = (np.mean(stacked, axis=0), np.std(stacked, axis=0))
    return json_save(stats, args, metrics, name)


def run_allocation(test_loader, args, criterion, name):
    """
    Evaluates the random and the cost-based allocation over ten seeded runs each.

    Args:
        test_loader: Loader forwarded to `eval_random_optimal`.
        args: Namespace forwarded to `eval_random_optimal` and `json_save`.
        criterion: Loss/metric object forwarded to `eval_random_optimal`.
        name: Forwarded to `json_save`.

    Returns:
        tuple: `(results_random, results_optimal)`, each the output of
               `json_save` holding the mean and std of every metric
               across the runs.
    """
    # The random allocation is evaluated first, then the cost-based one,
    # both with seeds 0..9.
    results_random = _aggregate_runs(test_loader, args, criterion, name, random_=True)
    results_optimal = _aggregate_runs(test_loader, args, criterion, name, random_=False)
    return results_random, results_optimal

def main():
    """
    Entry point: evaluates random and cost-based allocations and stores the results.

    Parses the options, seeds the generators, loads the predictor and the
    data, runs `run_allocation`, and writes a subset of the aggregated
    metrics to `<name_run>_random.json` and `<name_run>_optimal.json` in the
    current working directory.

    Raises:
        ValueError: If `--dataset` is anything other than 'cifar100', the only
                    dataset this function sets up.
    """
    os.environ["WANDB_INIT_TIMEOUT"] = "300"
    # Initialization
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = setup_args()
    # The detected device replaces whatever was passed through --device.
    args.device = device
    random_seed(args.seed)
    criterion = L.L2D_loss(args)

    # Fail early with a clear message instead of hitting an unbound
    # `test_loader` further down.
    if args.dataset != 'cifar100':
        raise ValueError(f"unsupported dataset '{args.dataset}': only 'cifar100' is handled")

    args.num_classes = 100
    predictor = load_model(args)
    train_loader, test_loader_subset, test_loader, dict_tr, dict_test = pre.processing(
        args.batch_size,
        args.batch_size_eval,
        args=args,
        predictor=predictor,
    )

    results_random, results_optimal = run_allocation(test_loader, args, criterion, args.name_run)

    # run_allocation already returns JSON-serializable {"mean", "std"} entries,
    # so they are written out directly rather than converted a second time.
    # NOTE(review): the output location (working directory, '<name>.json') is an
    # assumption — adjust if results belong elsewhere.
    saved_metrics = ['accuracy_top_k', 'loss_top_k', 'deferral_ratio', 'querying_costs']
    for suffix, results in (('random', results_random), ('optimal', results_optimal)):
        summary = {m: results[m] for m in saved_metrics}
        with open(f"{args.name_run}_{suffix}.json", "w", encoding="utf-8") as handle:
            json.dump(summary, handle, indent=2)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()



















