import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from dataset import UnLearningData
import numpy as np
from utils import *
import time


def UnlearnerLoss(output, labels, full_teacher_logits, unlearn_teacher_logits, KL_temperature):
    """KL-divergence loss between the student and a per-sample choice of teacher.

    Each sample is matched against the unlearning teacher when its label is 1
    (forget sample) and against the fully trained teacher when its label is 0
    (retain sample).  All three sets of logits are softened by
    ``KL_temperature`` before the softmax over dim 1.

    Args:
        output: Student logits.
        labels: Per-sample forget (1) / retain (0) indicator.  It is unsqueezed
            on dim 1 so it broadcasts over the class dimension
            (assumes a 1-D tensor of length batch -- TODO confirm).
        full_teacher_logits: Logits of the fully trained teacher.
        unlearn_teacher_logits: Logits of the unlearning teacher.
        KL_temperature: Divisor applied to every set of logits.

    Returns:
        Scalar tensor from ``F.kl_div`` with its default reduction, which
        averages over every element (batch and classes), not per sample.
    """
    # label 1 -> forget sample, label 0 -> retain sample
    forget_mask = labels.unsqueeze(1)
    retain_mask = 1 - forget_mask

    soft_full = F.softmax(full_teacher_logits / KL_temperature, dim=1)
    soft_unlearn = F.softmax(unlearn_teacher_logits / KL_temperature, dim=1)
    target = forget_mask * soft_unlearn + retain_mask * soft_full

    # F.kl_div expects log-probabilities for the input and probabilities for the target.
    log_student = F.log_softmax(output / KL_temperature, dim=1)
    return F.kl_div(log_student, target)

def unlearning_step(model, unlearning_teacher, full_trained_teacher, unlearn_data_loader, optimizer, 
            device, KL_temperature, trans=False):
    """Run one pass over the unlearning loader, updating ``model`` on every batch.

    For each batch both teachers are evaluated without gradient tracking, the
    student is evaluated with gradients, and one optimizer step is taken on
    ``UnlearnerLoss``.

    Args:
        model: Student network being updated.
        unlearning_teacher: Teacher used as target for forget samples.
        full_trained_teacher: Teacher used as target for retain samples.
        unlearn_data_loader: Iterable yielding ``(x, y)`` pairs, where ``y`` is
            passed to ``UnlearnerLoss`` as ``labels``.
        optimizer: Optimizer stepping the student's parameters.
        device: Device the batch is moved to.
        KL_temperature: Temperature forwarded to ``UnlearnerLoss``.
        trans: When true, ``x`` is treated as a mapping of tensors and passed
            to the networks as keyword arguments; otherwise ``x`` is a tensor
            passed positionally.

    Returns:
        ``np.mean`` of the per-batch loss values.
    """
    def run(net, inputs):
        # Mapping inputs are unpacked as keyword arguments, tensors go in positionally.
        return net(**inputs) if trans else net(inputs)

    batch_losses = []
    for inputs, flags in unlearn_data_loader:
        if trans:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        else:
            inputs = inputs.to(device)
        flags = flags.to(device)

        # Teachers only provide targets, so no graph is needed for them.
        with torch.no_grad():
            full_teacher_logits = run(full_trained_teacher, inputs)
            unlearn_teacher_logits = run(unlearning_teacher, inputs)

        student_logits = run(model, inputs)

        optimizer.zero_grad()
        loss = UnlearnerLoss(
            output=student_logits,
            labels=flags,
            full_teacher_logits=full_teacher_logits,
            unlearn_teacher_logits=unlearn_teacher_logits,
            KL_temperature=KL_temperature,
        )
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.detach().cpu().numpy())

    return np.mean(batch_losses)


def fit_one_unlearning_cycle(epochs,  model, train_loader, val_loader, lr, device, trans=False):
    """Train ``model`` with Adam for ``epochs`` epochs and evaluate after each one.

    The helpers ``training_step``, ``evaluate``, ``get_lr`` and ``epoch_end``
    are not defined in this file (presumably they come from ``utils`` --
    verify).  ``training_step`` must return a loss tensor and ``evaluate`` a
    result object supporting string-key assignment.

    Args:
        epochs: Number of passes over ``train_loader``.
        model: Network to train; a fresh Adam optimizer is built over its parameters.
        train_loader: Iterable of training batches handed to ``training_step``.
        val_loader: Loader handed to ``evaluate`` at the end of every epoch.
        lr: Learning rate for Adam.
        device: Forwarded to ``training_step`` and ``evaluate``.
        trans: Forwarded to ``training_step`` and ``evaluate``.

    Returns:
        Tuple ``(history, total_time)``: the per-epoch results (each extended
        with ``'train_loss'`` and ``'lrs'``) and the wall-clock seconds spent
        in the epoch loop.
    """
    history = []

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    # The loop below runs backward -> step -> zero_grad, so any gradient already
    # stored on the parameters (e.g. left behind by an earlier backward pass
    # through this model) would be folded into the very first update.
    optimizer.zero_grad()

    t_start = time.time()
    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch, device, trans)
            loss.backward()
            train_losses.append(loss.detach().cpu())

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

        result = evaluate(model, val_loader, device, trans)
        result['train_loss'] = torch.stack(train_losses).mean()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
    total_time = time.time() - t_start
    return history, total_time

def blindspot_unlearner(model, unlearning_teacher, full_trained_teacher, retain_data, forget_data, epochs = 10,
                optimizer = 'adam', lr = 0.01, batch_size = 256, num_workers = 32, 
                device = 'cuda', KL_temperature = 1, trans=False):
    """Unlearn ``forget_data`` from ``model`` by distilling from two teachers.

    Builds an ``UnLearningData`` dataset (imported from ``dataset``) over the
    forget and retain sets, then runs ``unlearning_step`` for ``epochs``
    epochs, printing the mean loss of each epoch.  Both teachers are put into
    eval mode; the student's train/eval mode is left as the caller set it.

    Args:
        model: Student network to update in place.
        unlearning_teacher: Teacher whose outputs are the target on forget samples.
        full_trained_teacher: Teacher whose outputs are the target on retain samples.
        retain_data: Dataset of samples to retain.
        forget_data: Dataset of samples to forget.
        epochs: Number of passes over the combined dataset.
        optimizer: One of
            * the string ``'adam'`` (case-insensitive): Adam is built over
              ``model.parameters()`` with ``lr``;
            * an optimizer instance: used as-is, ``lr`` is ignored;
            * a factory (class or other callable without a ``step``
              attribute): called as ``optimizer(model.parameters())``.
        lr: Learning rate, used only when ``optimizer`` is ``'adam'``.
        batch_size: Batch size of the unlearning loader.
        num_workers: Worker processes of the unlearning loader.
        device: Device forwarded to ``unlearning_step``.
        KL_temperature: Temperature forwarded to ``unlearning_step``.
        trans: Forwarded to ``unlearning_step`` (mapping-style inputs).

    Returns:
        Wall-clock seconds spent in the epoch loop.

    Raises:
        ValueError: If ``optimizer`` is a string other than ``'adam'``.
    """
    # creating the unlearning dataset.
    unlearning_data = UnLearningData(forget_data=forget_data, retain_data=retain_data)
    unlearning_loader = DataLoader(unlearning_data, batch_size = batch_size, shuffle=True, 
                            num_workers=num_workers, pin_memory=True)

    unlearning_teacher.eval()
    full_trained_teacher.eval()

    if isinstance(optimizer, str):
        # Fail here rather than deep inside the training loop on the first batch.
        if optimizer.lower() != 'adam':
            raise ValueError(
                "Unsupported optimizer name {!r}; pass 'adam', an optimizer "
                "instance, or a callable taking the model parameters".format(optimizer)
            )
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    elif isinstance(optimizer, type) or (callable(optimizer) and not hasattr(optimizer, 'step')):
        # A factory: an optimizer class or a function returning an optimizer.
        optimizer = optimizer(model.parameters())
    # Anything else is assumed to be a ready-made optimizer and is used unchanged.

    t_start = time.time()
    for epoch in range(epochs):
        loss = unlearning_step(model = model, unlearning_teacher= unlearning_teacher, 
                        full_trained_teacher=full_trained_teacher, unlearn_data_loader=unlearning_loader, 
                        optimizer=optimizer, device=device, KL_temperature=KL_temperature, trans=trans)
        print("Epoch {} Unlearning Loss {}".format(epoch+1, loss))
    total_time = time.time() - t_start
    return total_time
        
   
class UNSIR_noise(torch.nn.Module):
    """Module holding a single learnable noise tensor.

    The tensor is drawn from a standard normal distribution with the shape
    given to the constructor and registered as the parameter ``noise``.
    """

    def __init__(self, *dim):
        """Create the noise parameter with shape ``dim``."""
        super().__init__()
        initial = torch.randn(*dim)
        self.noise = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self):
        """Return the noise parameter itself (not a copy)."""
        return self.noise
    
def UNSIR_noise_train(
    noise,
    model,
    forget_class_label,    # int or list[int]
    num_epochs,
    noise_batch_size,
    device='cuda', trans=False
):
    """Optimise ``noise`` so that ``model`` misclassifies it on the forget labels.

    Each step draws one forget label per noise sample, then maximises the
    cross-entropy of ``model`` on those labels while penalising the squared
    magnitude of the noise (weight 0.1).  Only the parameters of ``noise`` are
    updated (Adam, lr 0.1); gradients are taken with respect to the noise
    parameters alone, so the ``.grad`` fields of ``model`` are not touched.

    Args:
        noise: Module whose call returns the noise batch and whose parameters
            are optimised.  It is moved to ``device``.
        model: Network the noise is optimised against.
        forget_class_label: A single class index (Python or NumPy integer) or
            an iterable of class indices.
        num_epochs: Number of optimisation steps.
        noise_batch_size: Not used by this function; the batch size is taken
            from the first dimension of the tensor returned by ``noise()``.
        device: Device for the noise, the inputs and the labels.
        trans: When true the noise output is passed to ``model`` as keyword
            arguments.  NOTE(review): the rest of this function calls tensor
            methods (``size``, ``pow``) on the same object, so this path does
            not look usable with the noise module defined in this file --
            confirm what input is expected here.

    Returns:
        The ``noise`` module after optimisation.
    """
    # normalize to a list of ints (NumPy integer scalars are not `int` instances)
    if isinstance(forget_class_label, (int, np.integer)):
        labels_list = [int(forget_class_label)]
    else:
        labels_list = list(forget_class_label)

    noise = noise.to(device)
    opt = torch.optim.Adam(noise.parameters(), lr=0.1)
    trainable = [p for p in noise.parameters() if p.requires_grad]

    for epoch in range(num_epochs):
        inputs = noise()    # batch of noise samples, first dim is the batch
        if not trans:
            # No-op when the tensor already lives on `device`.
            inputs = inputs.to(device)

        # pick one forget-label per sample
        chosen = np.random.choice(labels_list, size=inputs.size(0))
        labels = torch.from_numpy(chosen).long().to(device)

        outputs = model(**inputs) if trans else model(inputs)

        # Negated cross-entropy pushes the model's error up; the L2 term keeps
        # the noise from growing without bound.
        loss = (
            -F.cross_entropy(outputs, labels)
            + 0.1 * inputs.pow(2).flatten(1).sum(1).mean()
        )

        # Differentiate with respect to the noise only. A plain loss.backward()
        # would also accumulate into model.parameters().grad, which nothing here
        # clears and which a later optimizer step on the model would pick up.
        grads = torch.autograd.grad(loss, trainable, allow_unused=True)
        for param, grad in zip(trainable, grads):
            param.grad = grad    # overwrite, so no zero_grad is needed
        opt.step()

        if epoch % 25 == 0:
            print(f"[epoch {epoch:3d}] loss = {loss.item():.4f}")

    return noise


def UNSIR_create_noisy_loader(
    noise,
    forget_class_label,    # int or list[int]
    retain_loader,         # a DataLoader yielding (imgs, labels)
    batch_size,
    num_noise_batches=80,
    device='cuda', trans=False,
    num_workers=32,
):
    """Build a shuffled loader mixing noise samples with the retained data.

    Every item of the resulting dataset is a 3-tuple
    ``(sample, label, label)`` with the label stored twice.  Noise samples
    come from ``num_noise_batches`` calls to ``noise()``, each sample receiving
    a label drawn at random from the forget labels; retained samples keep the
    labels yielded by ``retain_loader``.  All tensors are stored on the CPU.

    Args:
        noise: Module whose call returns a batch of noise samples as a tensor.
        forget_class_label: A single class index (Python or NumPy integer) or
            an iterable of class indices.
        retain_loader: Iterable yielding ``(imgs, labels)`` tensor pairs.
        batch_size: Batch size of the returned loader.
        num_noise_batches: How many times the noise batch is appended.
        device: Not used by this function.
        trans: Not used by this function.
        num_workers: Worker processes of the returned loader.

    Returns:
        A ``DataLoader`` over the combined list, with shuffling and pinned
        memory enabled.
    """
    # normalize to a list of ints (NumPy integer scalars are not `int` instances)
    if isinstance(forget_class_label, (int, np.integer)):
        labels_list = [int(forget_class_label)]
    else:
        labels_list = list(forget_class_label)

    noisy_data = []

    # 1) generate the noise examples
    for _ in range(num_noise_batches):
        batch = noise().detach().cpu()          # assume noise() returns a Tensor
        chosen = np.random.choice(labels_list, size=batch.size(0))
        labels = torch.from_numpy(chosen).long()
        # Iterating a tensor walks its first dimension, i.e. one sample at a time.
        noisy_data.extend(zip(batch, labels, labels))

    # 2) append the retained-data examples
    for imgs, lbls in retain_loader:
        imgs = imgs.cpu()
        lbls = lbls.cpu()
        noisy_data.extend(zip(imgs, lbls, lbls))

    return DataLoader(
        noisy_data,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
