import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Callable, List, Optional

import comet_ml
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.dail.models import DAILAgent
from common.dail.utils import (calc_accuracy, calc_alignment_score,
                               logit_bernoulli_entropy, plot_alignment,
                               sigmoid_cross_entropy_with_logits)

# Shared tqdm layout for every training loop below: a fixed 50-char bar.
TQDM_BAR_FORMAT = "{l_bar}{bar:50}{r_bar}"

# Module logger. Its own level is INFO so the logger.info(...) checkpoint
# messages are not filtered here; whether they reach output still depends on
# the handlers the application configures.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def train_source_policy(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Train the source-domain policy with behavior cloning (BC).

    Only the source-domain samples of each batch (``domain_ids[..., 0] > 0.5``)
    are used, truncated to ``args.source_state_dim`` /
    ``args.source_action_dim``. When ``args.image_observation`` is set, the last
    ``args.image_state_dim`` observation entries are replaced by the output of
    ``agent.source_image_encoder``, which is trained jointly. If additionally
    ``args.use_image_decoder`` is set, an image reconstruction loss weighted by
    ``args.image_recon_coef`` is added to the optimized loss. That term is not
    included in the logged ``loss`` metric.

    The full agent state dict is saved to ``model_path`` after every epoch.

    Args:
        args: Run configuration (device, state/action dims, image flags).
        omega_args: OmegaConf config; ``omega_args.bc.lr`` is the Adam lr.
        agent: Agent whose source policy (and image encoder/decoder) is trained.
        epochs: Number of passes over ``train_loader``.
        train_loader: Yields dict batches with ``observations``, ``actions``,
            ``task_ids``, ``domain_ids`` (and ``images`` for image observations).
        model_path: Checkpoint destination, overwritten each epoch.
        experiment: Optional Comet experiment for per-epoch metric logging.
    """
    # ----------------------------------------
    # Train source domain policy with BC
    # ----------------------------------------
    print("Start training expert policy...")
    # TODO load from source_only_dataset (with correct batch_size)
    use_image_decoder = args.image_observation and args.use_image_decoder
    parameters = list(agent.source_policy.parameters())
    if args.image_observation:
        parameters += list(agent.source_image_encoder.parameters())
        if args.use_image_decoder:
            parameters += list(agent.source_image_decoder.parameters())
    optimizer = torch.optim.Adam(
        parameters,
        lr=omega_args.bc.lr,
        betas=[0.9, 0.999],
    )
    mse = nn.MSELoss()
    for epoch in range(1, epochs + 1):
        epoch_metrics_dict = defaultdict(list)
        for batch in tqdm(train_loader, bar_format=TQDM_BAR_FORMAT):
            obs = batch["observations"]
            act = batch["actions"]
            task_ids = batch["task_ids"]
            domain_ids = batch["domain_ids"]

            # select source domain data (and drop padding past source dims)
            source_flag = domain_ids[..., 0] > 0.5
            obs = obs[source_flag][..., :args.source_state_dim].to(args.device)
            task_ids = task_ids[source_flag].to(args.device)
            act = act[source_flag][..., :args.source_action_dim].to(
                args.device)

            if args.image_observation:
                # Trailing image-state slots are replaced by the encoding.
                images = batch["images"][source_flag].float().to(args.device)
                images_encoded = agent.source_image_encoder(images)
                obs = torch.concat(
                    (obs[..., :-args.image_state_dim], images_encoded), dim=-1)

            obs = torch.cat((obs, task_ids), dim=-1)

            act_pred = agent.source_policy(obs)
            bc_loss = mse(act, act_pred)
            # Log a Python float, not ``.detach().cpu().numpy()``: on CPU that
            # array shares storage with the loss tensor, so a later in-place
            # update of the loss would silently rewrite the logged value.
            epoch_metrics_dict["loss"].append(bc_loss.item())
            loss = bc_loss

            if use_image_decoder:
                images_recon = agent.source_image_decoder(images_encoded)
                images_scaled = (images / 255.).clamp(0, 1)
                image_recon_loss = mse(images_recon, images_scaled)
                epoch_metrics_dict['image_recon_loss'].append(
                    image_recon_loss.item())
                # Out-of-place: keep ``bc_loss`` a pure BC term.
                loss = loss + image_recon_loss * args.image_recon_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_means = {
            name: float(np.mean(values))
            for name, values in epoch_metrics_dict.items()
        }
        # NaN instead of a formatting crash if the loader yielded no batches.
        loss_mean = epoch_means.get('loss', float('nan'))

        s = f"Epoch {epoch:3d}/{epochs} | "
        s += f"Loss: {loss_mean:.3f} | "
        print(s)
        if experiment:
            experiment.log_metric(name='source_bc_loss',
                                  value=loss_mean,
                                  epoch=epoch)
            if use_image_decoder:
                experiment.log_metric(
                    name='source_image_recon_loss',
                    value=epoch_means.get('image_recon_loss', float('nan')),
                    epoch=epoch)

        # Checkpoint after every epoch (same file is overwritten).
        torch.save(agent.state_dict(), model_path)
        logger.info(f'Model is saved to {model_path}.')


def train_target_dynamics_model(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Train the target-domain forward dynamics model ``(s, a) -> s'``.

    Only target-domain samples (``domain_ids[..., 1] > 0.5``) are used,
    truncated to ``args.target_state_dim`` / ``args.target_action_dim``, and
    the model is fit with an MSE loss on the next observation. With
    ``args.image_observation``, the trailing ``args.image_state_dim`` entries
    of both observations are replaced by ``agent.target_image_encoder``
    outputs. Gradients flow only through the current-image encoding; the
    next-image encoding is detached and used as a fixed regression target.
    With ``args.use_image_decoder``, an image reconstruction loss weighted by
    ``args.image_recon_coef`` is added to the optimized loss. That term is not
    included in the logged ``loss`` metric.

    The full agent state dict is saved to ``model_path`` after every epoch.

    Args:
        args: Run configuration (device, state/action dims, image flags).
        omega_args: OmegaConf config;
            ``omega_args.models.dynamics_model.lr`` is the Adam lr.
        agent: Agent whose dynamics model (and target image encoder/decoder)
            is trained.
        epochs: Number of passes over ``train_loader``.
        train_loader: Yields dict batches with ``observations``, ``actions``,
            ``next_observations``, ``domain_ids`` (and ``images`` /
            ``next_images`` for image observations).
        model_path: Checkpoint destination, overwritten each epoch.
        experiment: Optional Comet experiment for per-epoch metric logging.
    """
    # ----------------------------------------
    # Train target domain dynamics model
    # ----------------------------------------

    print("Start training dynamics model...")
    # TODO load from target_only_dataset (with correct batch_size)
    use_image_decoder = args.image_observation and args.use_image_decoder
    parameters = list(agent.dynamics_model.parameters())
    lr = omega_args.models.dynamics_model.lr
    if args.image_observation:
        parameters += list(agent.target_image_encoder.parameters())
        if args.use_image_decoder:
            parameters += list(agent.target_image_decoder.parameters())

    optimizer = torch.optim.Adam(
        parameters,
        lr=lr,
        betas=[0.9, 0.999],
    )
    mse = nn.MSELoss()
    for epoch in range(1, epochs + 1):
        epoch_metrics_dict = defaultdict(list)
        for batch in tqdm(train_loader, bar_format=TQDM_BAR_FORMAT):
            obs = batch["observations"]
            act = batch["actions"]
            next_obs = batch["next_observations"]
            domain_ids = batch["domain_ids"]

            # select target domain data (and drop padding past target dims)
            target_flag = domain_ids[..., 1] > 0.5
            obs = obs[target_flag][..., :args.target_state_dim].to(args.device)
            act = act[target_flag][..., :args.target_action_dim].to(
                args.device)
            next_obs = next_obs[target_flag][..., :args.target_state_dim].to(
                args.device)

            if args.image_observation:
                images = batch["images"][target_flag].float().to(args.device)
                images_encoded = agent.target_image_encoder(images)
                obs = torch.concat(
                    (obs[..., :-args.image_state_dim], images_encoded), dim=-1)

                next_images = batch["next_images"][target_flag].float().to(
                    args.device)
                # Do not backpropagate the gradients for next_images_encoded:
                # it serves as the regression target.
                next_images_encoded = agent.target_image_encoder(
                    next_images).detach()
                next_obs = torch.concat((next_obs[..., :-args.image_state_dim],
                                         next_images_encoded),
                                        dim=-1)

            obs_act = torch.cat((obs, act), dim=-1)

            next_obs_pred = agent.dynamics_model(obs_act)
            dyn_loss = mse(next_obs_pred, next_obs)
            # Log a Python float, not ``.detach().cpu().numpy()``: on CPU that
            # array shares storage with the loss tensor, so a later in-place
            # update of the loss would silently rewrite the logged value.
            epoch_metrics_dict["loss"].append(dyn_loss.item())
            loss = dyn_loss

            if use_image_decoder:
                images_recon = agent.target_image_decoder(images_encoded)
                images_scaled = (images / 255.).clamp(0, 1)
                image_recon_loss = mse(images_recon, images_scaled)
                epoch_metrics_dict['image_recon_loss'].append(
                    image_recon_loss.item())
                # Out-of-place: keep ``dyn_loss`` a pure dynamics term.
                loss = loss + image_recon_loss * args.image_recon_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_means = {
            name: float(np.mean(values))
            for name, values in epoch_metrics_dict.items()
        }
        # NaN instead of a formatting crash if the loader yielded no batches.
        loss_mean = epoch_means.get('loss', float('nan'))

        s = f"Epoch {epoch:3d}/{epochs} | "
        s += f"Loss: {loss_mean:.3f} | "
        print(s)
        if experiment:
            experiment.log_metric(name='dyn_loss',
                                  value=loss_mean,
                                  epoch=epoch)
            if use_image_decoder:
                experiment.log_metric(
                    name='target_image_recon_loss',
                    value=epoch_means.get('image_recon_loss', float('nan')),
                    epoch=epoch)

        # Checkpoint after every epoch (same file is overwritten).
        torch.save(agent.state_dict(), model_path)
        logger.info(f'Model is saved to {model_path}.')


def train_gama(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    val_loader: Optional[DataLoader] = None,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Learn the target->source MDP alignment (state map + action map), GAMA-style.

    Each batch is split into source samples (``domain_ids[..., 0] > 0.5``) and
    target samples (the rest). For the target samples a "translated" source
    transition is built:

        source_obs_hat      = state_map(target_obs)
        source_act_hat      = source_policy(source_obs_hat, task_ids)
        target_act_hat      = action_map(source_act_hat [, target_obs])
        source_next_obs_hat = state_map(dynamics_model(target_obs,
                                                        target_act_hat))

    Then two updates are made per batch:

    1. the discriminator is trained to separate real source transitions
       ``(s, a, s')`` (label from ``domain_ids[:, 0]``) from translated ones;
    2. the state and action maps are trained on
       ``mse(target_act_hat, target_act) + args.adversarial_coef *
       adversarial_loss``, where the adversarial loss pushes translated
       transitions toward the discriminator's source label.

    Only the discriminator, state map and action map are stepped here. Image
    encodings (when ``args.image_observation``) are detached.

    If ``args.calc_align_score`` is set, an alignment score is computed on
    ``val_loader`` after each epoch and a plot is written to
    ``model_path.parent`` every ``args.plot_interval`` epochs. The agent state
    dict is saved to ``model_path`` after every epoch.

    Raises:
        ValueError: If ``args.calc_align_score`` is set but ``val_loader`` is
            None.
    """
    # ----------------------------------------
    # Train MDP Alignment with GAMA
    # ----------------------------------------

    print("Start training MDP Alignment...")
    calc_align_score = bool(getattr(args, 'calc_align_score', False))
    if calc_align_score and val_loader is None:
        # Fail up front rather than after a full epoch of training.
        raise ValueError(
            'val_loader must be provided when args.calc_align_score is set.')

    optimizer_dict = {
        "discriminator":
        torch.optim.Adam(agent.discriminator.parameters(),
                         lr=omega_args.models.discriminator.lr,
                         betas=[0.9, 0.999]),
        "state_map":
        torch.optim.Adam(agent.state_map.parameters(),
                         lr=omega_args.models.state_map.lr,
                         betas=[0.9, 0.999]),
        "action_map":
        torch.optim.Adam(agent.action_map.parameters(),
                         lr=omega_args.models.action_map.lr,
                         betas=[0.9, 0.999]),
    }
    metrics = [
        "discriminator_loss",
        "discriminator_accuracy",
        "adversarial_loss",
        "bc_loss",
        'alignment_score',
        "total_loss",
    ]
    mse = nn.MSELoss()

    for epoch in range(1, epochs + 1):
        s = f"Epoch {epoch:3d}/{epochs} | "
        epoch_metrics_dict = {metric: [] for metric in metrics}
        true_source_states = []
        predicted_source_states = []
        for batch in tqdm(
                train_loader,
                bar_format=TQDM_BAR_FORMAT,
        ):
            obs = batch["observations"]
            act = batch["actions"]
            next_obs = batch["next_observations"]
            task_ids = batch["task_ids"]
            domain_ids = batch["domain_ids"]

            # Split the mixed batch by domain and drop padding past each
            # domain's own state/action dims.
            source_flag = domain_ids[..., 0] > 0.5
            source_obs = obs[source_flag][..., :args.source_state_dim]
            source_act = act[source_flag][..., :args.source_action_dim]
            source_task_ids = task_ids[source_flag]
            source_domain_ids = domain_ids[source_flag]
            source_next_obs = next_obs[source_flag][
                ..., :args.source_state_dim]

            target_obs = obs[~source_flag][..., :args.target_state_dim]
            target_act = act[~source_flag][..., :args.target_action_dim]
            target_task_ids = task_ids[~source_flag]
            target_domain_ids = domain_ids[~source_flag]
            target_next_obs = next_obs[~source_flag][
                ..., :args.target_state_dim]

            if args.image_observation:
                # Encoders are frozen in this phase: encodings are detached
                # and moved back to CPU to match the non-image tensors.
                source_images = batch["images"][source_flag].float()
                source_images_encoded = agent.source_image_encoder(
                    source_images.to(args.device))
                source_images_encoded = source_images_encoded.detach().cpu()
                source_next_images = batch["next_images"][source_flag].float()
                source_next_images_encoded = agent.source_image_encoder(
                    source_next_images.to(args.device))
                source_next_images_encoded = source_next_images_encoded.detach(
                ).cpu()
                source_obs = torch.concat(
                    (source_obs[..., :-args.image_state_dim],
                     source_images_encoded),
                    dim=-1)
                source_next_obs = torch.concat(
                    (source_next_obs[..., :-args.image_state_dim],
                     source_next_images_encoded),
                    dim=-1)

                target_images = batch["images"][~source_flag].float()
                target_images_encoded = agent.target_image_encoder(
                    target_images.to(args.device))
                target_images_encoded = target_images_encoded.detach().cpu()
                target_obs = torch.concat(
                    (target_obs[..., :-args.image_state_dim],
                     target_images_encoded),
                    dim=-1)

            # Discriminator ----------
            source_obs_hat = agent.state_map(target_obs.to(args.device))
            source_act_hat = agent.source_policy(
                torch.cat((source_obs_hat, target_task_ids.to(args.device)),
                          dim=-1))
            if agent.decode_with_state:
                source_act_hat_inp = torch.cat(
                    (source_act_hat, target_obs.to(args.device)), dim=-1)
            else:
                source_act_hat_inp = source_act_hat
            target_act_hat = agent.action_map(source_act_hat_inp)
            source_next_obs_hat = agent.state_map(
                agent.dynamics_model(
                    torch.cat((target_obs.to(args.device), target_act_hat),
                              dim=-1)))

            logits_real = agent.discriminator(
                torch.cat((
                    source_obs.to(args.device),
                    source_act.to(args.device),
                    source_next_obs.to(args.device),
                ),
                          dim=-1)).squeeze()
            # Translated transitions are detached: this step updates only the
            # discriminator.
            logits_fake = agent.discriminator(
                torch.cat((
                    source_obs_hat.detach(),
                    source_act_hat.detach(),
                    source_next_obs_hat.detach(),
                ),
                          dim=-1)).squeeze()

            # ``squeeze()`` collapses the logits of a domain with a single
            # sample in this batch into a 0-d tensor (presumably the
            # discriminator outputs shape (N, 1)), which ``torch.cat`` cannot
            # concatenate. Skip exactly that case. Any other failure is a
            # real bug and must propagate.
            if logits_real.dim() == 0 or logits_fake.dim() == 0:
                logger.debug('Skipping batch: single source or target sample.')
                continue
            logits = torch.cat((logits_real, logits_fake), dim=0)
            labels = torch.cat((source_domain_ids, target_domain_ids),
                               dim=0)[:, 0].squeeze()
            # Assuming logit_bernoulli_entropy returns the entropy of the
            # predicted Bernoulli, this small term regularizes the
            # discriminator toward less confident outputs.
            entropy_loss = -logit_bernoulli_entropy(logits)
            discriminator_loss = sigmoid_cross_entropy_with_logits(
                logits, labels.to(args.device))
            discriminator_loss = discriminator_loss + 0.0005 * entropy_loss

            accuracy = calc_accuracy(logits, labels.to(args.device))
            epoch_metrics_dict["discriminator_loss"].append(
                discriminator_loss.detach().cpu().numpy())
            epoch_metrics_dict["discriminator_accuracy"].append(
                accuracy.detach().cpu().numpy())

            optimizer_dict["discriminator"].zero_grad()
            discriminator_loss.backward()
            optimizer_dict["discriminator"].step()

            # State map && Action map ----------
            # Re-score the (non-detached) translated transitions with the
            # freshly updated discriminator. The target samples are labelled
            # with domain_ids[:, 1] (the target flag), so the maps are pushed
            # toward the discriminator's "source" side.
            logits_fake = agent.discriminator(
                torch.cat((
                    source_obs_hat,
                    source_act_hat,
                    source_next_obs_hat,
                ),
                          dim=-1)).squeeze()
            labels = target_domain_ids[:, 1].squeeze()
            adversarial_loss = sigmoid_cross_entropy_with_logits(
                logits_fake, labels.to(args.device))
            bc_loss = mse(target_act_hat, target_act.to(args.device))
            policy_loss = bc_loss + args.adversarial_coef * adversarial_loss

            epoch_metrics_dict["adversarial_loss"].append(
                adversarial_loss.detach().cpu().numpy())
            epoch_metrics_dict["bc_loss"].append(
                bc_loss.detach().cpu().numpy())
            epoch_metrics_dict["total_loss"].append(
                policy_loss.detach().cpu().numpy())

            # NOTE(review): this backward also accumulates .grad on
            # source_policy and dynamics_model, which are never zeroed or
            # stepped here. Harmless for this loop, but confirm that later
            # phases call zero_grad() before reusing those grads.
            optimizer_dict["state_map"].zero_grad()
            optimizer_dict["action_map"].zero_grad()
            policy_loss.backward()
            optimizer_dict["state_map"].step()
            optimizer_dict["action_map"].step()

        if calc_align_score:
            for batch in val_loader:
                obs = batch["observations"]
                domain_ids = batch["domain_ids"]

                source_flag = domain_ids[..., 0] > 0.5
                # NOTE(review): unlike training, target observations are not
                # truncated to args.target_state_dim here. Confirm the
                # validation data has no padding.
                target_obs = obs[~source_flag]
                # Evaluation only: no autograd graph is needed.
                with torch.no_grad():
                    source_obs_hat = agent.state_map(
                        target_obs.to(args.device))

                alignment_score, true_source_state = calc_alignment_score(
                    target_states=target_obs,
                    source_states_hat=source_obs_hat,
                    apply_shift=args.shift,
                    device=args.device,
                )
                epoch_metrics_dict['alignment_score'].append(alignment_score)
                true_source_states.append(
                    true_source_state.detach().cpu().numpy())
                predicted_source_states.append(
                    source_obs_hat.detach().cpu().numpy())

        for metric in metrics:
            values = epoch_metrics_dict[metric]
            # np.mean([]) returns NaN *and* emits a RuntimeWarning (e.g. for
            # alignment_score when it is disabled); produce the NaN quietly.
            epoch_metrics_dict[metric] = (np.mean(values)
                                          if len(values) > 0 else float('nan'))

        s += "Loss | "
        s += f"BC: {epoch_metrics_dict['bc_loss']:.3f} | "
        s += f"Adv: {epoch_metrics_dict['adversarial_loss']:.3f} | "
        s += f"Disc: {epoch_metrics_dict['discriminator_loss']:.3f} | "
        s += f"Acc: {epoch_metrics_dict['discriminator_accuracy']:.3f} | "
        if calc_align_score:
            s += f"Align Score: {epoch_metrics_dict['alignment_score']:.3f} | "
        s += f"Total: {epoch_metrics_dict['total_loss']:.3f} | "
        print(s)
        if experiment:
            experiment.log_metrics(epoch_metrics_dict,
                                   epoch=epoch,
                                   prefix='gama')

        # Checkpoint after every epoch (same file is overwritten).
        torch.save(agent.state_dict(), model_path)
        logger.info(f'Model is saved to {model_path}.')

        if calc_align_score and epoch % args.plot_interval == 0:
            plot_alignment(
                true_source_states=true_source_states,
                predicted_source_states=predicted_source_states,
                logdir=model_path.parent,
                epoch=epoch,
            )
