import torchmetrics
import numpy as np
from model import *
import torch
from auto_LiRPA import BoundedModule, CrossEntropyWrapper
import torch.nn.functional as F
from train_verified import train_verified
from continuum import ClassIncremental, InstanceIncremental, Permutations, TransformationIncremental
from continuum.datasets import MNIST, CIFAR10, FashionMNIST, Synbols, CIFAR100, Core50, TinyImageNet200
import torch.nn.functional as F
from evaluate import evaluate_model
from loguru import logger
from data.RUARobot import get_RUARobot_datasets
from data.Medical import get_medical_datasets
from utils.training import get_embedded_dataset


def multiple_runs(args):
    """Run the experiment ``args.num_runs`` times and report aggregate statistics.

    Each run calls ``run_experiment(args)``, which returns an
    ``(accuracy, certified_accuracy)`` pair.  The mean and standard deviation
    of both quantities, and of ``args.forget_vals``, are logged and, when
    ``args.wandb`` is truthy, also sent to Weights & Biases.

    Args:
        args: Experiment namespace.  Reads ``num_runs`` and ``wandb``; reads
            ``forget_vals`` if present and sets it to ``[0]`` otherwise.
    """
    all_accs = []
    all_certs = []
    for _ in range(args.num_runs):
        acc, cert = run_experiment(args)
        logger.info('Avg Cert: {:.4f}'.format(cert))
        all_accs.append(acc)
        all_certs.append(cert)

    # Fall back to a single zero so the forgetting statistics below are
    # always defined, even if nothing populated ``args.forget_vals``.
    if not hasattr(args, 'forget_vals'):
        args.forget_vals = [0]

    # Compute each statistic once so the log lines and the wandb payload
    # cannot drift apart.
    acc_mean, acc_std = np.mean(all_accs), np.std(all_accs)
    cert_mean, cert_std = np.mean(all_certs), np.std(all_certs)
    forget_mean, forget_std = np.mean(args.forget_vals), np.std(args.forget_vals)

    logger.info('Final acc of all runs: mean:{:.4f} std:{:.4f}'.format(acc_mean, acc_std))
    logger.info('Final cert of all runs: mean:{:.4f} std:{:.4f}'.format(cert_mean, cert_std))
    logger.info('Final forgetting of all runs: mean:{:.4f} std:{:.4f}'.format(forget_mean, forget_std))

    if args.wandb:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb
        wandb.log({
            'final_acc': acc_mean,
            'final_std_acc': acc_std,
            'final_cert': cert_mean,
            # Report the standard deviation here, not the mean a second time.
            'final_std_cert': cert_std,
            'final_forgetting': forget_mean,
            'final_std_forgetting': forget_std,
        })

def _embeddable_image_scenarios(dataset_cls, args, n_classes, class_inc, image_shape):
    """Build class-incremental scenarios for an image dataset that may be pre-embedded.

    When ``args.embed_img`` is set, both splits are passed through
    ``get_embedded_dataset`` and the model input is a 768-d vector; otherwise
    the model input is the flattened ``image_shape``.

    Args:
        dataset_cls: continuum dataset class, called as
            ``dataset_cls(path, download=True, train=...)``.
        args: Experiment namespace; ``n_classes`` and ``class_inc`` are
            overwritten with the values given here.
        n_classes: Number of classes (also used as the model output size).
        class_inc: Number of classes added per task.
        image_shape: ``(channels, height, width)`` used when not embedding.

    Returns:
        ``(scenario, test_scenario, dummy_input, in_dim, out_dim)``.
    """
    dataset = dataset_cls(args.data_path, download=True, train=True)
    dataset_test = dataset_cls(args.data_path, download=True, train=False)
    # Set before embedding because ``get_embedded_dataset`` receives ``args``.
    args.n_classes = n_classes
    args.class_inc = class_inc
    if args.embed_img:
        in_dim = 768
        dummy_input = torch.randn(2, 768).to(args.device)
        dataset = get_embedded_dataset(dataset, args)
        dataset_test = get_embedded_dataset(dataset_test, args)
    else:
        channels, height, width = image_shape
        in_dim = channels * height * width
        dummy_input = torch.randn(2, channels, height, width).to(args.device)
    scenario = ClassIncremental(dataset, increment=args.class_inc)
    test_scenario = ClassIncremental(dataset_test, increment=args.class_inc)
    return scenario, test_scenario, dummy_input, in_dim, n_classes


def _raw_image_scenarios(dataset_cls, args, class_inc, image_shape, out_dim):
    """Build class-incremental scenarios for an image dataset used without pre-embedding.

    The dummy input always has the raw ``image_shape``; only ``in_dim``
    switches to 768 when ``args.embed_img`` is set.
    NOTE(review): in that case neither the datasets nor the dummy input are
    embedded here -- presumably embedding happens downstream; confirm.

    Args:
        dataset_cls: continuum dataset class.
        args: Experiment namespace; ``class_inc`` is overwritten.
        class_inc: Number of classes added per task (train and test alike).
        image_shape: ``(channels, height, width)`` of a raw sample.
        out_dim: Model output size.

    Returns:
        ``(scenario, test_scenario, dummy_input, in_dim, out_dim)``.
    """
    args.class_inc = class_inc
    dataset = dataset_cls(args.data_path, download=True, train=True)
    scenario = ClassIncremental(dataset, increment=args.class_inc)
    channels, height, width = image_shape
    dummy_input = torch.randn(2, channels, height, width).to(args.device)
    in_dim = 768 if args.embed_img else channels * height * width
    dataset_test = dataset_cls(args.data_path, download=True, train=False)
    # Train and test must be split with the same increment so that task i
    # covers the same classes in both scenarios.
    test_scenario = ClassIncremental(dataset_test, increment=args.class_inc)
    return scenario, test_scenario, dummy_input, in_dim, out_dim


def _text_embedding_scenarios(get_datasets, args, class_inc_after):
    """Build class-incremental scenarios for the 384-d, two-output datasets.

    Args:
        get_datasets: Loader called as
            ``get_datasets(path=..., args=..., num_tasks=...)`` returning a
            ``(train_dataset, test_dataset)`` pair.
        args: Experiment namespace; ``class_inc`` is set to 1 while the
            scenarios are built and to ``class_inc_after`` afterwards.
        class_inc_after: Value left in ``args.class_inc`` for later stages.

    Returns:
        ``(scenario, test_scenario, dummy_input, in_dim, out_dim)``.
    """
    train_dataset, test_dataset = get_datasets(path=args.data_path, args=args, num_tasks=args.n_tasks)
    args.class_inc = 1
    scenario = ClassIncremental(train_dataset, increment=args.class_inc)
    dummy_input = torch.randn(2, 384).to(args.device)
    test_scenario = ClassIncremental(test_dataset, increment=args.class_inc)
    args.class_inc = class_inc_after
    return scenario, test_scenario, dummy_input, 384, 2


def _build_scenarios(args):
    """Dispatch on ``args.dataset`` and build the train/test scenarios.

    Returns:
        ``(scenario, test_scenario, dummy_input, in_dim, out_dim)`` where
        ``dummy_input`` is a batch of two random samples on ``args.device``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    name = args.dataset
    if name == "mnist":
        return _raw_image_scenarios(MNIST, args, args.class_inc, (1, 28, 28), 10)
    if name == "tinyimg":
        return _embeddable_image_scenarios(TinyImageNet200, args, 200, 20, (3, 32, 32))
    # The datasets below marked "embed" should be used with --embed-img,
    # otherwise the data is too difficult for an MLP.
    if name == "cifar100":  # embed
        return _embeddable_image_scenarios(CIFAR100, args, 100, 10, (3, 32, 32))
    if name == "cifar10":  # embed
        return _embeddable_image_scenarios(CIFAR10, args, 10, 2, (3, 32, 32))
    if name == "synbols":  # embed
        return _raw_image_scenarios(Synbols, args, 6, (3, 32, 32), 48)
    if name == "core50":  # embed
        # The tracing input is kept consistent with in_dim = 3*224*224.
        # NOTE(review): confirm 224x224 is the resolution the data pipeline
        # actually delivers for non-embedded Core50.
        return _embeddable_image_scenarios(Core50, args, 50, 10, (3, 224, 224))
    if name == "fmnist":
        return _raw_image_scenarios(FashionMNIST, args, 2, (1, 28, 28), 10)
    if name == "permmnist":
        # Use the same seed for train and test, otherwise the permutations
        # of the two scenarios do not correspond.
        dataset = MNIST(args.data_path, download=True, train=True)
        nb_tasks = args.n_tasks
        scenario = Permutations(cl_dataset=dataset, nb_tasks=nb_tasks, seed=args.seed, shared_label_space=True)
        dummy_input = torch.randn(2, 1, 28, 28).to(args.device)
        args.class_inc = 10
        in_dim = 768 if args.embed_img else 1 * 28 * 28
        dataset_test = MNIST(args.data_path, download=True, train=False)
        test_scenario = Permutations(cl_dataset=dataset_test, nb_tasks=nb_tasks, seed=args.seed, shared_label_space=True)
        return scenario, test_scenario, dummy_input, in_dim, 10
    if name == 'ruarobot':
        return _text_embedding_scenarios(get_RUARobot_datasets, args, 1)
    if name == 'medical':
        # NOTE(review): class_inc is left at 2 after the scenarios were built
        # with an increment of 1 -- kept as found; confirm this is intended.
        return _text_embedding_scenarios(get_medical_datasets, args, 2)
    raise ValueError('Dataset {} not found'.format(args.dataset))


def _build_model(args, in_dim, out_dim):
    """Instantiate the network selected by ``args`` and move it to ``args.device``.

    ``args.model_func`` is set to the factory used, except for ``resnet18``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    if args.train_type == "intercontinet":
        from modular_intervalnet import IntervalNet
        args.model_func = IntervalNet.create_mlp
        model = args.model_func(input_size=in_dim, hidden_dim=args.hdim, output_classes=out_dim, heads=1, radius_multiplier=1.0, max_radius=args.max_radius)
    elif args.model == "mlp_3":
        args.model_func = mlp_3layer
        model = args.model_func(in_dim=in_dim, out_dim=out_dim, h_dim=args.hdim)
    elif args.model == 'mlp_2':
        args.model_func = mlp_2layer
        model = args.model_func(in_dim=in_dim, out_dim=out_dim, h_dim=args.hdim)
    elif args.model == "resnet18":
        from torchvision.models import resnet18
        # Size the classifier head for the selected dataset rather than a
        # fixed 10 classes.
        model = resnet18(num_classes=out_dim)
    else:
        raise ValueError("Model {} not found".format(args.model))
    model.to(args.device)
    return model


def run_experiment(args):
    """Train and evaluate one continual-learning experiment.

    Steps: pick the device, optionally load the ViT-B/16 image embedder,
    build the train/test scenarios, build the model, optionally wrap it for
    bound computation, create the optimizer, train with ``train_verified``
    and evaluate with ``evaluate_model``.

    Args:
        args: Experiment namespace.  Besides being read, it is mutated:
            ``device``, ``n_classes``, ``n_tasks``, ``class_inc``,
            ``model_func`` and (with ``embed_img``) ``embed_transform`` and
            ``embedder`` are written.

    Returns:
        Whatever ``evaluate_model(model_ori, test_scenario, args)`` returns;
        ``multiple_runs`` unpacks it as ``(acc, cert)``.

    Raises:
        ValueError: For an unknown dataset, model or optimizer, or when
            ``loss_fusion`` is requested without CerCE / buffer tracking.
    """
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.embed_img:
        from torchvision.models import vit_b_16, ViT_B_16_Weights
        args.embed_transform = ViT_B_16_Weights.IMAGENET1K_V1.transforms()
        args.embedder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        args.embedder.eval()
        # Follow the selected device instead of requiring CUDA, so that
        # CPU-only hosts do not crash here.
        args.embedder.to(args.device)

    scenario, test_scenario, dummy_input, in_dim, out_dim = _build_scenarios(args)
    logger.info(f"Number of classes: {scenario.nb_classes}.")
    logger.info(f"Number of tasks: {scenario.nb_tasks}.")
    # The scenario is the source of truth for class and task counts.
    args.n_classes = scenario.nb_classes
    args.n_tasks = scenario.nb_tasks

    model = _build_model(args, in_dim, out_dim)
    model.train()

    # Keep a handle on the unwrapped network; it is what gets evaluated.
    model_ori = model
    model_wrapped = None
    if args.train_type == 'cerce' or args.track_buffer:
        define_bounds(model)
        if args.loss_fusion:
            model_wrapped = ModelWrapperLoss(model)
            dummy_input_w = (
                dummy_input,
                torch.ones(dummy_input.size(0)).long().to(args.device),
                torch.tensor([args.n_classes]).long().to(args.device),
            )
            model_wrapped = ModelWrapper(model_wrapped, dummy_input_w, args)
        model = ModelWrapper(model, dummy_input, args)
        set_eps(model, args.gamma)
    elif args.loss_fusion:
        raise ValueError("Loss fusion without CerCE is not supported!")

    model.train()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    else:
        raise ValueError("Optimizer not found")

    # ``model_wrapped`` stays None unless loss fusion was set up above.
    train_verified(model, model_ori, optimizer, scenario, args, model_wrapped)

    return evaluate_model(model_ori, test_scenario, args)