import os
import copy
import pickle
import torch
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
from torch.utils.data import DataLoader, Subset
"""
The following code is adapted from https://github.com/AdityaGolatkar/SelectiveForgetting/blob/master/Forgetting.ipynb
"""


def vectorize_params(model):
    """Flatten all parameters of ``model`` into a single 1-D NumPy array.

    Parameters are visited in ``model.parameters()`` order and copied to host memory.
    """
    flat_chunks = [tensor.data.view(-1).cpu().numpy() for tensor in model.parameters()]
    return np.concatenate(flat_chunks)

def delta_w_utils(model_init,dataloader,name='complete'):
    """Compute per-sample, per-class output gradients and softmax residuals.

    Args:
        model_init: network to linearize. It is switched to eval mode, so callers
            should pass a copy if the mode matters to them.
        dataloader: only its ``dataset`` is used. It is re-wrapped with
            ``batch_size=1`` and ``shuffle=False``, so rows follow dataset order.
        name: unused; kept for backward compatibility.

    Returns:
        ``(G, f0_minus_y)``:
            ``G`` has shape ``(num_params, num_samples * num_classes)``. Column
            ``i * num_classes + c`` is d output[c] / d params for sample ``i``.
            ``f0_minus_y`` has shape ``(num_samples * num_classes, 1)`` and holds
            ``softmax(output) - onehot(target)`` stacked per sample.
    """
    # The Jacobian must be of the deterministic network function. In train mode,
    # dropout makes it random and BatchNorm normalizes each single-sample batch
    # with its own statistics (BatchNorm1d even raises on batch size 1).
    model_init.eval()
    params = list(model_init.parameters())
    device = params[0].device
    loader = DataLoader(dataloader.dataset, batch_size=1, shuffle=False)

    G_list = []
    f0_minus_y = []
    for batch in tqdm(loader, desc="Getting delta_w"):
        inputs, target = [tensor.to(device) for tensor in batch]
        target = target.cpu().numpy()
        output = model_init(inputs)

        # One gradient row per output logit; retain_graph lets us backprop repeatedly.
        for cls in range(output.shape[-1]):
            grads = torch.autograd.grad(output[0, cls], params, retain_graph=True)
            G_list.append(np.concatenate([g.reshape(-1).cpu().numpy() for g in grads]))

        # Cross-entropy residual: softmax probabilities minus the one-hot target,
        # laid out as a (num_classes, 1) column.
        probs = torch.nn.functional.softmax(output, dim=1).detach().cpu().numpy().transpose()
        probs[target] -= 1
        f0_minus_y.append(probs)

    return np.stack(G_list).transpose(), np.vstack(f0_minus_y)

def NIP(v1,v2):
    """Print and return the normalized inner product (cosine similarity) of two vectors."""
    unit_a = v1 / np.linalg.norm(v1)
    unit_b = v2 / np.linalg.norm(v2)
    cosine = np.inner(unit_a, unit_b)
    print(cosine)
    return cosine

def get_delta_w_dict(delta_w,model):
    """Split a flat parameter-space vector into per-parameter tensors.

    The vector is NOT normalized here; values are copied as-is.

    Args:
        delta_w: 1-D array consumed in ``model.named_parameters()`` order (the same
            order ``vectorize_params`` flattens in). Its length should equal the total
            parameter count.
        model: module whose parameter names and shapes define the split.

    Returns:
        OrderedDict mapping each parameter name to a CPU tensor with that
        parameter's shape and dtype.
    """
    delta_w_dict = OrderedDict()
    offset = 0
    for name, p in model.named_parameters():
        # numel() is an int even for 0-d parameters. np.prod of an empty shape is
        # the float 1.0, which cannot be used as a slice bound.
        count = p.numel()
        chunk = delta_w[offset:offset + count]
        delta_w_dict[name] = torch.tensor(chunk, dtype=p.dtype).view_as(p)
        offset += count
    return delta_w_dict

def ntk_init(config, model):
    """Return a deep copy of ``model`` loaded with the checkpoint's initial weights.

    Args:
        config: object with ``model_path``, the path of a ``torch.save`` checkpoint
            that contains an ``"init_state_dict"`` entry.
        model: template module. It is copied, not modified.

    Returns:
        A new module with ``checkpoint["init_state_dict"]`` loaded.
    """
    model_init = copy.deepcopy(model)
    # load_state_dict copies values onto model_init's own devices, so mapping to CPU
    # loses nothing. It avoids failures when the checkpoint was saved on a device
    # that is not available in this process.
    with open(config.model_path, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    model_init.load_state_dict(checkpoint["init_state_dict"])
    return model_init

def _load_or_compute_jacobian(cache_path, model, loader, name):
    """Return ``delta_w_utils`` output for ``loader``, reusing the pickle cache at ``cache_path``."""
    if os.path.exists(cache_path):
        # NOTE(review): pickle.load can execute arbitrary code; this is only safe because
        # the file is written by this helper into the caller-controlled save_dir.
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    result = delta_w_utils(copy.deepcopy(model), loader, name)
    with open(cache_path, "wb") as f:
        pickle.dump(result, f)
    return result


def _ntk_weights(G, f0_minus_y, num_samples, weight_decay):
    """Return ``-G (G^T G + num_samples * weight_decay * I)^{-1} f0_minus_y``."""
    theta = G.transpose().dot(G) + num_samples * weight_decay * np.eye(G.shape[1])
    # solve() is cheaper and numerically better than forming the explicit inverse.
    return -G.dot(np.linalg.solve(theta, f0_minus_y))


def _predicted_scale(delta_w, m_pred_error):
    """Return the factor that rescales ``delta_w`` to its predicted norm, printing diagnostics."""
    delta_norm = np.linalg.norm(delta_w)
    error_norm = np.linalg.norm(m_pred_error)

    inner = np.inner(delta_w / delta_norm, m_pred_error / error_norm)
    print(f"Inner Product--: {inner}")
    # Rounding can push |inner| slightly past 1, where arccos returns nan.
    inner = np.clip(inner, -1.0, 1.0)

    if inner < 0:
        angle = np.arccos(inner) - np.pi / 2
        print(f"Angle----------:  {angle}")
        predicted_norm = delta_norm + 2 * np.sin(angle) * error_norm
    else:
        angle = np.arccos(inner)
        print(f"Angle----------:  {angle}")
        predicted_norm = delta_norm + 2 * np.cos(angle) * error_norm
    print(f"Pred Act Norm--:  {predicted_norm}")

    predicted_scale = predicted_norm / delta_norm
    print(f"Predicted Scale:  {predicted_scale}")
    return predicted_scale


def ntk(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None, 
    unl_logs=None,
):
    """Unlearn ``forget_set`` from ``model`` in place with the NTK forgetting update.

    The network is linearized around its current weights. The regularized linear
    (kernel) solution is computed on retain + forget data ("complete") and on retain
    data only. Their difference ``w_retain - w_complete`` is rescaled by a predicted
    norm and added to the model's parameters.

    Args:
        model: trained network; its parameters are modified in place.
        forget_set: dataset of samples to forget.
        retain_set: ``torch.utils.data.Subset`` of samples to keep.
            ``config.sampling_size`` of its indices are sampled, capped at its size.
        config: needs ``save_dir``, ``sampling_size``, ``ntk_weight_decay`` and
            ``model_path`` (the last is read by ``ntk_init``).
        trainer_init_func, trainer_init_kwargs, device, unl_logs: unused; kept for
            interface compatibility.

    Side effects:
        Writes caches and intermediate weights to ``<save_dir>/NTK_temp``. Jacobian
        caches are reused whenever present, even if a different retain subsample
        produced them; delete that directory to force recomputation.
    """
    temp_dir = os.path.join(config.save_dir, "NTK_temp")
    print(f"NTK cached will be saved to {temp_dir}")
    os.makedirs(temp_dir, exist_ok=True)

    forget_loader = DataLoader(forget_set, batch_size=1, shuffle=True)
    # retain_set is a Subset, so sample from its indices into the underlying dataset.
    # Cap the sample size: choice(replace=False) raises if asked for more than exist.
    sample_size = min(config.sampling_size, len(retain_set.indices))
    sampling_ids = np.random.choice(retain_set.indices, size=sample_size, replace=False)
    sampled_retain_set = Subset(retain_set.dataset, sampling_ids)
    retain_loader = DataLoader(sampled_retain_set, batch_size=1, shuffle=True)

    num_retain_samples = len(sampled_retain_set)
    num_total_samples = len(forget_set) + num_retain_samples

    G_r, f0_minus_y_r = _load_or_compute_jacobian(
        os.path.join(temp_dir, "retain_cache.pkl"), model, retain_loader, "complete"
    )
    print("Loaded G_r, f0_minus_y_r")
    G_f, f0_minus_y_f = _load_or_compute_jacobian(
        os.path.join(temp_dir, "forget_cache.pkl"), model, forget_loader, "retain"
    )
    print("Loaded G_f, f0_minus_y_f")

    # Solution trained on all samples (retain + forget).
    G = np.concatenate([G_r, G_f], axis=1)
    f0_minus_y = np.concatenate([f0_minus_y_r, f0_minus_y_f])
    del G_f, f0_minus_y_f
    w_complete = _ntk_weights(G, f0_minus_y, num_total_samples, config.ntk_weight_decay)
    del G, f0_minus_y
    np.save(os.path.join(temp_dir, "w_complete.npy"), w_complete)

    # Solution trained on retain samples only.
    w_retain = _ntk_weights(G_r, f0_minus_y_r, num_retain_samples, config.ntk_weight_decay)
    del G_r, f0_minus_y_r
    np.save(os.path.join(temp_dir, "w_retain.npy"), w_retain)

    delta_w = (w_retain - w_complete).squeeze()

    model_init = ntk_init(config, model)
    m_pred_error = vectorize_params(model) - vectorize_params(model_init) - w_retain.squeeze()
    print(f"Delta w -------: {np.linalg.norm(delta_w)}")

    scale = _predicted_scale(delta_w, m_pred_error)

    direction = get_delta_w_dict(delta_w, model)
    for k, p in model.named_parameters():
        # Move each update to its own parameter's device (handles multi-device models).
        p.data += (direction[k] * scale).to(p.device)

