import torch
import numpy as np
import os
import json
from tqdm.auto import tqdm
import copy
from loss import vote_majority_loss, vote_weighted_loss

def images_process(images, labels, agents, args):
    """Move one batch onto the configured device.

    Args:
        images, labels, agents: tensors of a single batch (anything with ``.to``).
        args: namespace providing ``device``.

    Returns:
        ``(images, labels, agents)`` moved to ``args.device``, in that order.
    """
    target = args.device
    return tuple(tensor.to(target) for tensor in (images, labels, agents))

def scheduler_step(scheduler, epoch, i, len_train_loader, args):
    """Advance the LR scheduler after one optimizer step.

    ``'cyclelr'`` and ``'cosine'`` advance by one step. ``'cosine_restart'`` is stepped
    with the fractional epoch ``epoch + i / len_train_loader``; presumably this is a
    ``CosineAnnealingWarmRestarts``-style scheduler (confirm against its construction).
    Any other ``args.scheduler`` value leaves the scheduler untouched.

    Returns:
        The same ``scheduler`` object (``None`` when no scheduler is used; ``args`` is
        not read in that case).
    """
    if scheduler is not None:
        mode = args.scheduler
        if mode in ('cyclelr', 'cosine'):
            scheduler.step()
        elif mode == 'cosine_restart':
            scheduler.step(epoch + i / len_train_loader)
    return scheduler


def compute_deferral_ratio(deferral_ids_top_k, args):
    """Per-entity deferral ratios and querying costs for every top-k budget.

    Entities are indexed ``0 .. num_classes + n_agents - 1``. The first
    ``num_classes`` entries get a zero querying cost; the remaining ``n_agents``
    entries are priced by ``args.beta``.

    Args:
        deferral_ids_top_k: integer array-like of shape ``(num_samples, K)``. Row ``s``
            lists the entity ids ranked for sample ``s``. Ids are assumed to be
            non-negative and ``< num_classes + n_agents``.
        args: namespace with ``n_agents``, ``num_classes`` and ``beta`` (one cost
            per agent, concatenated after ``num_classes`` zeros).

    Returns:
        ``(percentages, beta_costs)``: two lists with one entry per
        ``k in range(num_classes + n_agents)``.
        - ``percentages[k][j]``: number of times entity ``j`` appears in the first
          ``k + 1`` ranked columns, divided by ``num_samples``.
        - ``beta_costs[k]``: that vector multiplied element-wise by the cost vector.
    """
    ids = np.asarray(deferral_ids_top_k)
    num_agents = args.n_agents + args.num_classes
    num_samples = ids.shape[0]
    num_ranked = ids.shape[1]
    # Only the external agents carry a querying cost; the first num_classes entries are free.
    cost_vector = np.concatenate((np.zeros(args.num_classes), np.array(args.beta)))

    # Top-(k+1) counts are top-k counts plus the occurrences in column k. Accumulating one
    # bincount per rank avoids materialising an (N, k+1, num_agents) one-hot tensor for
    # every k, which is cubic in num_agents in time and memory.
    counts = np.zeros(num_agents, dtype=np.int64)
    percentages, beta_costs = [], []
    for k in range(num_agents):
        if k < num_ranked:  # slicing past the last column added nothing in the one-hot form
            counts = counts + np.bincount(ids[:, k], minlength=num_agents)
        ratio = counts / num_samples
        percentages.append(ratio)
        beta_costs.append(ratio * cost_vector)
    return percentages, beta_costs

def train_model(model, criterion, optimizer, train_loader, test_loader, num_epochs, device,  scheduler = None, args=None, name=None, wandb=None):
    """Train the top-k L2D model, evaluating every ``args.log_freq`` steps.

    Args:
        model: module whose forward returns the rejector outputs (classification) or an
            ``(outputs, regression)`` pair (regression).
        criterion: callable ``criterion(outputs, labels, agents_, images=, model=, val=)``.
        optimizer: torch optimizer over the parameters being trained.
        train_loader, test_loader: iterables of ``(images, labels, agents)`` with ``len()``.
        num_epochs: number of passes over ``train_loader``.
        device: unused; batches are moved to ``args.device``.
        scheduler: optional LR scheduler, advanced via ``scheduler_step`` after every
            optimizer step.
        args: run configuration (``task``, ``num_classes``, ``log_freq``, ``scheduler``, ...).
        name: checkpoint sub-directory passed to ``eval``.
        wandb: W&B module/run handle used for logging (required despite the default).

    Returns:
        The best model returned by ``eval``, or ``None`` if no evaluation improved.
    """
    steps = 0
    store_metric = {'acc': 0, 'deferral_loss': 1000}
    saving_model = None
    num_batches = len(train_loader)
    progress_bar = tqdm(range(num_epochs * num_batches))

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels, agents) in enumerate(train_loader):
            # eval() switches the model to eval mode, so re-enable training every step.
            model.train()
            images, labels, agents = images_process(images, labels, agents, args)
            if args.task == 'classification':
                outputs = model(images)
                # The class indices 0..num_classes-1 are prepended to the agent predictions.
                # Build them on the agents' device instead of a hard-coded 'cuda'.
                model_l = torch.arange(args.num_classes, device=agents.device).repeat(len(labels), 1)
                agents_ = torch.concatenate((model_l, agents), dim=1)
            else:
                outputs, regression = model(images)
                agents_ = torch.concatenate((regression, agents), dim=1)

            loss = criterion(outputs, labels, agents_, images=images, model=model, val=False)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler_step already advances 'cosine' once per iteration; the former extra
            # per-epoch step advanced the same schedule twice at every epoch boundary.
            scheduler = scheduler_step(scheduler, epoch, i, num_batches, args)

            loss_value = loss.item()
            running_loss += loss_value
            wandb.log({'Train/loss_iteration': loss_value, 'current_step': steps}, commit=True)

            if steps % args.log_freq == 0:
                store_metric, saving_model = eval(model, criterion, test_loader, device, store_metric, args, name, wandb, steps=steps, best_model=saving_model)

            progress_bar.update(1)
            steps += 1

        wandb.log({
            'Train/loss_epoch': running_loss / num_batches,
            'current_epoch': epoch,
        }, commit=True)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}', flush=True)

    progress_bar.close()
    return saving_model

def accuracy_model(outputs_rejector, labels, criterion, agents):
    """Per-sample error of the model itself (not of the deferral system).

    - Classification: argmax over the first ``num_classes`` rejector outputs, compared
      to ``labels``; returns a float 0/1 mismatch tensor.
    - Any other task: ``criterion.rmse`` between the first ``num_classes`` columns of
      ``agents`` and ``labels.unsqueeze(1)``.
    """
    cfg = criterion.args
    if cfg.task != 'classification':
        return criterion.rmse(agents[:, :cfg.num_classes], labels.unsqueeze(1))
    predicted = outputs_rejector[:, :cfg.num_classes].argmax(dim=1)
    return predicted.ne(labels).float()

def metrics_top_k_l2d(criterion, model, images, labels, agents, outputs_rejector):
    """Compute every top-k L2D validation metric for one batch.

    The majority-vote, weighted-vote, top-k (``criterion.cost_cardinality``) and
    model-only values are computed as errors. For classification they are turned into
    accuracies via ``1 - error``; for regression the raw error values are returned
    unchanged (the ``acc_*`` names are kept for both tasks).

    Args:
        criterion: L2D criterion; ``criterion.args.task`` selects the task.
        model: passed through to ``criterion`` for the surrogate loss.
        images, labels: current batch.
        agents: entity predictions with the model entries prepended.
        outputs_rejector: rejector scores; sorted descending to obtain the deferral ranking.

    Returns:
        ``(loss, loss_top_k, ids_top_k, acc_majority, acc_weighted, MRR, proba_top_k_np,
        acc_top_k, acc_model)``.

    Raises:
        ValueError: if ``criterion.args.task`` is neither ``'classification'`` nor
            ``'regression'``. Previously this surfaced as an ``UnboundLocalError``.
    """
    task = criterion.args.task
    if task not in ('classification', 'regression'):
        raise ValueError(f"Unsupported task: {task!r}")

    loss = criterion(outputs_rejector, labels, agents, images=images, model=model, val=True)
    loss_top_k, ids_top_k = criterion.true_deferral_top_k(outputs_rejector, labels, agents)
    # Entities ranked from most to least preferred by the rejector.
    id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)

    majority_err = vote_majority_loss(id_deferral, labels, agents, selector=None, args=criterion.args,
                                      rmse=criterion.rmse)
    weighted_err = vote_weighted_loss(outputs_rejector, labels, agents, selector=None,
                                      args=criterion.args, rmse=criterion.rmse)
    top_k_err = criterion.cost_cardinality(agents, labels, id_deferral)
    model_err = accuracy_model(outputs_rejector, labels, criterion, agents)

    if task == 'classification':
        acc_majority, acc_weighted = 1 - majority_err, 1 - weighted_err
        acc_top_k, acc_model = 1 - top_k_err, 1 - model_err
    else:
        acc_majority, acc_weighted = majority_err, weighted_err
        acc_top_k, acc_model = top_k_err, model_err

    MRR = criterion.MRR(outputs_rejector, labels, agents)
    proba_top_k_np = criterion.proba_top_k(outputs_rejector)
    return loss, loss_top_k, ids_top_k, acc_majority, acc_weighted, MRR, proba_top_k_np, acc_top_k, acc_model

def wandb_log_top_k(deferral_loss_top_k, accuracy_top_k, mean_querying_costs,
                    acc_majority_list, acc_weighted_list, deferral_ratio, MRR_list, running_loss, test_loader, proba_top_k_list, acc_model_list,
                    steps, args, wandb, name=None):
    """Log aggregated top-k L2D metrics to W&B under the ``name`` prefix.

    For each ``k`` over all ``n_agents + num_classes`` entities, this issues one call
    with the per-k scalars, then one call per entity ``j`` with
    ``deferral_ratio[k][j]``. It then logs the surrogate loss (``running_loss``
    averaged over ``len(test_loader)``), MRR and model accuracy, and finally one call
    per agent with ``proba_top_k_list[j]``. Every call uses ``commit=True`` and carries
    ``'current_step': steps``.
    """
    n_entities = args.n_agents + args.num_classes
    for k in range(n_entities):
        per_k = {
            f'{name}/deferral_loss_top_{k}': deferral_loss_top_k[k],
            f'{name}/acc_majority_top_{k}': acc_majority_list[k],
            f'{name}/acc_weighted_top_{k}': acc_weighted_list[k],
            f'{name}/accuracy_top_{k}': accuracy_top_k[k],
            f'{name}/mean_querying_costs_{k}': mean_querying_costs[k],
            'current_step': steps,
        }
        wandb.log(per_k, commit=True)
        ratios_k = deferral_ratio[k]
        for j in range(n_entities):
            wandb.log({f'{name}/deferral_ratio_{j}_top_{k}': ratios_k[j], 'current_step': steps}, commit=True)

    summary = {
        f'{name}/surrogate_loss': running_loss / len(test_loader),
        f'{name}/MRR': MRR_list.item(),
        f'{name}/acc_model': acc_model_list.item(),
        'current_step': steps,
    }
    wandb.log(summary, commit=True)

    for j in range(args.n_agents):
        wandb.log({f'{name}/proba_top_{j}': proba_top_k_list[j], 'current_step': steps}, commit=True)


def eval(model, criterion, test_loader, device, store_metric, args=None, name=None, wandb=None, steps=0, best_model=None):
    """Evaluate the top-k L2D model on ``test_loader``, log metrics and checkpoint improvements.

    Batch metrics from ``metrics_top_k_l2d`` are averaged over the whole loader and
    logged under the ``Val`` prefix. When the top-1 deferral loss beats
    ``store_metric['deferral_loss']``, the function:
    - deep-copies the model;
    - writes its ``state_dict`` and a JSON metric summary to ``./checkpoints/{name}/``;
    - re-logs the metrics under ``Optimal``;
    - refreshes ``args.table_optimal_data[k]`` and the per-k W&B tables.

    Note: this shadows the ``eval`` builtin in this module; the name is kept for callers.

    Args:
        model: module whose forward returns the rejector outputs (classification) or an
            ``(outputs, regression)`` pair (regression).
        criterion: L2D criterion used by ``metrics_top_k_l2d``.
        test_loader: iterable of ``(images, labels, agents)`` batches with ``len()``.
        device: unused; batches are moved to ``args.device``.
        store_metric: dict holding the best ``'deferral_loss'`` so far (updated in place).
        args: run configuration.
        name: checkpoint sub-directory.
        wandb: W&B module/run handle.
        steps: current global step.
        best_model: best model so far; replaced by a deep copy on improvement.

    Returns:
        ``(store_metric, best_model)``.
    """
    model.eval()
    with torch.no_grad():
        running_loss = 0
        deferral_loss_top_k, accuracy_top_k, deferral_ids_top_k = [], [], []
        acc_majority_list, acc_weighted_list = [], []
        MRR_list, proba_top_k_list = [], []
        acc_model_list = []
        for images, labels, agents in test_loader:
            images, labels, agents = images_process(images, labels, agents, args)
            if args.task == 'classification':
                outputs = model(images)
                # The class indices 0..num_classes-1 are prepended to the agent predictions.
                # Build them on the agents' device instead of a hard-coded 'cuda'.
                model_l = torch.arange(args.num_classes, device=agents.device).repeat(len(labels), 1)
                agents_ = torch.concatenate((model_l, agents), dim=1)
            else:
                outputs, regression = model(images)
                agents_ = torch.concatenate((regression, agents), dim=1)

            (loss, loss_top_k, ids_top_k, acc_majority, acc_weighted, MRR, proba_top_k_np,
             acc_top_k, acc_model) = metrics_top_k_l2d(criterion, model, images, labels, agents_, outputs)

            # Accumulate batch results element-wise; they are averaged over the loader below.
            deferral_loss_top_k.extend(loss_top_k)
            accuracy_top_k.extend(acc_top_k.cpu().numpy())
            acc_weighted_list.extend(acc_weighted.cpu().numpy())
            acc_majority_list.extend(acc_majority.cpu().numpy())
            acc_model_list.extend(acc_model.cpu().numpy())
            deferral_ids_top_k.extend(ids_top_k)
            MRR_list.extend(MRR)
            proba_top_k_list.extend(proba_top_k_np)
            running_loss += loss.item()

        # Aggregate; accuracies are reported in percent.
        acc_model_list = np.array(acc_model_list).mean()*100
        deferral_loss_top_k = np.array(deferral_loss_top_k).mean(axis=0)
        accuracy_top_k = np.array(accuracy_top_k).mean(axis=0)*100
        deferral_ids_top_k = np.array(deferral_ids_top_k)
        acc_majority_list = np.array(acc_majority_list).mean(axis=0)*100
        acc_weighted_list = np.array(acc_weighted_list).mean(axis=0)*100
        MRR_list = np.array(MRR_list).mean(axis=0)
        proba_top_k_list = np.array(proba_top_k_list).mean(axis=0)

        # Per-entity deferral ratios and querying costs for every top-k budget.
        deferral_ratio, querying_costs = compute_deferral_ratio(deferral_ids_top_k, args)
        mean_querying_costs = np.array(querying_costs).sum(axis=1)

        wandb_log_top_k(deferral_loss_top_k, accuracy_top_k, mean_querying_costs,
                        acc_majority_list, acc_weighted_list, deferral_ratio, MRR_list, running_loss, test_loader,
                        proba_top_k_list, acc_model_list,
                        steps, args, wandb, name='Val')

        print(f'Validation Loss: {running_loss/len(test_loader):.4f}', flush=True)
        print(f'Accuracy top 1: {accuracy_top_k[0]:.4f}', flush=True)
        print(f'Deferral Loss top 1: {deferral_loss_top_k[0]:.4f}', flush=True)
        print(f'Accuracy Model: {acc_model_list:.4f}', flush=True)

        # Model selection is driven by the top-1 deferral loss.
        if deferral_loss_top_k[0] < store_metric['deferral_loss']:
            os.makedirs(f'./checkpoints/{name}', exist_ok=True)
            best_model = copy.deepcopy(model)
            torch.save(best_model.state_dict(), f'./checkpoints/{name}/best_model_deferral_loss_top.pth')
            dict_metrics = {
                'accuracy_model': acc_model_list.tolist(),
                'deferral_loss_top_k': deferral_loss_top_k.tolist(),
                'accuracy_top_k': accuracy_top_k.tolist(),
                'majority_top_k': acc_majority_list.tolist(),
                'weighted_top_k': acc_weighted_list.tolist(),
                'deferral_ratio': np.array(deferral_ratio).tolist(),
                'mean_querying_costs': mean_querying_costs.tolist(),
                'MRR': MRR_list.tolist(),
                'proba_top_k': proba_top_k_list.tolist(),
                'steps': steps
            }
            with open(f'./checkpoints/{name}/metrics_deferral_loss_top.json', 'w') as f:
                json.dump(dict_metrics, f)

            wandb_log_top_k(deferral_loss_top_k, accuracy_top_k, mean_querying_costs,
                            acc_majority_list, acc_weighted_list, deferral_ratio, MRR_list, running_loss, test_loader,
                            proba_top_k_list, acc_model_list,
                            steps, args, wandb, name='Optimal')

            for k in range(args.n_agents + args.num_classes):
                args.table_optimal_data[k] = [wandb.run.name, steps, acc_model_list, accuracy_top_k[k], acc_majority_list[k], acc_weighted_list[k],
                                              mean_querying_costs[k], deferral_loss_top_k[k], MRR_list, *deferral_ratio[k], *proba_top_k_list]
                new_table_optimal = wandb.Table(
                    columns=args.table_optimal[k].columns,
                    data=[args.table_optimal_data[k]]
                )
                wandb.log({f"Table Optimal top-{k} L2D": new_table_optimal}, commit=True)

            store_metric['deferral_loss'] = deferral_loss_top_k[0]

    return store_metric, best_model

def metrics_cardinality(criterion, model, images, labels, agents, rejector, outputs_selector, outputs_rejector):
    """Compute the cardinality-aware validation metrics for one batch.

    Majority-vote and weighted-vote losses use the selector outputs. For
    classification they are turned into accuracies via ``1 - loss``; for regression
    the raw values are returned unchanged.

    Args:
        criterion: cardinality-aware criterion; ``criterion.args.task`` selects the task.
        model: unused here (kept for interface compatibility).
        images, labels: current batch.
        agents: entity predictions with the model entries prepended.
        rejector: callable handed to ``criterion.cardinality_aware_true`` / ``cardinality_aware``.
        outputs_selector: selector outputs for the batch.
        outputs_rejector: rejector scores; sorted descending to obtain the deferral ranking.

    Returns:
        ``(cardinality_loss, k_pred, accuracy, cum_cost, deferral_perc, acc_majority,
        acc_weighted, deferral_loss, ids_top_k, loss)``.

    Raises:
        ValueError: if ``criterion.args.task`` is neither ``'classification'`` nor
            ``'regression'``. Previously this surfaced as an ``UnboundLocalError``.
    """
    task = criterion.args.task
    if task not in ('classification', 'regression'):
        raise ValueError(f"Unsupported task: {task!r}")

    id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)

    cardinality_loss, k_pred, accuracy, cum_cost, deferral_perc = criterion.cardinality_aware_true(
        rejector, images, outputs_selector, labels, agents)

    majority_err = vote_majority_loss(id_deferral, labels, agents, selector=outputs_selector,
                                      args=criterion.args, rmse=criterion.rmse)
    weighted_err = vote_weighted_loss(outputs_rejector, labels, agents, selector=outputs_selector,
                                      args=criterion.args, rmse=criterion.rmse)
    if task == 'classification':
        acc_majority, acc_weighted = 1 - majority_err, 1 - weighted_err
    else:
        acc_majority, acc_weighted = majority_err, weighted_err

    deferral_loss, ids_top_k = criterion.true_deferral_top_k(outputs_rejector, labels, agents, selector=outputs_selector)
    loss = criterion.cardinality_aware(rejector, images, outputs_selector, labels, agents)
    return (cardinality_loss, k_pred, accuracy, cum_cost, deferral_perc,
            acc_majority, acc_weighted, deferral_loss, ids_top_k, loss)

def eval_cardinality(model, rejector, criterion, test_loader, device, store_metric, args, name, wandb, steps=0):
    """Evaluate the cardinality-aware selector ``model`` and checkpoint improvements.

    Per-batch metrics from ``metrics_cardinality`` are averaged over the loader and
    logged under ``Cardinality``. When the cardinality loss beats
    ``store_metric['deferral_loss']``, the function:
    - saves ``model.state_dict()`` and a JSON metric summary under ``./checkpoints/{name}/``;
    - re-logs the metrics under ``Optimal``;
    - rewrites ``args.table_cardinality_data`` and logs it as a W&B table.

    Args:
        model: selector network evaluated on each batch.
        rejector: for classification, called as ``rejector(images)`` and passed as-is.
            For regression, it returns ``(outputs, regression)`` and ``rejector.classifier``
            is passed instead.
        criterion: cardinality-aware criterion.
        test_loader: iterable of ``(images, labels, agents)`` batches with ``len()``.
        device: unused; batches are moved to ``args.device``.
        store_metric: dict holding the best ``'deferral_loss'`` so far (updated in place).
        args: run configuration.
        name: checkpoint sub-directory.
        wandb: W&B module/run handle.
        steps: current global step.

    Returns:
        ``store_metric``.
    """
    model.eval()
    with torch.no_grad():
        running_loss = 0
        cardinality_loss_top_k, accuracy_top_k, k_pred_list, cum_cost_top_k = [], [], [], []
        acc_majority_list, acc_weighted_list = [], []
        true_deferral_loss_top_k = []
        deferral_perc_list = 0
        for images, labels, agents in test_loader:
            images, labels, agents = images_process(images, labels, agents, args)
            outputs = model(images)
            rejector.eval()
            if args.task == 'classification':
                # The class indices 0..num_classes-1 are prepended to the agent predictions.
                # Build them on the agents' device instead of a hard-coded 'cuda'.
                model_l = torch.arange(args.num_classes, device=agents.device).repeat(len(labels), 1)
                agents_ = torch.concatenate((model_l, agents), dim=1)
                outputs_rejector = rejector(images)
                rejector_fn = rejector
            else:
                outputs_rejector, regression = rejector(images)
                agents_ = torch.concatenate((regression, agents), dim=1)
                rejector_fn = rejector.classifier

            (cardinality_loss, k_pred, accuracy, cum_cost, deferral_perc,
             acc_majority, acc_weighted, deferral_loss, ids_top_k, loss) = metrics_cardinality(
                criterion, model, images, labels, agents_, rejector_fn, outputs, outputs_rejector)

            # Accumulate batch results element-wise; they are averaged over the loader below.
            cardinality_loss_top_k.extend(cardinality_loss)
            accuracy_top_k.extend(accuracy)
            k_pred_list.extend(k_pred)
            cum_cost_top_k.extend(cum_cost)
            acc_majority_list.extend(acc_majority.squeeze().cpu().numpy())
            acc_weighted_list.extend(acc_weighted.squeeze().cpu().numpy())
            true_deferral_loss_top_k.extend(deferral_loss)
            deferral_perc_list += deferral_perc
            running_loss += loss.item()

        # Aggregate; accuracies are reported in percent.
        cardinality_loss_top_k = np.array(cardinality_loss_top_k).mean(axis=0)
        true_deferral_loss_top_k = np.array(true_deferral_loss_top_k).mean(axis=0)
        accuracy_top_k = np.array(accuracy_top_k).mean(axis=0)*100
        mean_k = np.array(k_pred_list).mean(axis=0)
        counts = np.bincount(np.array(k_pred_list), minlength=args.n_agents + args.num_classes)
        querying_costs_mean = np.array(cum_cost_top_k).mean(axis=0)
        acc_majority_list = np.array(acc_majority_list).mean(axis=0)*100
        acc_weighted_list = np.array(acc_weighted_list).mean(axis=0)*100
        deferral_ratio = deferral_perc_list/len(k_pred_list)

        wandb_log(cardinality_loss_top_k, accuracy_top_k, mean_k,
                  querying_costs_mean, acc_majority_list, acc_weighted_list, deferral_ratio, true_deferral_loss_top_k, counts, steps, args, wandb, name='Cardinality')

        print(f'Validation Loss: {running_loss / len(test_loader):.4f}', flush=True)
        print(f'Accuracy top 1: {accuracy_top_k.item():.4f}', flush=True)
        print(f'Deferral Loss top 1: {cardinality_loss_top_k.item():.4f}', flush=True)

        # Model selection is driven by the cardinality-aware loss.
        if cardinality_loss_top_k.item() < store_metric['deferral_loss']:
            os.makedirs(f'./checkpoints/{name}', exist_ok=True)
            torch.save(model.state_dict(), f'./checkpoints/{name}/cardinality_best_model_deferral_loss_top.pth')
            dict_metrics = {
                'cardinality_loss_top_k': cardinality_loss_top_k.tolist(),
                'true_deferral_loss_top_k': true_deferral_loss_top_k.tolist(),
                'accuracy_top_k': accuracy_top_k.tolist(),
                'deferral_ratio': np.array(deferral_ratio).tolist(),
                'mean_querying_costs': querying_costs_mean.tolist(),
                'mean_k': mean_k.item(),
                'majority_top_k': acc_majority_list.tolist(),
                'weighted_top_k': acc_weighted_list.tolist(),
                'counts': np.array(counts).tolist(),
                'steps': steps
            }
            with open(f'./checkpoints/{name}/cardinality_metrics_loss_top.json', 'w') as f:
                json.dump(dict_metrics, f)

            wandb_log(cardinality_loss_top_k, accuracy_top_k, mean_k, querying_costs_mean,
                      acc_majority_list, acc_weighted_list, deferral_ratio, true_deferral_loss_top_k, counts, steps, args, wandb, name='Optimal')

            args.table_cardinality_data = [wandb.run.name, steps,
                                           accuracy_top_k.item(), acc_majority_list.item(),
                                           acc_weighted_list.item(), querying_costs_mean.item(),
                                           true_deferral_loss_top_k.item(), mean_k.item(),
                                           cardinality_loss_top_k.item(), *deferral_ratio, *counts]
            new_table_cardinality = wandb.Table(
                columns=args.table_cardinality.columns,
                data=[args.table_cardinality_data]
            )
            wandb.log({"Table Optimal Cardinality": new_table_cardinality}, commit=True)
            store_metric['deferral_loss'] = cardinality_loss_top_k.item()

    return store_metric

def wandb_log(deferral_loss_top_k, accuracy_top_k, mean_k, querying_costs_mean,
              acc_majority_list, acc_weighted_list, deferral_ratio, true_deferral_loss_top_k, counts,
              steps, args, wandb, name=None):
    """Log cardinality-aware evaluation metrics to W&B under the ``name`` prefix.

    The first call logs the scalar metrics (each converted with ``.item()``). It is
    followed by one call per entity ``j`` in ``range(n_agents + num_classes)`` with
    ``deferral_ratio[j]`` and ``counts[j]``. Every call uses ``commit=True`` and carries
    ``'current_step_card': steps``.
    """
    scalar_pairs = (
        (f'{name}/cardinality_loss', deferral_loss_top_k),
        (f'{name}/true_deferral_loss', true_deferral_loss_top_k),
        (f'{name}/accuracy', accuracy_top_k),
        (f'{name}/mean_k', mean_k),
        (f'{name}/acc_majority_list', acc_majority_list),
        (f'{name}/acc_weighted_list', acc_weighted_list),
        (f'{name}/mean_querying_costs', querying_costs_mean),
    )
    summary = {key: value.item() for key, value in scalar_pairs}
    summary['current_step_card'] = steps
    wandb.log(summary, commit=True)

    n_entities = args.n_agents + args.num_classes
    for idx in range(n_entities):
        per_entity = {
            f'{name}/deferral_ratio_{idx}': deferral_ratio[idx],
            f'{name}/count_{idx}': counts[idx],
            'current_step_card': steps,
        }
        wandb.log(per_entity, commit=True)


def train_cardinality(model, rejector, criterion, optimizer, train_loader, test_loader, num_epochs, device,  scheduler = None, args=None, name=None, wandb=None):
    """Train the cardinality-aware selector ``model`` with ``rejector`` kept in eval mode.

    Args:
        model: selector network being trained.
        rejector: for classification, passed as-is to ``criterion.cardinality_aware``.
            For regression, ``rejector(images)`` supplies the regression predictions and
            ``rejector.classifier`` is passed to the criterion.
        criterion: provides ``cardinality_aware(rejector, images, outputs, labels, agents_)``.
        optimizer: torch optimizer over the parameters being trained.
        train_loader, test_loader: iterables of ``(images, labels, agents)`` with ``len()``.
        num_epochs: number of passes over ``train_loader``.
        device: unused; batches are moved to ``args.device``.
        scheduler: optional LR scheduler, advanced via ``scheduler_step`` after every
            optimizer step.
        args: run configuration (``task``, ``num_classes``, ``log_freq``, ``scheduler``, ...).
        name: checkpoint sub-directory passed to ``eval_cardinality``.
        wandb: W&B module/run handle used for logging (required despite the default).

    Returns:
        ``None``. Best checkpoints are written by ``eval_cardinality``; the return value
        is kept for interface compatibility with ``train_model``.
    """
    steps = 0
    store_metric = {'acc': 0, 'deferral_loss': 1000}
    saving_model = None  # never updated here; see Returns
    num_batches = len(train_loader)
    progress_bar = tqdm(range(num_epochs * num_batches))

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels, agents) in enumerate(train_loader):
            # eval_cardinality() switches the model to eval mode, so re-enable training each step.
            model.train()
            rejector.eval()
            images, labels, agents = images_process(images, labels, agents, args)
            outputs = model(images)
            if args.task == 'classification':
                # The class indices 0..num_classes-1 are prepended to the agent predictions.
                # Build them on the agents' device instead of a hard-coded 'cuda'.
                model_l = torch.arange(args.num_classes, device=agents.device).repeat(len(labels), 1)
                agents_ = torch.concatenate((model_l, agents), dim=1).detach()
                loss = criterion.cardinality_aware(rejector, images, outputs, labels, agents_)
            else:
                _, regression = rejector(images)
                agents_ = torch.concatenate((regression, agents), dim=1).detach()
                loss = criterion.cardinality_aware(rejector.classifier, images, outputs, labels, agents_)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler_step already advances 'cosine' once per iteration; the former extra
            # per-epoch step advanced the same schedule twice at every epoch boundary.
            scheduler = scheduler_step(scheduler, epoch, i, num_batches, args)

            loss_value = loss.item()
            running_loss += loss_value
            wandb.log({'Train/loss_iteration': loss_value,
                       'lr': optimizer.param_groups[0]['lr'],
                       'current_step_card': steps}, commit=True)

            if steps % args.log_freq == 0:
                store_metric = eval_cardinality(model, rejector, criterion, test_loader, device, store_metric, args, name, wandb, steps=steps)

            progress_bar.update(1)
            steps += 1

        wandb.log({
            'Train/loss_epoch': running_loss / num_batches,
            'current_epoch_card': epoch,
        }, commit=True)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}', flush=True)

    progress_bar.close()
    return saving_model

