import argparse
import pickle
import copy
import os
import math

from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import TensorDataset

from common import set_random_seed
from ewc_model import ContinualMultiheadEWC

from multihead_bnn_learner import ContinualMultiheadBNN
from simple_tasksets import *
from functools import reduce

# Flip to True to print per-epoch training diagnostics from train_models().
DEBUG: bool = False


def train_models(models, loss_function, task_id, learning_rate, n_epochs, trn_loader, device,
                 weight_decay=1e-5):
    """Train several models side by side on a single task.

    Every model gets a fresh Adam optimizer over its ``shared_net`` parameters
    plus the head ``linear_heads[task_id]``. Heads of other tasks are not
    optimized. Optimizers are created per call, so Adam state does not carry
    over between tasks. For each batch, every model takes one step in turn.

    Args:
        models: iterable of models exposing ``shared_net``, ``linear_heads`` and
            ``loss(data, labels, loss_fn, task_id=..., is_test=...)``. It is
            materialized once, so dict views and single-use iterables both work.
        loss_function: criterion forwarded to ``model.loss``.
        task_id: index of the task (and head) being trained.
        learning_rate: Adam learning rate.
        n_epochs: number of passes over ``trn_loader``.
        trn_loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        weight_decay: Adam weight decay (default 1e-5).
    """
    # Iterated twice (optimizer creation and training), so materialize it once.
    models = list(models)
    optimizers = [
        torch.optim.Adam(
            list(m.shared_net.parameters()) + list(m.linear_heads[task_id].parameters()),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        for m in models
    ]
    train_loss = None
    for _epoch in range(n_epochs):
        for task_data, task_labels in trn_loader:
            task_data = task_data.to(device)
            task_labels = task_labels.to(device)

            for model, optimizer in zip(models, optimizers):
                optimizer.zero_grad()
                # Labels are handed to model.loss as float tensors.
                train_loss = model.loss(task_data, task_labels.float(), loss_function,
                                        task_id=task_id, is_test=False)
                train_loss.backward()
                optimizer.step()
        # Guard: with an empty loader nothing was trained, so there is nothing to report.
        if DEBUG and train_loss is not None:
            # Reports only the last model trained: its last-batch loss, then the
            # loss on that same batch recomputed with is_test=False and is_test=True.
            with torch.no_grad():
                print(task_id, train_loss.item(),
                      model.loss(task_data, task_labels.float(), loss_function, task_id=task_id, is_test=False).item(),
                      model.loss(task_data, task_labels.float(), loss_function, task_id=task_id, is_test=True).item())


def accuracy(out, expected):
    """Return the fraction of rows whose arg-max matches the arg-max of ``expected``.

    Args:
        out: score tensor with classes along dim 1 (one row per example).
        expected: tensor with the same layout, e.g. one-hot targets.

    Returns:
        A 0-dim tensor holding the share of rows whose predicted class matches.
    """
    predicted_class = out.max(dim=1).indices
    target_class = expected.max(dim=1).indices
    n_rows = out.shape[0]
    n_correct = (predicted_class == target_class).sum()
    return n_correct / n_rows


# Gaussian weight-noise levels. PRE_NOISE is the default ``noise_level`` of
# calculate_bound(), where it sets the 1 / (2 * sigma**2) scaling; POST_NOISE
# appears alongside it in the KL surrogate. Both are passed to the models.
PRE_NOISE: float = 0.1
POST_NOISE: float = 0.0001

def calculate_bound(model, prev_models, model_name, curr_task_id, kl_weight, trn, zero_one_loss,
                    device, noise_level=PRE_NOISE, seed=None):
    """Return the training loss plus a weighted KL surrogate for one model.

    Args:
        model: trained continual model exposing ``shared_net`` and ``loss``.
        prev_models: dict mapping model name -> ``shared_net`` state dict used
            as the prior. The script passes the snapshot taken before training
            on the current task.
        model_name: key into ``prev_models`` for ``model``.
        curr_task_id: index of the task head to evaluate.
        kl_weight: multiplier applied to the KL surrogate (non-BNN models only).
        trn: dataset of ``(inputs, labels)`` for the current task.
        zero_one_loss: loss callable forwarded to ``model.loss``.
        device: device the training data is moved to.
        noise_level: prior noise level sigma; squared weight distances are scaled
            by ``1 / (2 * sigma**2)``.
        seed: forwarded as ``specific_seed`` to ``model.loss`` for non-BNN
            models. When None, falls back to the script-level ``args.seed``.

    Returns:
        For ``ContinualMultiheadBNN``: ``model.loss(..., is_test=False)`` as-is.
        Otherwise: the mean of the losses returned by
        ``model.loss(..., is_test=True)`` plus ``kl_weight * kl_surrogate``.
    """
    # Whole training set as one batch. The DataLoader is kept (rather than
    # reading ``trn`` directly) because creating its iterator draws a seed
    # from the torch RNG; removing it would shift later random draws.
    trn_loader = DataLoader(trn, batch_size=len(trn))
    for d_, t_ in trn_loader:
        trn_data = d_.to(device)
        trn_labels = t_.to(device)
    with torch.no_grad():
        if isinstance(model, ContinualMultiheadBNN):
            # No KL surrogate is added for the BNN.
            # NOTE(review): presumably its training-mode loss already contains its own KL term.
            return model.loss(trn_data, trn_labels, zero_one_loss, task_id=curr_task_id,
                              is_test=False)

        prev_state = prev_models[model_name]
        noise_const = 0.5 / (noise_level ** 2)
        kl_surrogate = 0.0
        for name, new_par in model.shared_net.named_parameters():
            prev_par = prev_state[name].to(new_par.device)
            sq_dist = (new_par - prev_par).pow(2)
            if noise_level != POST_NOISE:
                # NOTE(review): the closed-form per-weight Gaussian KL is
                # log(sigma/POST) + (POST**2 + d**2) / (2 sigma**2) - 1/2. Here the
                # constant terms are added once per tensor (not per element),
                # POST_NOISE**2 is not scaled by 1/(2 sigma**2), and each tensor's
                # term is clipped at 0. Confirm this surrogate is intended.
                kl_surrogate += max(0, noise_const * sq_dist.sum() + POST_NOISE ** 2 - 0.5
                                    + math.log(noise_level / POST_NOISE))
            else:
                kl_surrogate += (noise_const * sq_dist).sum()

        if seed is None:
            # ``args`` is only defined when this file runs as a script.
            seed = args.seed
        trn_losses = model.loss(trn_data, trn_labels, zero_one_loss, task_id=curr_task_id,
                                specific_seed=seed,
                                is_test=True)
        return sum(trn_losses) / len(trn_losses) + kl_weight * kl_surrogate


def get_parser():
    """Build the command-line parser for the continual-learning bound experiments.

    Returns:
        An ``argparse.ArgumentParser``. Its defaults run split-MNIST with a
        conv shared net.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_sample_size', default=60000, type=int,
                        help="Number of training examples in each task")
    parser.add_argument('--domain', default="split_mnist", type=str,
                        choices=("10d", "mnist_multiclass", "cifar_multiclass", "split_mnist", "split_cifar"),
                        help="Problem domain: 10d, mnist_multiclass, cifar_multiclass, split_mnist, split_cifar")
    parser.add_argument('--n_tasks', default=120, type=int,
                        help="Number of tasks")
    parser.add_argument('--batch_size', default=128, type=int,
                        help="Gradient batch size")
    # Consumed as the number of epochs per task (n_epochs of train_models).
    parser.add_argument('--train_steps', default=1, type=int,
                        help="Number of training epochs per task")
    parser.add_argument('--lr', default=1e-3, type=float,
                        help="Learning rate")
    parser.add_argument('--net_size', default=400, type=int,
                        help="Shared net size")
    parser.add_argument('--net_depth', default=3, type=int,
                        help="Shared net depth")
    parser.add_argument('--net_conv', default="conv", type=str,
                        help="conv/lin - type of shared net layers")
    parser.add_argument('--kl_weight', default=1e-5, type=float,
                        help="KL lambda")
    # Seeds presumably used across runs: 42, 11, 451, 1337, 805287
    parser.add_argument('--seed', type=int, default=805287, help="Random seed")
    return parser


# Small numerical tolerance constant.
# NOTE(review): not referenced in this file; kept in case other modules import it.
EPS: float = 1e-07


def get_tasks(n_samples, n_tasks, domain, is_cnn=True):
    """Create the task generators for one problem domain.

    Args:
        n_samples: examples per task, forwarded to the taskset builders.
        n_tasks: number of tasks per taskset.
        domain: one of "10d", "mnist_multiclass", "cifar_multiclass",
            "split_mnist", "split_cifar".
        is_cnn: when False, image inputs are flattened and ``DIMS`` is a flat
            size. Ignored for "10d".

    Returns:
        ``(task_dict, DIMS, out_dim)``:
        - ``task_dict`` maps taskset name -> the object returned by its builder.
        - ``DIMS`` is the input size: an int, or a shape tuple such as ``(1, 28, 28)``.
        - ``out_dim`` is 10 for the multiclass domains and 2 otherwise.

    Raises:
        ValueError: if ``domain`` is not supported.
    """
    flatten = not is_cnn
    out_dim = 2
    task_dict = {}
    if domain == "10d":
        task_dict["positive_correlation_10d"] = get_positive_correlated_taskset(n_samples, n_tasks, num_dims=10)
        # task_dict["distractors_10d"] = get_positive_correlated_with_distractors_taskset(n_samples, n_tasks,
        #                                                                                 frac_distractors=0.2,
        #                                                                                 num_dims=10)
        task_dict["gradual_shift_10d"] = get_gradual_shift_taskset(n_samples, n_tasks, num_dims=10)
        task_dict["orthogonal_alternating_10d"] = get_orthogonal_taskset(n_samples, n_tasks, alternating=True,
                                                                         num_dims=10)
        task_dict["orthogonal_shift_10d"] = get_orthogonal_taskset(n_samples, n_tasks, alternating=False, num_dims=10)
        # task_dict["rand_10d"] = get_rand_taskset(n_samples, n_tasks, num_dims=10)
        DIMS = 10
    elif domain == "mnist_multiclass":
        task_dict["permuted_mnist_multiclass"] = get_multiclass_permuted_mnist(n_samples, n_tasks, flatten=flatten)
        DIMS = (1, 28, 28) if is_cnn else 28 ** 2
        out_dim = 10
    elif domain == "cifar_multiclass":
        task_dict["permuted_cifar_multiclass"] = get_multiclass_permuted_cifar10(n_samples, n_tasks, flatten=flatten)
        DIMS = (3, 32, 32) if is_cnn else 3 * (32 ** 2)
        out_dim = 10
    elif domain == "split_mnist":
        task_dict["split_mnist_grouped"] = get_split_mnist(n_samples, n_tasks, grouped=True, flatten=flatten)
        DIMS = (1, 28, 28) if is_cnn else 28 ** 2
    elif domain == "split_cifar":
        task_dict["split_cifar_grouped"] = get_split_cifar10(n_samples, n_tasks, grouped=True, flatten=flatten)
        DIMS = (3, 32, 32) if is_cnn else 3 * (32 ** 2)
    else:
        # Report the argument actually received, not the script-level ``args``
        # (which does not exist when this module is imported).
        raise ValueError(f"Not supported domain {domain}")
    return task_dict, DIMS, out_dim


def initialize_models(args, network_structure, DIMS, out_dim, device):
    """Create fresh, untrained instances of every compared learner.

    Args:
        args: parsed CLI arguments; only ``args.kl_weight`` is read.
        network_structure: list of shared-layer sizes, passed as ``shared_structure``.
        DIMS: input size or shape (as returned by ``get_tasks``), passed as ``in_size``.
        out_dim: output size of each task head.
        device: torch device the models are built for.

    Returns:
        Dict mapping display name to model, in the order "EWC", "Rolling prior".
    """
    common = dict(shared_structure=network_structure, in_size=DIMS, out_size=out_dim, device=device)
    # The BNN is deliberately constructed before the EWC model. If the
    # constructors draw random initial weights, swapping the order would change
    # both models' initialisation.
    # NOTE(review): PRE_NOISE/POST_NOISE are used as standard deviations in
    # calculate_bound but are passed here as *_var; confirm what the BNN expects.
    rolling_prior_bnn = ContinualMultiheadBNN(**common, kl_weight=args.kl_weight, n_MC=3,
                                              use_rolling_prior=True,
                                              pre_var=PRE_NOISE, post_var=POST_NOISE)
    ewc = ContinualMultiheadEWC(**common, lamda=100, noise_level=POST_NOISE)
    return {"EWC": ewc, "Rolling prior": rolling_prior_bnn}

# Channel count of each hidden conv layer in the shared net (used when --net_conv is "conv").
CONVS: int = 64

if __name__ == "__main__":
    # ``args`` must stay a module-level global: calculate_bound() falls back to
    # ``args.seed`` when it is not given a seed explicitly.
    args = get_parser().parse_args()
    set_random_seed(args.seed)
    exp_name = "forward_with_tanh_actually"

    n_samples = args.train_sample_size
    n_tasks = args.n_tasks
    batch_size = args.batch_size
    n_train_epochs = args.train_steps
    learning_rate = args.lr
    if args.net_conv == "conv":
        # Hidden conv layers use CONVS channels; the last shared layer uses
        # net_size, which is nudged off CONVS when the two coincide.
        # NOTE(review): the reason for the +1 is not visible in this file.
        # Presumably the model tells layers apart by size; confirm.
        if args.net_size == CONVS:
            args.net_size += 1
        network_structure = [CONVS] * (args.net_depth - 1) + [args.net_size]
    else:
        network_structure = [args.net_size] * args.net_depth
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def zero_one_loss(out, expected):
        """Return the misclassification rate ``1 - accuracy(out, expected)``."""
        return 1 - accuracy(out, expected)

    task_dict, DIMS, out_dim = get_tasks(n_samples, n_tasks, args.domain, args.net_conv == "conv")
    # DELTA only enters through log(1 / DELTA); presumably the bound's confidence parameter.
    DELTA = 0.01
    K = 1 - 1.0 / out_dim
    # Upper cap applied to every per-task bound constant computed after training.
    CONST = np.log(1 / DELTA) * args.kl_weight + np.power(K, 2) / n_samples / args.kl_weight / 8
    loss_function = nn.CrossEntropyLoss(reduction='mean').to(device)

    colors_array = ["red", "blue", "green", "black", "orange", "purple", "yellow", "brown", "pink", "beige", "cyan",
                    "gray"]

    for task_name, task_gen in task_dict.items():
        model_dict = initialize_models(args, network_structure, DIMS, out_dim, device)
        DO_BOUNDS = True

        test_sets = []
        test_losses = {model_name: [] for model_name in model_dict}
        # Always defined so the reporting code below never hits a NameError.
        test_bounds = {model_name: [] for model_name in model_dict}
        if DO_BOUNDS:
            # Snapshot the initial shared weights as the first prior. The deepcopy
            # matters: state_dict() tensors share storage with the live parameters,
            # so without it the "prior" would be updated in place by training and
            # equal the trained weights when the first bound is computed.
            prev_models = {name: copy.deepcopy(model.shared_net.state_dict())
                           for name, model in model_dict.items()}

        n_done = 0  # number of tasks actually processed; sizes the plots
        for i, task in enumerate(task_gen):
            print(f"{task_name}: Training epoch {i}")
            trn, tst = TensorDataset(task[0], task[1]), TensorDataset(task[2], task[3])
            n_tst_samples = len(task[3])
            trn_loader = DataLoader(trn, batch_size=batch_size, shuffle=True)
            for model in model_dict.values():
                model.adapt_new_task(i)
            train_models(model_dict.values(), loss_function, i, learning_rate, n_train_epochs, trn_loader, device)

            # Record test losses (generalization)
            test_data = task[2].to(device)
            test_labels = task[3].to(device)
            with torch.no_grad():
                for model_name, model in model_dict.items():
                    test_losses[model_name].append(
                        model.loss(test_data, test_labels, zero_one_loss, task_id=i, is_test=True).to("cpu").numpy())
                    print(model_name, test_losses[model_name][-1])

            # Record the bound on the current task, measured against the prior
            # snapshot taken before training on it.
            if DO_BOUNDS:
                for name, model in model_dict.items():
                    forward_bound = calculate_bound(model, prev_models, name, i, args.kl_weight, trn,
                                                    zero_one_loss, device)
                    test_bounds[name].append(forward_bound.cpu().item())
                    print("bound:", name, test_bounds[name][-1])

            # NOTE: the collected batches are not used afterwards. The loop is kept
            # because iterating a DataLoader draws a seed from the torch RNG, and
            # removing it would change all later random draws (and the results).
            test_loader = DataLoader(tst, batch_size=n_tst_samples)
            for test_set_i in test_loader:
                test_sets.append(test_set_i)
            if DO_BOUNDS:
                # The weights after this task become the prior for the next one.
                prev_models = {name: copy.deepcopy(model.shared_net.state_dict())
                               for name, model in model_dict.items()}
            if "EWC" in model_dict:
                model_dict["EWC"].get_previous_training(i, (task[0], task[1]))
            n_done = i + 1

        # ------------------------------------------------------------------------------------
        # Record outputs, plot results

        with open(f"results_{exp_name}_{args.domain}_{task_name}_{args.seed}.pkl", "wb") as fptr:
            if DO_BOUNDS:
                pickle.dump((test_losses, test_bounds), fptr)
            else:
                pickle.dump(test_losses, fptr)

        # To re-plot from saved results:
        # with open(f"results_{exp_name}_{args.domain}_{task_name}_{args.seed}.pkl", "rb") as fptr:
        #     test_losses, test_bounds = pickle.load(fptr)
        # model_dict.pop("Rolling prior", None)

        model_colors = {model_name: colors_array[idx % len(colors_array)]
                        for idx, model_name in enumerate(model_dict)}
        plot_range = np.arange(n_done)
        task_counts = np.arange(1, n_done + 1)

        # Per-task additive bound constants, capped at CONST. This must be a float
        # array: assigning these fractional values into an integer array (such as
        # np.arange(1, n + 1)) silently truncates them.
        t = task_counts.astype(float)
        consts_array = np.minimum(
            K * np.sqrt(np.sqrt(t)) / np.sqrt(2 * n_samples)
            + K * np.log(2) / (t * np.sqrt(t))
            + np.log(1 / DELTA) * args.kl_weight,
            CONST,
        )
        # Alternative constants tried previously (kept for reference):
        # consts_array = (np.log(1 / DELTA) * args.kl_weight / consts_array) + (np.power(K, 2) / n_samples / args.kl_weight / 8)
        # FRAC=0.005
        # lambda_t = args.kl_weight*1000
        # delta_2 = math.exp(-FRAC*(i+1)*n_samples)
        # consts_array[i] = ((np.log(1 / DELTA) * args.kl_weight) +
        #                    K+
        #             lambda_t*math.log((1-delta_2)*math.exp(K/lambda_t*(math.sqrt(FRAC/2))-1)+
        #                                     delta_2))

        # Running means of the test loss and of the bound (low is good).
        for model_name in model_dict:
            if not test_losses[model_name]:
                continue  # no tasks were processed; nothing to plot or report
            color = model_colors[model_name]
            plt.figure()
            if DO_BOUNDS:
                plt.scatter(plot_range, np.cumsum(test_bounds[model_name]) / task_counts + consts_array,
                            label=f"{model_name}_bound",
                            marker="+", color=color)
            plt.scatter(plot_range,
                        np.cumsum(test_losses[model_name]) / task_counts,
                        label=f"{model_name}_mean_tests", color=color)
            final_bound = test_bounds[model_name][-1] if DO_BOUNDS else None
            print(f"{task_name}: {model_name} final test loss: {test_losses[model_name][-1]}, final bound: {final_bound}")
            plt.legend()
            plt.xlabel("Task number")
            plt.ylabel("test loss")
            plt.savefig(f"gen_and_bound_{exp_name}_{task_name}_{model_name}.jpg")
            # Release the figure; otherwise figures accumulate across tasksets and models.
            plt.close()