import torch
from torch.utils.data import DataLoader
from continuum.datasets import InMemoryDataset
from tqdm import tqdm
from continuum import ClassIncremental
from continuum.tasks import TaskType
import torch.nn.functional as F
import numpy as np

def set_grads(model, grads):
    """Scatter a flat gradient vector into ``param.grad`` of each trainable parameter.

    Parameters with ``requires_grad=False`` are skipped and their ``.grad`` is
    left untouched. Consecutive slices of ``grads`` are consumed in
    ``model.parameters()`` order and reshaped to each parameter's shape. The
    assigned gradients are views into ``grads`` (no copy is made).

    Args:
        model: module whose trainable parameters receive the gradients.
        grads: 1-D tensor holding the concatenated gradients, laid out in the
            same order ``get_params`` / ``get_grads`` produce.
    """
    offset = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        count = param.numel()
        param.grad = grads[offset:offset + count].view(param.shape)
        offset += count

def get_params(model):
    """Return all trainable parameters of ``model`` flattened into one 1-D tensor.

    The result is a detached copy: it does not track gradients and does not
    share storage with the model. Order matches ``model.parameters()``
    restricted to parameters with ``requires_grad=True``. ``torch.cat`` raises
    if no parameter is trainable.
    """
    flat_pieces = [
        param.detach().clone().view(-1)
        for param in model.parameters()
        if param.requires_grad
    ]
    return torch.cat(flat_pieces)

def get_grads(model):
    """Return the gradients of all trainable parameters as one flat 1-D tensor.

    The layout matches ``get_params`` / ``set_grads``: one slice per parameter
    with ``requires_grad=True``, in ``model.named_parameters()`` order. The
    values are detached copies.

    A trainable parameter whose ``.grad`` is ``None`` (it did not take part in
    the last backward pass) contributes zeros. This keeps the vector length
    stable instead of raising ``AttributeError``.
    """
    grads = []
    for _, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            # Unused parameter: the loss does not depend on it, so its gradient is zero.
            grads.append(torch.zeros(param.numel(), dtype=param.dtype, device=param.device))
        else:
            grads.append(param.grad.detach().clone().view(-1))
    return torch.cat(grads)

def begin_task(model, model_ori, opt, loader, task_id, args):
    """Prepare strategy-specific state before training on task ``task_id``.

    * ``args.train_type == 'lwf'``: switches ``model`` to eval mode. For every
      task after the first, it reseeds the global torch RNG with
      ``args.order_seed``, runs ``model`` over ``loader`` without gradients and
      stores the list of per-batch outputs (moved to CPU) in ``args.logits``.
      The model is put back in train mode afterwards.
    * ``args.train_type == 'agem'``: stores the element count of every
      parameter of ``model_ori`` (trainable or not) in ``args.grad_dims``. It
      then allocates two zero-filled flat buffers of the total size,
      ``args.grad_xy`` and ``args.grad_er``, on ``args.device``.

    Any other ``train_type`` is a no-op. ``opt`` is not used.

    Args:
        model: network being trained.
        model_ori: network whose parameter sizes define the A-GEM buffers.
        opt: unused; kept for interface compatibility.
        loader: iterable of ``(x, y, t)`` batches (only read for LwF).
        task_id: index of the task about to be trained.
        args: mutable namespace read from and written to as described above.
    """
    if args.train_type == 'lwf':
        model.eval()
        if task_id > 0:
            logits = []
            with torch.no_grad():
                # Reseeding presumably makes a shuffling loader replay the same
                # batch order the training loop will see, so args.logits[i]
                # lines up with batch i -- confirm against the caller.
                torch.manual_seed(args.order_seed)
                for x, _, _ in loader:
                    logits.append(model(x.cuda()).cpu())
            args.logits = logits
        model.train()
    elif args.train_type == 'agem':
        args.grad_dims = [param.data.numel() for param in model_ori.parameters()]
        total = sum(args.grad_dims)
        # Explicit zeros instead of the legacy torch.Tensor(size) constructor,
        # which returns uninitialized memory.
        args.grad_xy = torch.zeros(total, device=args.device)
        args.grad_er = torch.zeros(total, device=args.device)

def end_task(model, opt, loader, buffer, args):
    """Consolidate strategy-specific state after training on a task.

    * ``args.train_type == 'ewc'``: accumulates a diagonal Fisher estimate over
      ``loader``. It decays any previous ``args.fish`` by ``args.ewc_gamma``
      and adds the new estimate, then snapshots the current parameters into
      ``args.checkpoint``.
    * ``args.train_type == 'agem'``: adds the inputs and labels of the first
      batch of ``loader`` to ``buffer``.

    Args:
        model: trained network.
        opt: optimizer, used only to zero gradients between EWC batches.
        loader: iterable of ``(x, y, t)`` batches.
        buffer: object with an ``add_data(examples=..., labels=...)`` method.
        args: namespace read from / written to as described above.
    """
    if args.train_type == 'ewc':
        fish = torch.zeros_like(get_params(model))

        for x, y, _ in loader:
            x, y = x.cuda(), y.cuda()
            opt.zero_grad()
            output = model(x)
            # nll_loss returns -log p(y|x); negating gives per-sample log-likelihoods.
            # dim=1 is the class dimension nll_loss expects (explicit instead of
            # the deprecated implicit-dim behaviour of log_softmax).
            loss = -F.nll_loss(F.log_softmax(output, dim=1), y, reduction='none')
            exp_cond_prob = torch.mean(torch.exp(loss.detach().clone()))
            loss = torch.mean(loss)
            loss.backward()
            fish += exp_cond_prob * get_grads(model) ** 2

        # NOTE(review): each batch adds one batch-mean term, so dividing by
        # batch_size as well as by the number of batches adds an extra
        # 1/batch_size factor. Kept as-is since ewc_lambda is presumably tuned
        # against this scale -- confirm before changing.
        fish /= (len(loader) * args.batch_size)

        if not hasattr(args, 'fish'):
            args.fish = fish
        else:
            args.fish *= args.ewc_gamma
            args.fish += fish

        args.checkpoint = get_params(model)
    elif args.train_type == 'agem':
        # Batches are (x, y, t), as in the EWC loop above, so take the first two
        # fields. The previous [1:] stored the task ids as examples.
        cur_x, cur_y = next(iter(loader))[:2]
        buffer.add_data(
            examples=cur_x.to(args.device),
            labels=cur_y.to(args.device)
        )

def get_penalty_grads(model, args):
    """Return the gradient of the EWC quadratic penalty w.r.t. the flat parameters.

    For penalty ``lambda * sum(F * (theta - theta*)**2)`` the gradient is
    ``2 * lambda * F * (theta - theta*)``. Here ``F`` is ``args.fish``,
    ``theta*`` is ``args.checkpoint`` and ``lambda`` is ``args.ewc_lambda``.
    """
    scale = args.ewc_lambda * 2
    drift = get_params(model) - args.checkpoint
    return scale * args.fish * drift

def smooth(logits, temp, dim):
    """Temperature-rescale ``logits`` and renormalize them along ``dim``.

    Every entry is raised to the power ``1 / temp``, then divided by the sum
    along ``dim``, so each slice along ``dim`` sums to one. Despite the
    parameter name, the inputs are presumably non-negative probabilities
    (e.g. softmax outputs); negative entries with a fractional exponent give
    NaN.

    Args:
        logits: tensor of non-negative scores.
        temp: temperature; ``temp > 1`` flattens and ``temp < 1`` sharpens.
        dim: dimension to normalize over.
    """
    powered = logits ** (1 / temp)
    # keepdim=True broadcasts correctly for any ``dim``. A hard-coded
    # unsqueeze(1) is only right when reducing over dim 1.
    return powered / torch.sum(powered, dim, keepdim=True)


def modified_kl_div(old, new):
    """Batch-mean cross-entropy ``-sum_c old[:, c] * log(new[:, c])`` over dim 1.

    This is the KL divergence ``KL(old || new)`` plus the entropy of ``old``.
    Its gradient with respect to ``new`` therefore equals that of the KL term.
    ``old`` and ``new`` are presumably per-row probability distributions.

    ``torch.xlogy`` treats ``0 * log(0)`` as 0, so rows where both
    distributions have an exact zero no longer produce NaN. All other entries
    are computed exactly as ``old * log(new)``.
    """
    return -torch.mean(torch.sum(torch.xlogy(old, new), 1))

def get_embedded_img(x, args):
    """Embed an image batch with ``args.embedder`` and return its class-token output.

    ``args.embed_transform`` is applied first. The remaining steps mirror the
    torchvision ``VisionTransformer`` forward pass up to the encoder
    (assumption: ``args.embedder`` is such a model -- confirm):
    ``_process_input``, prepending ``class_token`` along dim 1, then
    ``encoder``. They run under ``torch.no_grad()``, and the encoder output at
    sequence position 0 is returned, one row per batch element.
    """
    vit = args.embedder
    prepared = args.embed_transform(x)
    with torch.no_grad():
        patches = vit._process_input(prepared)
        cls_tokens = vit.class_token.expand(patches.shape[0], -1, -1)
        # presumably (n, 1 + num_patches, hidden_dim) -- n is the batch size
        encoded = vit.encoder(torch.cat([cls_tokens, patches], dim=1))
        cls_output = encoded[:, 0]
    return cls_output

def get_embedded_dataset(dataset, args):
    """Replace the images of ``dataset`` with their embeddings.

    ``dataset`` is split with ``ClassIncremental(increment=args.n_classes)``.
    Each task's data is batched (``args.test_batch_size``, shuffled), moved to
    CUDA and embedded with ``get_embedded_img``. When ``args.train_type`` is
    ``'intercontinet'`` a ReLU is applied to the embeddings. All embeddings and
    labels are concatenated into a continuum ``InMemoryDataset`` of tensor type.
    """
    scenario = ClassIncremental(dataset, increment=args.n_classes)
    features, targets = [], []
    with torch.no_grad():
        for taskset in scenario:
            loader = DataLoader(taskset, batch_size=args.test_batch_size, shuffle=True)
            for images, labels, _ in tqdm(loader, desc='Embedding Dataset...'):
                embedding = get_embedded_img(images.cuda(), args).detach().cpu()
                if args.train_type == 'intercontinet':
                    embedding = F.relu(embedding)
                features.append(embedding)
                targets.append(labels)
    all_x = torch.cat(features, dim=0).numpy()
    all_y = torch.cat(targets, dim=0).numpy()
    return InMemoryDataset(x=all_x, y=all_y, data_type=TaskType.TENSOR)


def create_intervalnet_optimizer(model, task_id, args):
    """Build a plain SGD optimizer over all of ``model``'s parameters.

    The base rate is ``args.lr`` (0.001 when absent). Task 0 uses it as-is;
    any other ``task_id`` uses a 10x smaller rate for stability.
    """
    base_lr = getattr(args, 'lr', 0.001)
    lr = base_lr if task_id == 0 else base_lr * 0.1
    return torch.optim.SGD(model.parameters(), lr=lr)

def update_intervalnet_lr(optimizer, mode, task_id, args):
    """Set the learning rate of every param group according to the training phase.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        mode: training phase, either an enum-like object exposing ``.value``
            or the raw value itself. ``2`` (CONTRACTION_SHIFT) selects
            ``args.center_lr``, ``3`` (CONTRACTION_SCALE) selects
            ``args.radii_lr``, and any other value uses ``args.lr``.
        task_id: not used. NOTE(review): a per-task decay
            ``0.1 ** (task_id // 3)`` was sketched here but never applied --
            confirm whether it is wanted.
        args: namespace providing ``center_lr`` / ``radii_lr`` (read only for
            their respective modes) and optionally ``lr`` (defaults to 0.001,
            matching ``create_intervalnet_optimizer``).
    """
    mode_val = mode.value if hasattr(mode, 'value') else mode

    if mode_val == 2:  # CONTRACTION_SHIFT
        lr = args.center_lr
    elif mode_val == 3:  # CONTRACTION_SCALE
        lr = args.radii_lr
    else:
        lr = getattr(args, 'lr', 0.001)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr