import argparse
import copy
import math
import pickle

import numpy as np
import matplotlib.pyplot as plt
import torch

from common import set_random_seed
from sgld import SGLD


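# Task model: each task has a weight row w_true of shape (1, total_dim). For random draws in
# LINEAR mode w_true is zero-packed: only a random subset of total_dim//5 coordinates is nonzero.
# Inputs x are standard Gaussian, y = x @ w_true.T + noise, and the distance between two tasks
# is ||w_true1 - w_true2||^2.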

def generate_alternating(mode, n_tasks, n_samples, total_dim):
    """Yield ``n_tasks`` regression tasks that alternate between two fixed weight rows.

    Two weight rows are drawn with ``get_random_w`` when iteration starts.
    Even-indexed tasks use the first one and odd-indexed tasks use the second.

    Yields:
        ``((X, y), w_true)`` for every task.
    """
    with torch.no_grad():
        candidates = (get_random_w(mode, total_dim), get_random_w(mode, total_dim))
    for task_idx in range(n_tasks):
        w_selected = candidates[task_idx % 2]
        yield get_task_data(n_samples, total_dim, w_selected), w_selected


def get_random_w(mode, total_dim):
    """Draw a standard-normal weight row of shape (1, total_dim).

    In ``LINEAR`` mode, every entry except a random subset of ``total_dim // 5``
    coordinates is set to zero. Other modes return the dense draw.
    """
    weights = torch.randn(1, total_dim)
    if mode != LINEAR:
        return weights
    # Keep the first total_dim // 5 indices of a random permutation; zero the rest.
    dropped = torch.randperm(total_dim)[total_dim // 5:]
    weights[0, dropped] = 0
    return weights

def generate_interpolate_extrapolate(mode, n_tasks, n_samples, total_dim, type="interpolate"):
    """Yield a deterministic task sequence whose weights are repeated midpoints of four anchors.

    The four anchor rows have shape (1, total_dim) and magnitude ``n_tasks // 8``:
    all ones (w1), ones then minus ones (w2), minus ones then ones (w3), and all
    minus ones (w4). Each round appends (w1, w2, w4, w3) to the sequence. It then
    replaces (w1, w2, w3, w4) by the midpoints of the cyclic neighbours
    (w1, w2), (w2, w3), (w3, w4), (w4, w1).

    Args:
        mode: Unused; accepted so every generator shares one signature.
        n_tasks: Number of tasks to yield. Any value works, including values
            that are not a multiple of 4.
        n_samples: Training examples per task.
        total_dim: Feature dimension. Odd values are supported.
        type: ``"interpolate"`` yields the sequence in construction order. Any
            other value yields the first ``n_tasks`` entries in reverse order.

    Yields:
        ``((X, y), w_true)`` for every task.
    """
    with torch.no_grad():
        scale = n_tasks // 8  # NOTE(review): for n_tasks < 8 every weight row is all zeros
        # Split so both halves together always span total_dim columns (odd dims too).
        half = total_dim // 2
        rest = total_dim - half
        w_true1 = torch.ones(1, total_dim) * scale
        w_true4 = -torch.ones(1, total_dim) * scale
        w_true2 = torch.concat((torch.ones(1, half), -torch.ones(1, rest)), dim=1) * scale
        w_true3 = torch.concat((-torch.ones(1, half), torch.ones(1, rest)), dim=1) * scale
        task_seq = []
        # Round up so the sequence holds at least n_tasks rows; n_tasks // 4 rounds
        # left it short (IndexError) whenever n_tasks % 4 != 0.
        for _ in range((n_tasks + 3) // 4):
            task_seq.extend([w_true1, w_true2, w_true4, w_true3])
            ws = (w_true1 + w_true2) / 2, (w_true2 + w_true3) / 2, (w_true3 + w_true4) / 2, (w_true4 + w_true1) / 2
            w_true1, w_true2, w_true3, w_true4 = ws
    for i in range(n_tasks):
        w_i = task_seq[i] if type == "interpolate" else task_seq[n_tasks - 1 - i]
        yield get_task_data(n_samples, total_dim, w_i), w_i


def generate_swap(mode, n_tasks, n_samples, total_dim):
    """Yield tasks from one random weight row, switching to a second row halfway through.

    The first ``n_tasks // 2`` tasks use the first draw from ``get_random_w``.
    All remaining tasks use the second draw.

    Yields:
        ``((X, y), w_true)`` for every task.
    """
    with torch.no_grad():
        w_first = get_random_w(mode, total_dim)
        w_second = get_random_w(mode, total_dim)
    switch_at = n_tasks // 2
    for task_idx in range(n_tasks):
        w_current = w_first if task_idx < switch_at else w_second
        yield get_task_data(n_samples, total_dim, w_current), w_current

# Default noise parameter, shared by the data generators, the SGLD temperature and the
# bound calculations.
# NOTE(review): get_task_data multiplies unit Gaussian noise by noise_level**2 (i.e. uses
# it as a standard deviation), whereas calculate_lin_bound_* use noise_level**2 as the
# noise variance. Confirm which is intended.
DEFAULT_NOISE_LEVEL = 0.3

def generate_gradual_change(mode, n_tasks, n_samples, total_dim, noise_level=DEFAULT_NOISE_LEVEL):
    """Yield tasks whose weight row performs a random walk kept inside the unit L2 ball.

    The first weight row is a ``get_random_w`` draw rescaled to unit L2 norm.
    After each task, Gaussian noise with standard deviation ``noise_level**2``
    is added to every coordinate. If the result has L2 norm above 1, it is
    projected back onto the unit sphere.

    Args:
        mode: Forwarded to ``get_random_w``.
        n_tasks: Number of tasks to yield.
        n_samples: Training examples per task.
        total_dim: Feature dimension.
        noise_level: Controls the step size of the walk (std = ``noise_level**2``).

    Yields:
        ``((X, y), w_true)`` for every task.
    """
    def _l2_norm(w):
        # Use the true L2 norm. torch.sum(w)**2 (the square of the entry sum) is not a
        # norm and can be ~0 when entries cancel, which blows the weights up.
        return torch.linalg.vector_norm(w, dim=1, keepdim=True)

    with torch.no_grad():
        w_true = get_random_w(mode, total_dim)
        initial_norm = _l2_norm(w_true)
        # Guard the degenerate all-zero draw (e.g. total_dim < 5 in LINEAR mode).
        w_current = w_true / initial_norm if initial_norm > 0 else w_true
    for _ in range(n_tasks):
        yield get_task_data(n_samples, total_dim, w_current), w_current
        with torch.no_grad():
            w_next = w_current + torch.randn_like(w_current) * (noise_level ** 2)
            norm_next = _l2_norm(w_next)
            if norm_next > 1:  # ensure change is not too big: project back onto the unit ball
                w_next = w_next / norm_next
            w_current = w_next

def get_task_data(n_samples, total_dim, w_true, noise_level=DEFAULT_NOISE_LEVEL):
    """Sample one linear-regression dataset for the weight row ``w_true``.

    Features are i.i.d. standard normal with shape (n_samples, total_dim).
    Targets are ``X @ w_true.T`` plus Gaussian noise of standard deviation
    ``noise_level**2``. With a (1, total_dim) ``w_true`` the targets have shape
    (n_samples, 1).

    Returns:
        ``(X, y)``.
    """
    features = torch.randn(n_samples, total_dim)
    noise_std = noise_level ** 2
    targets = features @ w_true.T + noise_std * torch.randn(n_samples, 1)
    return features, targets

# Model-mode identifiers passed via --mode. LinearNN uses a single linear layer for LINEAR;
# for any other value (NTK) it uses a shared ReLU hidden layer plus per-task linear heads.
LINEAR = "LINEAR"
NTK = "NTK"

class LinearNN(torch.nn.Module):
    """Shared first layer plus one linear output head per task.

    In ``LINEAR`` mode the model is a single ``Linear(in_features, 1)``, and
    ``forward`` ignores the per-task heads. In any other mode the shared layer
    maps to ``n_params`` hidden units and a ReLU is applied. The head stored for
    ``task_id`` then maps the hidden units to one output.

    NOTE(review): heads live in a plain dict, so they are not registered
    submodules. ``parameters()``, ``state_dict()`` and ``.to()`` do not see
    them, and callers must collect head parameters explicitly.
    """

    def __init__(self, in_features, n_params, device, mode):
        super().__init__()
        self.mode = mode
        out_features = 1 if mode == LINEAR else n_params
        self.linear = torch.nn.Linear(in_features, out_features, device=device)
        self.device = device
        self.dim = n_params
        self.out_layers = {}  # task_id -> Linear(n_params, 1) output head
        self.add_task(0)

    def add_task(self, task_id):
        """Create the output head for ``task_id`` on ``self.device`` unless it already exists."""
        if task_id not in self.out_layers:
            # Use the device given at construction rather than a module-level global,
            # so the class also works when imported.
            self.out_layers[task_id] = torch.nn.Linear(self.dim, 1, device=self.device)

    def forward(self, x, task_id=0):
        """Return the model output for ``x`` using the head of ``task_id`` (ignored in LINEAR mode)."""
        if self.mode == LINEAR:
            return self.linear(x)
        non_linear = torch.nn.functional.relu(self.linear(x))
        return self.out_layers[task_id](non_linear)

# Starting SGLD temperature passed to the optimizer in train_model_on_task_sgld, which
# halves it after every pass over a task's data.
INITIAL_TEMPERATURE = DEFAULT_NOISE_LEVEL*0.01
# Number of passes (epochs) over each task's training data in train_model_on_task_sgld.
MAX_ITERATIONS = 20

def train_model_on_task_sgld(model, task_x, task_y, batch_size, learning_rate, task_id):
    """Fit ``model`` to one task with SGLD while annealing the temperature.

    Only the shared layer (``model.linear``) and the head for ``task_id`` are
    handed to the optimizer. The task data is visited ``MAX_ITERATIONS`` times
    in shuffled mini-batches of ``batch_size``, minimizing mean squared error.
    After each pass the optimizer's ``T`` is halved, starting from
    ``INITIAL_TEMPERATURE``.

    NOTE(review): assumes SGLD stores the ``temperature`` argument as
    ``optimizer.T``. Confirm in sgld.py.
    """
    relevant_parameters = list(model.linear.parameters()) + list(model.out_layers[task_id].parameters())
    optimizer = SGLD(relevant_parameters, lr=learning_rate, temperature=INITIAL_TEMPERATURE)
    # The dataset and loader are loop-invariant. A shuffling DataLoader draws a fresh
    # permutation each time it is iterated, so building it once is equivalent.
    data_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(task_x, task_y), batch_size=batch_size, shuffle=True)
    for _epoch in range(MAX_ITERATIONS):
        for batch_x, batch_y in data_loader:
            optimizer.zero_grad()
            predictions = model(batch_x, task_id=task_id)
            train_mse = torch.nn.functional.mse_loss(predictions, batch_y, reduction='mean')
            train_mse.backward()
            optimizer.step()
        # Anneal: halve the Langevin temperature after every pass.
        optimizer.T = optimizer.T * 0.5

def get_parser():
    """Build the command-line parser for this experiment.

    ``--problem`` and ``--mode`` are restricted to the values the script
    actually dispatches on. A typo now fails fast instead of silently running
    a different experiment.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_sample_size', default=2048, type=int,  # also tried: 2000
                        help="Number of training examples in each task")
    parser.add_argument('--train_dim', default=3000, type=int,  # also tried: 50, 10, 5000
                        help="Number of feature dimensions")
    parser.add_argument('--problem', default="interpolate", type=str,
                        choices=["alternate", "swap", "gradual_change", "interpolate", "extrapolate"],
                        help="Problem domain: alternate, swap, gradual_change, interpolate, extrapolate")
    parser.add_argument('--n_tasks', default=200, type=int,
                        help="Number of tasks")
    parser.add_argument('--batch_size', default=128, type=int,
                        help="Gradient batch size")
    parser.add_argument('--lr', default=1e-3, type=float,
                        help="Learning rate")
    parser.add_argument("--mode", default=f"{LINEAR}", type=str, choices=[LINEAR, NTK],
                        help=f"Model mode: {NTK}, {LINEAR}")
    parser.add_argument('--net_size', default=4000, type=int,
                        help="Shared net size")
    parser.add_argument('--seed', type=int, default=42, help="Random seed")
    # Seeds used for the reported runs: [42, 11, 451, 1337, 805287]
    return parser

# Number of fresh, noise-free samples get_mse_loss draws for every evaluation.
TEST_SIZE = 400

def get_mse_loss(model, w_true, task_id, device):
    """Return the test MSE of ``model`` on the task defined by ``w_true``.

    Draws ``TEST_SIZE`` fresh noise-free samples (``noise_level=0``) on every
    call. It evaluates with the head for ``task_id`` and returns the mean
    squared error as a scalar tensor on ``device``.
    """
    with torch.no_grad():
        features, targets = get_task_data(TEST_SIZE, w_true.shape[1], w_true, noise_level=0)
        features = features.to(device)
        targets = targets.to(device)
        predictions = model(features, task_id=task_id)
        return torch.nn.functional.mse_loss(predictions, targets, reduction='mean')


def calculate_lin_bound_forget(n_params, n_samples, w_true_list, noise_level=DEFAULT_NOISE_LEVEL):
    """Expected average forgetting after the last task of ``w_true_list`` (Lin et al. linear model).

    The average is over the first T-1 tasks, where T = ``len(w_true_list)``.

    * Under-parameterized (``n_samples > n_params + 1``): the mean squared
      distance between each earlier task's weights and the last task's weights.
    * Over-parameterized (``n_params > n_samples + 1``): a closed form in
      ``r = 1 - n_samples / n_params`` with a signal term (f1), a pairwise
      task-distance term (f2) and a noise term (f3).

    NOTE(review): ``noise_level**2`` is used as the noise variance here, while
    get_task_data uses it as the noise standard deviation. Confirm intent.

    Raises:
        ValueError: if there are fewer than two tasks, or if
            ``n_samples - 1 <= n_params <= n_samples + 1``, where the
            expressions are undefined.
    """
    T = len(w_true_list)
    if T < 2:
        raise ValueError("forgetting is only defined for at least two tasks")
    earlier_tasks = w_true_list[:-1]  # the last task contributes zero forgetting

    if n_samples > n_params + 1:  # under-parametrized
        w_last = w_true_list[-1]
        expected_forget = sum(torch.sum((w_true - w_last) ** 2).cpu().item() for w_true in earlier_tasks)
        return expected_forget / (T - 1)

    if n_params <= n_samples + 1:
        raise ValueError(
            f"closed form undefined for n_params={n_params}, n_samples={n_samples}; "
            "need n_samples > n_params + 1 or n_params > n_samples + 1")
    r = 1.0 - n_samples / n_params  # over-parametrized
    expected_forget = 0
    for i, w_true in enumerate(earlier_tasks):
        f1 = (r ** T - r ** (i + 1)) * torch.sum(w_true ** 2).cpu().item()
        f2 = 0
        for j in range(i + 1, T):
            distance = torch.sum((w_true - w_true_list[j]) ** 2).cpu().item()
            f2 += (1 - r) * (r ** (T - i - 1) - r ** (j - i) + r ** (T - j - 1)) * distance
        f3 = n_params * (noise_level ** 2) * (r ** (i + 1) - r ** T) / (n_params - n_samples - 1)
        expected_forget += f1 + f2 + f3
    return expected_forget / (T - 1)

def calculate_lin_bound_gen(n_params, n_samples, w_true_list, noise_level=DEFAULT_NOISE_LEVEL):
    """Expected average generalization error over all tasks of ``w_true_list`` (Lin et al. linear model).

    * Under-parameterized (``n_samples > n_params + 1``): the mean squared
      distance to the last task's weights plus the noise term
      ``p * sigma^2 / (n - p - 1)``.
    * Over-parameterized (``n_params > n_samples + 1``): a closed form in
      ``r = 1 - n_samples / n_params`` with a signal term (g1), a pairwise
      task-distance term (g2) and the noise term
      ``p * sigma^2 * (1 - r^T) / (p - n - 1)`` (g3).

    Here p = n_params, n = n_samples and sigma^2 = ``noise_level**2``.
    NOTE(review): get_task_data uses ``noise_level**2`` as the noise standard
    deviation, not the variance. Confirm intent.

    Raises:
        ValueError: if ``w_true_list`` is empty, or if
            ``n_samples - 1 <= n_params <= n_samples + 1``.
    """
    T = len(w_true_list)
    if T == 0:
        raise ValueError("generalization error needs at least one task")
    noise_var = noise_level ** 2

    if n_samples > n_params + 1:  # under-parametrized
        w_last = w_true_list[-1]
        expected_gen = sum(torch.sum((w_true - w_last) ** 2).cpu().item() for w_true in w_true_list)
        # Noise term p*sigma^2/(n - p - 1). The denominator must be n - p - 1, which is
        # positive in this regime; p - n - 1 would make the noise contribution negative.
        g3_alt = n_params * noise_var / (n_samples - n_params - 1)
        return expected_gen / T + g3_alt

    if n_params <= n_samples + 1:
        raise ValueError(
            f"closed form undefined for n_params={n_params}, n_samples={n_samples}; "
            "need n_samples > n_params + 1 or n_params > n_samples + 1")
    r = 1.0 - n_samples / n_params  # over-parametrized
    expected_gen = 0
    for i, w_true in enumerate(w_true_list):
        g1 = (r ** T) * torch.sum(w_true ** 2).cpu().item()
        g2 = 0
        for w_true_j in w_true_list:
            g2 += ((n_samples / n_params) * (r ** (T - i - 1))) * torch.sum((w_true - w_true_j) ** 2).cpu().item()
        expected_gen += g1 + g2
    g3 = n_params * noise_var * (1 - (r ** T)) / (n_params - n_samples - 1)
    return expected_gen / T + g3


def empirical_ntk_approx(f0_func_single, model_params, x1, x2):
    """Empirical NTK entry J(x1) J(x2)^T for one pair of inputs.

    ``f0_func_single(params, x)`` is differentiated with respect to
    ``model_params``. That argument is expected to be a dict of tensors, since
    the Jacobian's ``.values()`` are used. Each input gets a leading batch
    dimension of 1. Per-parameter Jacobians are flattened after the first two
    dimensions, and their inner products are summed over all parameter tensors.

    Returns:
        The flattened kernel block as a 1-D tensor. It has a single element
        when ``f0_func_single`` returns one value.
    """
    batched_jacobian = torch.vmap(torch.func.jacrev(f0_func_single), (None, 0))

    def _flat_jacobians(x):
        # One entry per parameter tensor; dims from 2 on are merged into one feature axis.
        per_param = batched_jacobian(model_params, x.view(1, -1))
        return [jac.flatten(2) for jac in per_param.values()]

    left = _flat_jacobians(x1)
    right = _flat_jacobians(x2)
    blocks = [torch.einsum('Naf,Mbf->NMab', a, b) for a, b in zip(left, right)]
    return torch.stack(blocks).sum(0).reshape(-1)

def calculate_ntk_gen_bound(model0, x_data, y_data, task_id, model, prev_model, device):
    """Estimate the NTK generalization term for the current task.

    This is the SGD generalization bound for the current task (Theorem 3,
    second row) from Bennani et al. 2020. It builds the empirical NTK Gram
    matrix K of ``model0`` on ``x_data``. The forward uses ``model0``'s default
    ``task_id`` (0), and derivatives are taken with respect to
    ``model0.named_parameters()``. It then returns
    ``sqrt(|tr(K)/n * y^T K^{-1} y|)`` as a Python float, where ``y`` is
    ``y_data`` minus the prediction of ``model`` with ``prev_model`` as its
    shared layer.

    NOTE(review): if ``model0`` is a LinearNN, its heads are not registered
    parameters, so only the shared layer enters the Jacobian. Confirm this is
    intended.
    """
    params = dict(model0.named_parameters())

    def call_single(params, x):
        return torch.func.functional_call(model0, params, x).reshape(-1)

    n = len(y_data)
    kernel_x = torch.zeros((n, n), device=device)
    # The empirical NTK J(a) J(b)^T is symmetric: evaluate the upper triangle only and
    # mirror it, which roughly halves the (still O(n^2)) Jacobian evaluations.
    for ind1 in range(n):
        for ind2 in range(ind1, n):
            value = empirical_ntk_approx(call_single, params, x_data[ind1, :], x_data[ind2, :])
            kernel_x[ind1, ind2] = value
            kernel_x[ind2, ind1] = value
    y_normed = y_data - use_prev_model(model, x_data, task_id, prev_model)
    # Solving K z = y is cheaper and numerically more stable than forming K^{-1}.
    quad_form = y_normed.T @ torch.linalg.solve(kernel_x, y_normed)
    r_t = torch.trace(kernel_x) / n * quad_form
    return torch.sqrt(torch.abs(r_t)).cpu().item()


def use_prev_model(model, task_x, task_id, prev_model):
    """Run ``model`` on ``task_x`` with ``prev_model`` temporarily replacing ``model.linear``.

    The original shared-layer object, not a copy, is put back afterwards, even
    if the forward pass raises. Anything holding references to its parameters
    therefore stays valid.

    Returns:
        ``model(task_x, task_id=task_id)`` computed with ``prev_model`` as the
        shared layer.
    """
    current_linear = model.linear
    model.linear = prev_model
    try:
        return model(task_x, task_id=task_id)
    finally:
        model.linear = current_linear

# When True (and --mode is NTK), the main loop also calls calculate_ntk_gen_bound for every
# task. That fills an n x n empirical-NTK matrix with O(n^2) empirical_ntk_approx calls
# (n = --train_sample_size), so it is expensive.
USE_NTK_BOUND = False

if __name__ == '__main__':
    # Train one model sequentially on a stream of regression tasks and record, per task,
    # the test loss, forgetting and average error. Pickle the raw metrics and plot them,
    # together with the Lin et al. expressions in LINEAR mode.
    args = get_parser().parse_args()
    set_random_seed(args.seed)
    exp_name = "linear_overparam"

    n_samples = args.train_sample_size
    dim = args.train_dim
    n_tasks = args.n_tasks
    batch_size = args.batch_size
    learning_rate = args.lr
    mode = args.mode
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = LinearNN(dim, args.net_size, device, mode)

    # --problem name -> generator taking (mode, n_tasks, n_samples, total_dim).
    # "gradual" and "gradual_change" are aliases. Unknown names are rejected instead of
    # silently running the extrapolate stream.
    task_generators = {
        "swap": generate_swap,
        "alternate": generate_alternating,
        "gradual": generate_gradual_change,
        "gradual_change": generate_gradual_change,
        "interpolate": lambda a, b, c, d: generate_interpolate_extrapolate(a, b, c, d, type="interpolate"),
        "extrapolate": lambda a, b, c, d: generate_interpolate_extrapolate(a, b, c, d, type="extrapolate"),
    }
    if args.problem not in task_generators:
        raise ValueError(f"Unknown --problem {args.problem!r}; expected one of {sorted(task_generators)}")
    task_fun = task_generators[args.problem]

    task_test_ws = []   # true weight row of every task seen so far
    test_loss = []      # loss on task i right after training on task i
    test_bound = []     # loss on task i of the model from before training on task i
    prev_model = copy.deepcopy(model.linear)
    avg_gen_error = []  # mean loss over tasks 0..i after training on task i
    if mode == NTK:
        initial_model = copy.deepcopy(model)
        ntk_gens = []
    forget_loss = []    # mean increase in loss on tasks 0..i-1 after training on task i
    for i, (task_data, w_true) in enumerate(task_fun(mode, n_tasks, n_samples, dim)):
        print(f"Task {i}")
        model.add_task(i)
        task_x = task_data[0].to(device)
        task_y = task_data[1].to(device)
        task_test_ws.append(w_true)
        # train on data (SGLD with decreasing noise)
        train_model_on_task_sgld(model, task_x, task_y, batch_size, learning_rate, i)
        # use w_true to create test data and check error, same for forgetting
        with torch.no_grad():
            test_loss.append(get_mse_loss(model, w_true, i, device).cpu().item())
            gen = test_loss[-1]
            if mode == NTK and USE_NTK_BOUND:
                r_i = calculate_ntk_gen_bound(initial_model, task_x, task_y, i, model, prev_model, device)
                ntk_gens.append(r_i)
            if i > 0:
                forgetting = 0
                for j in range(i):
                    test_loss_j = get_mse_loss(model, task_test_ws[j], j, device).cpu().item()
                    forgetting += test_loss_j - test_loss[j]
                    gen += test_loss_j
                forget_loss.append(forgetting / i)
            avg_gen_error.append(gen / (i + 1))

            # Evaluate the pre-training shared layer on the current task, then restore the
            # trained layer object itself (not a copy of it).
            trained_linear = model.linear
            model.linear = prev_model
            try:
                test_bound.append(get_mse_loss(model, w_true, i, device).cpu().item())
            finally:
                model.linear = trained_linear
            prev_model = copy.deepcopy(model.linear)

    # plot out the average test loss for models (and the Lin. et. al. value for over-param), same for forgetting
    with open(f"results_{exp_name}_{args.problem}_{args.seed}.pkl", "wb") as fptr:
        if mode == NTK:
            pickle.dump((test_loss, forget_loss, test_bound, task_test_ws, avg_gen_error, ntk_gens), fptr)
        else:
            pickle.dump((test_loss, forget_loss, test_bound, task_test_ws, avg_gen_error), fptr)

    lin_gens = []
    lin_forgets = []
    if mode == LINEAR:
        # The Lin et al. expressions are only plotted for the linear model, so skip their
        # (cubic in n_tasks) evaluation in other modes.
        n_params = dim
        with torch.no_grad():
            lin_gens = [calculate_lin_bound_gen(n_params, n_samples, task_test_ws[:i + 1])
                        for i in range(n_tasks)]
            lin_forgets = [calculate_lin_bound_forget(n_params, n_samples, task_test_ws[:i + 1])
                           for i in range(1, n_tasks)]

    running_count = np.arange(1, n_tasks + 1)  # denominators for running means
    plt.figure()
    plt.scatter(range(n_tasks), np.cumsum(test_loss) / running_count, label="Online test loss")
    plt.scatter(range(n_tasks), np.cumsum(test_bound) / running_count, label="Test bound", marker="_")
    plt.scatter(range(n_tasks), avg_gen_error, label="Avg model error")
    if mode == LINEAR:
        plt.scatter(range(n_tasks), lin_gens, label="Lin expected", marker="_")
    if mode == NTK and USE_NTK_BOUND:
        c = np.max(test_loss)
        delta_ = 0.05
        remainder = 3 * c * math.sqrt(math.log2(2 / delta_) / (2 * n_samples))
        doan_bound = (np.cumsum(ntk_gens) + remainder) / running_count
        plt.scatter(range(n_tasks), doan_bound, label="Doan bound", marker="_")
        print(doan_bound)
    print(f"final test loss: {test_loss[-1]}")
    plt.legend()
    plt.xlabel("Task number")
    plt.ylabel("test MSE loss")
    plt.savefig(f"linear_gen_{exp_name}_{args.problem}.jpg")

    plt.figure()
    plt.scatter(range(1, n_tasks), forget_loss, label="Avg forgetting")
    if mode == LINEAR:
        plt.scatter(range(1, n_tasks), lin_forgets, label="Lin expected", marker="_")
    if forget_loss:  # empty when only one task was run
        print(f"average forgetting loss: {forget_loss[-1]}")
    plt.legend()
    plt.xlabel("Task number")
    plt.ylabel("test MSE forgetting loss")
    plt.savefig(f"linear_forget_{exp_name}_{args.problem}.jpg")