import math

import torch

from datasets import get_dataset
from utils.buffer import Buffer
from utils.args import *
from models.utils.continual_model import ContinualModel
from copy import deepcopy
import copy
from torch import nn
import os
import torch.nn.functional as F

def get_parser() -> ArgumentParser:
    """Build the command-line parser for this model.

    The shared management, experiment and rehearsal argument groups are
    registered first, followed by the model-specific options declared in
    the tables below.

    Returns:
        ArgumentParser: the fully populated parser.
    """
    cli = ArgumentParser(description='Complementary Learning Systems Based Experience Replay')
    for register_group in (add_management_args, add_experiment_args, add_rehearsal_args):
        register_group(cli)

    # Options declared without help text: (flag, type, default).
    silent_options = (
        # Consistency regularization weight
        ('--reg_weight', float, 0.1),
        # Stable (EMA) model parameters
        ('--ema_model_update_freq', float, 0.70),
        ('--ema_model_alpha', float, 0.999),
        # Training modes
        ('--save_interim', int, 0),
        # Sample selection
        ('--loss_margin', float, 1),
        ('--loss_alpha', float, 0.99),
        ('--std_margin', float, 1),
        ('--task_warmup', int, 1),
    )
    for flag, value_type, default_value in silent_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    # Penalty weights share the same type, default and help text.
    for flag in ('--alpha', '--beta', '--lamb'):
        cli.add_argument(flag, type=float, help='Penalty weight.', default=0)

    # Integer switches selecting the ablation components: (flag, help).
    ablation_switches = (
        ('--biam',
         'Reduces forgetting of past tasks by distilling only the current task’s logits during training.'),
        ('--newl',
         'Distills soft labels from the assistant teacher to guide the student.'),
        ('--divk',
         'Combines the assistant and main teacher’s logits for more diverse knowledge transfer'),
        ('--bufs',
         'Refines sample selection based on the agreement between the student and assistant teacher.'),
    )
    for flag, description in ablation_switches:
        cli.add_argument(flag, type=int, default=0, help=description)

    return cli


# =============================================================================
# ESMER ablation model
# =============================================================================
class ESMERABLATION(ContinualModel):
    """Rehearsal model trained alongside an EMA copy of itself, with optional teacher ablations.

    The network is optimised on the incoming batch with a per-sample
    cross-entropy that is down-weighted for samples whose loss under the EMA
    model exceeds ``loss_margin`` times a running mean of that loss. Samples
    below that threshold are stored in the replay buffer once the warm-up
    phase is over. Replayed samples contribute a cross-entropy term plus an
    MSE consistency term against the EMA model's logits.

    The integer switches ``biam``, ``newl``, ``divk`` and ``bufs`` in
    ``args`` enable the individual teacher-based components inside
    :meth:`observe`.
    """

    NAME = 'esmer_ablation'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """Create the replay buffer, the EMA model and the running-loss state.

        Args:
            backbone: network handed to the base class (exposed as ``self.net``).
            loss: loss callable handed to the base class (exposed as ``self.loss``).
            args: parsed command-line namespace (see ``get_parser``).
            transform: augmentation applied to samples drawn from the buffer.
        """
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(int(self.args.buffer_size), self.device, gpu_id=self.args.gpu_id)
        self.cpt = get_dataset(args).N_CLASSES_PER_TASK

        # EMA model starts as an exact copy of the freshly built network.
        self.ema_model = deepcopy(self.net).to(self.device)
        # Weight of the buffer consistency term.
        self.reg_weight = args.reg_weight
        # Probability of an EMA update per step and the EMA decay ceiling.
        self.ema_model_update_freq = args.ema_model_update_freq
        self.ema_model_alpha = args.ema_model_alpha
        # Both losses are kept per-element so they can be re-weighted.
        self.consistency_loss = nn.MSELoss(reduction='none')
        self.task_loss = nn.CrossEntropyLoss(reduction='none')

        self.current_task = 0
        self.global_step = 0
        self.task_iter = 0

        # Running estimates of the EMA model's loss.
        self.loss_running_sum = 0
        self.loss_running_mean = 0
        self.loss_running_std = 0
        self.n_samples_seen = 0
        self.seen_so_far = torch.LongTensor(size=(0,)).to(self.device)

        # begin_task() sets this for every task; initialising it here keeps
        # observe()/end_epoch() from raising AttributeError if they run first.
        self.warmup_phase = 1 if args.task_warmup > 0 else 0

    def _task_column_mask(self, n_columns, task_id, device):
        """Return a boolean mask over logit columns that belong to ``task_id``.

        Column ``c`` is selected when ``c // self.cpt == task_id``.
        """
        return torch.tensor(
            [column // self.cpt == task_id for column in range(n_columns)],
            device=device)

    def observe(self, inputs, labels, not_aug_inputs, epoch=None, task_id_nominal=None, teacher=None, noise=None):
        """Run one optimisation step and update buffer, EMA model and running loss.

        Args:
            inputs: batch fed to the network, the EMA model and the teacher.
            labels: class labels for ``inputs``.
            not_aug_inputs: the samples that are written to the buffer
                (presumably the un-augmented version of ``inputs`` — confirm
                against the caller).
            epoch: unused.
            task_id_nominal: index of the current task; used to select the
                task's logit columns and to tag samples stored in the buffer.
            teacher: optional callable mapping ``inputs`` to logits.
            noise: optional per-sample tensor stored in the buffer as ``is_noise``.

        Returns:
            tuple: ``(loss_value, {}, {}, 0, 0)`` where ``loss_value`` is the
            total loss of this step as a Python float.

        Raises:
            ValueError: if ``newl > 0`` without a teacher, or if ``divk > 0``
                needs stored teacher logits while ``lamb <= 0``.
        """
        if self.args.newl > 0 and not teacher:
            # Previously surfaced as an unbound-variable error further down.
            raise ValueError('newl > 0 requires a teacher to be passed to observe().')

        temperature = 2  # softmax temperature of the teacher KL term
        self.opt.zero_grad()
        self.net.train()
        self.ema_model.train()

        teacher_output = teacher(inputs) if teacher else None

        # The EMA model is never optimised, so no graph is needed for it.
        with torch.no_grad():
            ema_model_loss = self.task_loss(self.ema_model(inputs), labels)
        out = self.net(inputs)

        if self.args.biam > 0:
            # Restrict the task loss to the current task's logits.
            task_columns = self._task_column_mask(out.size(1), task_id_nominal, out.device)
            # NOTE(review): this uses self.loss rather than the per-sample
            # self.task_loss; if self.loss reduces to a scalar the sample
            # weights below only rescale the mean — confirm this is intended.
            task_loss = self.loss(out[:, task_columns], labels % self.cpt)
        else:
            task_loss = self.task_loss(out, labels)

        ignore_mask = torch.zeros_like(labels) > 0

        if self.loss_running_mean > 0:
            threshold = self.args.loss_margin * self.loss_running_mean
            # Samples at or above the threshold are scaled down by
            # running_mean / ema_loss; all others keep weight 1.
            sample_weight = torch.where(
                ema_model_loss >= threshold,
                self.loss_running_mean / ema_model_loss,
                torch.ones_like(ema_model_loss)
            )
            ce_loss = (sample_weight * task_loss).mean()
            ignore_mask = ema_model_loss > threshold
        else:
            ce_loss = task_loss.mean()

        loss = ce_loss

        if self.args.newl > 0:
            # KL between student and teacher on the current task's logits.
            task_columns = self._task_column_mask(
                teacher_output.size(1), task_id_nominal, teacher_output.device)
            student_log_probs = F.log_softmax(out[:, task_columns] / temperature, dim=1)
            teacher_probs = F.softmax(teacher_output[:, task_columns] / temperature, dim=1)
            kd_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            loss = loss + self.args.lamb * kd_loss * (temperature ** 2)

        # =====================================================================
        # Apply Buffer loss
        # =====================================================================
        if not self.buffer.is_empty():
            buf_teacher_logits = None
            if self.args.lamb > 0:
                buf_inputs, buf_labels, task_labels, buf_teacher_logits = self.buffer.get_data_old(
                    len(labels), transform=self.transform)
            else:
                buf_inputs, buf_labels, task_labels = self.buffer.get_data_old(
                    len(labels), transform=self.transform)

            with torch.no_grad():
                ema_model_logits = self.ema_model(buf_inputs)

            if teacher and self.args.divk > 0:
                past_tasks = range(task_id_nominal)
                if past_tasks and buf_teacher_logits is None:
                    # Stored teacher logits are only fetched when lamb > 0.
                    raise ValueError('divk > 0 requires lamb > 0 so that stored teacher logits are available.')
                for task in past_tasks:
                    task_mask = task_labels == task
                    start_kd = task * self.cpt
                    end_kd = (task + 1) * self.cpt

                    ema_task_logits = ema_model_logits[task_mask]
                    stored_task_logits = buf_teacher_logits[task_mask]
                    if len(stored_task_logits) == 0:
                        continue

                    # Replace the EMA logits of this task's classes with the
                    # average of the EMA logits and the stored teacher logits.
                    ema_model_logits[task_mask, start_kd:end_kd] = (
                        ema_task_logits[:, start_kd:end_kd]
                        + stored_task_logits[:, start_kd:end_kd]
                    ) / 2

            buf_out = self.net(buf_inputs)

            l_buf_cons = torch.mean(self.consistency_loss(buf_out, ema_model_logits.detach()))
            l_buf_ce = torch.mean(self.task_loss(buf_out, buf_labels))
            loss = loss + self.args.reg_weight * l_buf_cons + l_buf_ce

        loss.backward()
        self.opt.step()

        if not self.warmup_phase:
            keep = ~ignore_mask
            if teacher and self.args.bufs > 0:
                # Additionally require the student and teacher to agree on
                # the predicted class.
                _, teacher_pred = teacher_output.data.max(1)
                _, student_pred = out.data.max(1)
                keep = keep & (teacher_pred == student_pred)

            # Gate on the number of kept samples, not on the label values:
            # a selection made up only of class-0 labels must still be stored.
            n_keep = int(keep.sum())
            if n_keep > 0:
                teacher_kwargs = {'teacher_logits': teacher_output.data[keep]} if teacher else {}
                self.buffer.add_data(
                    examples=not_aug_inputs[keep],
                    labels=labels[keep],
                    is_noise=noise[keep] if noise is not None else None,
                    # Built at the kept length directly; avoids indexing a
                    # CPU tensor with a mask that lives on another device.
                    task_labels=torch.full((n_keep,), task_id_nominal),
                    **teacher_kwargs
                )

        # Update the ema model
        self.global_step += 1
        self.task_iter += 1
        if torch.rand(1) < self.ema_model_update_freq:
            self.update_ema_model_variables()

        # Drop losses above mean + std_margin * std before updating the
        # running statistics.
        loss_mean, loss_std = ema_model_loss.mean(), ema_model_loss.std()
        outlier_mask = ema_model_loss > (loss_mean + (self.args.std_margin * loss_std))
        ema_model_loss = ema_model_loss[~outlier_mask]

        if not self.warmup_phase:
            self.update_running_loss_ema(ema_model_loss.detach())

        if hasattr(self, 'writer'):
            self.writer.add_scalar('Overall/loss_running_mean', self.loss_running_mean, self.global_step)
            self.writer.add_scalar('Overall/loss_running_std', self.loss_running_std, self.global_step)

        # No selection-mask statistics are produced by this variant.
        stats = {}
        return loss.item(), {}, stats, 0, 0

    def update_ema_model_variables(self):
        """Move the EMA model's parameters towards the current network.

        The decay is ``min(1 - 1 / (global_step + 1), ema_model_alpha)``.
        Only ``parameters()`` are averaged.
        """
        alpha = min(1 - 1 / (self.global_step + 1), self.ema_model_alpha)
        for ema_param, param in zip(self.ema_model.parameters(), self.net.parameters()):
            # Keyword form of add_: the positional (scalar, tensor) overload
            # is deprecated.
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def update_running_loss_ema(self, batch_loss):
        """Update the running mean and std of the EMA model's loss.

        Args:
            batch_loss: per-sample losses of the current batch.
        """
        alpha = min(1 - 1 / (self.global_step + 1), self.args.loss_alpha)
        self.loss_running_mean = alpha * self.loss_running_mean + (1 - alpha) * batch_loss.mean()
        self.loss_running_std = alpha * self.loss_running_std + (1 - alpha) * batch_loss.std()

    def end_task(self, dataset) -> None:
        """Advance the task counter and reset the per-task iteration count."""
        self.current_task += 1
        self.task_iter = 0

    def begin_task(self, dataset):
        """Enable the warm-up phase when ``task_warmup`` is positive."""
        if self.args.task_warmup > 0:
            self.warmup_phase = 1
            print('Enabling Warmup phase')
        else:
            self.warmup_phase = 0

    def end_epoch(self, epoch, dataset) -> None:
        """Leave the warm-up phase once ``epoch`` reaches ``task_warmup``."""
        if (epoch >= self.args.task_warmup) and self.warmup_phase:
            self.warmup_phase = 0
            print('Disable Warmup phase')


# resnet reduced 1207.43896484375 MB, 1204MB 1426 1150