import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
import scipy
import scipy.linalg
from scipy.spatial.distance import cdist
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network.mlp import MNISTMLP
from inclearn.models.finetune import Finetune

# Small positive constant, presumably used elsewhere as a numerical-stability
# guard (e.g. against division by zero).
EPSILON = 1.0e-8

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class UDIL(Finetune):
    """UDIL domain-incremental learner built on top of ``Finetune``.

    On top of plain fine-tuning this model maintains:

    * a domain discriminator (``MNISTMLP`` over ``features_fused``) trained to
      tell domains apart; its balanced binary error rates are converted into
      H-divergence estimates ``2 * (1 - 2 * err)``;
    * learnable task-weight logits ``self._task_logits`` of shape
      ``(3, n_past_domains)``, softmax-normalised over the first axis (see
      ``_task_weight``) into ``alpha`` (row 0, past-domain distillation),
      ``beta`` (row 1, current-domain distillation / adversarial weighting)
      and ``gamma`` (row 2, past-domain cross-entropy);
    * an encoder step that adversarially aligns current and memory features
      against the discriminator while keeping memory features close to the
      frozen previous network's features.
    """

    def __init__(self, args):
        super().__init__(args)

        # Optimisation steps per batch for the discriminator / task weights.
        self._discriminator_k = args.get("discriminator_k", 1)
        self._task_weight_k = args.get("task_weight_k", 1)

        self._task_weight_lr = args.get("task_weight_lr", 1e-3)
        self._discriminator_lr = args.get("discriminator_lr", 1e-3)

        # Domain classifier over fused features with one output per task.
        self._discriminator = MNISTMLP(self._network.features_dim, 800, self._n_tasks).to(self._device)
        self._discriminator_optimizer = Adam(self._discriminator.parameters(), lr=self._discriminator_lr)

        # Per-domain 0-1 error rates, refreshed by `_after_task_intensive`
        # and consumed by `_compute_tradeoff`.
        self._past_errs = []

        # Loss configuration dicts; keys such as "lambda", "temperature",
        # "threshold", "form" and "mu" are read by the methods below.
        self._supcon_loss = args.get("supcon_loss")
        self._encoder_loss = args.get("encoder_loss")
        self._distil_loss = args.get("distill_loss")
        self._general_loss = args.get("general_loss")

    def _before_task(self):
        """Create fresh task-weight logits (and their optimiser) for tasks > 0.

        ``self._task_logits`` has shape ``(3, self._task)``, one column per
        past domain. Zero initialisation yields uniform weights after the
        softmax in ``_task_weight``.
        """
        if self._task > 0:
            self._task_logits = torch.zeros((3, self._task), requires_grad=True, device=self._device)
            self._task_weight_optimizer = Adam([self._task_logits], lr=self._task_weight_lr)
        super()._before_task()

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Run the training epochs ``initial_epoch .. nb_epochs - 1``.

        For ``self._task > 0`` each batch is split into current-domain samples
        (``task_id == self._task``) and replayed memory samples. The
        discriminator and the task weights are updated first; the network is
        then updated by ``_forward_loss`` on the recombined batch, ordered
        ``[current; memory]`` with sizes ``bs``.

        ``record_bn`` is accepted for interface compatibility and is unused.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        for epoch in range(initial_epoch, nb_epochs):
            self._metrics = collections.defaultdict(float)

            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )
            for i, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target").to(self._device)
                domains = input_dict.pop("task_id").to(self._device)
                inputs = {key: input_dict[key].to(self._device) for key in input_dict}

                bs = [len(domains)]

                if self._task > 0:
                    # Split with boolean masks rather than positional slices:
                    # correct regardless of sample order, and safe when the
                    # batch has no memory samples (`x[-0:]` would select the
                    # whole tensor and duplicate every input).
                    cur_mask = domains == self._task
                    past_mask = ~cur_mask
                    # Plain ints so `torch.split` receives integer sizes.
                    bs = [int(cur_mask.sum().item()), int(past_mask.sum().item())]

                    mem_targets, mem_domains = targets[past_mask], domains[past_mask]
                    mem_inputs = {key: value[past_mask] for key, value in inputs.items()}
                    targets, domains = targets[cur_mask], domains[cur_mask]
                    inputs = {key: value[cur_mask] for key, value in inputs.items()}

                    err_rates, _ = self._update_discriminator(
                        inputs, targets, domains, mem_inputs, mem_targets, mem_domains
                    )
                    # Empirical H-divergence estimate per past domain.
                    self._update_task_weights(
                        inputs, targets, domains, mem_inputs, mem_targets, mem_domains,
                        hdivs=2 * (1 - 2 * err_rates)
                    )

                    # Recombine as [current; memory]; `_forward_loss` splits by `bs`.
                    targets = torch.cat([targets, mem_targets], dim=0)
                    domains = torch.cat([domains, mem_domains], dim=0)
                    inputs = {key: torch.cat([inputs[key], mem_inputs[key]], dim=0) for key in inputs}

                self._forward_loss(
                    training_network,
                    inputs, targets, domains, batch_size=bs
                )

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, i)

            if self._disable_progressbar:
                self._print_metrics(None, epoch, nb_epochs, i)

            if self._scheduler:
                self._scheduler.step()

    def _forward_loss(
        self,
        training_network,
        inputs,
        targets,
        domains,
        batch_size,
        **kwargs
    ):
        """Compute the batch losses, update the network and accumulate metrics.

        Task 0: cross-entropy (plus supervised contrastive loss when
        ``supcon_loss['first_domain']`` is set) and one optimiser step.

        Later tasks: ``inputs``/``targets``/``domains`` are ordered
        ``[current; memory]`` with sizes ``batch_size``. Two optimiser steps:
        (1) current ERM + thresholded current KD + weighted past losses +
        supervised contrastive loss; (2) on a fresh forward pass, the encoder
        loss (adversarial alignment + memory feature stability).

        NOTE(review): forward passes use ``self._network`` directly, so
        ``training_network`` (possibly a DataParallel wrapper) is unused here.

        Returns:
            The loss tensor of the first optimisation step.
        """
        if self._task == 0:
            self._optimizer.zero_grad()
            outputs = self._network(inputs)
            loss = F.cross_entropy(outputs["logits"], targets)
            self._metrics["clf"] += loss.item()

            if self._supcon_loss.get('first_domain', False):
                loss_supcon = self._compute_supcon_loss(
                    cur_feats=outputs["features_fused"],
                    cur_labels=targets,
                    past_feats=None,
                    past_labels=None
                )
                self._metrics["con"] += loss_supcon.item()
                loss += loss_supcon
            loss.backward()
            self._optimizer.step()

        else:
            cur_targets, past_targets = torch.split(targets, batch_size)
            cur_domains, past_domains = torch.split(domains, batch_size)

            # Frozen previous-task network provides distillation targets.
            with torch.no_grad():
                old_outputs = self._old_network(inputs)
            old_cur_logits, old_past_logits = torch.split(old_outputs['logits'], batch_size)
            old_cur_feats, old_past_feats = torch.split(old_outputs['features_fused'], batch_size)

            # Step 1: classification / distillation / contrastive objective.
            self._optimizer.zero_grad()
            outputs = self._network(inputs)
            cur_logits, past_logits = torch.split(outputs['logits'], batch_size)
            cur_feats, past_feats = torch.split(outputs['features_fused'], batch_size)

            loss_cur_erm = F.cross_entropy(cur_logits, cur_targets)
            self._metrics["clf"] += loss_cur_erm.item()

            loss_cur_kd = self._compute_current_kd(cur_logits, old_cur_logits, threshold=self._distil_loss['threshold'])
            self._metrics["dis"] += loss_cur_kd.item()

            loss_past = self._compute_past_losses(
                logits_p=past_logits,
                logits_t=old_past_logits,
                labels=past_targets,
                domain_ids=past_domains,
                loss_form=self._distil_loss['form']
            )
            self._metrics["pas"] += loss_past.item()

            loss_supcon = self._compute_supcon_loss(
                cur_feats=cur_feats,
                past_feats=past_feats,
                cur_labels=cur_targets,
                past_labels=past_targets,
                cur_domains=cur_domains,
                past_domains=past_domains
            )
            self._metrics["con"] += loss_supcon.item()

            loss = (loss_cur_erm + loss_cur_kd + loss_past + loss_supcon)
            loss.backward()
            self._optimizer.step()

            # Step 2: encoder alignment on a fresh forward pass with the updated weights.
            self._optimizer.zero_grad()
            outputs = self._network(inputs)
            cur_logits, past_logits = torch.split(outputs['logits'], batch_size)
            cur_feats, past_feats = torch.split(outputs['features_fused'], batch_size)

            loss_encoder = self._compute_encoder_loss(cur_feats, past_feats, old_past_feats, past_domain_ids=past_domains)
            self._metrics["enc"] += loss_encoder.item()

            loss_encoder.backward()
            self._optimizer.step()

        self._metrics["loss"] += loss.item()
        self._metrics["acc"] += (targets == outputs['logits'].argmax(axis=1)).float().mean().item()

        return loss

    def _domain_counts(self, domain_ids, reference):
        """Return per-domain sample counts as a float tensor shaped like ``reference``.

        Entry ``i`` is the number of occurrences of domain id ``i`` in
        ``domain_ids``; ids that do not occur get 0.
        """
        unique_ids, counts = torch.unique(domain_ids, return_counts=True, sorted=True)
        full_cnts = torch.zeros_like(reference, device=self._device)
        full_cnts[unique_ids.type(torch.long).to(self._device)] = counts.type(torch.float).to(self._device)
        return full_cnts

    def _update_discriminator(self, cur_input, cur_target, cur_domain, past_input, past_target, past_domain):
        """Train the domain discriminator and estimate pairwise domain error rates.

        Current samples are weighted ``1 / N_cur``; a memory sample of domain
        ``i`` is weighted ``beta'_i / N_i`` where ``beta' = beta / sum(beta)``
        (detached) and ``N_i`` is its count in the batch. Features come from
        the network without gradients, so only the discriminator is updated.

        Returns:
            ``(errs, loss)``: ``errs`` is a float tensor of shape
            ``(self._task,)`` holding, for each past domain ``i``, the
            balanced binary error of the discriminator between domain ``i``
            and the current domain, folded into ``[0, 0.5]``; ``loss`` is the
            last weighted discriminator loss as a Python float.
        """
        task_weights = self._task_weight

        cur_sample_weights = torch.ones((cur_domain.shape[0], ), device=self._device) / cur_domain.shape[0]

        beta_prime = (task_weights[1] / task_weights[1].sum()).detach().clone()
        full_cnts = self._domain_counts(past_domain, beta_prime)
        past_samples_weights = beta_prime[past_domain] / full_cnts[past_domain]  # beta'_i / N_i

        inputs = {key: torch.cat([cur_input[key], past_input[key]], dim=0) for key in cur_input}
        domains = torch.cat([cur_domain, past_domain])
        sample_weights = torch.cat([cur_sample_weights, past_samples_weights])

        with torch.no_grad():
            features = self._network(inputs)['features_fused']

        for _ in range(self._discriminator_k):
            self._discriminator_optimizer.zero_grad()
            logits = self._discriminator(features)
            # reduction='none' keeps one loss per sample so `sample_weights`
            # actually re-weights them; the default 'mean' would collapse the
            # batch to a scalar and turn the weights into a global scale.
            batch_loss = F.cross_entropy(logits[:, :self._task + 1], domains, reduction='none')

            loss = (batch_loss * sample_weights).sum()
            loss.backward()
            self._discriminator_optimizer.step()

        # Error rates from the logits of the last discriminator iteration.
        logits = logits.detach()
        errs = []
        for i in range(self._task):
            binary_preds = torch.where(torch.argmax(logits[:, [i, self._task]], 1) == 0, i, self._task)
            incorrects = (binary_preds != domains).type(torch.float)
            past_inds, cur_inds = domains == i, domains == self._task
            # Normalise by the number of samples of each domain (not by the
            # batch size), clamped so an absent domain cannot yield 0/0.
            past_err = incorrects[past_inds].sum() / past_inds.sum().clamp(min=1)
            cur_err = incorrects[cur_inds].sum() / cur_inds.sum().clamp(min=1)
            error_rate = ((past_err + cur_err) / 2).item()
            # NOTE(review): folding into [0, 0.5] treats a consistently wrong
            # discriminator as informative; the original author also
            # questioned this — confirm against the method's derivation.
            errs.append(min(error_rate, 1 - error_rate))

        return torch.tensor(errs, dtype=torch.float, device=self._device), loss.item()

    def _update_task_weights(self, cur_input, cur_target, cur_domain, past_input, past_target, past_domain, hdivs):
        """Optimise ``self._task_logits`` for ``self._task_weight_k`` steps.

        Network and old-network outputs are computed once without gradients
        (network in eval mode), so only the task-weight logits are updated.
        The objective sums the 0-1 current-domain distillation, the 0-1
        past-domain losses, the H-divergence trade-off term (``hdivs``) and
        the generalisation term.

        Returns:
            The last objective value as a Python float.
        """
        n_cur, n_past = cur_domain.shape[0], past_domain.shape[0]
        inputs = {key: torch.cat([cur_input[key], past_input[key]]) for key in cur_input}

        self._network.eval()
        try:
            with torch.no_grad():
                logits = self._network(inputs)['logits']
                old_logits = self._old_network(inputs)['logits']
        finally:
            # Restore training mode even if a forward pass raises; the loop
            # below does not call the network again.
            self._network.train()

        cur_logits, past_logits = torch.split(logits, [n_cur, n_past])
        old_cur_logits, old_past_logits = torch.split(old_logits, [n_cur, n_past])

        for _ in range(self._task_weight_k):
            self._task_weight_optimizer.zero_grad()

            loss_cur_kd = self._compute_current_kd(cur_logits, old_cur_logits, loss_form='0-1')
            loss_past = self._compute_past_losses(
                logits_p=past_logits,
                logits_t=old_past_logits,
                labels=past_target,
                domain_ids=past_domain,
                loss_form='0-1'
            )
            loss_tradeoff = self._compute_tradeoff(hdivs=hdivs)
            loss_generalization = self._generalization_error()
            loss = loss_cur_kd + loss_past + loss_tradeoff + loss_generalization

            loss.backward()
            self._task_weight_optimizer.step()
        return loss.item()

    def _compute_current_kd(self, logits, old_logits, loss_form='ce', threshold=10.):
        """E_{Dt}(h, Ht-1): distillation on current-domain samples, scaled by ``sum(beta)``.

        Args:
            logits: current network logits on current-domain samples.
            old_logits: previous network logits on the same samples.
            loss_form: ``'ce'`` (soft cross-entropy against the old softmax)
                or ``'0-1'`` (argmax disagreement rate, used for the
                task-weight update).
            threshold: for ``'ce'``, the loss is zeroed while its value is
                ``>= threshold``.

        Raises:
            NotImplementedError: for any other ``loss_form``.
        """
        task_weights = self._task_weight

        if loss_form == 'ce':
            loss_kd = torch.mean(-(torch.softmax(old_logits, 1) * torch.log_softmax(logits, 1)).sum(1))
            # Threshold mechanism that suppresses the LwF loss before alignment.
            return loss_kd * task_weights[1].sum() * float(loss_kd.item() < threshold)
        elif loss_form == '0-1':
            loss_kd = (torch.argmax(logits, 1) != torch.argmax(old_logits, 1)).type(torch.float).mean()
            return loss_kd * task_weights[1].sum()
        else:
            supported_losses = ['ce', '0-1']
            raise NotImplementedError(f"Loss form '{loss_form}' not supported; currently supported: {supported_losses}")

    def _compute_past_losses(self, logits_p, logits_t, labels, domain_ids, loss_form='ce'):
        """Per-domain weighted ERM + distillation on memory samples.

        A memory sample of domain ``i`` contributes ``gamma_i / N_i`` times its
        classification loss and ``alpha_i / N_i`` times its distillation loss,
        ``N_i`` being the count of domain ``i`` in ``domain_ids``.

        Args:
            logits_p: current network logits on memory samples.
            logits_t: previous network logits on the same samples.
            labels: ground-truth labels of the memory samples.
            domain_ids: domain id of each memory sample.
            loss_form: ``'ce'``, ``'l2'`` (MSE distillation averaged over
                logit dimensions) or ``'0-1'``.

        Raises:
            NotImplementedError: for any other ``loss_form``.
        """
        task_weights = self._task_weight

        alpha, gamma = task_weights[0], task_weights[2]
        full_cnts = self._domain_counts(domain_ids, gamma)

        alpha_weights = alpha[domain_ids] / full_cnts[domain_ids]  # alpha_i / N_i
        gamma_weights = gamma[domain_ids] / full_cnts[domain_ids]  # gamma_i / N_i

        if loss_form == 'ce':
            loss_ce = (gamma_weights * F.cross_entropy(logits_p, labels, reduction='none')).sum()
            loss_kd = (alpha_weights * -(torch.softmax(logits_t, 1) * torch.log_softmax(logits_p, 1)).sum(1)).sum()
        elif loss_form == 'l2':
            loss_ce = (gamma_weights * F.cross_entropy(logits_p, labels, reduction='none')).sum()
            loss_kd = (alpha_weights * F.mse_loss(logits_p, logits_t, reduction='none').sum(dim=1)).sum() / logits_p.shape[1]
        elif loss_form == '0-1':
            loss_ce = (gamma_weights * (torch.argmax(logits_p, 1) != labels).type(torch.float)).sum()
            loss_kd = (alpha_weights * (torch.argmax(logits_p, 1) != torch.argmax(logits_t, 1)).type(torch.float)).sum()
        else:
            supported_losses = ['ce', '0-1', 'l2']
            raise NotImplementedError(f"Loss form '{loss_form}' not supported; currently supported: {supported_losses}")
        return loss_ce + loss_kd

    def _compute_tradeoff(self, hdivs):
        """Return ``0.5 * sum(beta * hdivs) + (alpha + beta) . past_errs``.

        Expects ``self._past_errs`` to hold one error rate per past domain
        (same length as each row of the task weights).
        """
        task_weights = self._task_weight
        beta = task_weights[1]
        past_errs = torch.tensor(self._past_errs, requires_grad=False, device=self._device, dtype=torch.float)
        alpha_beta_sum = task_weights[:2].sum(0)

        return 0.5 * (beta * hdivs).sum() + alpha_beta_sum.dot(past_errs)

    def _generalization_error(self):
        """Return ``lambda * sqrt((1 + sum beta)^2 / |current train set| + sum((alpha + gamma)^2) / M)``.

        ``M`` sums ``len(d)`` over ``self.inc_dataset.exemplar_dataset``;
        presumably the total memory size (TODO confirm the container layout).
        """
        task_weights = self._task_weight
        num_samples = torch.tensor(sum([len(d) for d in self.inc_dataset.exemplar_dataset])).to(self._device)

        term1 = (1. + task_weights[1].sum()) ** 2 / len(self.inc_dataset.cur_train_dataset)
        term2 = ((task_weights[0] + task_weights[2]) ** 2 / num_samples).sum()

        return self._general_loss['lambda'] * torch.sqrt(term1 + term2)

    def _compute_supcon_loss(self, cur_feats, cur_labels, cur_domains=None, past_feats=None, past_labels=None, past_domains=None):
        """Supervised contrastive loss over current (and optionally memory) features.

        Features are optionally L2-normalised (``supcon_loss['normalize']``)
        and given a singleton view dimension. Domain ids are forwarded to
        ``losses.cross_domain_sup_con`` only when ``supcon_loss['cross_domain']``
        is set. Returns the loss scaled by ``supcon_loss['lambda']``.
        """
        domains = None
        if past_feats is None or past_labels is None:
            feats = cur_feats
            labels = cur_labels
        else:
            feats = torch.cat([cur_feats, past_feats])
            labels = torch.cat([cur_labels, past_labels])
            if cur_domains is not None and past_domains is not None:
                domains = torch.cat([cur_domains, past_domains])

        if self._supcon_loss.get('normalize', False):
            feats = torch.nn.functional.normalize(feats).unsqueeze(1)
        else:
            feats = feats.unsqueeze(1)

        if self._supcon_loss.get('cross_domain', False):
            loss_supcon = losses.cross_domain_sup_con(feats, labels, domains,
                                                      base_temperature=self._supcon_loss['temperature'],
                                                      temperature=self._supcon_loss['temperature'],
                                                      loss_form=self._supcon_loss['sim'])
        else:
            loss_supcon = losses.cross_domain_sup_con(feats, labels,
                                                      base_temperature=self._supcon_loss['temperature'],
                                                      temperature=self._supcon_loss['temperature'],
                                                      loss_form=self._supcon_loss['sim'])

        return self._supcon_loss['lambda'] * loss_supcon

    def _compute_encoder_loss(self, cur_feats, past_feats, past_feats_stored, past_domain_ids, align_part='both'):
        """Encoder objective: minimise the H-divergence while keeping memory features stable.

        Part 1 (adversarial, scaled by ``-encoder_loss['lambda']``):
        ``'both'`` maximises the weighted cross-entropy of the discriminator
        against its own argmax prediction over current and memory features
        (weights ``1/N_cur`` and ``beta'_i/N_i``); ``'cur'`` uses current
        samples only, against label 1.
        Part 2: ``encoder_loss['mu']`` times the per-sample summed squared
        distance between ``past_feats`` and ``past_feats_stored``.

        Raises:
            ValueError: if ``align_part`` is neither ``'both'`` nor ``'cur'``.
        """
        task_weights = self._task_weight

        # Part 1: adversarial training against the discriminator.
        # Single pass over both groups to avoid BN in the discriminator cheating.
        combined_feats = torch.cat([cur_feats, past_feats])
        logits = self._discriminator(combined_feats)
        masked_logits = logits[:, :self._task + 1]

        if align_part == 'both':
            # Push down the discriminator's most confident (argmax) prediction.
            labels = torch.argmax(masked_logits, dim=1)
            cur_sample_weights = torch.ones((cur_feats.shape[0], ), device=self._device) / cur_feats.shape[0]

            beta_prime = (task_weights[1] / task_weights[1].sum()).detach().clone()  # no grad through beta
            full_cnts = self._domain_counts(past_domain_ids, beta_prime)
            past_samples_weights = beta_prime[past_domain_ids] / full_cnts[past_domain_ids]  # beta_i / N_i

            sample_weights = torch.cat([cur_sample_weights, past_samples_weights])
            not_normed_loss = (sample_weights * F.cross_entropy(masked_logits, labels, reduction='none')).sum()
            loss = -self._encoder_loss['lambda'] * not_normed_loss
        elif align_part == 'cur':
            # Align the current domain's embeddings to the memory distribution.
            cur_logits, _ = torch.split(logits, [cur_feats.shape[0], past_feats.shape[0]])
            cur_labels = torch.ones((cur_feats.shape[0],), dtype=torch.long, device=self._device)
            cur_sample_weights = torch.ones((cur_feats.shape[0], ), device=self._device) / cur_feats.shape[0]
            loss = -self._encoder_loss['lambda'] * (cur_sample_weights * F.cross_entropy(cur_logits, cur_labels, reduction='none')).sum()
        else:
            raise ValueError(f"align_part must be 'both' or 'cur', got {align_part!r}")

        # Part 2: keep memory-sample features close to the frozen network's.
        # Clamp the divisor so a batch without memory samples gives 0, not NaN.
        n_past = max(past_feats.shape[0], 1)
        loss = loss + self._encoder_loss['mu'] * F.mse_loss(past_feats, past_feats_stored, reduction='sum') / n_past

        return loss

    @property
    def _task_weight(self):
        """Softmax of ``self._task_logits`` over its first axis (each column sums to 1).

        The detached column-wise max is subtracted before the softmax.
        """
        base = torch.max(self._task_logits, 0, keepdim=True).values.detach()
        return F.softmax(self._task_logits - base, 0)

    def _after_task_intensive(self, inc_dataset):
        """Record per-domain error rates of the network, then update exemplars.

        Afterwards ``self._past_errs`` holds the 0-1 error on the exemplars of
        each past domain ``0 .. self._task - 1`` followed by the error on the
        full current training set.
        """
        self._past_errs = []
        with torch.no_grad():
            # Memory of the past domains.
            if self._task > 0:
                errors = []
                domains = []
                for data in inc_dataset.get_cur_exemplar_loader(drop_last=False):
                    target = data.pop("target")
                    domain = data.pop("task_id")
                    inputs = {key: item.to(self._device) for key, item in data.items()}
                    logits = self._network(inputs)["logits"].cpu()

                    errors.append((torch.argmax(logits, 1) != target).numpy())
                    domains.append(domain.numpy())

                errors = np.concatenate(errors)
                domains = np.concatenate(domains)
                for t in range(self._task):
                    idx = domains == t
                    self._past_errs.append(errors[idx].sum() / idx.sum())

            # Current-domain error is estimated on the whole current training set.
            error, total = 0, 0
            for data in inc_dataset.get_cur_train_loader(shuffle=False, num_workers=8):
                target = data.pop("target")
                domain = data.pop("task_id")
                inputs = {key: item.to(self._device) for key, item in data.items()}
                logits = self._network(inputs)["logits"].cpu()

                error += (torch.argmax(logits, 1) != target).sum().item()
                total += domain.shape[0]
            self._past_errs.append(error / total)

        inc_dataset.update_exemplar()

    def _after_task(self):
        """Snapshot a frozen copy of the network as the teacher for the next task."""
        self._old_network = self._network.copy().freeze().to(self._device)
        self._network.on_task_end()

    def save_parameters(self, directory, run_id):
        """Add discriminator state (and task logits for task > 0) to the parent's checkpoint.

        Reloads the file the parent just wrote at
        ``net_{run_id}_task_{task}.pth``, adds the entries and rewrites it.
        """
        super().save_parameters(directory, run_id)

        path = os.path.join(directory, f"net_{run_id}_task_{self._task}.pth")
        save_states = torch.load(path)

        save_states["discriminator"] = self._discriminator.state_dict()
        if self._task > 0:
            save_states["task_logits"] = self._task_logits

        torch.save(save_states, path)



