import copy
import torch
import torch.optim as optim
import random
import numpy as np
import wandb
import argparse as arg
import matplotlib.pyplot as plt
import os
from torch.optim.lr_scheduler import OneCycleLR
import json
import torch.nn.init as init
import preprocess as pre
from model import *
import loss as L
from train import *
from Random_allocation import run_allocation

def load_model(args):
    """Build the pre-trained predictor and restore its weights from disk.

    The architecture is created through ``model_choice`` from
    ``args.predictor`` with a fixed ``depth=4`` and ``args.num_classes``
    outputs, then moved to ``args.device``. The weights are read from
    ``Models_weights/<args.dataset>/<args.load_predictor>``; that file must be
    a checkpoint dict storing the weights under the ``'state_dict'`` key.

    Args:
        args: Namespace providing ``predictor``, ``num_classes``, ``device``,
            ``dataset`` and ``load_predictor``.

    Returns:
        The predictor with the checkpoint weights loaded.
    """
    predictor = model_choice(args.predictor, depth=4, args=args, n_agents=args.num_classes).to(args.device)
    checkpoint_path = f'Models_weights/{args.dataset}/{args.load_predictor}'
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU machine can still be restored on a CPU-only
    # one instead of failing while deserializing CUDA tensors.
    checkpoints = torch.load(checkpoint_path, map_location=args.device)
    predictor.load_state_dict(checkpoints['state_dict'])
    return predictor


def optimizer_scheduler(model, train_loader, args):
    """Create the optimizer and (optionally) the learning-rate scheduler.

    Args:
        model: Module whose ``parameters()`` are optimized.
        train_loader: Sized training loader; ``len(train_loader)`` is used as
            the number of steps per epoch for the 'cosine' and 'cyclelr'
            schedulers.
        args: Namespace providing ``optimizer``, ``lr``, ``weight_decay``,
            ``scheduler`` and ``epochs``.

    Returns:
        A ``(optimizer, scheduler)`` tuple. ``scheduler`` is ``None`` when
        ``args.scheduler`` is not one of 'cosine_restart', 'cosine' or
        'cyclelr'.
    """
    if args.optimizer == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        # Any other value (including 'SGD') falls back to SGD with momentum.
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)

    if args.scheduler == 'cosine_restart':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            # Restart period. Clamped to 1 because args.epochs // 20 is 0 for
            # runs shorter than 20 epochs and the scheduler rejects T_0 <= 0.
            T_0=max(1, args.epochs // 20),
            T_mult=1,  # Factor by which T_i is multiplied after a restart
            eta_min=args.lr / 20  # Minimum learning rate
        )
    elif args.scheduler == 'cosine':
        # Horizon expressed in batches: one full cosine over the whole run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader) * args.epochs, eta_min=args.lr / 1000)
    elif args.scheduler == 'cyclelr':
        scheduler = OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs)
    else:
        # Unrecognized names (e.g. the CLI default 'None') disable scheduling.
        scheduler = None

    return optimizer, scheduler


def conv_init(m):
    """Initialize a convolution-like module in place.

    Intended for ``model.apply(conv_init)``. A module is selected when its
    class name contains 'Conv'. Its weight gets Xavier-uniform initialization
    with gain sqrt(2) and its bias, if present, is set to zero.

    Modules that match by name but expose no ``weight`` tensor (for example a
    wrapper block whose class name contains 'Conv') are skipped instead of
    raising ``AttributeError``.

    Args:
        m: The module visited by ``apply``.
    """
    if 'Conv' not in m.__class__.__name__:
        return

    weight = getattr(m, 'weight', None)
    if weight is None:
        # Name matched, but this is not a layer that owns parameters directly.
        return

    init.xavier_uniform_(weight, gain=np.sqrt(2))
    bias = getattr(m, 'bias', None)
    if bias is not None:
        init.constant_(bias, 0)

def setup_wandb(name_project, name, args=None):
    """Start a wandb run and attach the result tables to ``args``.

    Initializes the run, declares the custom step metrics, and, when ``args``
    is given, creates the wandb tables used for later logging:

    * ``args.table_optimal``: a list of ``args.n_agents + args.num_classes``
      tables, one per index ``i`` in that range, with "Top-i" columns.
    * ``args.table_optimal_data``: one empty row list per table above.
    * ``args.table_cardinality``: a single table for the cardinality results.
    * ``args.table_cardinality_data``: an empty row list for that table.

    Args:
        name_project: wandb project name.
        name: wandb run name.
        args: Namespace used as the run config; must provide ``n_agents`` and
            ``num_classes``. When ``None`` only the run and the metrics are
            set up, since there is nothing to size or store the tables with.

    Returns:
        The ``wandb`` module.
    """
    wandb.init(config=args, project=name_project, name=name)
    for metric in ("current_step", "current_epoch", "current_step_card", "current_epoch_card"):
        wandb.define_metric(metric, step_metric=metric)

    if args is None:
        return wandb

    n_slots = args.n_agents + args.num_classes

    # Deferral-ratio and probability column names, shared across all tables.
    deferral_cols = [f'deferral_ratio_{j}' for j in range(n_slots)]
    proba_cols = [f'proba_{j}' for j in range(n_slots)]

    # "table_optimal": one table per index, each with its own persistent row list.
    table_optimal = []
    table_optimal_data = []
    for i in range(n_slots):
        cols = ["Run Name", "Steps", 'Accuracy_model', f"Accuracy Top-{i}", f"acc_majority Top-{i}",
                f'acc_weighted Top-{i}', f'mean_query_top_{i}', f'deferral_loss_top-{i}', 'MRR']
        cols.extend(deferral_cols)
        cols.extend(proba_cols)
        table_optimal.append(wandb.Table(columns=cols))
        table_optimal_data.append([])

    # "table_cardinality": a single table with its own persistent row list.
    columns_cardinality = ['Run Name', 'Steps', 'Accuracy_model', 'Accuracy', 'acc_majority', 'acc_weighted',
                           'true_cost', 'mean_querying_cost', 'mean_k']
    columns_cardinality.extend(deferral_cols)
    columns_cardinality.extend(proba_cols)
    table_cardinality = wandb.Table(columns=columns_cardinality)
    table_cardinality_data = []

    # Stored on args so the training/evaluation code can update and log them.
    args.table_optimal = table_optimal
    args.table_optimal_data = table_optimal_data
    args.table_cardinality = table_cardinality
    args.table_cardinality_data = table_cardinality_data

    return wandb

def random_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with the same value, and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def setup_args(argv=None):
    """Parse and validate the command-line options of the experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``setup_args()`` calls
            behave as before; passing a list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: If the number of ``--beta`` values differs from
            ``--n_agents``.
    """
    parser = arg.ArgumentParser(description='L2D robust')
    parser.add_argument('--beta', nargs='+', type=float, default=[0.055, 0.05, 0.045, 0.04, 0.035, 0.03],
                        help='per-agent beta (cost) values; exactly one value per agent (see --n_agents)')
    # NOTE(review): the help text of --beta_model is copy-pasted; its meaning
    # is not visible in this file, so it is left unchanged until confirmed.
    parser.add_argument('--beta_model', type=float, default=0.0, help='lr')
    parser.add_argument('--n_agents', type=int, default=6, help='number of agents')
    parser.add_argument('--batch_size', type=int, default=8, help='batch_size')
    parser.add_argument('--batch_size_eval', type=int, default=8, help='evaluation batch size')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--epochs', type=int, default=0, help='number of training epochs (0 skips training)')
    parser.add_argument('--epochs_cardinality', type=int, default=1,
                        help='number of epochs of the cardinality stage')
    parser.add_argument('--lr', type=float, default=1e-3, help='lr')
    parser.add_argument('--name_run', type=str, default='top_k_one_stage', help='tag included in the run name')
    parser.add_argument('--name_project', type=str, default='top_k_one_stage2', help='wandb project name')
    parser.add_argument('--dataset', type=str, default='cifar10', help='cifar100, california, svhn, ames')
    parser.add_argument('--task', type=str, default='classification', help='classification/regression')
    parser.add_argument('--log_freq', type=int, default=30, help='evaluation frequency')
    parser.add_argument('--predictor', type=str, default='cnn', help='cnn/wideresnet/resnet18/resnet4')
    parser.add_argument('--load_predictor', type=str, default='model_resnet4_finalresnet-4.pth',
                        help='predictor checkpoint file name inside Models_weights/<dataset>/')
    parser.add_argument('--model', type=str, default='cnn', help='cnn/wideresnet/resnet18/resnet/MLP')
    parser.add_argument('--optimizer', type=str, default='Adam', help='AdamW/Adam/SGD')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout value passed to the model')
    parser.add_argument('--scheduler', type=str, default='None',
                        help='cosine_restart/cosine/cyclelr (any other value disables the scheduler)')
    parser.add_argument('--subset_test', type=int, default=0, help='subset:1')
    # NOTE(review): help text of --n_points and --top_k is copy-pasted; their
    # meaning is not visible in this file, so they are left unchanged.
    parser.add_argument('--n_points', type=int, default=-1, help='loop=1, vectorized=0')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device name (main() replaces it with the auto-detected cuda/cpu device)')
    parser.add_argument('--top_k', type=int, default=1, help='overfit')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='optimizer weight decay')
    parser.add_argument('--alpha', type=int, default=1, help='alpha')
    parser.add_argument('--depth', type=int, default=16, help='depth passed to the main model')
    parser.add_argument('--depth_card', type=int, default=16, help='depth passed to the cardinality selector model')
    # NOTE(review): help text of --Lambda is copy-pasted from --alpha; confirm.
    parser.add_argument('--Lambda', type=float, default=0.01, help='alpha')
    parser.add_argument('--cardinality', type=int, default=0, help='run the cardinality-aware stage: 1')
    parser.add_argument('--random_optimal', type=int, default=1,
                        help='run the random/optimal allocation evaluation: 1')
    parser.add_argument('--lr_card', type=float, default=1e-3, help='learning rate of the cardinality stage')
    args = parser.parse_args(argv)

    # argparse already converted --lr and --beta to float via type=float.
    if len(args.beta) != args.n_agents:
        raise ValueError('beta must have the same length as n_agents')

    return args

def main():
    """Run one experiment end to end.

    Steps, in order: parse the arguments and seed the RNGs, resolve the
    dataset (number of classes and optional pre-trained predictor), build the
    data loaders, the model and the loss, start the wandb run, optionally
    evaluate the random/optimal allocations (``--random_optimal``), train the
    model when ``--epochs`` is non-zero, and optionally train the
    cardinality-aware selector (``--cardinality``).

    Raises:
        ValueError: If ``--dataset`` is not one of the supported names.
    """
    os.environ["WANDB_INIT_TIMEOUT"] = "300"
    # Initialization
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = setup_args()
    # The auto-detected device replaces whatever --device was given.
    args.device = device
    random_seed(args.seed)

    num_classes_by_dataset = {
        'cifar100': 100,
        'cifar10': 10,
        'cifar10H': 10,
        'svhn': 10,
        'california': 1,
        'ames': 1,
    }
    if args.dataset not in num_classes_by_dataset:
        # Fail early with a clear message instead of hitting an unbound
        # `predictor` / missing `args.num_classes` further down.
        raise ValueError(
            f'Unknown dataset {args.dataset!r}; expected one of {sorted(num_classes_by_dataset)}')
    args.num_classes = num_classes_by_dataset[args.dataset]
    # Only cifar100 uses a pre-trained predictor; load_model reads
    # args.num_classes, so it must run after the assignment above.
    predictor = load_model(args) if args.dataset == 'cifar100' else None

    train_loader, test_loader_subset, test_loader, dict_tr, dict_test, max_loss = pre.processing(
        args.batch_size,
        args.batch_size_eval,
        args=args,
        predictor=predictor)

    # Classification adds one output per class on top of one per agent.
    if args.task == 'classification':
        n_outputs = args.n_agents + args.num_classes
    else:
        n_outputs = args.n_agents
    model = model_choice(args.model, depth=args.depth, args=args, n_agents=n_outputs,
                         dropout=args.dropout).to(device)

    if args.model == 'resnet':
        model.apply(conv_init)

    args.expert_tr = dict_tr
    args.expert_test = dict_test
    args.upper = max_loss

    criterion = L.L2D_loss(args)
    name_run = (f'{args.dataset}_{args.name_run}_{args.name_project}_lr_{args.lr}_seed_{args.seed}_dropout_{args.dropout}_scheduler_{args.scheduler}'
                f'_optimizer_{args.optimizer}_cost_{args.beta}')
    setup_wandb(args.name_project, name_run, args=args)

    # Per-run output folder; exist_ok avoids the check-then-create race.
    save_path = f'./checkpoints/{name_run}'
    os.makedirs(save_path, exist_ok=True)

    if args.random_optimal:
        results_random, results_optimal = run_allocation(test_loader, args, criterion, name_run)

        # Save JSON files
        random_json_path = os.path.join(save_path, 'random_allocation.json')
        optimal_json_path = os.path.join(save_path, 'optimal_allocation.json')

        with open(random_json_path, 'w') as f:
            json.dump(results_random, f)

        with open(optimal_json_path, 'w') as f:
            json.dump(results_optimal, f)

        # Also expose the results in the wandb run summary.
        wandb.summary["results_random"] = results_random
        wandb.summary["results_optimal"] = results_optimal

    # Training ------------------------------------------
    # Without training, the freshly initialized model is what the cardinality
    # stage uses as rejector (previously this path hit an unbound name).
    saving_model = model
    if args.epochs != 0:
        # Built only when training actually runs: several schedulers reject a
        # zero-epoch horizon, which must not break evaluation-only runs.
        optimizer, scheduler = optimizer_scheduler(model, train_loader, args)
        saving_model = train_model(model, criterion, optimizer, train_loader, test_loader_subset, args.epochs,
                                   device, scheduler=scheduler, args=args, name=name_run, wandb=wandb)

    # Cardinality-aware
    if args.cardinality:
        rejector = copy.deepcopy(saving_model)
        rejector.eval()
        selector = model_choice(args.model, depth=args.depth_card, args=args, n_agents=args.n_agents,
                                dropout=args.dropout).to(device)
        if args.model == 'resnet':
            selector.apply(conv_init)
        args.lr = args.lr_card
        # The scheduler horizon of this stage is its own epoch count, not the
        # main training one; a shallow copy keeps args.epochs untouched.
        card_args = copy.copy(args)
        card_args.epochs = args.epochs_cardinality
        optimizer_card, scheduler_card = optimizer_scheduler(selector, train_loader, card_args)
        train_cardinality(selector, rejector, criterion, optimizer_card, train_loader, test_loader_subset,
                          args.epochs_cardinality,
                          device, scheduler=scheduler_card, args=args, name=name_run, wandb=wandb)
    wandb.finish()

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
















