import logging
import os
from pathlib import Path
from typing import Callable, List, Optional

import comet_ml
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.dail.models import DAILAgent
from common.dail.utils import (calc_accuracy, calc_alignment_score,
                               logit_bernoulli_entropy, plot_alignment,
                               sigmoid_cross_entropy_with_logits)

# Module-level logger used by the training loops below to report checkpoint
# writes; its level is set to INFO as soon as this module is imported.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def train_source_policy(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Train the source-domain policy with behavioral cloning (BC).

    Each batch from ``train_loader`` is unpacked as a 6-tuple whose first four
    items are ``(obs, task_id, domain_id, act)``. Samples with
    ``domain_id[..., 0] > 0.5`` are treated as source-domain data. The policy
    input is ``obs`` concatenated with ``task_id`` along the last axis, and
    the policy output is regressed onto ``act`` with an MSE loss.

    Args:
        args: Run options; only ``args.device`` is read here.
        omega_args: Config; ``omega_args.bc.lr`` is the Adam learning rate.
        agent: Agent whose ``source_policy`` sub-module is optimized.
        epochs: Number of passes over ``train_loader``.
        train_loader: Loader yielding the batches described above.
        model_path: Destination of the checkpoint written after every epoch.
            The state dict of the whole ``agent`` is saved, not only the
            source policy.
        experiment: Optional Comet experiment; receives the per-epoch mean
            loss under the name ``source_bc_loss``.
    """
    print("Start training expert policy...")
    # TODO load from source_only_dataset (with correct batch_size)
    optimizer = torch.optim.Adam(
        agent.source_policy.parameters(),
        lr=omega_args.bc.lr,
        betas=(0.9, 0.999),
    )
    mse = nn.MSELoss()
    device = args.device

    for epoch in range(1, epochs + 1):
        batch_losses = []
        for obs, task_id, domain_id, act, _, _ in tqdm(
                train_loader, bar_format="{l_bar}{bar:50}{r_bar}"):
            # Keep only the source-domain samples of this (mixed) batch.
            source_flag = domain_id[..., 0] > 0.5
            if not source_flag.any():
                # MSELoss averages over zero elements for an empty selection,
                # which yields NaN and would poison the epoch statistics.
                continue

            policy_input = torch.cat(
                (obs[source_flag], task_id[source_flag]), dim=-1).to(device)
            act_target = act[source_flag].to(device)

            act_pred = agent.source_policy(policy_input)
            loss = mse(act_pred, act_target)
            batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # An epoch without any source sample has no meaningful loss; report
        # NaN on the console but do not send it to the experiment tracker.
        if batch_losses:
            mean_loss = float(np.mean(batch_losses))
        else:
            mean_loss = float("nan")

        print(f"Epoch {epoch:3d}/{epochs} | Loss: {mean_loss:.3f} | ")
        if experiment and batch_losses:
            experiment.log_metric(name='source_bc_loss',
                                  value=mean_loss,
                                  epoch=epoch)

        # The checkpoint is overwritten after every epoch.
        torch.save(agent.state_dict(), model_path)
        logger.info('Model is saved to %s.', model_path)


def train_target_dynamics_model(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Train the target-domain dynamics model by next-observation regression.

    Each batch from ``train_loader`` is unpacked as a 6-tuple
    ``(obs, task_id, domain_id, act, next_obs, _)``. Samples with
    ``domain_id[..., 1] > 0.5`` are treated as target-domain data. The model
    input is ``obs`` concatenated with ``act`` along the last axis, and the
    model output is regressed onto ``next_obs`` with an MSE loss.

    Args:
        args: Run options; only ``args.device`` is read here.
        omega_args: Config; ``omega_args.models.dynamics_model.lr`` is the
            Adam learning rate.
        agent: Agent whose ``dynamics_model`` sub-module is optimized.
        epochs: Number of passes over ``train_loader``.
        train_loader: Loader yielding the batches described above.
        model_path: Destination of the checkpoint written after every epoch.
            The state dict of the whole ``agent`` is saved, not only the
            dynamics model.
        experiment: Optional Comet experiment; receives the per-epoch mean
            loss under the name ``dyn_loss``.
    """
    print("Start training dynamics model...")
    # TODO load from target_only_dataset (with correct batch_size)
    optimizer = torch.optim.Adam(
        agent.dynamics_model.parameters(),
        lr=omega_args.models.dynamics_model.lr,
        betas=(0.9, 0.999),
    )
    mse = nn.MSELoss()
    device = args.device

    for epoch in range(1, epochs + 1):
        batch_losses = []
        for obs, _, domain_id, act, next_obs, _ in tqdm(
                train_loader, bar_format="{l_bar}{bar:50}{r_bar}"):
            # Keep only the target-domain samples of this (mixed) batch.
            target_flag = domain_id[..., 1] > 0.5
            if not target_flag.any():
                # MSELoss averages over zero elements for an empty selection,
                # which yields NaN and would poison the epoch statistics.
                continue

            obs_act = torch.cat(
                (obs[target_flag], act[target_flag]), dim=-1).to(device)
            next_obs_target = next_obs[target_flag].to(device)

            next_obs_pred = agent.dynamics_model(obs_act)
            loss = mse(next_obs_pred, next_obs_target)
            batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # An epoch without any target sample has no meaningful loss; report
        # NaN on the console but do not send it to the experiment tracker.
        if batch_losses:
            mean_loss = float(np.mean(batch_losses))
        else:
            mean_loss = float("nan")

        print(f"Epoch {epoch:3d}/{epochs} | Loss: {mean_loss:.3f} | ")
        if experiment and batch_losses:
            experiment.log_metric(name='dyn_loss',
                                  value=mean_loss,
                                  epoch=epoch)

        # The checkpoint is overwritten after every epoch.
        torch.save(agent.state_dict(), model_path)
        logger.info('Model is saved to %s.', model_path)


def _discriminator_logits(agent, transition, task_id, task_cond):
    """Return squeezed discriminator logits for a batch of transitions.

    ``transition`` is an iterable of tensors (state, action, next state) that
    are concatenated along the last axis; ``task_id`` is appended to that
    input only when ``task_cond`` is truthy.
    """
    parts = list(transition)
    if task_cond:
        parts.append(task_id)
    return agent.discriminator(torch.cat(parts, dim=-1)).squeeze()


def train_gama(
    args,
    omega_args: DictConfig,
    agent: DAILAgent,
    epochs: int,
    train_loader: DataLoader,
    model_path: Path,
    val_loader: Optional[DataLoader] = None,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Train the MDP alignment (state map and action map) with GAMA.

    Every batch is split into source samples (``domain_id[..., 0] > 0.5``)
    and target samples (the rest). Two updates are performed per batch:

    1. Discriminator: separates real source transitions from transitions
       produced by mapping target observations through ``state_map``,
       ``source_policy``, ``action_map`` and ``dynamics_model``. The loss is
       a sigmoid cross-entropy plus a small entropy term.
    2. State map / action map: minimise the MSE between the mapped action and
       the recorded target action, plus ``args.adversarial_coef`` times the
       adversarial loss against the freshly updated discriminator.

    Batches are skipped when they contain no target sample, or when either
    logit tensor collapses to zero dimensions after ``squeeze()`` (which
    happens for a single-sample domain split and cannot be concatenated).

    Args:
        args: Run options. Reads ``device``, ``task_cond``,
            ``adversarial_coef``, ``source_state_dim``, ``source_action_dim``,
            ``target_state_dim``, ``target_action_dim`` and, when present and
            truthy, ``calc_align_score`` together with ``shift`` and
            ``plot_interval``.
        omega_args: Config providing ``models.discriminator.lr``,
            ``models.state_map.lr`` and ``models.action_map.lr``.
        agent: Agent holding the sub-modules listed above.
        epochs: Number of passes over ``train_loader``.
        train_loader: Loader yielding 6-tuples
            ``(obs, task_id, domain_id, act, next_obs, _)``.
        model_path: Destination of the checkpoint written after every epoch;
            alignment plots go to ``model_path.parent``.
        val_loader: Loader used for the alignment score; required when
            ``args.calc_align_score`` is truthy.
        experiment: Optional Comet experiment; receives the per-epoch mean of
            every metric that was recorded, with prefix ``gama``.

    Raises:
        ValueError: If alignment scoring is requested without ``val_loader``.
    """
    print("Start training MDP Alignment...")
    calc_align_score = getattr(args, 'calc_align_score', False)
    if calc_align_score and val_loader is None:
        # Fail before training starts rather than after the first epoch.
        raise ValueError(
            "val_loader is required when args.calc_align_score is set.")

    optimizer_dict = {
        "discriminator":
        torch.optim.Adam(agent.discriminator.parameters(),
                         lr=omega_args.models.discriminator.lr,
                         betas=(0.9, 0.999)),
        "state_map":
        torch.optim.Adam(agent.state_map.parameters(),
                         lr=omega_args.models.state_map.lr,
                         betas=(0.9, 0.999)),
        "action_map":
        torch.optim.Adam(agent.action_map.parameters(),
                         lr=omega_args.models.action_map.lr,
                         betas=(0.9, 0.999)),
    }
    metrics = [
        "discriminator_loss",
        "discriminator_accuracy",
        "adversarial_loss",
        "bc_loss",
        'alignment_score',
        "total_loss",
    ]
    mse = nn.MSELoss()
    device = args.device
    entropy_coef = 0.0005  # weight of the discriminator entropy term
    nan = float("nan")

    for epoch in range(1, epochs + 1):
        epoch_metrics_dict = {metric: [] for metric in metrics}
        true_source_states = []
        predicted_source_states = []
        for obs, task_id, domain_id, act, next_obs, _ in tqdm(
                train_loader,
                bar_format="{l_bar}{bar:50}{r_bar}",
        ):
            source_flag = domain_id[..., 0] > 0.5
            target_flag = ~source_flag
            if not target_flag.any():
                # Without target samples every generator loss is a mean over
                # zero elements (NaN); nothing useful can be learned.
                continue

            source_obs = obs[source_flag][
                ..., :args.source_state_dim].to(device)
            source_act = act[source_flag][
                ..., :args.source_action_dim].to(device)
            source_next_obs = next_obs[source_flag][
                ..., :args.source_state_dim].to(device)
            source_task_id = task_id[source_flag].to(device)
            source_domain_id = domain_id[source_flag]

            target_obs = obs[target_flag][
                ..., :args.target_state_dim].to(device)
            target_act = act[target_flag][
                ..., :args.target_action_dim].to(device)
            target_task_id = task_id[target_flag].to(device)
            target_domain_id = domain_id[target_flag]

            # Map target observations into a source-domain transition.
            source_obs_hat = agent.state_map(target_obs)
            source_act_hat = agent.source_policy(
                torch.cat((source_obs_hat, target_task_id), dim=-1))
            target_act_hat = agent.action_map(source_act_hat)
            source_next_obs_hat = agent.state_map(
                agent.dynamics_model(
                    torch.cat((target_obs, target_act_hat), dim=-1)))
            fake_transition = (source_obs_hat, source_act_hat,
                               source_next_obs_hat)

            # Discriminator ----------
            logits_real = _discriminator_logits(
                agent,
                (source_obs, source_act, source_next_obs),
                source_task_id,
                args.task_cond,
            )
            # Detach so the discriminator update does not reach the maps.
            logits_fake = _discriminator_logits(
                agent,
                [part.detach() for part in fake_transition],
                target_task_id,
                args.task_cond,
            )
            if logits_real.dim() == 0 or logits_fake.dim() == 0:
                # squeeze() turned a single-sample output into a
                # zero-dimensional tensor, which torch.cat rejects.
                continue
            logits = torch.cat((logits_real, logits_fake), dim=0)
            labels = torch.cat((source_domain_id, target_domain_id),
                               dim=0)[:, 0].squeeze().to(device)

            entropy_loss = -logit_bernoulli_entropy(logits)
            discriminator_loss = (
                sigmoid_cross_entropy_with_logits(logits, labels) +
                entropy_coef * entropy_loss)
            accuracy = calc_accuracy(logits, labels)

            epoch_metrics_dict["discriminator_loss"].append(
                discriminator_loss.detach().cpu().numpy())
            epoch_metrics_dict["discriminator_accuracy"].append(
                accuracy.detach().cpu().numpy())

            optimizer_dict["discriminator"].zero_grad()
            discriminator_loss.backward()
            optimizer_dict["discriminator"].step()

            # State map && Action map ----------
            # Re-score the (non-detached) fake transition with the updated
            # discriminator so gradients flow back into the maps.
            logits_fake = _discriminator_logits(
                agent,
                fake_transition,
                target_task_id,
                args.task_cond,
            )
            labels = target_domain_id[:, 1].squeeze().to(device)
            adversarial_loss = sigmoid_cross_entropy_with_logits(
                logits_fake, labels)
            bc_loss = mse(target_act_hat, target_act)
            policy_loss = bc_loss + args.adversarial_coef * adversarial_loss

            epoch_metrics_dict["adversarial_loss"].append(
                adversarial_loss.detach().cpu().numpy())
            epoch_metrics_dict["bc_loss"].append(
                bc_loss.detach().cpu().numpy())
            epoch_metrics_dict["total_loss"].append(
                policy_loss.detach().cpu().numpy())

            optimizer_dict["state_map"].zero_grad()
            optimizer_dict["action_map"].zero_grad()
            policy_loss.backward()
            optimizer_dict["state_map"].step()
            optimizer_dict["action_map"].step()

        if calc_align_score:
            for obs, _, domain_id, _, _, _ in val_loader:
                source_flag = domain_id[..., 0] > 0.5
                # NOTE(review): unlike the training loop, the observation is
                # not sliced to args.target_state_dim here -- confirm that
                # this is intended.
                target_obs = obs[~source_flag]
                # Evaluation only: no autograd graph is needed.
                with torch.no_grad():
                    source_obs_hat = agent.state_map(target_obs.to(device))

                alignment_score, true_source_state = calc_alignment_score(
                    target_states=target_obs,
                    source_states_hat=source_obs_hat,
                    apply_shift=args.shift,
                    device=device,
                )
                epoch_metrics_dict['alignment_score'].append(alignment_score)
                true_source_states.append(
                    true_source_state.detach().cpu().numpy())
                predicted_source_states.append(
                    source_obs_hat.detach().cpu().numpy())

        # Average only metrics that were actually recorded; np.mean of an
        # empty list would emit a RuntimeWarning and log NaN.
        epoch_means = {
            metric: float(np.mean(values))
            for metric, values in epoch_metrics_dict.items()
            if len(values) > 0
        }

        s = f"Epoch {epoch:3d}/{epochs} | "
        s += "Loss | "
        s += f"BC: {epoch_means.get('bc_loss', nan):.3f} | "
        s += f"Adv: {epoch_means.get('adversarial_loss', nan):.3f} | "
        s += f"Disc: {epoch_means.get('discriminator_loss', nan):.3f} | "
        s += f"Acc: {epoch_means.get('discriminator_accuracy', nan):.3f} | "
        if calc_align_score:
            s += f"Align Score: {epoch_means.get('alignment_score', nan):.3f} | "
        s += f"Total: {epoch_means.get('total_loss', nan):.3f} | "
        print(s)
        if experiment:
            experiment.log_metrics(epoch_means,
                                   epoch=epoch,
                                   prefix='gama')

        # The checkpoint is overwritten after every epoch.
        torch.save(agent.state_dict(), model_path)
        logger.info('Model is saved to %s.', model_path)

        if calc_align_score and epoch % args.plot_interval == 0:
            plot_alignment(
                true_source_states=true_source_states,
                predicted_source_states=predicted_source_states,
                logdir=model_path.parent,
                epoch=epoch,
            )
