from utils.buffer import Buffer
import torch
import numpy as np
from continuum.tasks import split_train_val
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import cvxpy as cp
from auto_LiRPA import BoundedModule
from collections import defaultdict
import sys 
import torchmetrics
from model import set_eps, clip_model_grads
from evaluate import *
from auto_LiRPA.utils import get_spec_matrix
from utils.training import *
from utils.gem import *
import time 

def train_verified(model, model_ori, optimizer, scenario, args, model_wrapped=None):
    """Train ``model`` on ``scenario`` jointly or task by task and log forgetting.

    Two modes are selected by ``args.train_type``:

    * ``'joint'``: ``scenario.cl_dataset`` is converted to a single taskset and
      trained with ``update_naive`` for ``args.epochs`` epochs.
    * anything else: ``scenario`` is iterated task by task. Each task is split
      90/10 into train/validation data, ``model.num_classes`` grows by
      ``args.class_inc``, and every batch is dispatched to the update routine
      matching ``args.train_type`` (``naive``, ``cerce``, ``er``, ``ewc``,
      ``lwf``, ``agem``, ``intercontinet``). After each batch the replay buffer
      is refreshed through ``update_buffer``. After each task the validation
      loaders are evaluated, and after the last task the mean forgetting is
      appended to ``args.forget_vals``.

    Args:
        model: Network used for bound computation and most update routines
            (``compute_bounds`` is called on it).
        model_ori: Network used for plain forward passes, evaluation and A-GEM.
        optimizer: Optimizer stepped by the update routines.
        scenario: Iterable of tasksets; also provides ``cl_dataset`` and
            ``nb_classes``.
        args: Configuration namespace. It is also used as mutable state: this
            function sets ``metric``, ``buffer_acc``, ``num_cert``, ``g_list``,
            ``order_seed``, ``lpr_plugin`` and ``forget_vals`` on it, and
            removes ``logits`` and ``checkpoint`` at the end if present.
        model_wrapped: Optional model handed to ``update_cerce`` instead of
            ``model`` when ``args.loss_fusion`` is set.

    Returns:
        None.

    Raises:
        ValueError: If ``args.train_type`` is not one of the handled types in
            the task-incremental branch.
    """
    batch_size = args.batch_size
    if args.train_type == 'joint':
        # Joint training: one pass structure over the full dataset, all classes active.
        dataset = scenario.cl_dataset
        train_taskset = dataset.to_taskset()
        model.num_classes = args.n_classes
        model_ori.num_classes = args.n_classes
        train_loader = DataLoader(train_taskset, batch_size=batch_size, shuffle=True)
        pbar = tqdm(total = (args.epochs * len(train_loader)), file=sys.stdout)
        args.metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=model.num_classes).to(args.device)
        for epoch in range(args.epochs):
            # The metric has no updates before the first epoch, so 0 is shown instead.
            pbar.set_description(f"Epoch {epoch + 1} Running Acc: {0 if epoch == 0 else args.metric.compute().item()}")
            for x, y, t in train_loader:
                pbar.update(1)
                optimizer.zero_grad()
                x, y = x.cuda(), y.cuda()
                update_naive(model, optimizer, x, y, args)
    else:
        buffer = Buffer(args.buffer_size)
        # NOTE(review): big_mat_l, big_bias_l, last_chkpnt and samples_list are
        # never read in this function.
        big_mat_l = []
        big_bias_l = []
        val_loaders = []
        last_chkpnt = []
        samples_list = []
        max_accs = []  # best accuracy seen so far for each task's validation split
        for task_id, train_taskset in enumerate(scenario):
            training_times = []
            train_taskset, val_taskset = split_train_val(train_taskset, val_split=0.1)
            train_loader = DataLoader(train_taskset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_taskset, batch_size=batch_size, shuffle=True)
            val_loaders.append(val_loader)
            args.g_list = None
            model.num_classes += args.class_inc
            # NOTE(review): model_ori / model_wrapped receive the value *before*
            # it is clamped to args.n_classes two lines below, so they can end up
            # with a larger num_classes than model — confirm this is intended.
            model_ori.num_classes = model.num_classes
            if model_wrapped is not None:
                model_wrapped.num_classes = model.num_classes
            model.num_classes = min(args.n_classes, model.num_classes)
            pbar =  tqdm(total = (args.epochs * len(train_loader)), file=sys.stdout)
            # Fresh metrics per task.
            args.metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=scenario.nb_classes).to(args.device)
            args.buffer_acc = torchmetrics.aggregation.MeanMetric().to(args.device)
            args.num_cert = torchmetrics.aggregation.MeanMetric().to(args.device)

            begin_task(model, model_ori, optimizer, train_loader, task_id, args)
            if args.train_type=='intercontinet':
                model.before_task(task_id)
                # optimizer = create_intervalnet_optimizer(model, task_id, args)
            for epoch in range(args.epochs):
                if args.train_type=='intercontinet':
                    model.before_epoch(epoch, args.epochs, args.contraction_epochs)
                    update_intervalnet_lr(optimizer, model.mode, task_id, args)
                elif args.lpr:
                    # NOTE(review): this sits inside the epoch loop, so a new
                    # LPRPlugin is constructed every epoch — confirm that any
                    # state it keeps is meant to be discarded between epochs.
                    from lpr_plugin import LPRPlugin
                    args.lpr_plugin = LPRPlugin(buffer=buffer, args=args)
                
                pbar.set_description(f"Task {task_id + 1} Epoch {epoch + 1} Running Acc: {0 if epoch == 0 else args.metric.compute().item()}")
                if args.train_type == 'lwf': # fix order of dataset for logit pairing loss
                    import random
                    args.order_seed = random.randint(0, 10000)
                    torch.manual_seed(args.order_seed)
                time_start = time.time()
                for i, (x, y, t) in enumerate(train_loader):
                    pbar.update(1)
                    optimizer.zero_grad()
                    x, y = x.cuda(), y.cuda()
                    
                    # Dispatch one optimisation step to the selected training method.
                    if args.train_type == "naive":
                        update_naive(model, optimizer, x, y, args)
                    elif args.train_type == "cerce":
                        update_cerce(model_wrapped if args.loss_fusion else model, model_ori, optimizer, buffer, x, y, args)
                    elif args.train_type == "er":
                        update_er(model, optimizer, buffer, x, y, args)
                    elif args.train_type == "ewc":
                        update_ewc(model, optimizer, x, y, args)
                    elif args.train_type == "lwf":
                        update_lwf(model, optimizer, x, y, i, args)
                    elif args.train_type == "agem":
                        update_agem(model_ori, optimizer, buffer, x, y, args)
                    elif args.train_type == 'intercontinet':
                        update_intercontinet(model, optimizer, x, y, task_id, args)
                    else:
                        raise ValueError(f"Unknown training type: {args.train_type}")
                    # Buffer insertion needs no gradients.
                    with torch.no_grad():
                        update_buffer(model, model_ori, optimizer, x, y, buffer, args)
                training_times.append(time.time() - time_start)
                # Optional per-epoch tracking of accuracy and bound-certified
                # ratio on the buffer contents.
                if args.track_buffer and not args.train_type == "intercontinet":
                    if not buffer.is_empty():
                        acc = 0
                        cert = 0
                        # With args.dark the buffer returns a third element, which is ignored here.
                        if not args.dark:
                            x, y = buffer.get_data(args.buffer_size, device="cuda")
                        else:
                            x, y, _ = buffer.get_data(args.buffer_size, device="cuda")
                        x, y = x.cuda(), y.cuda()
                        num_total = y.size(0)
                        # NOTE(review): floor division skips the trailing partial
                        # batch (and every sample when num_total < mini_batchsize),
                        # yet acc/cert are still divided by num_total below. The
                        # min() on `high` suggests a ceiling division was intended.
                        for i in range(num_total//args.mini_batchsize):
                            low = i*args.mini_batchsize
                            high = min((i+1)*args.mini_batchsize, x.size(0))
                            x_batch = x[low:high]
                            y_batch = y[low:high]
                            pred = model_ori(x_batch)[:, :model.num_classes].detach()
                            acc += torch.sum((pred.argmax(1) == y_batch).float()).item()
                            set_eps(model, args.gamma)
                            # if args.loss_fusion:
                            #     C = None
                            #     inp = (x_batch, y_batch, torch.tensor([model.num_classes]).long().to(args.device))
                            #     _, ub = model.compute_bounds(x=inp, C= C, method=args.lirpa_method, bound_lower=False)
                            #     num = (ub <= 1e-4).float().sum()
                            #     print(ub)
                            # else:
                            C = get_spec_matrix(x_batch, y_batch, args.n_classes)
                            lb, _ = model.compute_bounds(x=(x_batch,), C= C, method=args.lirpa_method, bound_upper=False)
                            lb = lb[:, :model.num_classes - 1]
                            # A sample counts as certified when all of its
                            # (num_classes - 1) lower bounds are strictly positive.
                            num = torch.sum((torch.maximum(torch.zeros_like(lb), lb) > 0).float(), dim=1).detach() 
                            num = (num == torch.ones_like(num)*(model.num_classes - 1)).float().sum().item()
                            cert += num
                        args.buffer_acc.update(acc/num_total)
                        args.num_cert.update(cert/num_total)

                        if args.wandb:
                            import wandb
                            wandb.log({"Buffer Acc": acc/num_total})
                            wandb.log({"Ratio Cert": cert/num_total})
            # NOTE(review): `logger` is not imported by name in the visible file;
            # presumably it arrives through one of the star imports — confirm.
            logger.info(f"Train time per epoch: {np.mean(training_times)}s")    
            end_task(model, optimizer, train_loader, buffer, args)
            pbar.close()
            # Report plain accuracy of model_ori on the buffer contents after the task.
            if not buffer.is_empty() and not args.train_type == 'intercontinet':
                if not args.dark:
                    x, y = buffer.get_data(args.buffer_size, device="cuda")
                else:
                    x, y, _ = buffer.get_data(args.buffer_size, device="cuda")
                x, y = x.cuda(), y.cuda()
                pred = model_ori(x)[:, :model.num_classes]
                acc_buffer = (pred.argmax(1) == y).float().mean().item()
                logger.info(f"Task {task_id + 1}: Buffer accuracy: {acc_buffer}")

            if args.eval_current:
                # NOTE(review): called with three arguments here but with four
                # (task index included) everywhere else — verify against the
                # signature of eval_current_task. This branch also never fills
                # max_accs, so the forgetting loop below would index an empty
                # list and raise IndexError when args.eval_current is set.
                eval_current_task(model_ori, val_loader, args)
            else:
                # Evaluate every task seen so far and keep the best accuracy per task.
                for i, val_loader in enumerate(val_loaders):
                    acc = eval_current_task(model_ori, val_loader, i + 1, args)
                    if len(max_accs) < i + 1:
                        max_accs.append(acc)
                    else:
                        max_accs[i] = max(max_accs[i], acc)
        ### calc forgetting
        # Forgetting per task = best accuracy recorded during training minus final accuracy.
        forget_vals = []
        for i, val_loader in enumerate(val_loaders):
            acc = eval_current_task(model_ori, val_loader, i + 1, args)
            forget_vals.append( max_accs[i] - acc)
        if hasattr(args, 'forget_vals'):
            args.forget_vals.append(np.mean(forget_vals))
        else:
            args.forget_vals = [np.mean(forget_vals)]
        logger.info(f"Final Forgetting: {args.forget_vals[-1]}")
        # Drop per-run state stored on args (read by update_lwf / update_ewc)
        # so it is not present on a subsequent call.
        if hasattr(args, 'logits'):
            delattr(args, 'logits')
        if hasattr(args, 'checkpoint'):            
            delattr(args, 'checkpoint')
        
def update_agem(model, opt, buffer, x, y, args):
    """Perform one A-GEM step and return the loss on the current batch.

    The gradient of the current batch is always computed first. When the
    replay buffer holds data, that gradient is saved into ``args.grad_xy``, a
    second gradient is computed on ``args.mini_batchsize`` buffer samples and
    saved into ``args.grad_er``, and the gradient finally written back to the
    model is either the result of ``project`` (when the two gradients have a
    negative dot product) or the unchanged current-batch gradient.

    Args:
        model: Network to train; its full output is used for the loss.
        opt: Optimizer stepped once at the end.
        buffer: Replay buffer providing ``is_empty`` and ``get_data``.
        x: Input batch.
        y: Target labels for ``x``.
        args: Namespace providing ``grad_xy``, ``grad_er``, ``grad_dims``,
            ``mini_batchsize`` and ``device``.

    Returns:
        The current-batch cross-entropy loss as a Python float.
    """
    opt.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()

    # Nothing to protect yet: take a plain gradient step.
    if buffer.is_empty():
        opt.step()
        return loss.item()

    # Snapshot the current-batch gradient before it is cleared below.
    store_grad(model.parameters, args.grad_xy, args.grad_dims)

    mem_inputs, mem_labels = buffer.get_data(args.mini_batchsize, device=args.device)
    model.zero_grad()
    mem_loss = F.cross_entropy(model(mem_inputs), mem_labels)
    mem_loss.backward()
    store_grad(model.parameters, args.grad_er, args.grad_dims)

    # Only a negative dot product triggers the projection; otherwise the
    # current-batch gradient is restored as-is.
    if torch.dot(args.grad_xy, args.grad_er).item() < 0:
        final_grad = project(gxy=args.grad_xy, ger=args.grad_er)
    else:
        final_grad = args.grad_xy
    overwrite_grad(model.parameters, final_grad, args.grad_dims)

    opt.step()
    return loss.item()
        

def update_cerce(model, model_ori, opt, buffer, x, y, args):
    """Run one step that combines cross-entropy with a bound-based penalty.

    The total loss is ``cross_entropy + args.lam * penalty``. The penalty is
    computed with ``model.compute_bounds`` on buffer samples, or on the
    current batch while the buffer is still empty. With ``args.all_samples``
    the cross-entropy additionally covers the drawn buffer samples.

    Args:
        model: Bounded model; the caller passes the wrapped model when
            ``args.loss_fusion`` is set.
        model_ori: Network used for the cross-entropy forward pass in the
            ``loss_fusion`` branch.
        opt: Optimizer stepped once at the end.
        buffer: Replay buffer providing ``is_empty`` and ``get_data``.
        x: Input batch.
        y: Target labels for ``x``.
        args: Namespace providing ``all_samples``, ``batch_size``,
            ``mini_batchsize``, ``gamma``, ``loss_fusion``, ``lirpa_method``,
            ``n_classes``, ``lam``, ``grad_clip``, ``device`` and ``metric``.
    """
    # Defaults used while the buffer is empty: everything runs on the batch.
    x_ce, y_ce = x, y
    bound_x, bound_y = x, y
    if not buffer.is_empty():
        if args.all_samples:
            bound_x, bound_y = buffer.get_data(args.batch_size, device="cuda")
            x_ce = torch.cat((x, bound_x))
            y_ce = torch.cat((y, bound_y))
        else:
            bound_x, bound_y = buffer.get_data(args.mini_batchsize, device="cuda")

    set_eps(model, args.gamma)
    opt.zero_grad()
    n_active = model.num_classes

    # The bound computation runs before the cross-entropy forward pass in
    # both branches; keep that order.
    if args.loss_fusion:
        fused_input = (bound_x, bound_y, torch.tensor([n_active]).long().to(args.device))
        _, ub = model.compute_bounds(x=fused_input, C=None, method=args.lirpa_method, bound_lower=False)
        loss_lag = torch.mean(torch.log(ub))
        ce_net = model_ori
    else:
        spec = get_spec_matrix(bound_x, bound_y, args.n_classes)
        lb, _ = model.compute_bounds(x=(bound_x,), C=spec, method=args.lirpa_method, bound_upper=False)
        lb = lb[:, :n_active - 1]
        # Only negative lower bounds contribute to the penalty.
        loss_lag = torch.sum(torch.maximum(torch.zeros_like(lb), -lb))
        ce_net = model

    pred = ce_net(x_ce)[:, :n_active]
    args.metric(pred.detach().argmax(1), y_ce.detach())
    loss_ce = F.cross_entropy(pred, y_ce)
    loss = loss_ce + args.lam * loss_lag

    loss.backward()
    if args.grad_clip:
        clip_model_grads(model, args.gamma)
    opt.step()
   
        # update_buffer(model, model_ori, opt, x, y, buffer, args, labels=pred.argmax(1)[:y.size(0)])

def update_buffer(model, model_ori, opt, x, y, buffer,  args, labels=None):
    """Select samples from the current batch and add them to the replay buffer.

    The selection strategy is chosen by ``args.buffer_select``:

    * ``'bound'`` (only while ``model.num_classes > 1``): with
      ``args.loss_fusion`` the samples classified correctly by ``model_ori``
      are kept; otherwise the samples whose lower bounds from
      ``model.compute_bounds`` are all non-negative are kept.
    * ``'correct'``: samples whose predicted label equals ``y`` are kept.
      ``labels`` can supply precomputed predictions.
    * ``'rand'``, or any strategy not matched above while
      ``model.num_classes == 1``: the whole batch is kept.

    With ``args.dark`` the logits ``model(x)`` of the kept samples are stored
    alongside them.

    Args:
        model: Bounded model used for ``compute_bounds`` and dark logits.
        model_ori: Network used for predictions when ``args.loss_fusion`` is set.
        opt: Optimizer; only its gradients are cleared here.
        x: Input batch.
        y: Target labels for ``x``.
        buffer: Replay buffer receiving the selected samples via ``add_data``.
        args: Namespace providing ``gamma``, ``dark``, ``buffer_select``,
            ``loss_fusion``, ``n_classes`` and ``lirpa_method``.
        labels: Optional precomputed predictions for the ``'correct'`` strategy.

    Raises:
        ValueError: If no selection strategy applies. Previously this surfaced
            as an unbound-variable error at the ``add_data`` call.
    """
    set_eps(model, args.gamma)
    opt.zero_grad()
    x, y = x.cuda(), y.cuda()
    log = model(x) if args.dark else None

    n_active = model.num_classes
    # Boolean mask over the batch; None means "keep every sample".
    keep = None
    if args.buffer_select == 'bound' and n_active > 1:
        if args.loss_fusion:
            # TODO (kept from the original "not correct" note): this falls back
            # to plain correctness instead of a bound-based criterion.
            pred = model_ori(x)[:, :n_active]
            keep = pred.argmax(1) == y
        else:
            C = get_spec_matrix(x, y, args.n_classes)
            lb, ub = model.compute_bounds(x=(x,), C=C, method=args.lirpa_method)
            lb = lb[:, :n_active - 1]
            keep = lb.min(dim=1)[0] >= 0
    elif args.buffer_select == 'correct':
        if labels is None:
            if args.loss_fusion:
                pred = model_ori(x)[:, :n_active]
            else:
                # NOTE(review): unlike the other paths, the output is not
                # sliced to the active classes here — confirm this is intended.
                pred = model(x)
            labels = pred.argmax(1)
        keep = labels == y
    elif args.buffer_select == 'rand' or n_active == 1:
        keep = None
    else:
        raise ValueError(
            f"Unsupported buffer selection {args.buffer_select!r} "
            f"for num_classes={n_active}"
        )

    if keep is None:
        x_select, y_select = x, y
    else:
        x_select, y_select = x[keep], y[keep]
        if args.dark:
            # Filter the logits with the same mask in every branch. The
            # original indexed with `lb` here, which is undefined on the
            # loss_fusion path.
            log = log[keep]

    if args.dark:
        buffer.add_data(examples=x_select, labels=y_select, logits=log.detach().cpu())
    else:
        buffer.add_data(examples=x_select, labels=y_select)

def update_naive(model, opt, x, y, args):
    """Take one plain cross-entropy step on the batch ``(x, y)``.

    The model output is restricted to the first ``model.num_classes`` columns,
    ``args.metric`` is updated with the arg-max predictions, and the optimizer
    is stepped once. Gradients are not cleared here; the caller does that
    before invoking this function.
    """
    logits = model(x)
    logits = logits[:, :model.num_classes]
    predicted = logits.detach().argmax(1)
    args.metric(predicted, y)
    batch_loss = F.cross_entropy(logits, y)
    batch_loss.backward()
    opt.step()

def update_intercontinet(model, opt, x, y, task_id, args):
    """Take one optimisation step using the model's own loss and hooks.

    The loss comes from ``model.compute_loss``; after the optimizer step
    ``model.after_step()`` is called, and ``args.metric`` is updated with the
    output of ``model.predict_classes`` evaluated after the step.

    Args:
        model: Model exposing ``compute_loss``, ``after_step`` and
            ``predict_classes``.
        opt: Optimizer stepped once.
        x: Input batch.
        y: Target labels for ``x``.
        task_id: Index of the current task. Kept for interface compatibility;
            task labels are not passed to the model, and the per-step tensor
            that used to be built from it was never read, so it is no longer
            allocated.
        args: Namespace providing ``metric``.
    """
    loss = model.compute_loss(x, y)

    opt.zero_grad()
    loss.backward()
    opt.step()
    model.after_step()

    # Metric bookkeeping only; no gradients needed.
    with torch.no_grad():
        preds = model.predict_classes(x)
        args.metric(preds, y)

def update_er(model, opt, buffer, x, y, args):
    """Take one experience-replay step on the batch ``(x, y)``.

    Without ``args.dark`` the current batch is concatenated with
    ``args.mini_batchsize`` buffer samples and trained with cross-entropy.
    With ``args.dark`` the cross-entropy covers the current batch only, and
    two replay terms are added: an MSE between the current outputs and the
    stored logits (weight ``args.alpha_d``) and a cross-entropy on a second,
    freshly drawn buffer sample (weight ``args.beta_d``). When ``args.lpr`` is
    set, ``args.lpr_plugin`` is invoked around the backward pass.

    Args:
        model: Network to train; outputs are sliced to ``model.num_classes``.
        opt: Optimizer stepped once at the end.
        buffer: Replay buffer providing ``is_empty`` and ``get_data``.
        x: Input batch.
        y: Target labels for ``x``.
        args: Namespace providing ``dark``, ``lpr``, ``mini_batchsize``,
            ``device`` and, depending on the flags, ``alpha_d``, ``beta_d``
            and ``lpr_plugin``.
    """
    n_active = model.num_classes
    has_replay = not buffer.is_empty()

    inputs, targets = x, y
    if has_replay:
        if args.dark:
            # Dark mode: the stored logits are used further down.
            buf_x, buf_y, buf_log = buffer.get_data(args.mini_batchsize, device=args.device)
        else:
            buf_x, buf_y = buffer.get_data(args.mini_batchsize, device=args.device)
            inputs = torch.cat((x, buf_x))
            targets = torch.cat((y, buf_y))

    pred = model(inputs)[:, :n_active]
    # args.metric(ped, y_ce)
    loss = F.cross_entropy(pred, targets)

    if args.dark and has_replay:
        replay_out = model(buf_x)[:, :n_active]
        loss += args.alpha_d * F.mse_loss(replay_out, buf_log[:, :n_active])

        # A second, independent draw from the buffer for the label term.
        buf_x, buf_y, _ = buffer.get_data(args.mini_batchsize, device=args.device)
        replay_pred = model(buf_x)[:, :n_active]
        loss += args.beta_d * F.cross_entropy(replay_pred, buf_y)

    if args.lpr:
        args.lpr_plugin.before_backward(model, args.device)
    loss.backward()
    if args.lpr:
        args.lpr_plugin.after_backward(model)
    opt.step()





def update_ewc(model, opt, x, y, args):
    """Take one EWC-style step on the batch ``(x, y)``.

    When ``args`` carries a ``checkpoint`` attribute, the gradients returned
    by ``get_penalty_grads`` are installed through ``set_grads`` before the
    cross-entropy backward pass (presumably so the backward pass accumulates
    onto them — confirm against ``set_grads``).

    Args:
        model: Network to train; outputs are sliced to ``model.num_classes``.
        opt: Optimizer, cleared at the start and stepped at the end.
        x: Input batch.
        y: Target labels for ``x``.
        args: Namespace; only the presence of ``checkpoint`` is checked here.
    """
    opt.zero_grad()
    logits = model(x)
    logits = logits[:, :model.num_classes]

    penalty_active = hasattr(args, 'checkpoint')
    if penalty_active:
        penalty_grads = get_penalty_grads(model, args)
        set_grads(model, penalty_grads)

    ce = F.cross_entropy(logits, y)
    assert not torch.isnan(ce)
    ce.backward()
    opt.step()

def update_lwf(model, opt, x, y, ind, args):
    """Take one LwF-style step: cross-entropy plus a term on stored logits.

    The cross-entropy uses the full model output. When ``args`` carries a
    ``logits`` attribute, the entry ``args.logits[ind]`` is compared with the
    current outputs on the first ``model.num_classes - args.class_inc``
    columns and the result, scaled by ``args.alpha``, is added to the loss.

    Args:
        model: Network to train.
        opt: Optimizer, cleared at the start and stepped at the end.
        x: Input batch.
        y: Target labels for ``x``.
        ind: Batch index used to look up the stored logits for this batch.
        args: Namespace providing, when ``logits`` is present, ``class_inc``,
            ``alpha`` and ``softmax_temp``.
    """
    opt.zero_grad()
    outputs = model(x)
    loss = F.cross_entropy(outputs, y)

    if hasattr(args, 'logits'):
        stored = args.logits[ind]
        n_prev = model.num_classes - args.class_inc
        if n_prev > 1:
            # Distribution-level comparison over the first n_prev columns.
            reference = smooth(F.softmax(stored[:, :n_prev], dim=1).cuda(), args.softmax_temp, 1)
            current = smooth(F.softmax(outputs[:, :n_prev], dim=1), args.softmax_temp, 1)
            loss_kl = args.alpha * modified_kl_div(reference, current)
        else:
            # Single-column case: compare the raw first logit directly.
            loss_kl = args.alpha * (stored[:, 0].cuda() - outputs[:, 0]).norm()
        loss += loss_kl

    loss.backward()
    opt.step()
