import random
from pathlib import Path

from termcolor import colored

import numpy as np

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from pareto.metrics import topk_accuracy
from model import MultiLeNet
from pareto.utils import evenly_dist_weights
from pareto.optim import find_min_norm_element

import argparse
import copy

from min_norm_solver import MinNormSolver
from metric import hypervolumn, get_pareto_front

import os
# import setGPU

import torch.backends.cudnn as cudnn

from epo_lp import EPO_LP
import pickle


@torch.no_grad()
def evaluate(network, dataloader, device, closures, header=''):
    """Print and return the sample-weighted mean loss and top-1 accuracy per task.

    The network is switched out of training mode (``train(False)``) and is
    left in that mode when the function returns.

    :param network: model called as ``network(images)``; its output is indexed
        per task (``logits[0]``, ``logits[1]``)
    :param dataloader: iterable of ``(images, labels)`` batches; column ``i``
        of ``labels`` is used as the target of task ``i``
    :param device: device every batch is moved to before the forward pass
    :param closures: callables ``c(network, logits, labels)`` returning a
        scalar loss tensor; their values are accumulated into a length-2 array
    :param header: optional prefix for the printed summary line
    :return: ``(losses, top1s)`` as two length-2 numpy arrays
    """
    network.train(False)

    loss_totals = np.zeros(2)
    top1_totals = np.zeros(2)
    seen = 0
    for images, labels in dataloader:
        n_batch = len(images)
        seen += n_batch
        images, labels = images.to(device), labels.to(device)
        logits = network(images)

        # Weight each batch by its size so a short final batch is not over-counted.
        batch_losses = np.array([closure(network, logits, labels).item() for closure in closures])
        loss_totals += n_batch * batch_losses
        for task in (0, 1):
            top1_totals[task] += n_batch * topk_accuracy(logits[task], labels[:, task], k=1)

    losses = loss_totals / seen
    top1s = top1_totals / seen

    loss_msg = '[' + '/'.join(f'{value:.6f}' for value in losses) + ']'
    top1_msg = '[' + '/'.join(f'{value * 100.0:.2f}%' for value in top1s) + ']'
    prefix = f'{header}:' if header else ''
    print(' '.join((
        prefix,
        'loss', colored(loss_msg, 'yellow'),
        'top@1', colored(top1_msg, 'yellow'),
    )))
    return losses, top1s

def getNumParams(params):
    """Count the scalar entries held by ``params``.

    ``Tensor.numel()`` is used so that the counts are plain Python ints for
    every shape, including 0-d parameters.

    :param params: iterable of tensors/parameters (consumed once, so a
        generator such as ``module.parameters()`` is fine)
    :return: ``(numParams, numTrainable)`` -- the total entry count and the
        count restricted to entries with ``requires_grad`` set
    """
    numParams, numTrainable = 0, 0
    for param in params:
        count = param.numel()
        numParams += count
        if param.requires_grad:
            numTrainable += count
    return numParams, numTrainable

def circle_points(K, min_angle=None, max_angle=None):
    """Return ``K`` evenly spaced unit vectors on an arc of the unit circle.

    ``K + 4`` angles are spread linearly over ``[min_angle, max_angle]`` and
    the two outermost angles on each side are dropped, so the returned points
    lie strictly inside the requested arc.

    :param K: number of preference vectors to return
    :param min_angle: start of the arc in radians (defaults to ``pi / 20``)
    :param max_angle: end of the arc in radians (defaults to ``9 * pi / 20``)
    :return: array of shape ``(K, 2)`` with rows ``(cos(angle), sin(angle))``
    """
    if min_angle is None:
        min_angle = np.pi / 20.
    if max_angle is None:
        max_angle = np.pi * 9 / 20.

    inner = np.linspace(min_angle, max_angle, K + 4)[2:-2]
    return np.column_stack((np.cos(inner), np.sin(inner)))

def energy_grad_to_vec(networks, optimizers, energy):
    """Back-propagate ``energy`` and return each network's unit-norm gradient.

    All optimizers are zeroed first, then ``energy.backward()`` is run once.
    For every network the gradients of the parameters that received one are
    flattened (in ``named_parameters()`` order) into a single column.

    The flattening is done with tensor ops instead of a Python-list round
    trip; the result is still a CPU tensor of the default dtype.

    :param networks: models whose parameters take part in ``energy``
    :param optimizers: optimizers whose gradients are cleared before the
        backward pass
    :param energy: scalar tensor to differentiate
    :return: list with one ``(n_param, 1)`` tensor per network, each divided
        by its L2 norm (an all-zero gradient therefore produces non-finite
        entries, because no guard against a zero norm is applied)
    """
    for optimizer in optimizers:
        optimizer.zero_grad()

    energy.backward()

    out_dtype = torch.get_default_dtype()
    energy_grad_all = []
    for network in networks:
        pieces = [
            param.grad.detach().reshape(-1).to(device='cpu', dtype=out_dtype)
            for _, param in network.named_parameters()
            if param.grad is not None
        ]
        flat = torch.cat(pieces) if pieces else torch.zeros(0, dtype=out_dtype)
        energy_grad = flat.reshape(-1, 1)  # n_param, 1

        # normalize
        energy_grad = energy_grad / (energy_grad ** 2).sum().sqrt()

        energy_grad_all.append(energy_grad)

    return energy_grad_all

def grad_to_vec(network, optimizer, losses, ret_shared_idx=False):
    """Stack the flattened gradient of every loss into one matrix.

    For each loss the optimizer is zeroed, the loss is back-propagated (the
    graph is retained for all but the last loss) and the gradients of the
    parameters that have one are concatenated in ``named_parameters()`` order.

    The flattening is done with tensor ops instead of a Python-list round
    trip; the result is still a CPU tensor of the default dtype.

    NOTE(review): every loss must populate the same set of parameter
    gradients, otherwise the columns have different lengths and stacking
    fails -- confirm how ``optimizer.zero_grad()`` resets gradients in the
    PyTorch version in use.

    :param network: model whose parameter gradients are collected
    :param optimizer: optimizer used to clear gradients between losses
    :param losses: sequence of scalar loss tensors sharing one graph
    :param ret_shared_idx: when true, also return the row indices of all
        parameters whose name does not start with ``'fc3'``
    :return: ``loss_grads`` of shape ``(n_param, n_tasks)``, or
        ``(loss_grads, shared_idx)`` when ``ret_shared_idx`` is set
    """
    out_dtype = torch.get_default_dtype()
    num_tasks = len(losses)
    columns = []
    shared_idx = []
    count = 0
    for task_id, loss in enumerate(losses):
        optimizer.zero_grad()
        # Later losses reuse the same graph, so keep it until the last one.
        loss.backward(retain_graph=task_id < num_tasks - 1)

        pieces = []
        for name, param in network.named_parameters():
            if param.grad is None:
                continue
            pieces.append(param.grad.detach().reshape(-1).to(device='cpu', dtype=out_dtype))

            # calculate idx for shared
            if task_id == 0 and ret_shared_idx:
                size = param.numel()
                if not name.startswith('fc3'):
                    shared_idx.extend(range(count, count + size))
                count += size
        columns.append(torch.cat(pieces) if pieces else torch.zeros(0, dtype=out_dtype))

    if columns:
        loss_grads = torch.stack(columns, dim=1)  # n_param, n_tasks
    else:
        loss_grads = torch.zeros((0, 0), dtype=out_dtype)

    if ret_shared_idx:
        return loss_grads, shared_idx
    return loss_grads


def mgd(G, shared_idx=None, normalize=False):
    """Combine per-task gradients with min-norm weights on the shared rows.

    Relies on the module-level ``solver`` for ``find_min_norm_element``.

    :param G: gradient matrix of shape ``n_param * n_task``
    :param shared_idx: row indices the min-norm combination is applied to;
        defaults to every row
    :param normalize: scale each task's shared gradient to unit L2 norm
        before solving
    :return: ``(g_proj, alpha)`` -- ``g_proj`` is ``n_param * 1`` and holds
        the weighted combination on the shared rows and the plain sum over
        tasks on all other rows; ``alpha`` is the ``n_task * 1`` weight tensor
    """
    rows = list(range(0, G.shape[0])) if shared_idx is None else shared_idx

    # Rows outside ``rows`` keep this plain sum of the task gradients.
    combined = G.sum(axis=1, keepdim=True)

    shared = G[rows, :]
    if normalize:
        column_norms = (shared ** 2).sum(axis=0, keepdim=True).sqrt()
        shared = shared / column_norms

    per_task = [shared[:, task].squeeze() for task in range(G.shape[1])]

    weights, _ = solver.find_min_norm_element(per_task)
    weights = torch.from_numpy(weights).unsqueeze(1).float()

    combined[rows, :] = shared.mm(weights)
    return combined, weights

def project_to_linear_span(G, g, epsilon=0.0):
    """Solve the (optionally ridge-regularised) least-squares problem
    ``min_beta || G * beta - g ||`` through its normal equations.

    ``torch.linalg.solve`` is used when available because ``torch.solve`` no
    longer exists in current PyTorch; the legacy call is kept as a fallback
    for old installations.

    :param G: n_param * n_task
    :param g: n_param * 1
    :param epsilon: multiple of the identity added to ``G^T G``
    :return: beta, n_task * 1, solving ``(G^T G + epsilon * I) beta = G^T g``
    """
    # Build the identity with G's dtype/device so the sum below is well defined.
    identity = torch.eye(G.shape[1], dtype=G.dtype, device=G.device)
    lhs = (G.T).mm(G) + epsilon * identity
    rhs = (G.T).mm(g)

    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        return linalg.solve(lhs, rhs)

    beta, _ = torch.solve(rhs, lhs)
    return beta

def constraint_search(G, g, alpha=0.2, shared_idx=None, mgd_threshold=0.1, constrain_lr=0.1, constrain_iters=100, constrain_epsilon=0.01, normalize=False):
    """Build an update direction from task gradients ``G`` and energy gradient ``g``.

    Rows outside ``shared_idx`` receive the plain sum of the task gradients.
    On the shared rows one of two things happens:

    * if the squared norm of the min-norm combination returned by ``mgd`` is
      at most ``mgd_threshold`` times the mean squared task-gradient norm,
      the direction is ``0.1 * g`` (``g`` scaled to unit norm);
    * otherwise non-negative weights ``beta`` are fitted by projected
      gradient descent on
      ``0.5 * ||G beta + g||^2 - alpha * ||g_mgd||^2 * sum(beta)``
      and the direction is ``G beta + 0.1 * g``.

    ``g`` is always rescaled to unit L2 norm on the shared rows; the columns
    of ``G`` are rescaled only when ``normalize`` is set.

    :param G: n_param * n_task
    :param g: n_param * 1
    :param alpha: weight of the ``sum(beta)`` term in the fitted objective
    :param shared_idx: row indices treated as shared; defaults to every row
    :param mgd_threshold: relative threshold selecting the first branch
    :param constrain_lr: step size of the projected gradient descent
    :param constrain_iters: maximum number of descent steps; ``0`` keeps the
        weights returned by ``mgd``
    :param constrain_epsilon: stop once the squared change of ``beta``
        between two consecutive steps is at most this value
    :param normalize: scale every task gradient to unit norm first
    :return: direction of shape n_param * 1
    """
    g_proj = G.sum(axis=1, keepdim=True)

    if shared_idx is None:
        shared_idx = list(range(0, G.shape[0]))

    G_shared = G[shared_idx, :]
    g_shared = g[shared_idx, :]
    g_shared = g_shared / ((g_shared**2).sum(axis=0, keepdim=True).sqrt())
    if normalize:
        G_shared = G_shared / ((G_shared**2).sum(axis=0, keepdim=True).sqrt())

    # check mgd
    g_mgd, beta = mgd(G_shared, normalize=False)
    g_mgd_norm = (g_mgd**2).sum()
    avg_G_norm = (G_shared**2).sum(axis=0).mean()

    # if is close to Pareto front
    if g_mgd_norm <= mgd_threshold * avg_G_norm:
        g_proj[shared_idx, :] = g_shared * 0.1
        return g_proj

    # if is close to boundary, control phase
    beta.requires_grad = True # n_task, 1

    # Stays None until one full step has completed, so the convergence test
    # is skipped on the first step and nothing is unbound when
    # constrain_iters is 0.
    last_beta = None
    for _ in range(constrain_iters):
        constrain_loss = 0.5 * ((G_shared.mm(beta) + g_shared)**2).sum() - beta.sum() * alpha * g_mgd_norm
        constrain_loss.backward()
        beta.data -= constrain_lr * beta.grad.data
        # Project back onto beta >= 0 by zeroing the negative entries.
        beta.data *= (beta.data >= 0.)
        beta.grad.data.zero_()
        if last_beta is not None and ((beta.data - last_beta.data)**2).sum() <= constrain_epsilon:
            break
        last_beta = beta.clone().detach()

    beta_ = beta.detach().clone()
    g_proj[shared_idx, :] = (G_shared.mm(beta_) + 0.1 * g_shared)
    return g_proj
        
def vec_to_grad(network, optimizer, device, g_p, g_v=None, alpha=0.):
    """Write a flat direction vector back into the parameters' ``.grad``.

    The vector is consumed in ``named_parameters()`` order, restricted to the
    parameters that hold a gradient when the function is called. The target
    parameters are picked *before* ``optimizer.zero_grad()`` so the write
    still happens when ``zero_grad()`` resets gradients to ``None`` rather
    than to zero.

    :param network: model whose gradients are overwritten
    :param optimizer: optimizer used to clear the existing gradients
    :param device: device ``g_p`` and ``g_v`` are moved to
    :param g_p: n_param (or n_param * 1) vector written into the gradients
    :param g_v: optional second vector with the same layout; when given,
        ``alpha * g_v`` is added on top of ``g_p``
    :param alpha: scale applied to ``g_v``
    :return: None
    """
    g_p = g_p.to(device)
    if g_v is not None:
        g_v = g_v.to(device)

    targets = [param for _, param in network.named_parameters() if param.grad is not None]

    optimizer.zero_grad()
    count = 0
    for param in targets:
        size = param.numel()
        if param.grad is None:
            # zero_grad() dropped the buffer; recreate it so the additions below apply.
            param.grad = torch.zeros_like(param)

        param.grad.data += g_p[count:(count + size)].reshape(param.shape)
        if g_v is not None:
            param.grad.data += alpha * g_v[count:(count + size)].reshape(param.shape)
        count += size


def energy(M, device, s=2, grad=False, MAX=10000., EPS=0.0001):
    """Repulsion energy between the loss vectors in ``M``.

    A symmetric matrix of squared Euclidean distances between all pairs of
    rows is built, its diagonal is set to ``MAX`` and the matrix is reduced
    to a scalar: the sum of ``(EPS + d) ** (-s / 2)`` for ``s < 10``, or the
    maximum of ``1 / (EPS + d)`` for ``s >= 10``.

    :param M: loss matrix (list): [[task1, task2, ...], [task1, task2, ...], ....]
    :param device: device the distance matrix is created on
    :param s: exponent of the energy; values ``>= 10`` select the max form
    :param grad: not used by this function
    :param MAX: value placed on the diagonal of the distance matrix
    :param EPS: offset added to every distance before inversion
    :return: scalar tensor
    """
    n_points = len(M)
    sq_dist = torch.zeros(n_points, n_points).to(device)

    for first in range(0, n_points):
        for second in range(first + 1, n_points):
            for task_id in range(len(M[0])):
                gap = (M[first][task_id] - M[second][task_id])**2
                # The matrix is symmetric: mirror every contribution.
                sq_dist[first, second] += gap
                sq_dist[second, first] += gap

    sq_dist = sq_dist + MAX * torch.eye(sq_dist.shape[0]).to(device)

    if s >= 10:
        return (1. / (EPS + sq_dist)).max()
    return ((EPS + sq_dist) ** (-s / 2.)).sum()


def weighted_sum_step(images, labels, closures, networks, optimizers, lr_schedulers):
    """Run one weighted-sum (linear scalarisation) update on every network.

    Network ``i`` minimises the task losses weighted by ``prefs[i]``, where
    ``prefs`` is the module-level preference table. Each network's optimizer
    and learning-rate scheduler are stepped once.

    :param images: input batch, already on the training device
    :param labels: target batch, already on the training device
    :param closures: callables ``c(network, logits, labels)`` returning one
        loss tensor per task
    :param networks: models to update
    :param optimizers: one optimizer per network
    :param lr_schedulers: one scheduler per network
    """
    for idx, model in enumerate(networks):
        model.train(True)
        outputs = model(images)
        task_losses = [closure(model, outputs, labels) for closure in closures]

        scalarised = sum(weight * task_loss for weight, task_loss in zip(prefs[idx], task_losses))

        optimizer = optimizers[idx]
        optimizer.zero_grad()
        scalarised.backward()
        optimizer.step()
        lr_schedulers[idx].step()


def mgd_step(images, labels, closures, networks, optimizers, lr_schedulers):
    """Run one multiple-gradient-descent update on every network.

    Per-task gradients are collected for each network, combined by ``mgd``
    on the shared parameter rows (indices taken from the first network),
    written back into ``.grad`` and applied with one optimizer/scheduler
    step per network. Uses the module-level ``device``.

    :param images: input batch, already on the training device
    :param labels: target batch, already on the training device
    :param closures: callables ``c(network, logits, labels)`` returning one
        loss tensor per task
    :param networks: models to update
    :param optimizers: one optimizer per network
    :param lr_schedulers: one scheduler per network
    """
    grads_per_network = []
    for idx, model in enumerate(networks):
        outputs = model(images)
        task_losses = [closure(model, outputs, labels) for closure in closures]

        if idx != 0:
            grads_per_network.append(
                grad_to_vec(model, optimizers[idx], task_losses, ret_shared_idx=False))
            continue

        # only apply mgd to shared parameters
        first_grads, shared_idx = grad_to_vec(model, optimizers[idx], task_losses, ret_shared_idx=True)
        grads_per_network.append(first_grads)

    for idx, model in enumerate(networks):
        direction, _ = mgd(grads_per_network[idx], shared_idx=shared_idx)
        vec_to_grad(model, optimizers[idx], device, direction, None, alpha=1.)

    for idx in range(len(networks)):
        optimizers[idx].step()
        lr_schedulers[idx].step()

def epo_step(images, labels, closures, networks, optimizers, lr_schedulers, preferences, epo_lps):
    """Run one EPO update on every network.

    For each network the task-gradient Gram matrix and the current losses
    are handed to its LP solver to obtain task weights. If the solver raises
    or returns ``None``, the normalised preference vector of that network is
    used instead. The weighted gradient is written into ``.grad``; afterwards
    every optimizer and scheduler is stepped once. Uses the module-level
    ``n_tasks`` and ``device``.

    :param images: input batch, already on the training device
    :param labels: target batch, already on the training device
    :param closures: callables ``c(network, logits, labels)`` returning one
        loss tensor per task
    :param networks: models to update
    :param optimizers: one optimizer per network
    :param lr_schedulers: one scheduler per network
    :param preferences: 2-D array of preference vectors, indexed by network
    :param epo_lps: one LP solver per network, exposing ``get_alpha``
    """
    for idx, model in enumerate(networks):
        outputs = model(images)
        task_losses = [closure(model, outputs, labels) for closure in closures]
        loss_values = np.stack([task_loss.detach().cpu().numpy() for task_loss in task_losses])
        task_grads = grad_to_vec(model, optimizers[idx], task_losses, ret_shared_idx=False)
        gram = task_grads.T @ task_grads

        weights = None
        try:
            # Calculate the alphas from the LP solver
            weights = epo_lps[idx].get_alpha(loss_values, G=gram.cpu().numpy(), C=True)
        except Exception as e:
            print(e)
            weights = None

        if weights is None:   # A patch for the issue in cvxpy
            weights = preferences[idx, :] / preferences[idx, :].sum()
            print('manual_adjusts')

        weights = (n_tasks * torch.from_numpy(weights)).unsqueeze(1).float()
        direction = task_grads.mm(weights)
        vec_to_grad(model, optimizers[idx], device, direction, None, alpha=0.)

    for idx in range(len(networks)):
        optimizers[idx].step()
        lr_schedulers[idx].step()

def constrain_step(images, labels, closures, networks, optimizers, lr_schedulers, front=[]):
    """Run one energy-constrained update on every network.

    The losses of all networks, extended by the rows of ``front``, define an
    energy whose gradient is computed per network. A second forward pass
    collects the per-task gradients, ``constraint_search`` combines both into
    a direction, and every optimizer and scheduler is stepped once. Reads the
    module-level ``args`` (``energy_s``, ``only_shared``, ``constrain_alpha``,
    ``threshold``, ``energy_normalize``) and ``device``.

    :param images: input batch, already on the training device
    :param labels: target batch, already on the training device
    :param closures: callables ``c(network, logits, labels)`` returning one
        loss tensor per task
    :param networks: models to update
    :param optimizers: one optimizer per network
    :param lr_schedulers: one scheduler per network
    :param front: extra loss vectors appended to the networks' losses before
        the energy is evaluated (not modified)
    """
    rows = []
    for model in networks:
        model.train(True)
        outputs = model(images)
        rows.append([closure(model, outputs, labels) for closure in closures])
    rows.extend(front)

    energy_value = energy(rows, device, s=args.energy_s)
    energy_grads = energy_grad_to_vec(networks, optimizers, energy_value)

    # Second forward pass: the per-task gradients are taken on a fresh graph.
    task_grads = []
    shared_idx = None
    for idx, model in enumerate(networks):
        outputs = model(images)
        task_losses = [closure(model, outputs, labels) for closure in closures]

        want_shared = idx == 0 and args.only_shared
        if want_shared:
            current, shared_idx = grad_to_vec(model, optimizers[idx], task_losses, ret_shared_idx=True)
        else:
            current = grad_to_vec(model, optimizers[idx], task_losses, ret_shared_idx=False)
        task_grads.append(current)

    for idx, model in enumerate(networks):
        direction = constraint_search(
            task_grads[idx],
            energy_grads[idx],
            alpha=args.constrain_alpha,
            shared_idx=shared_idx,
            mgd_threshold=args.threshold,
            normalize=args.energy_normalize,
        )
        vec_to_grad(model, optimizers[idx], device, direction, None, alpha=1.)

    for idx in range(len(networks)):
        optimizers[idx].step()
        lr_schedulers[idx].step()


def train(args):
    """Train ``args.num_models`` networks with the algorithm named by ``args.algo``.

    Besides ``args`` this function reads several module-level names that are
    set up by the script entry point: ``device``, ``lr``, ``weight_decay``,
    ``num_epochs``, ``start_energy``, ``prefs``, ``n_tasks``, ``ckpt_name``,
    ``trainloader`` and ``testloader``.

    Steps:

    1. build the networks, one Adam optimizer and one cosine learning-rate
       scheduler per network; optionally save the random initialisation and
       optionally resume from ``args.load_name``;
    2. evaluate every network on the test set;
    3. for every epoch, run one algorithm-specific step per training batch,
       then evaluate every network on the test and the training set;
    4. from epoch ``args.epoch_to_eval_pareto`` on, accumulate the metrics,
       print the hypervolume of losses and accuracies and refresh ``front``;
    5. save checkpoints at the epochs in ``args.epoch_to_store`` and once
       more after the last epoch.

    :param args: parsed command-line namespace
    """

    # prepare models/optimizers
    num_models = args.num_models
    networks = [MultiLeNet() for i in range(num_models)]
    for network in networks:
        network.to(device)

    optimizers = [Adam(networks[i].parameters(), lr=lr, weight_decay=weight_decay) for i in range(num_models)]
    # The schedulers are stepped once per batch, hence the horizon in batches.
    lr_schedulers = [CosineAnnealingLR(optimizers[i], num_epochs * len(trainloader)) for i in range(num_models)]

    if args.save_random_init:
        save_models(networks, optimizers, lr_schedulers, -1, args, 'random' + str(args.seed), all_losses=[], all_top1s=[])

    # metrics of every evaluated (epoch, network) pair, accumulated over time
    all_losses = []
    all_top1s = []
    all_train_losses = []

    start_epoch = 0
    if args.load_name:
        start_epoch, all_losses, all_top1s = load_models(networks, optimizers, lr_schedulers, args, args.load_name, device)
        # the checkpoint stores the last finished epoch; resume with the next one
        start_epoch += 1

    # prepare losses: one cross-entropy closure per task, each reading its own
    # logits entry and its own label column
    criterion = F.cross_entropy
    closures = [lambda n, l, t: criterion(l[0], t[:, 0]), lambda n, l, t: criterion(l[1], t[:, 1])]

    # first evaluation
    for net_id, network in enumerate(networks):
        evaluate(network, testloader, device, closures, f'{ckpt_name}')

    # for mgd, epo and png, need to initialize all the gradients: one
    # backward pass gives every parameter a .grad before it is cleared again
    # NOTE(review): the later `param.grad is not None` checks only see these
    # gradients if zero_grad() zeroes them instead of resetting them to None
    # -- confirm for the PyTorch version in use.
    if args.algo in ['mgd', 'epo', 'png']:
        trainiter = iter(trainloader)
        images, labels = next(trainiter)
        images = images.to(device)
        labels = labels.to(device)
        for net_id, network in enumerate(networks):
            network.train(True)
            logits = network(images)
            losses = [c(network, logits, labels) for c in closures]
            loss = sum(w * l for w, l in zip(prefs[net_id], losses))
            loss.backward()
            optimizers[net_id].zero_grad()
        # initial front from the (possibly loaded) loss history
        front = get_pareto_front(np.array(all_losses), type='loss')

    if args.algo == 'epo':
        _, n_params = getNumParams(networks[0].parameters())
        preferences = circle_points(args.num_models, min_angle=0.0001*np.pi/2, max_angle=0.9999*np.pi/2)  # preference
        epo_lps = []
        # NOTE(review): the solvers are built from the reversed preference
        # list, while epo_step's fallback indexes `preferences` unreversed --
        # confirm this is intended.
        for _, pref in enumerate(preferences[::-1]):
            epo_lps.append(EPO_LP(m=n_tasks, n=n_params, r=pref))

    num_steps = len(trainloader)
    for epoch in range(start_epoch, num_epochs):
        trainiter = iter(trainloader)
        for _ in range(1, num_steps + 1):

            images, labels = next(trainiter)
            images = images.to(device)
            labels = labels.to(device)

            if args.algo == 'ws':
                weighted_sum_step(images, labels, closures, networks, optimizers, lr_schedulers)
            elif args.algo == 'mgd':
                # weighted-sum warm-up up to and including epoch `start_energy`
                if epoch <= start_energy:
                    weighted_sum_step(images, labels, closures, networks, optimizers, lr_schedulers)
                else:
                    mgd_step(images, labels, closures, networks, optimizers, lr_schedulers)
            elif args.algo == 'epo':
                epo_step(images, labels, closures, networks, optimizers, lr_schedulers, preferences, epo_lps)

            elif args.algo == 'png':
                # after the warm-up, even epochs run the energy-constrained
                # step and odd epochs run plain mgd
                if epoch <= start_energy:
                    weighted_sum_step(images, labels, closures, networks, optimizers, lr_schedulers)
                elif epoch % 2 == 0:
                    constrain_step(images, labels, closures, networks, optimizers, lr_schedulers, front=front)
                else:
                    mgd_step(images, labels, closures, networks, optimizers, lr_schedulers)
            else:
                print('Method not implemented')
                return

        for net_id, network in enumerate(networks):
            loss, top1 = evaluate(network, testloader, device, closures, f'{ckpt_name}: {epoch+1}/{num_epochs}')
            train_loss, train_top1 = evaluate(network, trainloader, device, closures, f'{ckpt_name}: {epoch+1}/{num_epochs}')
            if epoch >= args.epoch_to_eval_pareto:
                all_losses.append(list(loss))
                all_top1s.append(list(top1))
                all_train_losses.append(list(train_loss))

        if epoch >= args.epoch_to_eval_pareto:
            hv_loss = hypervolumn(np.array(all_losses), type='loss', ref=np.ones(2) * 0.6)
            hv_acc = hypervolumn(np.array(all_top1s), type='acc', ref=-np.ones(2) * 0.6)
            # the front handed to constrain_step is taken from the training losses
            front = get_pareto_front(np.array(all_train_losses), type='loss')
            print('hv_loss: ', hv_loss)
            print('hv_acc: ', hv_acc)

        print('------------------------')

        if args.epoch_to_store and epoch in args.epoch_to_store:
            save_models(networks, optimizers, lr_schedulers, epoch, args, args.store_name + '_' + str(epoch), all_losses=all_losses, all_top1s=all_top1s)

    # saving
    # NOTE(review): `epoch` is only bound by the loop above; when
    # start_epoch >= num_epochs the loop never runs and this call fails.
    save_models(networks, optimizers, lr_schedulers, epoch, args, args.store_name + '_' + str(args.constrain_alpha) + '_' + str(args.threshold), all_losses=all_losses, all_top1s=all_top1s)

def save_models(networks, optimizers, lr_schedulers, epoch, args, name, all_losses=[], all_top1s=[]):
    """Write one checkpoint holding the state of every network.

    The file is stored as ``<directory of this script>/<args.save_path>/<name>``.

    :param networks: models whose ``state_dict`` is saved
    :param optimizers: optimizers whose ``state_dict`` is saved
    :param lr_schedulers: schedulers whose ``state_dict`` is saved
    :param epoch: epoch number recorded in the checkpoint
    :param args: namespace providing ``save_path``
    :param name: file name of the checkpoint
    :param all_losses: loss history stored alongside the states
    :param all_top1s: accuracy history stored alongside the states
    """
    script_dir = Path(__file__).resolve().parent
    destination = os.path.join(script_dir, args.save_path, name)

    model_states = [network.state_dict() for network in networks]
    optimizer_states = [optimizer.state_dict() for optimizer in optimizers]
    scheduler_states = [lr_scheduler.state_dict() for lr_scheduler in lr_schedulers]

    ckpt = {
        'state_dict': model_states,
        'optimizer': optimizer_states,
        'lr_scheduler': scheduler_states,
        'epoch': epoch,
        'all_losses': all_losses,
        'all_top1s': all_top1s,
    }
    torch.save(ckpt, destination)
    print('model saved ...')

def load_models(networks, optimizers, lr_schedulers, args, name, device):
    """Restore networks, optimizers and schedulers from a checkpoint.

    The file is read from ``<directory of this script>/<args.save_path>/<name>``
    and its tensors are mapped onto ``device``.

    :param networks: models that receive the stored ``state_dict`` entries
    :param optimizers: optimizers restored in the same order
    :param lr_schedulers: schedulers restored in the same order
    :param args: namespace providing ``save_path``
    :param name: file name of the checkpoint
    :param device: ``map_location`` used when loading
    :return: ``(epoch, all_losses, all_top1s)``; both histories are empty
        lists when the checkpoint does not contain both of them
    """
    script_dir = Path(__file__).resolve().parent
    source = os.path.join(script_dir, args.save_path, name)

    ckpt = torch.load(source, map_location=device)

    for idx, model in enumerate(networks):
        model.load_state_dict(ckpt['state_dict'][idx])
        optimizers[idx].load_state_dict(ckpt['optimizer'][idx])
        lr_schedulers[idx].load_state_dict(ckpt['lr_scheduler'][idx])

    print('model loaded ...')

    has_history = 'all_losses' in ckpt.keys() and 'all_top1s' in ckpt.keys()
    if not has_history:
        return ckpt['epoch'], [], []
    return ckpt['epoch'], ckpt['all_losses'], ckpt['all_top1s']

if __name__ == '__main__':
    # Script entry point: parse the command line, publish the hyper-parameters
    # as module-level names (the training functions above read them as
    # globals), load the dataset, build the data loaders and start training.

    # Pickled (trainX, trainLabel, testX, testLabel) tuples, keyed by --dataset.
    dataset_files = {
        'mnist': 'data/multi_mnist.pickle',
        'fashion': 'data/multi_fashion.pickle',
        'fashion_and_mnist': 'data/multi_fashion_and_mnist.pickle',
    }

    parser = argparse.ArgumentParser()

    # common
    parser.add_argument('--num_epochs', type=int, default=100,
                        help='num_epochs to train the model')
    # NOTE(review): parsed and stored below, but not passed to the DataLoaders.
    parser.add_argument('--num_workers', type=int, default=2,
                        help='num_workers')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed')
    # `choices` rejects unknown names up front; previously they fell through
    # every dataset branch and failed later on an undefined `trainX`.
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=list(dataset_files),
                        help='dataset to run. Choose from {mnist, fashion, fashion_and_mnist}')

    # i/o
    parser.add_argument('--save_path', type=str, default='result',
                        help='path to save the model')
    parser.add_argument('--store_name', type=str, default='ckpt',
                        help='checkpoint name to store')
    parser.add_argument('--load_name', type=str, default=None,
                        help='checkpoint name to load')
    parser.add_argument('--epoch_to_store', type=int, nargs='+', default=None,
                        help='which epoches to store model')
    # NOTE(review): with nargs='+' any value given on the command line
    # (including 0) becomes a non-empty list and therefore enables saving.
    parser.add_argument('--save_random_init', type=int, nargs='+', default=0,
                        help='whether to save random init network')

    # method
    parser.add_argument('--algo', type=str, default='mgd',
                        help='algorithm for learning. ws: weighted_sum; mgd: multiple_gradient_descent; epo; png')

    # optimizer
    parser.add_argument('--lr', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='batch size')

    # Pareto Explore
    parser.add_argument('--num_models', type=int, default=5,
                        help='number of models used to approximate Pareto Front')
    parser.add_argument('--start_energy', type=int, default=60,
                        help='epoch start to run first order algorithm')
    parser.add_argument('--only_shared', type=int, default=0,
                        help='only apply to shared parameter')
    parser.add_argument('--energy_s', type=int, default=2,
                        help='norm used for energy function')
    parser.add_argument('--energy_normalize', type=int, default=1,
                        help='normalize each task gradient to unit norm in the constrained step')
    parser.add_argument('--constrain_alpha', type=float, default=0.5,
                        help='alpha for constrain')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='threshold')

    # eval
    parser.add_argument('--epoch_to_eval_pareto', type=int, default=60, help='epoch start to evaluate the quality of pareto set')

    args = parser.parse_args()

    # prepare hyper-parameters
    seed = args.seed
    cuda_enabled = True
    cuda_deterministic = False

    batch_size = args.batch_size
    num_workers = args.num_workers

    lr = args.lr
    momentum = 0.9
    weight_decay = 0.0

    num_epochs = args.num_epochs

    start_energy = args.start_energy

    prefs = evenly_dist_weights(args.num_models + 2, 2)

    n_tasks = 2

    # prepare path

    root_path = Path(__file__).resolve().parent
    ckpt_path = root_path / args.save_path

    root_path.mkdir(parents=True, exist_ok=True)
    ckpt_path.mkdir(parents=True, exist_ok=True)

    args.store_name = args.store_name + 'seed' + str(args.seed)
    ckpt_name = args.store_name

    # fix random seed

    use_cuda = cuda_enabled and torch.cuda.is_available()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)

    # prepare device

    if use_cuda:
        device = torch.device('cuda')
        if cuda_deterministic:
            cudnn.benchmark = False
            cudnn.deterministic = True
        else:
            cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # prepare dataset
    # pickle can execute arbitrary code while loading: only point this at
    # trusted dataset files.
    with open(dataset_files[args.dataset], 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)

    # Images are reshaped to 1 x 36 x 36; the sample count is inferred from
    # the array instead of being hard-coded.
    trainX = torch.from_numpy(trainX.reshape(-1, 1, 36, 36)).float()
    trainLabel = torch.from_numpy(trainLabel).long()
    testX = torch.from_numpy(testX.reshape(-1, 1, 36, 36)).float()
    testLabel = torch.from_numpy(testLabel).long()

    train_set = torch.utils.data.TensorDataset(trainX, trainLabel)
    test_set = torch.utils.data.TensorDataset(testX, testLabel)

    trainloader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True)
    testloader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False)

    print('==>>> dataset: {}'.format(args.dataset))
    print('==>>> total trainning batch number: {}'.format(len(trainloader)))
    print('==>>> total testing batch number: {}'.format(len(testloader)))

    # `mgd` looks this solver up as a module-level name.
    solver = MinNormSolver()
    train(args)
