from collections import defaultdict
from pathlib import Path
import torch
import wandb
import numpy as np
import src.metrics as metrics
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from copy import deepcopy
from scipy.optimize import minimize
from collections import OrderedDict
from src.min_norm_solvers import MinNormSolver

from src import resources
from src import utils


def cos(t1, t2):
    """Return the cosine similarity of ``t1`` and ``t2`` taken along dim 0.

    Both tensors are L2-normalised along their first dimension with
    ``F.normalize`` and the element-wise product is summed over that
    dimension, so 1-D inputs yield a 0-d tensor.
    """
    return (F.normalize(t1, dim=0) * F.normalize(t2, dim=0)).sum(dim=0)

def pair_cos_with_d(pair, d):
    """Return the cosine similarity between ``d`` and every row of ``pair``.

    ``pair`` is indexed along its first dimension; the result is a flat
    tensor holding ``cos(d, pair[i])`` for each ``i`` in order.
    """
    similarities = []
    for row in range(pair.size(0)):
        similarities.append(cos(d, pair[row]))

    return torch.stack(similarities).view(-1)

def pair_cos(pair):
    """Return the cosine similarity of every unordered pair of rows of ``pair``.

    Pairs ``(i, j)`` with ``i < j`` are visited in row-major order and the
    similarities are returned as one flat tensor.
    """
    n_rows = pair.size(0)

    similarities = [
        cos(pair[first], pair[second])
        for first in range(n_rows - 1)
        for second in range(first + 1, n_rows)
    ]

    return torch.stack(similarities).view(-1)


def train_search(device,
                 start_epoch,
                 max_epochs,
                 tasks,
                 trainloader_weight,
                 trainloader_arch,
                 model,
                 loss,
                 optimizer_weight,
                 optimizer_arch,
                 exp_dir):
    """Alternate architecture and weight updates, one phase of each per epoch.

    During the first ``max_epochs // 20`` epochs only the weight phase runs
    (``model.warmup_flag`` is set). In every later epoch one pass over
    ``trainloader_arch`` updates the architecture parameters through
    ``optimizer_arch`` (one step per batch, gradients accumulated over the
    tasks), then one pass over ``trainloader_weight`` updates the weights
    through ``optimizer_weight`` (one step per task per batch).

    A checkpoint is written to ``exp_dir/checkpoint.pth`` after every epoch
    and ``model.get_branch_config()`` is written to
    ``exp_dir/branch_config.json`` once all epochs are done. Scalars and the
    softmax of the architecture parameters are logged to TensorBoard every
    100 weight batches.

    Args:
        device: device the batches are moved to.
        start_epoch: first epoch to run (epochs are 1-based in the
            iteration counter used for logging).
        max_epochs: last epoch to run, inclusive.
        tasks: iterable of task names; each is a key of the batch dict.
        trainloader_weight: loader used for the weight phase.
        trainloader_arch: loader used for the architecture phase.
        model: search model; called as ``model(inputs, task=task)``.
        loss: callable ``loss(output, target, task=..., [omit_resource=...])``.
        optimizer_weight: optimizer of the network weights.
        optimizer_arch: optimizer of the architecture parameters.
        exp_dir: directory for TensorBoard logs, checkpoint and config.
    """

    def has_no_labels(label):
        # many images don't have human parts annotations; such targets
        # consist solely of the value 255 and are skipped
        uniq = torch.unique(label)
        return len(uniq) == 1 and uniq[0] == 255

    writer = SummaryWriter(log_dir=exp_dir)

    iter_per_epoch = len(
        trainloader_weight.dataset) // trainloader_weight.batch_size
    total_iter = iter_per_epoch * max_epochs
    delay_epochs = max_epochs // 20
    # length of the temperature schedule; clamped so that very short runs
    # (max_epochs - delay_epochs == 1) do not divide by zero
    anneal_epochs = max(1, max_epochs - delay_epochs - 1)

    try:
        model.train()
        for epoch in range(start_epoch, max_epochs + 1):

            model.warmup_flag = (epoch <= delay_epochs)
            # set the gumbel temperature according to a linear schedule
            # (5.0 during warm-up, decreasing to 0.1 at the last epoch)
            model.gumbel_temp = min(5.0 - (epoch - delay_epochs - 1)
                                    / anneal_epochs * (5.0 - 0.1), 5.0)

            arch_loss = 0
            arch_counter = 0

            if epoch > delay_epochs:
                print('modifying architecture...')

                # we reset the arch optimizer state
                optimizer_arch.state = defaultdict(dict)

                # we use current batch statistics in search period
                model.freeze_encoder_bn_running_stats()

                for samples_search in trainloader_arch:

                    inputs_search = samples_search['image'].to(
                        device, non_blocking=True)
                    target_search = {task: samples_search[task].to(
                        device, non_blocking=True) for task in tasks}

                    optimizer_arch.zero_grad()

                    # gradients of all tasks accumulate before a single step
                    for task in tasks:
                        if has_no_labels(target_search[task]):
                            continue

                        output = model(inputs_search, task=task)
                        tot_loss = loss(output, target_search, task=task)
                        tot_loss.backward()

                        arch_loss += tot_loss.item()
                        arch_counter += 1

                    optimizer_arch.step()

                # we reset the main optimizer state because arch has changed
                optimizer_weight.state = defaultdict(dict)

                # we should reset bn running stats
                model.unfreeze_encoder_bn_running_stats()
                model.reset_encoder_bn_running_stats()

            for batch_idx, samples in enumerate(trainloader_weight):

                inputs = samples['image'].to(device, non_blocking=True)
                target = {task: samples[task].to(
                    device, non_blocking=True) for task in tasks}

                current_loss = 0
                counter = 0

                # one optimizer step per task
                for task in tasks:
                    if has_no_labels(target[task]):
                        continue

                    optimizer_weight.zero_grad()

                    output = model(inputs, task=task)
                    tot_loss = loss(
                        output, target, task=task, omit_resource=True)
                    tot_loss.backward()

                    optimizer_weight.step()

                    current_loss += tot_loss.item()
                    counter += 1

                if (batch_idx + 1) % 100 == 0:
                    n_iter = (epoch - 1) * iter_per_epoch + batch_idx + 1
                    # every task may have been skipped for this batch, so
                    # guard the mean the same way as for the arch loss
                    mean_loss = current_loss / max(1, counter)
                    print('Train Iterations: {}, Loss: {:.4f}'.format(
                        utils.progress(n_iter, total_iter), mean_loss))
                    writer.add_scalar('loss_current', mean_loss, n_iter)
                    writer.add_scalar(
                        'arch_loss', arch_loss / max(1, arch_counter), n_iter)
                    writer.add_scalar('gumbel_temp', model.gumbel_temp, n_iter)
                    for name, param in model.named_arch_parameters():
                        writer.add_image(name, torch.nn.functional.softmax(
                            param.data, dim=-1), n_iter, dataformats='HW')

            # save model
            state = {
                'state_dict': model.state_dict(),
                'tasks': tasks,
                'epoch': epoch,
                'optimizer_weight': optimizer_weight.state_dict(),
                'optimizer_arch': optimizer_arch.state_dict(),
            }
            torch.save(state, Path(exp_dir) / 'checkpoint.pth')
    finally:
        # flush pending events and release the log file, also on errors
        writer.close()

    branch_config = model.get_branch_config()
    utils.write_json({'config': branch_config},
                     Path(exp_dir) / 'branch_config.json')


def train_branched(local_rank,
                   world_size,
                   device,
                   start_epoch,
                   max_epochs,
                   tasks,
                   trainloader,
                   testloader,
                   model,
                   loss,
                   optimizer,
                   scheduler,
                   metrics_dict,
                   exp_dir):
    """Train a branched multi-task model on a single combined loss.

    Each batch is forwarded once, ``loss(output, target)`` is back-propagated
    and the optimizer is stepped. The scheduler is stepped once per epoch and
    the model is evaluated with ``test_branched`` every 20 epochs and after
    the last one. Only the process with ``local_rank == 0`` prints and writes
    TensorBoard scalars (every 100 batches).

    NOTE: this function does not write any checkpoint.

    Args:
        local_rank: rank of this process; rank 0 does the logging.
        world_size: number of processes; when greater than 1 the sampler's
            ``set_epoch`` is called at the start of every epoch.
        device: device the batches are moved to.
        start_epoch: first epoch to run.
        max_epochs: last epoch to run, inclusive.
        tasks: iterable of task names; each is a key of the batch dict.
        trainloader: training data loader.
        testloader: evaluation data loader, forwarded to ``test_branched``.
        model: model called as ``model(inputs)``.
        loss: callable ``loss(output, target)`` returning a scalar tensor.
        optimizer: optimizer of the model parameters.
        scheduler: learning-rate scheduler.
        metrics_dict: per-task metric objects, forwarded to ``test_branched``.
        exp_dir: directory for the TensorBoard logs.
    """

    writer = SummaryWriter(log_dir=exp_dir) if local_rank == 0 else None

    iter_per_epoch = len(
        trainloader.dataset) // (trainloader.batch_size * max(1, world_size))
    total_iter = iter_per_epoch * max_epochs

    try:
        for epoch in range(start_epoch, max_epochs + 1):
            # test_branched leaves the model in eval mode, so switch back
            model.train()

            if world_size > 1:
                trainloader.sampler.set_epoch(epoch)

            for batch_idx, samples in enumerate(trainloader):

                inputs = samples['image'].to(device, non_blocking=True)
                target = {task: samples[task].to(
                    device, non_blocking=True) for task in tasks}

                optimizer.zero_grad()

                output = model(inputs)
                tot_loss = loss(output, target)
                tot_loss.backward()

                optimizer.step()

                if (batch_idx + 1) % 100 == 0 and local_rank == 0:
                    current_loss = tot_loss.item()
                    n_iter = (epoch - 1) * iter_per_epoch + batch_idx + 1
                    print('Train Iterations: {}, Loss: {}'.format(
                        utils.progress(n_iter, total_iter), current_loss))
                    writer.add_scalar('loss_current', current_loss, n_iter)
                    writer.add_scalar(
                        'learning_rate', optimizer.param_groups[0]['lr'], n_iter)

            scheduler.step()

            if epoch % 20 == 0 or epoch == max_epochs:
                # evaluate the model
                test_branched(device, tasks, testloader,
                              model, metrics_dict, epoch)
    finally:
        # flush pending events and release the log file, also on errors
        if writer is not None:
            writer.close()

def train(opt,
        local_rank,
        world_size,
        device,
        start_epoch,
        max_epochs,
        tasks,
        trainloader,
        testloader,
        model,
        loss,
        optimizer,
        scheduler,
        metrics_dict,
        exp_dir):
    """Train a multi-task model with a selectable rule for combining task gradients.

    ``opt.method`` selects how the per-task gradients of the shared
    parameters are turned into one update:

    * ``'nothing'``: the task losses are summed and back-propagated once.
    * ``'pcgrad'``, ``'cagrad'``, ``'graddrop'``, ``'mgd'``: every task loss
      is back-propagated separately, the shared-parameter gradients are
      flattened into one column per task, combined by the matching helper
      below and written back into ``param.grad`` before the optimizer step.
    * ``'branch_layer'`` / ``'bmtas_branch'``: same as above, with the
      combination chosen by ``opt.method_sub`` (``'nothing'`` = mean of the
      task gradients, or ``'cagrad'``). While ``epoch < opt.start_epoch`` the
      per-layer cosine similarities between task gradients, and between each
      task gradient and the combined gradient, are recorded; after the last
      epoch they are saved under ``./saved/``.

    The scheduler is stepped once per epoch and ``test_branched`` is run
    every 20 epochs and after the last epoch.

    Args:
        opt: options object; ``method`` is always read, ``alpha``,
            ``method_sub``, ``start_epoch``, ``seed`` and ``optimizer`` are
            read depending on the method.
        local_rank: rank of this process; rank 0 does the logging.
        world_size: number of processes.
        device: device for the batches and the gradient buffer.
        start_epoch: first epoch to run.
        max_epochs: last epoch to run, inclusive.
        tasks: sequence of task names; each is a key of the batch dict.
        trainloader: training data loader.
        testloader: evaluation data loader.
        model: model providing ``shared_parameters()`` (name -> parameter
            mapping) and ``zero_grad_shared_modules()``.
        loss: callable ``loss(output, target)`` returning one loss per task,
            indexable by task position.
        optimizer: optimizer of the model parameters.
        scheduler: learning-rate scheduler.
        metrics_dict: per-task metric objects, forwarded to ``test_branched``.
        exp_dir: directory for the TensorBoard logs.

    Raises:
        ValueError: if ``opt.method`` is not one of the methods listed above,
            or if a branch method is used with an unsupported
            ``opt.method_sub``.
    """

    def graddrop(grads):
        """Combine the task columns of ``grads`` with a random sign mask."""
        # P: per-coordinate probability of keeping the positive entries
        P = 0.5 * (1. + grads.sum(1) / (grads.abs().sum(1) + 1e-8))
        U = torch.rand_like(grads[:, 0])
        M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0)
        return (grads * M.float()).mean(1)

    def mgd(grads):
        """Weight the task columns by the solution of ``MinNormSolver``."""
        grads_cpu = grads.t().cpu()
        sol, min_norm = MinNormSolver.find_min_norm_element(
            [grads_cpu[t] for t in range(grads.shape[-1])])
        w = torch.FloatTensor(sol).to(grads.device)
        return grads.mm(w.view(-1, 1)).view(-1)

    def pcgrad(grads, rng, n_tasks=4):
        """Project away pairwise conflicting components, then average."""
        grad_vec = grads.t()

        # row i holds the indices of all tasks except i, in random order
        shuffled_task_indices = np.zeros((n_tasks, n_tasks - 1), dtype=int)
        for i in range(n_tasks):
            task_indices = np.arange(n_tasks)
            task_indices[i] = task_indices[-1]
            shuffled_task_indices[i] = task_indices[:-1]
            rng.shuffle(shuffled_task_indices[i])
        shuffled_task_indices = shuffled_task_indices.T

        normalized_grad_vec = grad_vec / (
            grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )  # num_tasks x dim
        modified_grad_vec = deepcopy(grad_vec)
        for task_indices in shuffled_task_indices:
            normalized_shuffled_grad = normalized_grad_vec[
                task_indices
            ]  # num_tasks x dim
            dot = (modified_grad_vec * normalized_shuffled_grad).sum(
                dim=1, keepdim=True
            )  # num_tasks x 1
            # only negative projections (conflicts) are removed
            modified_grad_vec -= torch.clamp_max(dot, 0) * normalized_shuffled_grad
        return modified_grad_vec.mean(dim=0)

    def cagrad(grads, alpha=0.5, rescale=1, n_tasks=4):
        """Combine the task columns with weights found by a constrained solve.

        The weights are restricted to [0, 1] and to sum to one; ``rescale``
        selects the final normalisation (0: none, 1: ``1 + alpha ** 2``,
        otherwise ``1 + alpha``).
        """
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(n_tasks) / n_tasks
        bnds = tuple((0, 1) for _ in x_start)
        cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (x.reshape(1, n_tasks).dot(A).dot(b.reshape(n_tasks, 1)) + c * np.sqrt(
                x.reshape(1, n_tasks).dot(A).dot(x.reshape(n_tasks, 1)) + 1e-8)).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        ww = torch.Tensor(res.x).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        lmbda = c / (gw.norm() + 1e-8)
        g = grads.mean(1) + lmbda * gw
        if rescale == 0:
            return g
        if rescale == 1:
            return g / (1 + alpha ** 2)
        return g / (1 + alpha)

    def grad2vec(m, grads, grad_dims, task):
        """Copy the shared-parameter gradients of ``m`` into column ``task``."""
        grads[:, task].fill_(0.0)
        beg = 0
        for (name, p), dim in zip(m.shared_parameters().items(), grad_dims):
            en = beg + dim
            # parameters without a gradient keep their zero slice
            if p.grad is not None:
                grads[beg:en, task].copy_(p.grad.data.detach().reshape(-1))
            beg = en

    def grad2vec_list(m):
        """Return one flattened gradient copy per shared parameter of ``m``.

        A parameter without gradient contributes zeros, so that positions in
        the result always line up with the indices stored in ``layer_dict``.
        """
        grad_list = []
        for name, param in m.shared_parameters().items():
            if param.grad is None:
                grad_list.append(torch.zeros_like(param.data).reshape(-1))
            else:
                grad_list.append(param.grad.data.detach().clone().reshape(-1))
        return grad_list

    def split_layer(grad_list, name_dict):
        """Concatenate the per-parameter vectors belonging to the same layer."""
        return [torch.cat([grad_list[i] for i in indices])
                for indices in name_dict.values()]

    def get_layer_dict(m):
        """Map layer name -> indices of its shared parameters.

        The layer name is the parameter name with ``.weight`` / ``.bias``
        removed, so a layer's weight and bias are grouped together.
        """
        layer_dict = {}
        for i, name in enumerate(m.shared_parameters().keys()):
            if '.weight' in name:
                name = name.replace('.weight', '')
            elif '.bias' in name:
                name = name.replace('.bias', '')
            layer_dict.setdefault(name, []).append(i)
        return layer_dict

    def reshape_grad(g, grad_dims):
        """Split the flat vector ``g`` into per-parameter copies."""
        grad_all = []
        beg = 0
        for dim in grad_dims:
            grad_all.append(g[beg:beg + dim].data.detach().clone())
            beg += dim
        return grad_all

    def overwrite_grad(m, newgrad, grad_dims, n_tasks=2):
        """Write the flat vector ``newgrad`` back into the shared ``.grad``s."""
        # NOTE(review): every caller below relies on the default n_tasks=2,
        # independent of len(tasks); confirm this scaling is intended.
        newgrad = newgrad * n_tasks  # to match the sum loss
        beg = 0
        for (name, param), dim in zip(m.shared_parameters().items(), grad_dims):
            en = beg + dim
            this_grad = newgrad[beg:en].contiguous().view(param.data.size())
            param.grad = this_grad.data.clone()
            beg = en

    def backward_per_task(losses, grads, per_layer=None):
        """Back-propagate every task loss on its own and store its gradient.

        Column ``i`` of ``grads`` receives the shared-parameter gradient of
        task ``i``; when ``per_layer`` is a list, the per-layer split of that
        gradient is appended to it as well.
        """
        for i in range(n_tasks):
            # the graph is shared by all tasks; free it only after the last
            losses[i].backward(retain_graph=(i < n_tasks - 1))
            grad2vec(model, grads, grad_dims, i)
            if per_layer is not None:
                per_layer.append(
                    split_layer(grad2vec_list(model), name_dict=layer_dict))
            model.zero_grad_shared_modules()

    surgery_methods = ('pcgrad', 'cagrad', 'graddrop', 'mgd')
    branch_methods = ('branch_layer', 'bmtas_branch')
    # fail early instead of silently running epochs without any update
    if opt.method != 'nothing' and opt.method not in surgery_methods + branch_methods:
        raise ValueError(f'Unsupported training method: {opt.method}')
    is_branch_method = opt.method in branch_methods

    writer = SummaryWriter(log_dir=exp_dir) if local_rank == 0 else None

    rng = np.random.default_rng()

    n_tasks = len(tasks)
    grad_dims = [param.data.numel()
                 for key, param in model.shared_parameters().items()]
    total_grad_dim = sum(grad_dims)

    layer_dict = get_layer_dict(model)
    layer_name = list(layer_dict.keys())

    layer_wise_angle = OrderedDict((name, []) for name in layer_name)
    layer_wise_task_angle = OrderedDict((name, []) for name in layer_name)

    iter_per_epoch = len(
        trainloader.dataset) // (trainloader.batch_size * max(1, world_size))
    total_iter = iter_per_epoch * max_epochs

    try:
        for epoch in range(start_epoch, max_epochs + 1):
            # test_branched leaves the model in eval mode, so switch back
            model.train()

            if world_size > 1:
                trainloader.sampler.set_epoch(epoch)

            for batch_idx, samples in enumerate(trainloader):
                inputs = samples['image'].to(device, non_blocking=True)
                target = {task: samples[task].to(
                    device, non_blocking=True) for task in tasks}

                optimizer.zero_grad()

                output = model(inputs)
                losses = loss(output, target)

                if opt.method == 'nothing':
                    tot_loss = 0.0
                    for task_loss in losses:
                        tot_loss += task_loss

                    tot_loss.backward()
                    optimizer.step()
                    del tot_loss
                else:
                    # one column of flattened shared gradients per task,
                    # allocated on the same device as the batch
                    grads = torch.zeros(total_grad_dim, n_tasks, device=device)
                    grad_all = [] if is_branch_method else None
                    backward_per_task(losses, grads, per_layer=grad_all)

                    if opt.method == 'pcgrad':
                        g = pcgrad(grads, rng, n_tasks=n_tasks)
                    elif opt.method == 'cagrad':
                        g = cagrad(grads, opt.alpha, rescale=1, n_tasks=n_tasks)
                    elif opt.method == 'graddrop':
                        g = graddrop(grads)
                    elif opt.method == 'mgd':
                        g = mgd(grads)
                    elif opt.method_sub == 'nothing':
                        # branch method: plain mean of the task gradients
                        g = grads[:, 0]
                        for i in range(1, n_tasks):
                            g = g + grads[:, i]
                        g = g / n_tasks
                    elif opt.method_sub == 'cagrad':
                        g = cagrad(grads, opt.alpha, rescale=1, n_tasks=n_tasks)
                    else:
                        raise ValueError(f'Error: {opt.method_sub}')

                    if is_branch_method and epoch < opt.start_epoch:
                        target_g_list = split_layer(
                            reshape_grad(g, grad_dims), name_dict=layer_dict)
                        for i, name in enumerate(layer_name):
                            # rows: gradient of layer i for every task
                            pair = torch.stack(
                                [task_grads[i] for task_grads in grad_all])
                            layer_wise_task_cos = pair_cos_with_d(
                                pair, target_g_list[i]).cpu()
                            layer_wise_cos = pair_cos(pair).cpu()

                            layer_wise_angle[name].append(layer_wise_cos)
                            layer_wise_task_angle[name].append(
                                layer_wise_task_cos)
                        del target_g_list

                    overwrite_grad(model, g, grad_dims)
                    optimizer.step()

                    del g
                    del grads
                    del grad_all

                if (batch_idx + 1) % 100 == 0 and local_rank == 0:
                    tot_loss = sum(losses)
                    current_loss = tot_loss.item()
                    n_iter = (epoch - 1) * iter_per_epoch + batch_idx + 1
                    print('Train Iterations: {}, Loss: {}'.format(
                        utils.progress(n_iter, total_iter), current_loss))
                    writer.add_scalar('loss_current', current_loss, n_iter)
                    writer.add_scalar(
                        'learning_rate', optimizer.param_groups[0]['lr'], n_iter)

                    del tot_loss

                del losses
                del output
                del inputs
                del target

            scheduler.step()

            if epoch % 20 == 0 or epoch == max_epochs:
                # evaluate the model
                test_branched(device, tasks, testloader,
                              model, metrics_dict, epoch)

            if epoch == max_epochs and is_branch_method:
                # make sure the output directory exists before saving
                Path('./saved').mkdir(parents=True, exist_ok=True)
                saved_dict = {'cos': layer_wise_angle}
                torch.save(saved_dict, f'./saved/{opt.method}{opt.seed}{opt.method_sub}_{opt.optimizer}_{epoch}_lw_cos.pt')
                saved_dict = {'task_cos': layer_wise_task_angle}
                torch.save(saved_dict, f'./saved/{opt.method}{opt.seed}{opt.method_sub}_{opt.optimizer}_{epoch}_lw_task_cos.pt')
    finally:
        # flush pending events and release the log file, also on errors
        if writer is not None:
            writer.close()

@torch.no_grad()
def test_branched(device, tasks, testloader, model, metrics_dict, epoch):
    """Evaluate ``model`` on ``testloader`` and log the per-task scores.

    The model is switched to eval mode (and left in it). Every batch is
    forwarded once; for each task whose target is not made up solely of the
    value 255 the matching metric in ``metrics_dict`` is updated. Afterwards
    the scores of every task are printed and logged to wandb together with
    ``epoch``, and all metrics are reset.

    Args:
        device: device the batches are moved to.
        tasks: iterable of task names; each is a key of the batch dict and
            of ``metrics_dict``.
        testloader: evaluation data loader; batches also carry a ``'meta'``
            entry with ``'im_size'`` and ``'image'``.
        model: model called as ``model(inputs)``; the output is indexed by
            task name.
        metrics_dict: task name -> metric object providing ``update``,
            ``get_score`` (returning a dict) and ``reset``.
        epoch: epoch number attached to the logged scores.
    """

    model.eval()

    # get resources
    sample = next(iter(testloader))
    height, width = sample['image'].shape[-2:]
    gflops = resources.compute_gflops(model, device=device,
                                      in_shape=(1, 3, height, width))
    params = resources.count_parameters(model)
    # NOTE(review): `results` is assembled but not consumed in this function
    results = {
        'gmadds': gflops / 2.0,
        'mparams': params / 1e6
    }

    for idx, samples in enumerate(testloader):

        inputs = samples['image'].to(device, non_blocking=True)
        target = {task: samples[task].to(
            device, non_blocking=True) for task in tasks}
        im_size = tuple(x.item() for x in samples['meta']['im_size'])
        im_name = samples['meta']['image'][0]

        output = model(inputs)

        for task in tasks:

            # targets consisting solely of 255 are skipped for this task
            uniq = torch.unique(target[task])
            if len(uniq) == 1 and uniq[0] == 255:
                continue

            metrics_dict[task].update(
                output[task], target[task], im_size, im_name)

        if (idx + 1) % 100 == 0:
            print('{} / {} images done.'.format(idx + 1, len(testloader)))

    for task in tasks:
        scores = metrics_dict[task].get_score()
        # printed as "key1,key2,: value1 value2 "
        header = ''.join(f'{key},' for key in scores.keys())
        body = ''.join(f'{value} ' for value in scores.values())
        print(f'{header}: {body}')
        # log a copy so the dict returned by the metric is left untouched
        wandb.log({**scores, 'epoch': epoch})

    for task in tasks:
        metrics_dict[task].reset()
