from comet_ml import Experiment

# isort: split

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Union

import comet_ml
import numpy as np
import torch
from torch.utils.data import DataLoader

from common.dail.models import DAILAgent
from common.utils.process_dataset import get_obs_converter

logger = logging.getLogger(__name__)


@dataclass
class GAMAModelConfig:
    """Settings for a single sub-network listed in ``GAMAConfig.models``.

    ``-1`` values are placeholders: ``configure_model_params`` in this file
    overwrites ``in_dim`` and, for most models, the last ``hid_dims`` entry
    once the state/action dimensions are known.
    """

    # Learning rate for this sub-network.
    lr: float = 1e-4
    # Input width; -1 means "not configured yet".
    in_dim: int = -1
    # Layer widths; the last entry acts as the output width.
    hid_dims: List[int] = field(default_factory=list)
    # One activation name per layer; None presumably means "no activation"
    # (the consumer of this list is not visible here -- confirm).
    activations: List[Optional[str]] = field(default_factory=list)


@dataclass
class GAMAConfig:
    """Top-level configuration for a GAMA training / adaptation run.

    Fields declared with ``field(init=False)`` have no default and are not
    accepted by ``__init__``; they must be assigned after construction
    ("set in the script"). Reading one before it is assigned raises
    ``AttributeError`` on a plain dataclass instance.
    """

    # Optional path to an external config file.
    config: Optional[Path] = None
    # Per-domain settings. ``calc_alignment_score`` in this file reads the
    # 'obs_converter' / 'obs_converter_args' keys of entry 0.
    domains: List[Dict] = field(default_factory=list)
    adapt_domains: List[Dict] = field(default_factory=list)

    n_domains: int = 2
    add_domain: bool = False
    complex_task: bool = False
    robot: bool = False

    name: Optional[str] = None
    goal: int = 6
    n_tasks: int = -1
    n_traj: int = 1000
    train_ratio: float = 0.9
    device: str = 'cuda:0'
    comet: bool = False

    # Evaluation schedule during alignment training.
    eval_interval: int = 5
    n_eval_episodes: int = 20
    n_render_episodes: int = 10

    # Evaluation schedule during adaptation.
    adapt_eval_interval: int = 40
    adapt_n_eval_episodes: int = 50
    adapt_n_traj: int = -1

    image_observation: bool = False
    pretrained: bool = False  # whether to use pretrained image encoder
    image_state_dim: int = 1024
    use_image_decoder: bool = False
    use_coord_conv: bool = False
    image_recon_coef: float = 1.0
    evaluate: bool = True
    evaluate_parallel: bool = False  # avoid bug in parallelization of P2A
    amp: bool = False

    multienv: bool = False
    n_task_ids: Optional[int] = None
    target_goal_id: Optional[int] = None
    task_id_offset_list: Optional[List[int]] = None
    target_task_id_offset: Optional[int] = None

    max_dataset_size: int = 2_000_000
    adversarial_coef: float = 0.5
    batch_size: int = 256
    num_epoch_bc: int = 30
    num_epoch_dynamics: int = 20
    num_epoch_gama: int = 30
    num_epoch_adapt: int = 50
    # When true, ``configure_model_params`` widens the action map input by
    # the target state dimension.
    decode_with_state: bool = False

    # set in the script
    max_obs_dim: int = field(init=False)
    max_action_dim: int = field(init=False)
    max_seq_len: int = field(init=False)
    train_goal_ids: List[int] = field(init=False)
    logdir: Union[str, Path] = field(init=False)
    num_task_ids: int = field(init=False)
    source_state_dim: int = field(init=False)
    target_state_dim: int = field(init=False)
    source_action_dim: int = field(init=False)
    target_action_dim: int = field(init=False)

    # Behavioural-cloning hyper-parameters.
    bc: Dict = field(default_factory=lambda: {
        'num_epoch': 30,
        'batch_size': 256,
        'lr': 1e-3
    })

    h: int = 128
    act: str = 'leaky_relu'

    lr: float = 1e-4
    disc_lr: float = 1e-5

    # Per-network settings. Every ``-1`` is a placeholder that
    # ``configure_model_params`` fills in from the *_state_dim / *_action_dim
    # fields above; the discriminator keeps its fixed output width of 1.
    models: Dict = field(
        default_factory=lambda: {
            'source_policy':
            GAMAModelConfig(lr=1e-4,
                            in_dim=-1,
                            hid_dims=[300, 200, -1],
                            activations=['leaky_relu', 'leaky_relu', None]),
            'target_policy':
            GAMAModelConfig(lr=1e-5,
                            in_dim=-1,
                            hid_dims=[300, 200, -1],
                            activations=['leaky_relu', 'leaky_relu', None]),
            'state_map':
            GAMAModelConfig(lr=1e-4,
                            in_dim=-1,
                            hid_dims=[128, 128, -1],
                            activations=['leaky_relu', 'leaky_relu', None]),
            'action_map':
            GAMAModelConfig(lr=1e-4,
                            in_dim=-1,
                            hid_dims=[128, 128, -1],
                            activations=['leaky_relu', 'leaky_relu', None]),
            'inv_state_map':
            GAMAModelConfig(lr=1e-4,
                            in_dim=-1,
                            hid_dims=[200, 200, -1],
                            activations=['leaky_relu', 'leaky_relu', None]),
            'dynamics_model':
            GAMAModelConfig(lr=1e-4,
                            in_dim=-1,
                            hid_dims=[128, 128, 128, -1],
                            activations=
                            ['leaky_relu', 'leaky_relu', 'leaky_relu', None]),
            'discriminator':
            GAMAModelConfig(lr=1e-5,
                            in_dim=-1,
                            hid_dims=[128, 128, 1],
                            activations=['leaky_relu', 'leaky_relu', None]),
        })


def calc_accuracy(logits, labels):
    """Return the fraction of entries whose thresholded value matches ``labels``.

    An entry is predicted positive when it is strictly greater than 0.5.
    The result is a 0-dim tensor (sum of matches divided by ``len(logits)``).
    """
    # NOTE(review): the 0.5 cut is applied to the raw input; if these are
    # un-squashed logits the natural threshold would be 0 -- confirm.
    predicted = (logits > 0.5).long()
    true_positive = predicted * labels
    true_negative = (1 - predicted) * (1 - labels)
    n_correct = (true_positive + true_negative).sum()
    return n_correct / len(logits)


def sigmoid_cross_entropy_with_logits(logits, labels):
    """Return the mean binary cross-entropy between ``logits`` and ``labels``.

    Uses the numerically stable form
    ``max(x, 0) - x * z + log(1 + exp(-|x|))`` evaluated element-wise, so it
    works for any input shape (e.g. ``(N,)`` or ``(N, 1)``) as long as
    ``labels`` broadcasts against ``logits``.

    Args:
        logits: Un-squashed scores ``x``.
        labels: Targets ``z`` in ``[0, 1]``.

    Returns:
        A 0-dim tensor holding the mean loss over all elements.
    """
    # clamp is the element-wise max(x, 0); unlike stacking a zero column and
    # reducing over dim 1, it does not depend on the rank of ``logits``.
    positive_part = torch.clamp(logits, min=0)
    softplus_term = torch.log1p(torch.exp(-torch.abs(logits)))
    per_element = positive_part - logits * labels + softplus_term
    return torch.mean(per_element)


def logsigmoid(a):
    """Return ``log(sigmoid(a))`` element-wise, computed as ``-softplus(-a)``."""
    softplus = torch.nn.Softplus()
    return -softplus(-a)


def logit_bernoulli_entropy(logits):
    """Return the mean entropy of Bernoulli distributions given by ``logits``.

    Per element this evaluates ``(1 - sigmoid(x)) * x - log(sigmoid(x))``.
    """
    prob_negative = 1. - torch.sigmoid(logits)
    entropy = prob_negative * logits - logsigmoid(logits)
    return entropy.mean()


def mse_loss(source: np.ndarray, target: np.ndarray):
    """Return the mean of the element-wise squared difference of two arrays."""
    residual = source - target
    return (residual * residual).mean()


def calc_alignment_score(target_states,
                         source_states_hat,
                         apply_shift: bool = False,
                         device: str = 'cuda:0'):
    """Score predicted source states against axis-swapped target states.

    The reference is built by moving ``target_states`` to ``device`` and
    re-ordering its last axis as ``[1, 0, 3, 2]``; with ``apply_shift`` the
    vector ``[6, 6, 0, 0]`` is subtracted from it.

    NOTE(review): a second function named ``calc_alignment_score`` is defined
    later in this file, so after import this definition is no longer
    reachable under that name -- confirm which one callers expect.

    Returns:
        ``(mse, reference)`` where ``mse`` is a Python float and
        ``reference`` is the re-ordered (and optionally shifted) tensor.
    """
    swapped_columns = [1, 0, 3, 2]
    # List indexing yields a fresh tensor, so the in-place shift below does
    # not touch the caller's ``target_states``.
    inverted_states = target_states.to(device)[..., swapped_columns]
    if apply_shift:
        offset = torch.tensor([6, 6, 0, 0],
                              device=inverted_states.device,
                              dtype=torch.float)
        inverted_states.sub_(offset)

    criterion = torch.nn.MSELoss()
    align_score = criterion(inverted_states, source_states_hat).item()
    return align_score, inverted_states


def plot_alignment(
    true_source_states,
    predicted_source_states,
    logdir: Path,
    epoch: int,
    n_plot: int = 1000,
    experiment: Optional[comet_ml.Experiment] = None,
):
    """Scatter true vs. predicted source states and save the figure.

    Only the first two state dimensions are drawn. Up to 15 true/predicted
    pairs are joined by a line. The image is written to
    ``<logdir>/all-z.png`` and, if ``experiment`` is given, logged to it.

    Args:
        true_source_states: Sequence of arrays; stacked with ``np.vstack``.
        predicted_source_states: Sequence of arrays; stacked the same way.
        logdir: Output directory (``str`` or ``Path``).
        epoch: Epoch number shown in the title and used as the logging step.
        n_plot: Maximum number of rows kept from each stacked array.
        experiment: Optional experiment that receives the saved image.
    """
    import matplotlib.pyplot as plt

    plt.clf()
    true_source_states = np.vstack(true_source_states)[:n_plot]
    predicted_source_states = np.vstack(predicted_source_states)[:n_plot]

    # Axis limits span both point clouds.
    concat_states = np.concatenate(
        (true_source_states, predicted_source_states))
    state_min, state_max = concat_states.min(0), concat_states.max(0)
    plt.xlim([state_min[0], state_max[0]])
    plt.ylim([state_min[1], state_max[1]])

    true_source_pos = true_source_states[..., :2]
    predicted_source_pos = predicted_source_states[..., :2]
    plt.scatter(*true_source_pos.T, label='source', marker='.')
    plt.scatter(*predicted_source_pos.T, label='predicted', marker='.')

    distance = mse_loss(true_source_states, predicted_source_states)

    # Slicing (rather than indexing 0..14) keeps this safe when fewer than
    # 15 points are available.
    max_links = 15
    for first, second in zip(true_source_pos[:max_links],
                             predicted_source_pos[:max_links]):
        plt.plot((first[0], second[0]), (first[1], second[1]),
                 color='black',
                 lw=1)
    plt.legend()
    plt.title(f'(All) {logdir}\n epoch {epoch}: dist={distance:.5f}')
    plt.tight_layout()
    # Path(...) accepts both str and Path log directories.
    png_path = Path(logdir) / 'all-z.png'
    plt.savefig(png_path)
    if experiment:
        experiment.log_image(str(png_path), step=epoch)


def configure_model_params(args):
    """Fill the placeholder input/output widths of every entry in ``args.models``.

    Input widths go to ``in_dim`` and output widths to the last element of
    ``hid_dims``. The discriminator only gets its ``in_dim`` set. ``args`` is
    modified in place and also returned.
    """
    models = args.models
    n_goal = args.num_task_ids
    src_state, tgt_state = args.source_state_dim, args.target_state_dim
    src_action, tgt_action = args.source_action_dim, args.target_action_dim

    def _wire(model_cfg, in_dim, out_dim):
        model_cfg.in_dim = in_dim
        model_cfg.hid_dims[-1] = out_dim

    # Policies are conditioned on the task-id vector.
    _wire(models.source_policy, src_state + n_goal, src_action)
    _wire(models.target_policy, tgt_state + n_goal, tgt_action)

    _wire(models.state_map, tgt_state, src_state)

    # The action map optionally also receives the target state.
    action_map_in = src_action
    if getattr(args, 'decode_with_state', False):
        action_map_in = src_action + tgt_state
    _wire(models.action_map, action_map_in, tgt_action)

    _wire(models.inv_state_map, src_state, tgt_state)
    _wire(models.dynamics_model, tgt_state + tgt_action, tgt_state)

    # Two source states plus one source action.
    models.discriminator.in_dim = src_state * 2 + src_action

    return args


def process_alignment_dataset(
    source_dataset: Dict[str, np.ndarray],
    target_dataset: Dict[str, np.ndarray],
    num_task_ids: int,
    max_size: Optional[int] = None,
    use_domain_id: bool = True,
) -> Dict[str, torch.Tensor]:
    """Merge a source and a target dataset into one dict of tensors.

    Both datasets are read through the attributes ``obs``, ``next_obs``,
    ``actions`` and ``task_ids``. Rows of each domain are shuffled with
    ``np.random`` (source first, then target) and, when ``max_size`` is
    truthy, cut to ``max_size // 2`` rows per domain. Where the two domains
    differ in observation or action width, the narrower one is zero-padded
    on the right. Source rows always precede target rows in the output.

    Returns:
        Dict with keys ``observations``, ``next_observations``, ``actions``,
        ``action_masks`` (1 for real action entries, 0 for padding),
        ``task_ids`` (one-hot over ``num_task_ids``) and ``domain_ids``
        (``[1, 0]`` for source, ``[0, 1]`` for target, or all zeros when
        ``use_domain_id`` is false).
    """

    def _pad_columns(width, reference, *arrays):
        # Right-pad every array with zero columns up to ``width``; the pad's
        # row count is taken from ``reference``.
        missing = width - reference.shape[1]
        if missing <= 0:
            return arrays
        zeros = np.zeros((reference.shape[0], missing))
        return tuple(np.hstack((array, zeros)) for array in arrays)

    def _stack(source_rows, target_rows):
        # Shuffled/truncated source rows on top of target rows.
        merged = np.vstack((source_rows[source_order],
                            target_rows[target_order]))
        return torch.Tensor(merged.copy())

    # Observations ----------
    src_obs, tgt_obs = source_dataset.obs, target_dataset.obs
    src_next_obs = np.array(source_dataset.next_obs)
    tgt_next_obs = np.array(target_dataset.next_obs)

    source_order = np.arange(len(src_obs))
    target_order = np.arange(len(tgt_obs))
    np.random.shuffle(source_order)
    np.random.shuffle(target_order)
    if max_size:
        per_domain = max_size // 2
        source_order = source_order[:per_domain]
        target_order = target_order[:per_domain]

    obs_width = max(src_obs.shape[1], tgt_obs.shape[1])
    src_obs, src_next_obs = _pad_columns(obs_width, src_obs, src_obs,
                                         src_next_obs)
    tgt_obs, tgt_next_obs = _pad_columns(obs_width, tgt_obs, tgt_obs,
                                         tgt_next_obs)

    # Actions ----------
    src_actions, tgt_actions = source_dataset.actions, target_dataset.actions
    src_masks = np.ones_like(src_actions)
    tgt_masks = np.ones_like(tgt_actions)

    action_width = max(src_actions.shape[1], tgt_actions.shape[1])
    src_actions, src_masks = _pad_columns(action_width, src_actions,
                                          src_actions, src_masks)
    tgt_actions, tgt_masks = _pad_columns(action_width, tgt_actions,
                                          tgt_actions, tgt_masks)

    # Task ids ----------
    one_hot = np.eye(num_task_ids)
    src_task_ids = one_hot[source_dataset.task_ids]
    tgt_task_ids = one_hot[target_dataset.task_ids]

    # Domain ids ----------
    domain_codes = np.eye(2) if use_domain_id else np.zeros((2, 2))
    src_domain_ids = np.repeat(domain_codes[:1], src_obs.shape[0], axis=0)
    tgt_domain_ids = np.repeat(domain_codes[1:], tgt_obs.shape[0], axis=0)

    alignment_dataset = {
        "observations": _stack(src_obs, tgt_obs),
        "next_observations": _stack(src_next_obs, tgt_next_obs),
        "actions": _stack(src_actions, tgt_actions),
        "action_masks": _stack(src_masks, tgt_masks),
        "task_ids": _stack(src_task_ids, tgt_task_ids),
        "domain_ids": _stack(src_domain_ids, tgt_domain_ids),
    }

    print("Source dataset length:", len(source_order))
    print("Target dataset length:", len(target_order))
    print("Total dataset length:", len(alignment_dataset["observations"]))

    return alignment_dataset


def calc_alignment_score(
    args,
    agent: DAILAgent,
    data_loader: DataLoader,
    experiment: Experiment,
):
    """Log how far ``agent.state_map`` outputs are from converted target states.

    For every batch, rows whose last ``domain_ids`` entry exceeds 0.5 are
    treated as target-domain observations. They are converted with the obs
    converter configured in ``args.domains[0]`` and, separately, mapped by
    ``agent.state_map``. The score is the mean Euclidean distance between the
    two, divided by the mean of the averaged row norms.

    The score is written to the module logger and, when ``experiment`` is not
    ``None``, logged there as ``latent_dist``. Nothing is returned. If no obs
    converter is configured, or no target rows are found, a warning is logged
    and the function returns without computing a score.

    NOTE(review): this definition reuses the name of an earlier function in
    this file and replaces it at import time -- confirm that is intended.

    Args:
        args: Config exposing ``domains`` (list of dict-like) and ``device``.
        agent: Model exposing a callable ``state_map``.
        data_loader: Yields dicts with ``domain_ids`` and ``observations``.
        experiment: Optional metric sink with ``log_metrics``.
    """
    source_domain = args.domains[0]
    obs_converter_name = source_domain.get('obs_converter')
    if obs_converter_name:
        converter_kwargs = source_domain.get('obs_converter_args') or {}
        obs_converter = get_obs_converter(name=obs_converter_name,
                                          **converter_kwargs)
    else:
        obs_converter = None

    if obs_converter is None:
        # Without a converter no reference source states are ever collected,
        # so there is nothing to compare the state map against.
        logger.warning(
            'Skipping latent score: no obs_converter is configured for '
            'domain 0')
        return

    source_s_list = []
    source_s_hat_list = []
    for batch in data_loader:
        target_flag = (batch['domain_ids'][..., -1:] > .5).flatten()
        target_states = batch['observations'][target_flag].to(args.device)
        if len(target_states) == 0:
            continue

        source_s_list.append(obs_converter(target_states.cpu().numpy()))

        with torch.no_grad():
            source_states_hat = agent.state_map(target_states)

        source_s_hat_list.append(source_states_hat.cpu().numpy())

    if not source_s_list:
        # np.concatenate raises on an empty list.
        logger.warning(
            'Skipping latent score: data loader produced no target-domain '
            'observations')
        return

    source_s = np.concatenate(source_s_list)
    source_s_hat = np.concatenate(source_s_hat_list)

    dists = np.linalg.norm(source_s - source_s_hat, axis=-1, keepdims=True)

    # Normalise by the average magnitude of the true and predicted states.
    source_norm = np.linalg.norm(source_s, axis=-1, keepdims=True)
    source_hat_norm = np.linalg.norm(source_s_hat, axis=-1, keepdims=True)
    norm_two = (source_norm + source_hat_norm) / 2
    norm_all = norm_two.mean()
    dist_norm_all = (dists / norm_all).mean()
    logger.info('Latent score = %.4f', dist_norm_all)

    if experiment is not None:
        metrics_dict = {'latent_dist': dist_norm_all}
        experiment.log_metrics(metrics_dict)
