import logging
from collections import defaultdict
from typing import Dict, Optional

import comet_ml
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Module-level logger used by process_metrics_array_dict for per-epoch summaries.
# NOTE(review): setLevel here forces this logger's level to INFO at import time,
# replacing any level set on it earlier; consider leaving level configuration
# to the application entry point.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def calc_accuracy(pred_logit, gt_logit):
    """Return the fraction of rows whose arg-max index agrees between two tensors.

    Args:
        pred_logit: predicted scores; the class axis is dim 1.
        gt_logit: reference scores or one-hot labels with the same layout.

    Returns:
        Python float in [0, 1]: matching rows divided by the number of rows.
    """
    with torch.no_grad():
        predicted_class = torch.max(pred_logit, dim=1)[1]
        reference_class = torch.max(gt_logit, dim=1)[1]
        matches = predicted_class == reference_class
        n_rows = len(reference_class)
        n_hits = matches.sum().item()
    return n_hits / n_rows


def calc_recon_align_score(encoder: torch.nn.Module,
                           recon_decoder: torch.nn.Module, states, domain_ids):
    """Score alignment by reconstructing source states from a target-style input.

    Rows whose ``domain_ids[..., 0] > 0.5`` are treated as source rows. Their
    state features are permuted with the fixed index ``[1, 0, 3, 2]``, tagged
    with the domain id ``[0, 1]``, and passed through ``encoder`` and
    ``recon_decoder``. The MSE between that reconstruction and the original
    source states is returned; lower means the encoder maps the permuted input
    to a latent that decodes back to the source state.

    NOTE(review): the permutation assumes the last state axis has size 4 and
    that the target domain's state is the source state with feature pairs
    swapped -- TODO confirm against the environment definition.

    Returns:
        0-dim tensor computed under ``torch.no_grad`` (no autograd graph).
    """
    is_source = domain_ids[..., 0] > 0.5
    source_states = states[is_source]
    swapped_states = source_states[..., [1, 0, 3, 2]]
    with torch.no_grad():
        target_id = torch.tensor([0, 1],
                                 dtype=domain_ids.dtype,
                                 device=domain_ids.device)
        target_ids = target_id.tile((source_states.shape[0], 1))
        encoder_input = torch.cat((swapped_states, target_ids), dim=-1)
        reconstruction = recon_decoder(encoder(encoder_input))
        return nn.MSELoss()(reconstruction, source_states)


def calc_recon_decoder_loss(encoder: torch.nn.Module,
                            recon_decoder: torch.nn.Module, states,
                            domain_ids):
    """Reconstruction loss for training ``recon_decoder`` on source-domain rows.

    Rows with ``domain_ids[..., 0] > 0.5`` are selected. Their states,
    concatenated with their own domain ids, are encoded with gradients
    disabled, so only ``recon_decoder`` receives gradients from the loss.

    Returns:
        0-dim MSE tensor between the decoded latent and the source states.
    """
    is_source = domain_ids[..., 0] > 0.5
    targets = states[is_source]

    # Freeze the encoder: the latent is produced without building a graph.
    with torch.no_grad():
        encoder_input = torch.cat((targets, domain_ids[is_source]), dim=-1)
        latent = encoder(encoder_input)

    reconstruction = recon_decoder(latent.detach())
    return nn.MSELoss()(reconstruction, targets)


def _masked_subset_mse(masked_error: torch.Tensor,
                       row_flag: torch.Tensor) -> torch.Tensor:
    """Mean of squared ``masked_error`` over the rows selected by ``row_flag``.

    Returns a float zero (same dtype/device as ``masked_error``) when no row is
    selected, instead of the NaN ``torch.mean`` gives for an empty tensor.
    NaNs originating from the data or the model are deliberately propagated.
    """
    if not bool(row_flag.any()):
        return masked_error.new_zeros(())
    return torch.mean(masked_error[row_flag]**2)


def _encoder_squared_norm(encoder: torch.nn.Module):
    """Sum of squared entries over all parameters of ``encoder`` (sum of ||w||^2).

    Returns 0.0 when the encoder has no parameters.
    """
    return sum((w.pow(2).sum() for w in encoder.parameters()), 0.0)


def calc_loss_and_learn(
    args,
    batch,
    model_dict: Optional[Dict[str, torch.nn.Module]],
    optimizer_dict: Optional[Dict[str, torch.optim.Optimizer]],
):
    """Core function of the algorithm: compute losses for one batch and, for
    every model that has an optimizer, take one update step.

    Args:
        args: config namespace. Reads ``device``; optionally ``source_coef``
            and ``target_coef`` (per-domain BC weighting, used only when both
            exist), ``enc_decay`` (L2 penalty on ``policy.encoder``),
            ``calc_align_score``, and ``adversarial_coef`` (required when a
            discriminator is supplied).
        batch: (states, conds (one-hot task_ids), domain_ids (2D), actions,
            next_states, action_masks). ``next_states`` is ignored.
        model_dict: 'policy' (required), optionally 'discriminator' and
            'recon_decoder'. ``None`` is treated as an empty dict.
        optimizer_dict: optimizers keyed like ``model_dict``. A missing entry
            (e.g. ``{}`` during validation) means that model is not updated.
            ``None`` is treated as an empty dict.

    Returns:
        Dict of float metrics: always 'bc_loss' and 'policy_loss'; plus
        'source_loss'/'target_loss' with per-domain weighting,
        'recon_loss'/'recon_align_score' when ``calc_align_score`` is set and
        a recon decoder exists, and 'accuracy'/'disc_loss' with a
        discriminator.
    """
    model_dict = model_dict or {}
    optimizer_dict = optimizer_dict or {}

    states, conds, domain_ids, actions, _, action_masks = batch
    states = states.to(args.device)
    conds = conds.to(args.device)
    domain_ids = domain_ids.to(args.device)
    actions = actions.to(args.device)
    action_masks = action_masks.to(args.device)
    # Rows whose second domain-id column is set are target-domain samples.
    target_flag = domain_ids[..., 1] > 0.5

    policy = model_dict.get('policy')
    assert policy is not None, "model_dict must contain a 'policy' model"
    discriminator = model_dict.get('discriminator')
    recon_decoder = model_dict.get('recon_decoder')

    optimizer = optimizer_dict.get('policy')
    optimizer_disc = optimizer_dict.get('discriminator')
    opt_recon_decoder = optimizer_dict.get('recon_decoder')

    pred_actions, z, z_alpha = policy(s=states, c=conds, d=domain_ids)

    metrics_dict = {}

    # --- Behavior cloning loss ---------------------------------------------
    if hasattr(args, 'target_coef') and hasattr(args, 'source_coef'):
        masked_error = (pred_actions - actions) * action_masks
        # A domain can be absent from a batch; its loss is then 0 (not NaN).
        # Real NaNs (diverged model / bad data) are no longer hidden.
        source_loss = _masked_subset_mse(masked_error, ~target_flag)
        target_loss = _masked_subset_mse(masked_error, target_flag)
        policy_loss = (args.source_coef * source_loss +
                       args.target_coef * target_loss) / 2

        metrics_dict['source_loss'] = source_loss.item()
        metrics_dict['target_loss'] = target_loss.item()
    else:
        policy_loss = torch.nn.MSELoss()(pred_actions * action_masks,
                                         actions * action_masks)

    metrics_dict['bc_loss'] = policy_loss.item()

    # --- L2 regularization of the encoder (computed only when used) --------
    enc_decay = getattr(args, 'enc_decay', 0)
    if enc_decay > 0:
        policy_loss = policy_loss + enc_decay * _encoder_squared_norm(
            policy.encoder)

    # --- Reconstruction decoder training and alignment score ---------------
    if getattr(args, 'calc_align_score', False) and recon_decoder is not None:
        recon_loss = calc_recon_decoder_loss(
            encoder=policy.encoder,
            states=states,
            domain_ids=domain_ids,
            recon_decoder=recon_decoder,
        )
        metrics_dict['recon_loss'] = recon_loss.item()

        if opt_recon_decoder is not None:
            opt_recon_decoder.zero_grad()
            recon_loss.backward()
            opt_recon_decoder.step()

        recon_align_score = calc_recon_align_score(
            encoder=policy.encoder,
            states=states,
            domain_ids=domain_ids,
            recon_decoder=recon_decoder,
        )
        metrics_dict['recon_align_score'] = recon_align_score.item()

    # --- Adversarial domain confusion --------------------------------------
    if discriminator is not None:
        cross_entropy = torch.nn.CrossEntropyLoss()

        # Discriminator step on detached features: only the discriminator
        # learns here. domain_ids is passed as a (soft) class target.
        pred_domain_logit = discriminator(z=z.detach(),
                                          z_alpha=z_alpha.detach(),
                                          c=conds)
        disc_loss = cross_entropy(pred_domain_logit, domain_ids)
        metrics_dict['accuracy'] = calc_accuracy(pred_logit=pred_domain_logit,
                                                 gt_logit=domain_ids)

        if optimizer_disc is not None:
            optimizer_disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

        # Policy side: maximize the (possibly just-updated) discriminator's
        # loss. Out-of-place add avoids mutating a tensor in the graph.
        pred_domain_logit = discriminator(z=z, z_alpha=z_alpha, c=conds)
        adv_loss = -cross_entropy(pred_domain_logit, domain_ids)
        policy_loss = policy_loss + adv_loss * args.adversarial_coef

        metrics_dict['disc_loss'] = disc_loss.item()

    if optimizer is not None:
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

    metrics_dict['policy_loss'] = policy_loss.item()

    return metrics_dict


def train(
    args,
    dataloader_dict: Dict[str, DataLoader],
    model_dict: Dict[str, torch.nn.Module],
    optimizer_dict: Dict[str, torch.optim.Optimizer],
):
    """Run one pass over ``dataloader_dict["train"]``, updating the models.

    Each batch goes through ``calc_loss_and_learn`` with the given optimizers.
    When ``args.verbose`` is truthy the loader is wrapped in a tqdm bar.

    Returns:
        defaultdict(list) mapping each metric name to its per-batch values.
    """
    batches = dataloader_dict["train"]
    if getattr(args, "verbose", False):
        from tqdm import tqdm
        batches = tqdm(batches, bar_format="{l_bar}{bar:50}{r_bar}")

    history = defaultdict(list)
    for batch in batches:
        step_metrics = calc_loss_and_learn(
            args,
            batch=batch,
            model_dict=model_dict,
            optimizer_dict=optimizer_dict,
        )
        for name, value in step_metrics.items():
            history[name].append(value)

    return history


def validation(
    args,
    dataloader_dict: Dict[str, DataLoader],
    model_dict: Dict[str, torch.nn.Module],
):
    """Evaluate every batch of ``dataloader_dict["validation"]`` without updates.

    All models in ``model_dict`` are put in eval mode for the pass, so dropout
    is disabled and batch-norm running statistics are not updated from
    validation data. Every submodule's previous ``training`` flag is restored
    afterwards, even if an exception is raised. Gradients are disabled and no
    optimizers are passed, so no parameters change.

    Args:
        args: config namespace; ``verbose`` (optional) enables a tqdm bar.
        dataloader_dict: must contain 'validation'.
        model_dict: models forwarded to ``calc_loss_and_learn``; non-Module
            values (e.g. ``None``) are ignored for mode switching.

    Returns:
        defaultdict(list) mapping each metric name to its per-batch values.
    """
    val_metrics_dict = defaultdict(list)

    if hasattr(args, "verbose") and args.verbose:
        from tqdm import tqdm
        loader = tqdm(dataloader_dict["validation"],
                      bar_format="{l_bar}{bar:50}{r_bar}")
    else:
        loader = dataloader_dict["validation"]

    models = [m for m in model_dict.values() if isinstance(m, torch.nn.Module)]
    # Record each submodule's flag (pre-order: parents before children) so a
    # mixed train/eval configuration is restored exactly afterwards.
    saved_modes = [(sub, sub.training) for m in models for sub in m.modules()]
    for m in models:
        m.eval()

    try:
        for batch in loader:
            with torch.no_grad():
                metrics_dict = calc_loss_and_learn(
                    args,
                    batch=batch,
                    model_dict=model_dict,
                    optimizer_dict={},
                )

            for key, val in metrics_dict.items():
                val_metrics_dict[key].append(val)
    finally:
        # Parents come first, so later child entries override the recursive
        # effect of a parent's train(flag) call.
        for sub, was_training in saved_modes:
            sub.train(was_training)

    return val_metrics_dict


def process_metrics_array_dict(
        metrics_array_dict: dict,
        context: str,
        epoch: int,
        experiment: Optional[comet_ml.Experiment] = None):
    """Average each metric's per-batch values, log them, and return the means.

    Args:
        metrics_array_dict: metric name -> sequence of per-batch values.
        context: label such as 'train' or 'valid_align'; used as the comet
            metric prefix and in the log line.
        epoch: epoch number attached to the logged metrics.
        experiment: optional comet experiment; when truthy the means are
            sent via ``log_metrics``.

    Returns:
        Dict mapping each metric name to the mean of its values.
    """
    means = {
        name: np.array(values).mean()
        for name, values in metrics_array_dict.items()
    }
    summary = ''.join(f'{name}: {value:.5f} ' for name, value in means.items())

    if experiment:
        experiment.log_metrics(
            means,
            prefix=context,
            epoch=epoch,
        )

    logger.info('Epoch: %s %s %s', epoch, context, summary)
    return means


def epoch_loop(
    args,
    model_dict: Dict[str, torch.nn.Module],
    optimizer_dict: Dict[str, torch.optim.Optimizer],
    dataloader_dict: Dict[str, DataLoader],
    epoch: int,
    experiment: Optional[comet_ml.Experiment],
    log_prefix: str = '',
):
    """Run a single epoch: one training pass, then one validation pass.

    Args:
        args: config namespace forwarded to ``train`` / ``validation``.
        model_dict: 'policy': Policy (required), 'discriminator':
            Discriminator, 'recon_decoder': ReconstructionDecoder.
        optimizer_dict: 'policy', 'discriminator', 'recon_decoder' ->
            Optimizer (used for training only).
        dataloader_dict: 'train': DataLoader, 'validation': DataLoader.
        epoch: epoch number used when logging.
        experiment: optional comet experiment that receives the metrics.
        log_prefix: e.g. 'align' or 'adapt'. Metric contexts become
            'train_<log_prefix>' and 'valid_<log_prefix>' (just 'train' and
            'valid' when empty), e.g. train_align_bc_loss.

    Returns:
        Mean of the per-batch validation 'policy_loss' values.
    """
    suffix = f'_{log_prefix}' if log_prefix else ''

    train_history = train(
        args,
        model_dict=model_dict,
        optimizer_dict=optimizer_dict,
        dataloader_dict=dataloader_dict,
    )
    valid_history = validation(
        args,
        model_dict=model_dict,
        dataloader_dict=dataloader_dict,
    )

    process_metrics_array_dict(train_history,
                               context='train' + suffix,
                               epoch=epoch,
                               experiment=experiment)
    valid_means = process_metrics_array_dict(valid_history,
                                             context='valid' + suffix,
                                             epoch=epoch,
                                             experiment=experiment)

    return valid_means['policy_loss']
