import torch
import numpy as np
import math
from scipy import optimize
import pdb
import torchvision.models
import models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import time


def f(x, a, b, c, d):
    """Residual whose root is searched for by the optimization solver.

    Evaluates ``sum(a * b * exp(-x / c)) - d``.

    Parameters:
    - x: Point at which the residual is evaluated.
    - a, b: Factors multiplied element-wise into every term of the sum.
    - c: Divisor applied to ``-x`` inside the exponential.
    - d: Value subtracted from the sum.

    Returns:
    - The summed, exponentially damped product minus ``d``.
    """
    damping = np.exp(-1 * x / c)
    terms = a * b * damping
    return np.sum(terms) - d

def accuracy(output, target, topk=(1,)):
    """Computes the top-k error rate (in percent) for each k in ``topk``.

    Despite its name, this function returns ``100 - accuracy``: the percentage
    of samples whose target is NOT among the k highest-scoring predictions.

    Parameters:
    - output: Score tensor with one row per sample and one column per class.
    - target: Tensor holding one class index per sample.
    - topk: Iterable of k values to evaluate.

    Returns:
    - List with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        widest_k = max(topk)
        num_samples = target.size(0)

        # Class indices of the best `widest_k` scores per sample (best first),
        # transposed so that row j holds every sample's (j+1)-th guess.
        top_classes = output.topk(widest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = top_classes.eq(target.view(1, -1).expand_as(top_classes))

        # A sample counts as correct for k if any of its first k guesses hit.
        return [
            100 - hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]

def get_class_indices(dataset, desired_classes):
    """
    Collects the positions of all samples whose label is one of the desired classes.

    Parameters:
    - dataset: Iterable yielding ``(sample, label)`` pairs.
    - desired_classes: Container of class labels to keep.

    Returns:
    - List of positions (in iteration order) of the matching samples.
    """
    selected = []
    for position, sample in enumerate(dataset):
        _, label = sample
        if label in desired_classes:
            selected.append(position)
    return selected


def opt_solver(probs, target_distb, num_iter=10, th=1e-5, num_newton=30):
    """Rescales a probability matrix by alternating row and column updates.

    Each sweep first computes one factor per row so that the scaled row sums
    to one, then finds one value per column with ``optimize.newton`` applied
    to ``f``, i.e. a root ``y`` of
    ``sum_i(probs[i, k] / e * row_factor[i] * exp(-y / w[i])) - target_distb[k]``.
    The per-sample weights ``w`` are all ones.

    Parameters:
    - probs: 2-D tensor (rows are samples, columns are classes); must be
      convertible with ``.numpy()``.
    - target_distb: 1-D tensor with one target value per column of ``probs``;
      must be convertible with ``.numpy()``.
    - num_iter: Number of alternating sweeps.
    - th: Tolerance forwarded to ``optimize.newton``.
    - num_newton: Maximum iterations of each ``optimize.newton`` call.

    Returns:
    - ``torch.Tensor`` shaped like ``probs``; because the row factors are
      recomputed last, every row sums to one.
    """
    sample_weights = torch.ones(probs.shape[0]).numpy()
    n_rows, n_cols = probs.size(0), probs.size(1)

    prob_matrix = probs.numpy()
    row_targets = np.ones(n_rows)
    col_targets = target_distb.numpy()
    scaled_probs = prob_matrix / math.e

    def row_factors(col_scale):
        # One factor per row such that the scaled row sums to its target (1).
        return row_targets / np.sum(scaled_probs * col_scale, 1)

    def column_scale(col_values):
        # Turns the per-column values into an (n_rows, n_cols) multiplier.
        return np.exp(-1 * col_values.reshape(1, -1) / sample_weights.reshape(-1, 1))

    # The column values enter the first sweep as ones; the root search of the
    # first sweep starts from zero.
    col_scale = column_scale(np.ones(n_cols))
    col_values = np.zeros(n_cols)

    for _ in range(num_iter):
        row_scale = row_factors(col_scale)
        solved = np.zeros(n_cols)
        for k in range(n_cols):
            # Warm-start each root search from the previous sweep's solution.
            solved[k] = optimize.newton(
                f,
                col_values[k],
                maxiter=num_newton,
                args=(scaled_probs[:, k], row_scale, sample_weights, col_targets[k]),
                tol=th,
            )
        col_values = solved
        col_scale = column_scale(col_values)

    row_scale = row_factors(col_scale)
    return torch.Tensor(scaled_probs * row_scale.reshape(-1, 1) * col_scale)

def load_tensors():
    """Reads the saved "train" and "forget" tensors from disk and moves them to the GPU.

    NaN entries in the two output tensors are overwritten with 1.0, and the
    shapes of the output and target tensors are printed.

    Returns:
    - Tuple ``(train outputs, train targets, forget outputs, forget targets,
      train inputs, forget inputs)``.
    """
    def fetch(path):
        return torch.load(path).to('cuda')

    # Outputs and targets
    tensor1 = fetch('train_cifar5_resnet_before_all_outputs_tensor.pt')
    tensor2 = fetch('train_cifar5_resnet_before_all_targets_tensor.pt')
    tensor3 = fetch('forget_cifar5_resnet_before_all_outputs_tensor.pt')
    tensor4 = fetch('forget_cifar5_resnet_before_all_targets_tensor.pt')

    # Inputs
    inputs_tensor1 = fetch('train_cifar5_resnet_before_all_inputs_tensor.pt')
    inputs_tensor3 = fetch('forget_cifar5_resnet_before_all_inputs_tensor.pt')

    # Replace NaNs in the saved outputs in place
    for outputs in (tensor1, tensor3):
        outputs[torch.isnan(outputs)] = 1.0

    print(tensor1.shape, tensor2.shape, tensor3.shape, tensor4.shape)

    return tensor1, tensor2, tensor3, tensor4, inputs_tensor1, inputs_tensor3

def refine_predictions(tensor1, tensor2, tensor3, tensor4, inputs_tensor1, inputs_tensor3, model,
                       num_epochs=60, lr=1e-4, batch_size=32):
    """Refines predictions using the optimization solver and updates model weights.

    Soft labels are built by stacking ``softmax(tensor1)`` on top of a uniform
    distribution (one row per row of ``tensor3``) and passing the stack
    through ``opt_solver``; the solver's target is the per-class sum of
    ``softmax(tensor1)`` plus the per-class sum of ``softmax(tensor3)``.
    ``model`` is then trained with Adam and a KL-divergence loss to reproduce
    those labels on the concatenated inputs. After every epoch the top-1
    error rate on ``inputs_tensor3`` is recorded. The trained weights are
    saved to 'complete_weights.pth'.

    Parameters:
    - tensor1: Raw model outputs for ``inputs_tensor1`` (softmax is applied
      over dim 1) -- presumably the retain set; confirm against the caller.
    - tensor2: Not used by this function; kept for call compatibility.
    - tensor3: Raw model outputs for ``inputs_tensor3`` (softmax is applied
      over dim 1) -- presumably the forget set; confirm against the caller.
    - tensor4: Targets compared against ``model(inputs_tensor3)``.
    - inputs_tensor1: Inputs matching the rows of ``tensor1``.
    - inputs_tensor3: Inputs matching the rows of ``tensor3``.
    - model: Network to fine-tune in place; expected to live on the GPU.
    - num_epochs: Number of passes over the combined inputs.
    - lr: Adam learning rate.
    - batch_size: Mini-batch size used for training.

    Returns:
    - Softmax of the fine-tuned model's outputs on the concatenated inputs
      (rows of ``inputs_tensor1`` first, then rows of ``inputs_tensor3``).
    """
    start = time.time()

    probabilities_tensor1 = torch.softmax(tensor1, dim=1)
    probabilities_tensor3 = torch.softmax(tensor3, dim=1)

    # Per-class mass of the original predictions over both sets; passed to the
    # solver as its target.
    probabilities_sum = probabilities_tensor1.sum(dim=0) + probabilities_tensor3.sum(dim=0)

    # Rows belonging to tensor3 start from a uniform distribution over classes.
    num_classes = probabilities_tensor3.size(1)
    uniform_probabilities = torch.ones_like(probabilities_tensor3) / num_classes

    probabilities = torch.cat((probabilities_tensor1, uniform_probabilities), dim=0)
    refined_labels = opt_solver(probabilities.cpu(), probabilities_sum.cpu()).detach().to('cuda')
    combined_inputs = torch.cat((inputs_tensor1, inputs_tensor3), dim=0).to('cuda')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.KLDivLoss(reduction='batchmean')

    dataset = torch.utils.data.TensorDataset(combined_inputs, refined_labels)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    forget_set_errors = []
    for epoch in range(num_epochs):
        # The evaluation at the end of each epoch puts the model in eval mode,
        # so training mode has to be restored here. Otherwise every epoch
        # after the first would train with mode-dependent layers (such as
        # BatchNorm or Dropout, if the model has any) behaving as in inference.
        model.train()

        epoch_loss = 0.0
        for inputs_batch, targets_batch in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs_batch)
            # KLDivLoss expects log-probabilities as input and probabilities as target.
            loss = criterion(torch.log_softmax(outputs, dim=1), targets_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Track the top-1 error rate on inputs_tensor3 after this epoch.
        model.eval()
        with torch.no_grad():
            forget_outputs = model(inputs_tensor3)
            forget_set_errors.append(accuracy(forget_outputs, tensor4)[0].item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader):.4f}')

    model.eval()

    elapsed_time = time.time() - start
    print(f"\nTotal Execution Time: {elapsed_time:.2f} seconds")

    with torch.no_grad():
        updated_refined_prediction = torch.softmax(model(combined_inputs), dim=1)
    torch.save(model.state_dict(), 'complete_weights.pth')
    print(forget_set_errors)

    return updated_refined_prediction

def evaluate_accuracy(tensor1, tensor2, tensor3, tensor4, refined_prediction):
    """Prints top-1 error rates of the retain and forget sets before and after refinement.

    Parameters:
    - tensor1, tensor2: Outputs and targets reported as the retain set.
    - tensor3, tensor4: Outputs and targets reported as the forget set.
    - refined_prediction: Predictions after refinement; its first
      ``tensor1.shape[0]`` rows are scored against ``tensor2`` and the
      remaining rows against ``tensor4``.
    """
    def top1_error(scores, labels):
        return accuracy(scores, labels)[0].item()

    retain_count = tensor1.shape[0]

    print("Before refinement:")
    print("Retain set error rate:", top1_error(tensor1, tensor2))
    print("Forget set error rate:", top1_error(tensor3, tensor4))

    print("\nAfter refinement and weight update:")
    refined_on_gpu = refined_prediction.to('cuda')
    retain_rows = refined_on_gpu[:retain_count, :]
    print("Retain set error rate:", top1_error(retain_rows, tensor2))
    forget_rows = refined_on_gpu[retain_count:, :]
    print("Forget set error rate:", top1_error(forget_rows, tensor4))

def main(model):
    """Runs the full pipeline: load tensors, refine and fine-tune, report error rates.

    Parameters:
    - model: Network handed to ``refine_predictions`` for fine-tuning.
    """
    loaded = load_tensors()
    # The first four entries are the outputs/targets needed for evaluation;
    # the last two are the input tensors only used for fine-tuning.
    outputs_and_targets = loaded[:4]

    updated_refined_prediction = refine_predictions(*loaded, model)
    evaluate_accuracy(*outputs_and_targets, updated_refined_prediction)

if __name__ == "__main__":
    # Build the network and place it on the GPU before restoring its weights.
    # Alternative configuration:
    #   models.get_model('allcnn', num_classes=10, filters_percentage=1.0)
    model = models.get_model('resnet', num_classes=5, filters_percentage=0.4).to('cuda')

    pretrained_state = torch.load('final_model_weights.pth')
    model.load_state_dict(pretrained_state)

    main(model)
    

    
