import copy
import torch
import torch.optim as optim
import random
import numpy as np
import wandb
import argparse as arg
import matplotlib.pyplot as plt
import os
from torch.optim.lr_scheduler import OneCycleLR
import json
import torch.nn.init as init
import preprocess as pre
from model import *
import loss as L
from train import *

def load_model(args):
    """Build the predictor network and restore its weights from disk.

    The architecture is obtained from ``model_choice(args.predictor, ...)`` with a
    fixed ``depth=4`` and ``n_agents=args.num_classes``. The weights are read from
    ``Models_weights/<args.dataset>/<args.load_predictor>``.

    Args:
        args: Namespace providing ``predictor``, ``num_classes``, ``device``,
            ``dataset`` and ``load_predictor``.

    Returns:
        The predictor, moved to ``args.device``, with the checkpoint's
        ``'state_dict'`` entry loaded into it.
    """
    predictor = model_choice(args.predictor, depth=4, args=args, n_agents=args.num_classes).to(args.device)
    # map_location remaps the stored tensors onto the target device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only host.
    checkpoint = torch.load(
        f'Models_weights/{args.dataset}/{args.load_predictor}',
        map_location=args.device,
    )
    predictor.load_state_dict(checkpoint['state_dict'])
    return predictor


def optimizer_scheduler(model, train_loader, args):
    """Create the optimizer and, optionally, the learning-rate scheduler.

    Args:
        model: Module whose ``parameters()`` are optimized.
        train_loader: Sized collection of training batches. ``len(train_loader)``
            is used as the number of steps per epoch by the 'cosine' and
            'cyclelr' schedulers.
        args: Namespace providing ``optimizer``, ``lr``, ``weight_decay``,
            ``scheduler`` and ``epochs``.

    Returns:
        A ``(optimizer, scheduler)`` tuple. ``scheduler`` is ``None`` when
        ``args.scheduler`` is not one of 'cosine_restart', 'cosine', 'cyclelr'.
    """
    if args.optimizer == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        # Any other value falls back to SGD with momentum.
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)

    if args.scheduler == 'cosine_restart':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            # T_0 must be a positive integer, but epochs // 20 is 0 for runs
            # shorter than 20 epochs, so clamp it to at least one.
            T_0=max(1, args.epochs // 20),
            T_mult=1,  # Factor by which T_i is multiplied after a restart
            eta_min=args.lr / 20  # Minimum learning rate
        )
    elif args.scheduler == 'cosine':
        # One cosine half-period spanning every optimizer step of the run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader) * args.epochs, eta_min=args.lr / 1000)
    elif args.scheduler == 'cyclelr':
        scheduler = OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs)
    else:
        scheduler = None

    return optimizer, scheduler


def conv_init(m):
    """Xavier-initialize convolution-like modules; meant for ``Module.apply``.

    A module is treated as a convolution when its class name contains 'Conv'.
    Its weight is filled with Xavier-uniform values (gain ``sqrt(2)``) and its
    bias, when present, is set to zero. Modules that match by name but own no
    weight tensor (for instance a container class called ``ConvBlock``) are
    skipped, so applying this function to a whole network does not fail on them.

    Args:
        m: The module to initialize in place.
    """
    if 'Conv' not in m.__class__.__name__:
        return

    weight = getattr(m, 'weight', None)
    if weight is None:
        # Name matched, but there is nothing to initialize on this module.
        return

    init.xavier_uniform_(weight, gain=np.sqrt(2))
    bias = getattr(m, 'bias', None)
    if bias is not None:
        init.constant_(bias, 0)

def setup_wandb(name_project, name, args=None):
    """Start a wandb run and attach the logging tables to ``args``.

    The run is initialized with ``args`` as its config, and four metrics
    ("current_step", "current_epoch", "current_step_card",
    "current_epoch_card") are each declared with themselves as step metric.

    When ``args`` is given, ``args.n_agents + args.num_classes`` "optimal"
    tables and one "cardinality" table are created. They are stored on ``args``
    together with empty row buffers, as ``table_optimal``,
    ``table_optimal_data``, ``table_cardinality`` and ``table_cardinality_data``.

    Args:
        name_project: wandb project name.
        name: wandb run name.
        args: Optional namespace providing ``n_agents`` and ``num_classes``.
            When ``None`` the run is still initialized but no tables are built.

    Returns:
        The ``wandb`` module.
    """
    wandb.init(config=args, project=name_project, name=name)
    for metric in ("current_step", "current_epoch", "current_step_card", "current_epoch_card"):
        wandb.define_metric(metric, step_metric=metric)

    if args is None:
        # The tables are sized from args and stored on it, so without args
        # there is nothing to build.
        return wandb

    n_slots = args.n_agents + args.num_classes

    # Column names shared across tables.
    deferral_cols = [f'deferral_ratio_{j}' for j in range(n_slots)]
    counts_cols = [f'counts_{j}' for j in range(n_slots)]

    # "table_optimal": one table per index in range(n_slots), each paired with
    # its own persistent list of rows.
    table_optimal = []
    table_optimal_data = []
    for i in range(n_slots):
        cols = ["Run Name", "Steps", f"Accuracy Top-{i}", f'summed_query_top_{i}', f'deferral_loss_top-{i}']
        cols.extend(deferral_cols)
        table_optimal.append(wandb.Table(columns=cols))
        table_optimal_data.append([])

    # "table_cardinality": a single table with its own persistent row list.
    columns_cardinality = ['Run Name', 'Steps', 'Accuracy', 'acc_majority',
                           'acc_weighted', 'mean_querying_cost',  'true_deferral_loss', 'mean_k', 'cardinality_loss']
    columns_cardinality.extend(deferral_cols)
    columns_cardinality.extend(counts_cols)
    table_cardinality = wandb.Table(columns=columns_cardinality)
    table_cardinality_data = []

    # Save these on args for later updates and logging.
    args.table_optimal = table_optimal
    args.table_optimal_data = table_optimal_data
    args.table_cardinality = table_cardinality
    args.table_cardinality_data = table_cardinality_data

    return wandb


def random_seed(seed):
    """Seed every random number generator used here and pin cuDNN settings.

    Seeds torch (CPU, current CUDA device, all CUDA devices), NumPy and the
    standard ``random`` module with the same value, then sets
    ``cudnn.deterministic`` to True and ``cudnn.benchmark`` to False.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def setup_args():
    """Parse the command-line options of the training script.

    ``--beta`` takes one or more floats and must supply exactly ``--n_agents``
    values.

    Help texts were corrected only where this file shows what the option is
    used for. The remaining ones are kept as they were and may be stale.

    Returns:
        The parsed ``argparse.Namespace``; ``beta`` is a list of floats.

    Raises:
        ValueError: If the number of ``--beta`` values differs from ``--n_agents``.
    """
    parser = arg.ArgumentParser(description='L2D robust')
    # type=float lets argparse reject malformed values with a usage error.
    parser.add_argument('--beta', nargs='+', type=float, default=[0.1],
                        help='beta, one value per agent (as many values as --n_agents)')
    parser.add_argument('--beta_model', type=float, default=0.05, help='lr')
    parser.add_argument('--n_agents', type=int, default=1, help='number of agents')
    parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
    parser.add_argument('--batch_size_eval', type=int, default=16, help='evaluation batch size')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='lr')
    parser.add_argument('--name_run', type=str, default='top_k_one_stage', help='run name (part of the wandb run name)')
    parser.add_argument('--name_project', type=str, default='top_k_one_stage', help='wandb project name')
    parser.add_argument('--dataset', type=str, default='cifar10H', help='cifar100/cifar10/svhn/california/ames/cifar10H')
    parser.add_argument('--task', type=str, default='classification', help='classification/regression/multi_task')
    parser.add_argument('--log_freq', type=int, default=10, help='evaluation frequency')
    parser.add_argument('--predictor', type=str, default='resnet', help='cnn/wideresnet/resnet18/resnet4')
    parser.add_argument('--load_predictor', type=str, default='model_resnet4_finalresnet-4.pth',
                        help='predictor checkpoint file under Models_weights/<dataset>/')
    parser.add_argument('--model', type=str, default='clip2', help='cnn/wideresnet/resnet18/resnet4_rej/MLP')
    parser.add_argument('--optimizer', type=str, default='Adam', help='AdamW/Adam/SGD')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout value passed to the models')
    parser.add_argument('--scheduler', type=str, default='cosine',
                        help='cosine_restart/cosine/cyclelr (any other value: no scheduler)')
    parser.add_argument('--subset_test', type=int, default=0, help='subset:1')
    parser.add_argument('--n_points', type=int, default=50, help='loop=1, vectorized=0')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device (main() overrides it: cuda when available, else cpu)')
    parser.add_argument('--top_k', type=int, default=1, help='overfit')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='optimizer weight decay')
    parser.add_argument('--alpha', type=int, default=1, help='alpha')
    parser.add_argument('--depth_card', type=int, default=4, help='alpha')
    parser.add_argument('--depth', type=int, default=16, help='depth passed to model_choice')
    parser.add_argument('--decision', type=str, default='weighted', help='indicator/weighted/majority')
    parser.add_argument('--rejector', type=str, default='cnn', help='cnn/wideresnet/resnet18/resnet4_rej/MLP/rejector_regressor')
    parser.add_argument('--load_classifier', type=str, default='cifar10H_top_k_one_stage_top_k_one_stage_lr_0.001_seed_42_dropout_'
                                                             '0.0_scheduler_None_optimizer_Adam_cost_[0.01]/best_model_deferral_loss_top.pth',
                        help='checkpoint path under ./checkpoints/ loaded into the rejector')
    parser.add_argument('--Lambda', type=float, default=1e-10, help='alpha')

    args = parser.parse_args()

    # Keep beta a plain list of floats whether it came from the default or the CLI.
    args.beta = [float(value) for value in args.beta]
    if len(args.beta) != args.n_agents:
        raise ValueError('beta must have the same length as n_agents')

    return args

def main():
    """Entry point: parse options, build data and models, then train.

    Steps, in order: select the device, parse and seed, resolve the number of
    classes for the dataset (loading a pretrained predictor for 'cifar100'
    only), build the data loaders, build the model and the rejector, restore
    the rejector weights from ``./checkpoints/<args.load_classifier>``, start
    the wandb run and call ``train_cardinality``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported datasets.
    """
    os.environ["WANDB_INIT_TIMEOUT"] = "300"
    # Initialization
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = setup_args()
    args.device = device  # replaces whatever --device was given
    random_seed(args.seed)

    num_classes_by_dataset = {
        'cifar100': 100,
        'cifar10': 10,
        'svhn': 10,
        'california': 1,
        'ames': 1,
        'cifar10H': 10,
    }
    if args.dataset not in num_classes_by_dataset:
        # Fail here with a clear message rather than later with an unbound
        # `predictor` / missing `args.num_classes`.
        raise ValueError(
            f'unsupported dataset {args.dataset!r}; expected one of {sorted(num_classes_by_dataset)}')
    args.num_classes = num_classes_by_dataset[args.dataset]
    # Only cifar100 loads a pretrained predictor; load_model reads
    # args.num_classes, so it must already be set.
    predictor = load_model(args) if args.dataset == 'cifar100' else None

    train_loader, test_loader_subset, test_loader, dict_tr, dict_test, max_loss = pre.processing(args.batch_size,
                                                                                       args.batch_size_eval,
                                                                                       args=args,
                                                                                       predictor=predictor)
    model = model_choice(args.model, depth=args.depth, args=args, n_agents=args.n_agents + args.num_classes, dropout=args.dropout).to(
        device)
    if args.model == 'resnet':
        model.apply(conv_init)

    args.expert_tr = dict_tr
    args.expert_test = dict_test
    args.upper = max_loss

    if args.task == 'classification':
        rejector = model_choice(args.rejector, depth=args.depth, args=args, n_agents=args.n_agents + args.num_classes, dropout=args.dropout).to(device)
    else:
        rejector = model_choice(args.rejector, depth=args.depth, args=args, n_agents=args.n_agents, dropout=args.dropout).to(device)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    rejector.load_state_dict(torch.load(f'./checkpoints/{args.load_classifier}', map_location=device))

    criterion = L.L2D_loss(args)
    name_run = (f'{args.dataset}_{args.name_run}_{args.name_project}_lr_{args.lr}_seed_{args.seed}_dropout_{args.dropout}_scheduler_{args.scheduler}'
                f'_optimizer_{args.optimizer}_cost_{args.beta}_Lambda_{args.Lambda}_selector_{args.model}{args.depth_card}_{args.decision}')
    setup_wandb(args.name_project, name_run, args=args)
    optimizer, scheduler = optimizer_scheduler(model, train_loader, args)

    # Training
    train_cardinality(model, rejector, criterion, optimizer, train_loader, test_loader_subset, args.epochs,
                                     device, scheduler=scheduler, args=args, name=name_run, wandb=wandb)



# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
















