import json
import time
import torch
import wandb
import torch.nn as nn
import itertools

from torch.nn.modules.loss import CrossEntropyLoss

from pathlib import Path
from collections import OrderedDict, defaultdict

def read_json(fname):
    """Parse the JSON file at ``fname`` and return its content.

    Every JSON object is decoded as an ``OrderedDict`` (via ``object_hook``).
    ``fname`` may be a string or a ``Path``; the file is opened in text mode.
    """
    source = Path(fname)
    with source.open('rt') as stream:
        parsed = json.load(stream, object_hook=OrderedDict)
    return parsed

def write_json(content, fname):
    """Serialise ``content`` as JSON into the file at ``fname``.

    The output is indented by four spaces and keys keep the order in which
    ``content`` provides them (no sorting). ``fname`` may be a string or a
    ``Path``; the file is opened in text mode and overwritten.
    """
    destination = Path(fname)
    with destination.open('wt') as stream:
        json.dump(content, stream, indent=4, sort_keys=False)

class Timer():
    """Wall-clock stopwatch that reports elapsed time as a short string."""

    def __init__(self):
        # reference instant (``time.time()``) all measurements are relative to
        self.o = time.time()

    def measure(self, p=1):
        """Return the time elapsed since construction, divided by ``p``, as text.

        The quotient is truncated to whole seconds and then formatted as
        ``'<n>s'`` below one minute, ``'<n>m'`` (rounded minutes) below one
        hour, and ``'<x.x>h'`` (one decimal) from one hour upwards.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return f'{seconds}s'
        if seconds < 3600:
            return f'{round(seconds / 60)}m'
        return f'{seconds / 3600:.1f}h'

class Averager():
    """Running weighted mean that also keeps a log of every value it was given.

    Attributes:
        val:   the most recently added value.
        sum:   weighted sum of all values added since the last reset.
        count: total weight added since the last reset.
        avg:   ``sum / count`` (0 until something is added).
        data:  every value ever added, in order.
    """

    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0
        self.data = []

    def reset(self):
        """Zero the running statistics; the ``data`` log is left untouched."""
        self.val = self.avg = self.sum = self.count = 0

    def add(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # the mean is refreshed before logging, so a zero total weight raises
        # here and the value is not appended to ``data``
        self.avg = self.sum / self.count
        self.data.append(val)

    def item(self):
        """Return the current weighted mean."""
        return self.avg

    def obtain_data(self):
        """Return the list of all recorded values (the live list, not a copy)."""
        return self.data

    def __len__(self):
        """Number of values recorded in ``data``."""
        return len(self.data)


class BranchingLoss(nn.Module):
    """
    Proxyless resource loss for generating efficient branched networks.

    The loss is the expected resource cost of a branched architecture sampled
    from the supergraph: for every encoder layer, the probability of each task
    grouping is multiplied by a pre-computed cost (look-up table stored in the
    ``weights`` buffer) and everything is summed.

    The constructor reads ``model.alphas``, ``model.tasks``, ``model.encoder``
    and ``model.get_flops()``.
    """

    def __init__(self, model):
        super().__init__()

        self.alphas = model.alphas
        self.n_tasks = len(model.tasks)

        # every grouping except the one with n_tasks parts (all tasks
        # separate); that one is handled as the implicit last column
        task_ids = list(range(self.n_tasks))
        groupings = [grouping for grouping in self.partition(task_ids)
                     if len(grouping) != self.n_tasks]
        self.total_l = len(model.encoder)  # number of encoder layers
        self.total_k = len(groupings) + 1  # number of task groupings

        # buffers needed by forward()
        self._initialize_weights(groupings, model)
        self._initialize_ancestors(groupings)
        self._initialize_indices(groupings)

    def forward(self):
        """ Return the scalar loss: sum over layers and groupings of
        (grouping probability) * (grouping cost).
        """
        n_layers, n_groupings = self.total_l, self.total_k
        last = n_groupings - 1
        probs = self.alphas[0].new_zeros((n_layers, n_groupings))
        for layer in range(n_layers):
            # softmax over the last dim of this layer's alphas
            choice_prob = nn.functional.softmax(self.alphas[layer], dim=-1)
            buffers = zip(self.resind_buffers(), self.ancestor_buffers())
            for col, (flat_idx, ancestors) in enumerate(buffers):
                # probability of this grouping given the layer's choices:
                # one product per index row, summed over all rows
                cond = torch.take(choice_prob, flat_idx).prod(dim=1).sum()
                if layer == 0:
                    probs[layer, col] = cond
                else:
                    # weight by the probability that the previous layer was
                    # in one of this grouping's ancestor groupings
                    probs[layer, col] = cond * probs[layer - 1, ancestors].sum()
            # the remaining probability mass goes to the last column
            probs[layer, last] = 1. - probs[layer].sum()
        return (probs * self.weights).sum()

    def resind_buffers(self):
        """ Yield the ``resind_<i>`` buffers in index order. """
        for i in range(self.total_k - 1):
            yield getattr(self, f'resind_{i}')

    def ancestor_buffers(self):
        """ Yield the ``ancestors_<i>`` buffers in index order. """
        for i in range(self.total_k - 1):
            yield getattr(self, f'ancestors_{i}')

    def partition(self, input_set):
        """ Generate every set partition of the list ``input_set`` (recursively). """
        # credit: https://stackoverflow.com/questions/19368375/set-partitions-in-python
        if len(input_set) == 1:
            yield [input_set]
            return
        head, tail = input_set[0], input_set[1:]
        for sub in self.partition(tail):
            # add `head` to each existing subset in turn ...
            for pos, members in enumerate(sub):
                yield sub[:pos] + [[head] + members] + sub[pos + 1:]
            # ... and finally give `head` a subset of its own
            yield [[head]] + sub

    def _initialize_weights(self, partitions, model):
        """ Build the resource look-up table and store it in the ``weights`` buffer.
        Entry (layer, grouping) equals that layer's value from ``model.get_flops()``
        divided by 1e9, times the number of parts (branches) of the grouping.
        The last column uses ``n_tasks`` parts, i.e. the task-specific case.
        """
        layer_cost = torch.tensor(model.get_flops()).view(-1, 1)
        layer_cost = layer_cost.repeat(1, self.total_k) / 1e9
        branch_counts = [len(grouping) for grouping in partitions]
        branch_counts.append(self.n_tasks)  # task-specific case
        branch_counts = torch.tensor(branch_counts).repeat(self.total_l, 1)
        self.register_buffer('weights', layer_cost * branch_counts)

    def _initialize_indices(self, partitions):
        """ For every partition, register a ``resind_<i>`` buffer holding the
        indices into a flattened (n_tasks x n_tasks) 'probs' matrix that lead
        to that partition. Each row lists one index per task.
        Example (2 tasks, partition [[0, 1]]): torch.tensor([[0, 2], [1, 3]])
        """
        n = self.n_tasks
        for i, grouping in enumerate(partitions):
            wanted = sorted(grouping)
            rows = []
            # an assignment maps task index -> chosen column; tasks sharing a
            # column end up in the same part
            for assignment in itertools.product(range(n), repeat=n):
                parts = defaultdict(list)
                for task, column in enumerate(assignment):
                    parts[column].append(task)
                if sorted(parts.values()) == wanted:
                    rows.append([column + task * n
                                 for task, column in enumerate(assignment)])
            self.register_buffer(f'resind_{i}',
                                 torch.tensor(rows, dtype=torch.long))

    def _initialize_ancestors(self, partitions):
        """ For every partition k, register an ``ancestors_<i>`` buffer with the
        indices of its ancestor partitions, i.e. the partitions of which k is a
        refinement (k itself included).
        Example (3 tasks, partition [[0], [1, 2]]): torch.tensor([0, 1])
        """
        # brute force is fine here: this only runs once, at construction
        def is_refinement(fine, coarse):
            # True when every part of `fine` fits inside some part of `coarse`
            return all(
                any(set(part).issubset(set(block)) for block in coarse)
                for part in fine
            )

        for i, grouping in enumerate(partitions):
            ancestor_ids = [idx for idx, candidate in enumerate(partitions)
                            if is_refinement(grouping, candidate)]
            self.register_buffer(f'ancestors_{i}',
                                 torch.tensor(ancestor_ids, dtype=torch.long))


def train_search(opt,
                 tasks,
                 model,
                 trainloader_arch,
                 trainloader_weight,
                 optimizer_arch,
                 optimizer_weight,
                 exp_dir):
    """Run the branching search and write the resulting branch configuration.

    For ``opt.n_epoch`` epochs the function alternates two phases. After a
    warm-up of ``n_epoch // 20`` epochs, an architecture phase steps
    ``optimizer_arch`` on cross-entropy plus the resource loss
    (``BranchingLoss`` scaled by ``opt.resource_loss_weight``). Every epoch,
    a weight phase steps ``optimizer_weight`` on cross-entropy only. Epoch
    averages are printed and logged to wandb. At the end,
    ``model.get_branch_config()`` is saved as
    ``<exp_dir>/branch_config.json``.

    Args:
        opt: object providing ``n_epoch`` and ``resource_loss_weight``.
        tasks: iterable of task names; only ``'t1'`` and ``'t2'`` are supported.
        model: network called as ``model(x, task=task)`` and returning a
            mapping indexed by task name. It must also expose the attributes
            and methods used below (``warmup_flag``, ``gumbel_temp``, the
            encoder batch-norm helpers and ``get_branch_config``).
        trainloader_arch: iterable of batches for the architecture phase;
            ``batch[0]`` is the input and ``batch[1]`` holds the labels, with
            column 0 used for ``'t1'`` and column 1 for ``'t2'``.
        trainloader_weight: iterable of batches (same layout) for the weight phase.
        optimizer_arch: optimizer stepped in the architecture phase.
        optimizer_weight: optimizer stepped in the weight phase.
        exp_dir: directory in which ``branch_config.json`` is written.

    Raises:
        ValueError: if ``tasks`` contains a name other than ``'t1'`` / ``'t2'``.

    Note:
        Batches and the resource loss are moved with ``.cuda()``, so a CUDA
        device is required.
    """

    def select_target(labels, task):
        # column 0 carries the 't1' labels, column 1 the 't2' labels
        if task == 't1':
            return labels[:, 0]
        if task == 't2':
            return labels[:, 1]
        raise ValueError('Error')

    wandb.define_metric("epoch")
    wandb.define_metric("Train_loss_avg", step_metric="epoch")
    # the keys actually logged below; tie them to the epoch axis as well
    wandb.define_metric("Tr_Loss", step_metric="epoch")
    wandb.define_metric("Branch_Loss", step_metric="epoch")

    timer = Timer()
    max_epochs = opt.n_epoch

    ce_loss = CrossEntropyLoss()
    branch_loss = BranchingLoss(model=model).cuda()

    delay_epochs = max_epochs // 20
    # number of epochs over which the temperature is annealed from 5.0 to 0.1;
    # clamped to 1 so that a single-epoch run does not divide by zero
    anneal_span = max(max_epochs - delay_epochs - 1, 1)

    for epoch in range(1, max_epochs + 1):
        avg_tr_loss = Averager()
        avg_branch_loss = Averager()

        model.train()
        model.warmup_flag = (epoch <= delay_epochs)

        # linear decay after the warm-up, capped at 5.0 during the warm-up
        model.gumbel_temp = min(
            5.0 - (epoch - delay_epochs - 1) / anneal_span * (5.0 - 0.1), 5.0)

        if epoch > delay_epochs:
            print('modifying architecture...')

            # we reset the arch optimizer state
            optimizer_arch.state = defaultdict(dict)

            # we use current batch statistics in search period
            model.freeze_encoder_bn_running_stats()

            for data in trainloader_arch:
                x = data[0].cuda()
                ts = data[1].cuda()
                # batch size of THIS batch, used to weight the running average
                n_samples = x.size(0)

                optimizer_arch.zero_grad()

                # gradients of all tasks are accumulated before one step
                for task in tasks:
                    target = select_target(ts, task)

                    output = model(x, task=task)[task]
                    tot_loss = (ce_loss(output, target)
                                + branch_loss() * opt.resource_loss_weight)

                    tot_loss.backward()
                    avg_branch_loss.add(tot_loss.item(), n_samples)

                optimizer_arch.step()

            # we reset the main optimizer state because arch has changed
            optimizer_weight.state = defaultdict(dict)

            # we should reset bn running stats
            model.unfreeze_encoder_bn_running_stats()
            model.reset_encoder_bn_running_stats()

        for data in trainloader_weight:
            x = data[0].cuda()
            ts = data[1].cuda()
            n_samples = x.size(0)

            # one optimizer step per task
            for task in tasks:
                optimizer_weight.zero_grad()
                target = select_target(ts, task)

                output = model(x, task=task)
                tot_loss = ce_loss(output[task], target)

                tot_loss.backward()
                optimizer_weight.step()

                # store a plain float: keeping the tensor would retain
                # autograd history in the averager and log a tensor to wandb
                avg_tr_loss.add(tot_loss.item(), n_samples)

        # `epoch` epochs are complete here, so elapsed / (epoch / max_epochs)
        # is the projected total run time
        print(
            f'Epoch:{epoch} | ETA:{timer.measure()}/{timer.measure(epoch / max_epochs)}')
        print(f'Loss_tr: {avg_tr_loss.item()}, Loss_branch: {avg_branch_loss.item()}')
        log_dict = {
            'epoch': epoch,
            'Tr_Loss': avg_tr_loss.item(),
            'Branch_Loss': avg_branch_loss.item()
        }
        wandb.log(log_dict)

    branch_config = model.get_branch_config()
    write_json({'config': branch_config},
               Path(exp_dir) / 'branch_config.json')