import logging
import random
from collections import defaultdict
from itertools import combinations
from typing import Any, Dict, Optional

import comet_ml
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.ours.models import MLP, discretize_actions
from common.ours.utils import PLPConfig

# Module-level logger; its level is set to INFO as soon as the module is imported.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Passed to tqdm as `bar_format` by train() and validation(): left prefix,
# a bar of fixed width 50, then the right-hand statistics.
TQDM_BAR_FORMAT = "{l_bar}{bar:50}{r_bar}"


def calc_accuracy(pred_logit, gt_logit):
    """Return the fraction of rows whose predicted class equals the target class.

    Both inputs are reduced along dim 1 to the index of their largest entry
    and the indices are compared row by row. The computation runs without
    gradient tracking and the result is a plain Python float.
    """
    with torch.no_grad():
        _, target_ids = torch.max(gt_logit, dim=1)
        _, predicted_ids = torch.max(pred_logit, dim=1)
        n_correct = (target_ids == predicted_ids).sum().item()
        return n_correct / len(target_ids)


def get_padded_batch_and_valid_mask(arr: torch.Tensor, target_dim: int):
    """Zero-pad the last dimension of ``arr`` up to ``target_dim``.

    Works for any number of leading dimensions (the padding block copies
    every dimension of ``arr`` except the last one).

    Args:
        arr: tensor whose last dimension is to be extended.
        target_dim: desired size of the last dimension; expected to be at
            least ``arr.shape[-1]``.

    Returns:
        A tuple ``(padded, mask)`` with the same dtype and device as ``arr``:
        ``padded`` is ``arr`` followed by zeros along the last dimension and
        ``mask`` holds ones where ``padded`` comes from ``arr`` and zeros in
        the padded region.
    """
    pad_size = target_dim - arr.shape[-1]
    # new_zeros allocates directly with arr's dtype/device (no CPU round trip)
    # and mirroring all leading dims lets this handle inputs of any rank.
    pad_array = arr.new_zeros((*arr.shape[:-1], pad_size))
    result_array = torch.cat([arr, pad_array], dim=-1)
    result_mask = torch.cat([torch.ones_like(arr), pad_array], dim=-1)
    return result_array, result_mask


def check_for_nans(batch: Dict, loss_values: Dict):
    """Log a warning for every batch tensor or loss value that contains NaN.

    This is a diagnostic helper: it never raises for NaNs and returns nothing.

    Args:
        batch: mapping of name -> batch entry. Only ``torch.Tensor`` entries
            are inspected; anything else is ignored.
        loss_values: mapping of name -> numeric value (anything accepted by
            ``np.isnan``).
    """
    for key, value in batch.items():
        # torch.isnan only accepts tensors; skip other entries so that a
        # non-tensor item in the batch cannot crash the training step.
        if isinstance(value, torch.Tensor) and torch.isnan(value).any():
            logger.warning(f'Found nan in batch item: {key}')
    for key, value in loss_values.items():
        if np.isnan(value).any():
            logger.warning(f'Found nan in loss item: {key}')


def is_at_most_one_more_than_zero(a, b, c, d):
    """Return True when no more than one of the four values is strictly positive."""
    positives = [value for value in (a, b, c, d) if value > 0]
    return len(positives) <= 1


def gaussian_kernel_linear(source_z, target_z, sigma: float):
    """Element-wise (paired) Gaussian kernel.

    Row i of ``source_z`` is compared with row i of ``target_z``: the squared
    difference is summed over the last dimension, divided by
    ``2 * sigma**2 + 1e-6`` and passed through ``exp(-x)``. No reduction is
    applied over the remaining dimensions.
    """
    squared_dist = (source_z - target_z).pow(2).sum(dim=-1)
    bandwidth = 2 * sigma**2 + 1e-6  # epsilon guards against sigma == 0
    return torch.exp(-squared_dist / bandwidth)


def l2_kernel_linear(source_z, target_z, sigma: float = 1.0):
    """Element-wise (paired) negative squared L2 distance.

    Returns ``-sum((source_z - target_z)**2, dim=-1) / (sigma**2 + 1e-6)``
    for matching rows; the result is therefore <= 0.
    """
    delta = source_z - target_z
    scale = sigma**2 + 1e-6  # epsilon guards against sigma == 0
    return -delta.pow(2).sum(dim=-1) / scale


def gaussian_kernel(source_z, target_z, sigma: float):
    """Mean Gaussian kernel value over all (source row, target row) pairs.

    Every row of ``source_z`` is compared with every row of ``target_z``;
    the kernel is ``exp(-||s - t||^2 / (2 * sigma**2 + 1e-6))`` and the
    mean over all pairs is returned as a scalar tensor.
    """
    pairwise_delta = source_z.unsqueeze(1) - target_z.unsqueeze(0)
    exponent = -pairwise_delta.pow(2).sum(dim=-1) / (2 * sigma**2 + 1e-6)
    return torch.exp(exponent).mean()


def laplacian_kernel(source_z, target_z, sigma: float):
    """Mean Laplacian kernel value over all (source row, target row) pairs.

    The kernel is ``exp(-||s - t||_1 / (sigma + 1e-6))``; the mean over all
    pairs is returned as a scalar tensor.
    """
    pairwise_delta = source_z.unsqueeze(1) - target_z.unsqueeze(0)
    exponent = -pairwise_delta.abs().sum(dim=-1) / (sigma + 1e-6)
    return torch.exp(exponent).mean()


def l2_kernel(source_z, target_z, sigma: float):
    """Mean negative squared L2 distance over all (source row, target row) pairs.

    Each pair contributes ``-||s - t||^2 / (sigma**2 + 1e-6)``; the mean over
    all pairs is returned as a scalar tensor (always <= 0).
    """
    pairwise_delta = source_z.unsqueeze(1) - target_z.unsqueeze(0)
    neg_sq_dist = -pairwise_delta.pow(2).sum(dim=-1) / (sigma**2 + 1e-6)
    return neg_sq_dist.mean()


def calc_mmd_loss(source_z,
                  target_z,
                  mmd_kernel: str,
                  mmd_sigma: float,
                  normalize: bool = False,
                  linear: bool = True):
    """Compute an MMD-style discrepancy between two sets of latent vectors.

    Args:
        source_z: 2-D tensor of source latents (rows are samples).
        target_z: 2-D tensor of target latents (rows are samples).
        mmd_kernel: 'gaussian', 'laplacian' or 'l2'. The linear estimator
            only supports 'gaussian'.
        mmd_sigma: kernel bandwidth. It is never modified in place.
        normalize: when True the bandwidth is multiplied by a distance scale
            computed from the current samples (mean pairwise distance in the
            linear branch; root of the mean squared pairwise distance over
            the pooled samples in the quadratic branch).
        linear: use the paired (linear-time) estimator instead of the
            all-pairs estimator.

    Returns:
        A scalar tensor, or ``None`` in the linear branch when the smaller of
        the two sets has fewer than two samples.

    Raises:
        ValueError: if ``mmd_kernel`` is not supported by the chosen branch.
    """
    if linear:
        # Explicit check instead of `assert`: asserts vanish under `python -O`
        # and the gaussian kernel would then be used silently.
        if mmd_kernel != 'gaussian':
            raise ValueError(
                f'Linear MMD only supports the gaussian kernel, got: {mmd_kernel}'
            )

        # Use an even number of samples so both sets split into two halves.
        n_samples = min(len(source_z), len(target_z)) // 2 * 2
        if n_samples == 0:
            return None
        source_z1, source_z2 = rearrange(source_z[:n_samples],
                                         '(n b) d -> n b d',
                                         n=2)
        target_z1, target_z2 = rearrange(target_z[:n_samples],
                                         '(n b) d -> n b d',
                                         n=2)

        if normalize:
            pair_dists_neg_sq = (l2_kernel_linear(source_z1, source_z2),
                                 l2_kernel_linear(target_z1, target_z2),
                                 l2_kernel_linear(source_z1, target_z2),
                                 l2_kernel_linear(source_z2, target_z1))
            mean_pair_distance = (-torch.cat(pair_dists_neg_sq)).sqrt().mean()
            # Out-of-place so a tensor-valued mmd_sigma owned by the caller
            # is never mutated.
            mmd_sigma = mmd_sigma * mean_pair_distance

        kernel = gaussian_kernel_linear
        distance = kernel(source_z1, source_z2, mmd_sigma) + kernel(target_z1, target_z2, mmd_sigma) \
                   - kernel(source_z1, target_z2, mmd_sigma) - kernel(source_z2, target_z1, mmd_sigma)
        return distance.mean()

    if mmd_kernel == 'gaussian':
        kernel = gaussian_kernel
    elif mmd_kernel == 'laplacian':
        kernel = laplacian_kernel
    elif mmd_kernel == 'l2':
        kernel = l2_kernel
    else:
        raise ValueError(f'Invalid mmd_kernel: {mmd_kernel}')

    if normalize:
        all_z = torch.cat((source_z, target_z))
        all_dist_mean = (-l2_kernel(all_z, all_z, 1.0)).sqrt()
        mmd_sigma = mmd_sigma * all_dist_mean

    source_dist = kernel(source_z, source_z, mmd_sigma)
    target_dist = kernel(target_z, target_z, mmd_sigma)
    cross_dist = kernel(source_z, target_z, mmd_sigma)

    return source_dist + target_dist - 2 * cross_dist


def calc_hausdorff_loss(source_z,
                        target_z,
                        soft=False,
                        normalize: bool = False):
    """Hausdorff-style distance between two point sets, on squared distances.

    For each direction the nearest-neighbour squared distance of every point
    is computed and then the largest of those is taken; the result is the
    larger of the two directed values. With ``soft=True`` every min/max is
    replaced by a softmax-weighted average. With ``normalize=True`` each
    pairwise distance matrix is divided by its own mean first.
    """

    def _directed(from_z, to_z):
        pairwise = (from_z.unsqueeze(1) - to_z.unsqueeze(0)).pow(2).sum(dim=-1)
        if normalize:
            pairwise /= pairwise.mean()
        if not soft:
            nearest, _ = pairwise.min(dim=-1)
            return nearest.max()
        # soft-min over targets: smaller distance -> larger weight
        near_weight = (-pairwise).softmax(dim=-1)
        soft_nearest = (pairwise * near_weight).sum(dim=-1)
        # soft-max over sources: larger distance -> larger weight
        far_weight = soft_nearest.softmax(dim=-1)
        return (soft_nearest * far_weight).sum()

    source_to_target = _directed(from_z=source_z, to_z=target_z)
    target_to_source = _directed(from_z=target_z, to_z=source_z)

    if not soft:
        return torch.max(source_to_target, target_to_source)

    both = torch.stack([source_to_target, target_to_source])
    return (both * both.softmax(dim=0)).sum()


def calc_distreg_loss(source_z, target_z):
    """Symmetric mean nearest-neighbour squared distance between two point sets.

    For each direction, every point is matched to its closest point in the
    other set (squared Euclidean distance) and the matches are averaged; the
    two directed averages are then averaged.
    """

    def _mean_nearest_sq_dist(from_z, to_z):
        pairwise = (from_z.unsqueeze(1) - to_z.unsqueeze(0)).pow(2).sum(dim=-1)
        nearest, _ = pairwise.min(dim=-1)
        return nearest.mean()

    source_to_target = _mean_nearest_sq_dist(from_z=source_z, to_z=target_z)
    target_to_source = _mean_nearest_sq_dist(from_z=target_z, to_z=source_z)

    return (source_to_target + target_to_source) / 2


def calc_bc_loss(pred_actions: torch.FloatTensor,
                 actions: torch.FloatTensor,
                 action_masks: torch.FloatTensor,
                 discrete: bool = False,
                 discrete_bins: int = 10):
    """Behaviour-cloning loss, returned as a scalar tensor.

    Continuous mode (default): predictions and targets are both multiplied
    by ``action_masks``, then the squared error is averaged over the last
    dimension and over the remaining dimensions.

    Discrete mode: targets are binned with ``discretize_actions`` and a
    cross-entropy is computed after moving the last axis of ``pred_actions``
    to position 1; ``action_masks`` is not used in this mode.
    """
    if not discrete:
        masked_pred = pred_actions * action_masks
        masked_target = actions * action_masks
        per_sample = (masked_pred - masked_target).pow(2).mean(dim=-1)
        return per_sample.mean()

    logits = rearrange(pred_actions, 'b n d -> b d n')
    binned_targets = discretize_actions(actions, n_bins=discrete_bins)
    per_sample = torch.nn.functional.cross_entropy(
        logits, binned_targets, reduction='none').mean(dim=-1)
    return per_sample.mean()


def calc_loss_and_learn(
    args: PLPConfig,
    batch,
    model_dict: Optional[Dict[str, torch.nn.Module]],
    optimizer_dict: Optional[Dict[str, torch.optim.Optimizer]],
    scaler: torch.cuda.amp.GradScaler,
    tcc_batch: Optional[Any] = None,
):
    """Core function of the algorithm: compute all losses and take one step.

    Args:
        args: run configuration (loss coefficients, device, amp flag, ...).
        batch: dict of tensors. Keys read here: 'observations',
            'next_observations', 'task_ids', 'domain_ids', 'actions',
            'action_masks', plus 'images' when ``args.image_observation`` and
            'next_images' when ``args.state_pred`` is also set. The tensors
            stored in ``batch`` are not modified by the noise augmentation.
        model_dict: must contain 'policy'; may contain 'discriminator'.
        optimizer_dict: may contain 'policy' and/or 'discriminator'. A model
            whose optimizer is missing is evaluated but not updated (pass an
            empty dict for pure evaluation).
        scaler: gradient scaler used for every backward/step.
        tcc_batch: optional paired batch; when given (and the policy has an
            encoder) the TCC loss is added.

    Returns:
        Dict mapping metric name to a Python float; always contains
        'policy_loss'.
    """
    policy = model_dict.get('policy')
    assert policy is not None
    # Only one of the latent-alignment regularisers may be active at a time.
    assert is_at_most_one_more_than_zero(args.adversarial_coef, args.mmd_coef,
                                         args.hausdorff_coef,
                                         args.distreg_coef)

    states = batch['observations'].to(args.device)
    next_states = batch['next_observations'].to(args.device)
    conds = batch['task_ids'].to(args.device)
    domain_ids = batch['domain_ids'].to(args.device)
    actions = batch['actions'].to(args.device)
    action_masks = batch['action_masks'].to(args.device)
    if args.image_observation:
        images = batch['images'].to(args.device)
    if args.image_observation and args.state_pred:
        next_images = batch['next_images'].to(args.device)

    discriminator = model_dict.get('discriminator')

    optimizer = optimizer_dict.get('policy')
    optimizer_disc = optimizer_dict.get('discriminator')

    # Samples whose last domain-id entry is below 0.5 are treated as source.
    source_flag = domain_ids[..., -1] < 0.5

    # Out-of-place on purpose: `.to(device)` returns the very same tensor when
    # it already lives on that device, so `+=` would write the noise back into
    # the caller's batch.
    if args.state_noise:
        states = states + torch.randn_like(states) * args.state_noise

    if args.action_noise:
        actions = actions + torch.randn_like(actions) * args.action_noise

    with torch.cuda.amp.autocast(enabled=args.amp):
        if args.image_observation:
            # The trailing `image_state_dim` entries of the state vector are
            # overwritten with the encoded image.
            images_encoded = policy.image_encoder(images.float())
            states[..., -args.image_state_dim:] = images_encoded
        if args.image_observation and args.state_pred:
            next_images_encoded = policy.image_encoder(next_images.float())
            next_states[..., -args.image_state_dim:] = next_images_encoded

        pred_actions, z, z_alpha = policy(s=states, c=conds, d=domain_ids)
        policy_loss = 0
        metrics_dict = {}

        if discriminator is not None and z is not None:
            cross_entropy = torch.nn.CrossEntropyLoss()
            # Discriminator is trained on detached latents so this loss does
            # not reach the policy.
            pred_domain_logit = discriminator(z=z.detach(), c=conds)

            disc_loss = cross_entropy(pred_domain_logit, domain_ids)

            metrics_dict['accuracy'] = calc_accuracy(
                pred_logit=pred_domain_logit, gt_logit=domain_ids)

            if optimizer_disc is not None:
                optimizer_disc.zero_grad()
                scaler.scale(disc_loss).backward()
                scaler.step(optimizer_disc)
                scaler.update()

                # Adversarial term: the policy maximises the (updated)
                # discriminator's cross-entropy.
                pred_domain_logit = discriminator(z=z, c=conds)
                adv_loss = -cross_entropy(pred_domain_logit, domain_ids)
                policy_loss += adv_loss * args.adversarial_coef

            metrics_dict['disc_loss'] = disc_loss.item()

        if args.state_pred:
            assert args.decode_with_state, 'state_pred requires decode_with_state'
            # Next state prediction
            policy_inp = torch.cat((next_states, domain_ids), dim=-1)
            next_z = policy.encoder(policy_inp).detach()
            next_state_reg_loss = (next_z - z_alpha).pow(2).mean()
            policy_loss += next_state_reg_loss * args.next_state_reg_coef
            metrics_dict['next_state_reg_loss'] = next_state_reg_loss.item()

            # IDM loss (the encoder is run again on the same inputs as above)
            policy_inp = torch.cat((next_states, domain_ids), dim=-1)
            next_z = policy.encoder(policy_inp).detach()
            decoder_inp = torch.cat((next_z, domain_ids, states), dim=-1)
            pred_actions = policy.head(decoder_inp)

            # The IDM loss is computed on non-source samples only.
            if (~source_flag).sum() > 0:
                idm_loss = calc_bc_loss(
                    pred_actions=pred_actions[~source_flag],
                    actions=actions[~source_flag],
                    action_masks=action_masks[~source_flag])
                policy_loss += idm_loss * args.idm_coef
                metrics_dict['idm_loss'] = idm_loss.item()
        else:
            bc_loss = calc_bc_loss(pred_actions=pred_actions,
                                   actions=actions,
                                   action_masks=action_masks,
                                   discrete=args.discrete,
                                   discrete_bins=args.discrete_bins).mean()
            metrics_dict['bc_loss'] = bc_loss.item()
            # With bc disabled a zero that still requires grad is added so
            # that policy_loss is always a tensor backward() can be called on.
            policy_loss += bc_loss if args.bc else torch.zeros(1, requires_grad=True).to(
                args.device)

        if args.image_observation and args.use_image_decoder:
            images_recon = policy.image_decoder(images_encoded)
            images_scaled = (images.float() / 255.).clamp(0, 1)
            image_recon_loss = nn.MSELoss()(images_recon, images_scaled)
            policy_loss += image_recon_loss * args.image_recon_coef
            metrics_dict['image_recon_loss'] = image_recon_loss.item()

        if z is not None:
            metrics_dict['repr_norm'] = z.norm(dim=-1).mean().item()

        if args.norm_hinge_coef > 0 and z is not None:
            # if norm is larger than 1.0, then the loss is 0.
            # if this loss is ignored, we can try relu(1 / z.norm() - 1.0)
            hinge_loss = torch.relu(1.0 - z.norm(dim=-1)).mean()
            policy_loss += hinge_loss * args.norm_hinge_coef
            metrics_dict['norm_hinge_loss'] = hinge_loss.item()

        if args.distreg_coef > 0 and z is not None:
            # Skip the regulariser unless every domain has at least one sample.
            n_traj_of_each_domain = (domain_ids > 0.5).sum(dim=0) > 0
            if n_traj_of_each_domain.all():
                source_z = z[source_flag]
                target_z = z[~source_flag]

                distreg_loss = calc_distreg_loss(source_z, target_z)
                metrics_dict['distreg_loss'] = distreg_loss.item()
                policy_loss += distreg_loss * args.distreg_coef

        if args.mmd_coef > 0 and z is not None:
            n_traj_of_each_domain = (domain_ids > 0.5).sum(dim=0) > 0
            if n_traj_of_each_domain.all():
                # MMD is averaged over every unordered pair of domains.
                domain_lists = list(range(domain_ids.shape[-1]))
                mmd_losses = []
                for domain1, domain2 in combinations(domain_lists, 2):
                    source_z = z[domain_ids[..., domain1] > 0.5]
                    target_z = z[domain_ids[..., domain2] > 0.5]
                    mmd_loss = calc_mmd_loss(source_z,
                                             target_z,
                                             mmd_kernel=args.mmd_kernel,
                                             mmd_sigma=args.mmd_sigma,
                                             normalize=args.mmd_norm,
                                             linear=args.mmd_linear)
                    # calc_mmd_loss returns None when a pair has too few samples.
                    if mmd_loss is not None:
                        mmd_losses.append(mmd_loss)

                if len(mmd_losses) > 0:
                    mmd_loss = torch.stack(mmd_losses).mean()
                    metrics_dict['mmd_loss'] = mmd_loss.item()
                    policy_loss += mmd_loss * args.mmd_coef

        if args.hausdorff_coef > 0 and z is not None:
            n_traj_of_each_domain = (domain_ids > 0.5).sum(dim=0) > 0
            if n_traj_of_each_domain.all():
                source_z = z[source_flag]
                target_z = z[~source_flag]
                hausdorff_loss = calc_hausdorff_loss(
                    source_z,
                    target_z,
                    soft=args.hausdorff_soft,
                    normalize=args.hausdorff_norm)

                metrics_dict['hausdorff_loss'] = hausdorff_loss.item()
                policy_loss += hausdorff_loss * args.hausdorff_coef

        if tcc_batch is not None and policy.encoder is not None:
            tcc_loss = tcc_process_batch(
                args,
                batch=tcc_batch,
                encoder=policy.encoder,
                image_encoder=policy.image_encoder
                if args.image_observation else None,
            )
            metrics_dict['tcc_loss'] = tcc_loss.item()
            policy_loss += tcc_loss * args.tcc_coef

        check_for_nans(batch, metrics_dict)

    if optimizer is not None:
        optimizer.zero_grad()
        scaler.scale(policy_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    metrics_dict['policy_loss'] = policy_loss.item()

    return metrics_dict


# for tcc
def _get_weight_vector(
    query: torch.Tensor,
    key: torch.Tensor,
    key_mask: torch.Tensor,
    temperature: float = 1.0,
    return_logit: bool = False,
):
    """Distance-based attention scores of every query against every key.

    The score of a (query, key) pair is the negative squared distance divided
    by the feature size and by ``temperature``, plus the additive
    ``key_mask`` entry of that key. Returns the raw scores when
    ``return_logit`` is True, otherwise their softmax over the key axis.

    Shapes below assume ``query``/``key`` are [batch, n, dim] and
    ``key_mask`` is [batch, n_key] -- TODO confirm against callers.
    """
    feature_dim = query.shape[-1]
    # [batch, query, 1, dim] - [batch, 1, key, dim] -> [batch, query, key, dim]
    pairwise_delta = query.unsqueeze(-2) - key.unsqueeze(-3)
    scores = -pairwise_delta.pow(2).sum(dim=-1) / feature_dim / temperature
    # The mask is broadcast over the query axis.
    scores = scores + key_mask.unsqueeze(-2)
    if return_logit:
        return scores
    return scores.softmax(dim=-1)  # each row sums to 1 over keys


# for tcc
def _get_domain_ids(obs: torch.Tensor,
                    domain_id: int,
                    n_domain_id: int = 2) -> torch.Tensor:
    """Build a constant one-hot domain tensor matching ``obs``.

    The result is float32, lives on ``obs.device`` and has shape
    ``(*obs.shape[:2], n_domain_id)`` with entry ``domain_id`` of the last
    axis set to 1 everywhere.
    """
    out_shape = tuple(obs.shape[:2]) + (n_domain_id, )
    one_hot = torch.zeros(out_shape, dtype=torch.float32, device=obs.device)
    one_hot[..., domain_id] = 1.
    return one_hot


def tcc_process_batch(
        args,
        batch,
        encoder: MLP,
        image_encoder: Optional[
            nn.Module] = None,  # adapter for converting latents
):
    """Compute the TCC (cycle-consistency) loss for one paired batch.

    The "1" entries of ``batch`` form the source sequences and the "2"
    entries the target sequences. Every ``args.tcc_frame_interval``-th frame
    is kept, both sides are encoded, each source latent softly retrieves a
    target latent, and the retrieved latent is classified back onto the
    source frame indices with a cross-entropy. Frames flagged in the padding
    masks are excluded as keys (additive -1e9) and as classification targets.

    Returns:
        Scalar cross-entropy tensor.
    """

    def _build_encoder_input(obs_key, mask_key, image_key, domain_key):
        """Return (encoder input, sub-sampled pad mask) for one side."""
        obs = batch[obs_key].to(args.device)
        # Right-pad the state features with zeros up to the encoder's width.
        missing = args.max_obs_dim - obs.shape[-1]
        if missing > 0:
            filler = torch.zeros((*obs.shape[:2], missing),
                                 device=obs.device,
                                 dtype=obs.dtype)
            obs = torch.cat((obs, filler), dim=-1)

        pad_mask = batch[mask_key].to(args.device)

        # Keep every tcc_frame_interval-th frame.
        frame_idx = torch.arange(pad_mask.shape[-1], device=pad_mask.device)
        keep = frame_idx % args.tcc_frame_interval == 0
        obs = obs[:, keep]
        pad_mask = pad_mask[:, keep]

        if args.image_observation:
            assert image_encoder is not None
            image = batch[image_key].to(args.device)[:, keep]
            lead_shape, frame_shape = image.shape[:2], image.shape[2:]
            # Fold the first two axes together for the encoder, then unfold.
            encoded = image_encoder(image.reshape(
                -1, *frame_shape)).reshape(*lead_shape, -1)
            obs[..., -args.image_state_dim:] = encoded

        domain_ids = batch[domain_key].to(args.device)[:, keep]
        return torch.concat((obs, domain_ids), dim=-1), pad_mask

    def _additive_mask(pad_mask):
        """0 where the frame is valid, -1e9 where it is padding."""
        additive = torch.zeros_like(pad_mask, dtype=torch.float32)
        additive[pad_mask] = -1e9
        return additive

    s_inp, s_pad_mask = _build_encoder_input('obs1', 'seq_pad_masks1',
                                             'image1', 'domain_ids1')
    t_inp, t_pad_mask = _build_encoder_input('obs2', 'seq_pad_masks2',
                                             'image2', 'domain_ids2')

    s_latent = encoder(s_inp)
    t_latent = encoder(t_inp)

    # Source -> target soft retrieval.
    alpha = _get_weight_vector(query=s_latent,
                               key=t_latent,
                               key_mask=_additive_mask(t_pad_mask))
    retrieved_vector = torch.bmm(alpha, t_latent)

    # Retrieved -> source logits; the correct class is the frame's own index.
    beta = _get_weight_vector(
        query=retrieved_vector,
        key=s_latent,
        key_mask=_additive_mask(s_pad_mask),
        return_logit=True,
    )
    frame_ids = torch.arange(s_latent.shape[-2],
                             dtype=torch.long,
                             device=s_latent.device)

    is_valid = ~s_pad_mask
    valid_prediction = beta[is_valid]
    valid_ans = frame_ids.repeat((beta.shape[0], 1))[is_valid]

    return torch.nn.CrossEntropyLoss()(valid_prediction, valid_ans)


def train(
    args,
    dataloader: DataLoader,  # StepLoader
    model_dict: Dict[str, torch.nn.Module],
    optimizer_dict: Dict[str, torch.optim.Optimizer],
    paired_dataloader: Optional[DataLoader] = None,
):
    """Run one training pass over ``dataloader``.

    For every batch, with probability ``args.tcc_prob`` (and only when
    ``paired_dataloader`` is given) a paired batch is drawn and forwarded as
    ``tcc_batch``. Paired batches come from a single iterator that is kept
    across steps and restarted once exhausted, so successive draws walk
    through the paired loader instead of rebuilding it for every step.

    Returns:
        ``defaultdict(list)`` mapping metric name to the per-step values.
    """
    train_metrics_dict = defaultdict(list)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    paired_iter = None  # created lazily on the first TCC step

    for batch in tqdm(dataloader,
                      bar_format=TQDM_BAR_FORMAT,
                      total=len(dataloader)):
        tcc_batch = None
        if paired_dataloader is not None and random.random() < args.tcc_prob:
            if paired_iter is None:
                paired_iter = iter(paired_dataloader)
            try:
                tcc_batch = next(paired_iter)
            except StopIteration:
                # Paired loader exhausted: start a new pass over it.
                paired_iter = iter(paired_dataloader)
                tcc_batch = next(paired_iter)

        metrics_dict = calc_loss_and_learn(
            args,
            batch=batch,
            model_dict=model_dict,
            optimizer_dict=optimizer_dict,
            tcc_batch=tcc_batch,
            scaler=scaler,
        )

        for key, val in metrics_dict.items():
            train_metrics_dict[key].append(val)

    return train_metrics_dict


def validation(
    args,
    dataloader: DataLoader,
    model_dict: Dict[str, torch.nn.Module],
    paired_dataloader: Optional[DataLoader] = None,
):
    """Run one evaluation pass over ``dataloader`` without updating any model.

    ``calc_loss_and_learn`` is called with an empty optimizer dict, so no
    backward pass or optimizer step happens; the call is therefore wrapped in
    ``torch.no_grad()`` to avoid building autograd graphs. Paired (TCC)
    batches are drawn with probability ``args.tcc_prob`` from a single
    iterator that is kept across steps and restarted once exhausted.

    Returns:
        ``defaultdict(list)`` mapping metric name to the per-step values.
    """
    val_metrics_dict = defaultdict(list)
    # Still required by calc_loss_and_learn's signature even though no step
    # is taken here.
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    paired_iter = None  # created lazily on the first TCC step

    for batch in tqdm(dataloader,
                      bar_format=TQDM_BAR_FORMAT,
                      total=len(dataloader)):
        tcc_batch = None
        if paired_dataloader is not None and random.random() < args.tcc_prob:
            if paired_iter is None:
                paired_iter = iter(paired_dataloader)
            try:
                tcc_batch = next(paired_iter)
            except StopIteration:
                # Paired loader exhausted: start a new pass over it.
                paired_iter = iter(paired_dataloader)
                tcc_batch = next(paired_iter)

        with torch.no_grad():
            metrics_dict = calc_loss_and_learn(
                args,
                batch=batch,
                model_dict=model_dict,
                optimizer_dict={},
                tcc_batch=tcc_batch,
                scaler=scaler,
            )

        for key, val in metrics_dict.items():
            val_metrics_dict[key].append(val)

    return val_metrics_dict


def process_metrics_array_dict(
        metrics_array_dict: dict,
        context: str,
        epoch: int,
        experiment: Optional[comet_ml.Experiment] = None):
    """Average per-step metrics, log them, and return the averages.

    Args:
        metrics_array_dict: metric name -> sequence of per-step values.
        context: label used as the metric prefix for the experiment and
            included in the log line.
        epoch: epoch number attached to the logged metrics.
        experiment: when truthy, receives the averages via ``log_metrics``.

    Returns:
        Dict mapping each metric name to the mean of its values.
    """
    metrics_dict = {
        name: np.array(values).mean()
        for name, values in metrics_array_dict.items()
    }
    msg = ''.join(f'{name}: {mean_val:.5f} '
                  for name, mean_val in metrics_dict.items())

    if experiment:
        experiment.log_metrics(
            metrics_dict,
            prefix=context,
            epoch=epoch,
        )

    logger.info(f'Epoch: {epoch} {context} {msg}')
    return metrics_dict


def epoch_loop(
    args,
    model_dict: Dict[str, torch.nn.Module],
    optimizer_dict: Dict[str, torch.optim.Optimizer],
    dataloader_dict: Dict[str, DataLoader],
    epoch: int,
    experiment: Optional[comet_ml.Experiment],
    paired_dataloader_dict: Optional[Dict[str, DataLoader]] = None,
    log_prefix: str = '',
):
    """Run a single epoch: one training pass, then one validation pass.

    Args:
        args: run configuration forwarded to train() and validation().
        model_dict: models by name, forwarded unchanged.
        optimizer_dict: optimizers by name, used for the training pass only.
        dataloader_dict: the 'train' and 'validation' entries are used as the
            loaders of the respective pass.
        epoch: epoch number used when logging.
        experiment: optional experiment handle that receives the averaged
            metrics.
        paired_dataloader_dict: optional paired loaders under the keys
            'train' and 'validation'; missing keys disable the paired batch
            for that pass.
        log_prefix: 'align', 'adapt', for example. When non-empty the logging
            contexts become 'train_<log_prefix>' and 'valid_<log_prefix>'.

    Returns:
        The mean validation 'policy_loss'.
    """
    paired_loaders = {} if paired_dataloader_dict is None else paired_dataloader_dict
    context_suffix = f'_{log_prefix}' if log_prefix else ''

    train_steps = train(
        args,
        model_dict=model_dict,
        optimizer_dict=optimizer_dict,
        dataloader=dataloader_dict['train'],
        paired_dataloader=paired_loaders.get('train'),
    )
    val_steps = validation(
        args,
        model_dict=model_dict,
        dataloader=dataloader_dict['validation'],
        paired_dataloader=paired_loaders.get('validation'),
    )

    # Training averages are only logged; validation averages are also needed
    # for the return value.
    process_metrics_array_dict(train_steps,
                               context='train' + context_suffix,
                               epoch=epoch,
                               experiment=experiment)
    val_means = process_metrics_array_dict(val_steps,
                                           context='valid' + context_suffix,
                                           epoch=epoch,
                                           experiment=experiment)

    return val_means['policy_loss']
