import os
import json
from typing import Union

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader

from .models.tgmm import TGMMModel, MultiTaskTGMMModel
from .models.em import GaussianMixtureEM
from .models.spectral import GaussianMixtureSpectral
from .evaluation import GMMEvaluator
from .task import (
    IsotropicGaussianMixtureTask,
    AnisotropicGaussianMixtureTask,
    OODIsotropicGaussianMixtureTask,
    SphericalGaussianMixtureTask,
    MultiTaskGaussianMixtureTask,
    concat_task_sample,
)
from .dataset import (
    GaussianMixtureDataset,
    StaticGaussianMixtureDataset,
    check_or_create_static_dataset,
)
from .utils import seed_everything, get_device, StreamingLossMeter


def _init_task_and_model(cfg):
    r"""Build the ``(task, model)`` pair described by ``cfg``.

    Supported ``cfg.task.type`` values:

    * ``"IsotropicGaussianMixture"`` -- a single ``IsotropicGaussianMixtureTask``
      and a ``TGMMModel`` on top of it.
    * ``"MultiTaskIsotropicGaussianMixture"`` /
      ``"MultiTaskAnisotropicGaussianMixture"`` -- one subtask per entry of
      ``cfg.task.n_components``, wrapped in ``MultiTaskGaussianMixtureTask``,
      and a ``MultiTaskTGMMModel``. When ``cfg.model.model_type`` is
      ``"transformer"`` the GPT-style hyper-parameters are used; any other
      value receives the Mamba2-style hyper-parameters.
    * ``"PhaseTransitionGaussianMixture"`` -- the a-b-n task family of
      https://arxiv.org/abs/1812.08078 (built with ``n = cfg.train.n_sample``)
      and a ``TGMMModel`` built from its first subtask.

    Returns:
        A tuple ``(task, model)``.

    Raises:
        ValueError: for any other ``cfg.task.type``; the message is the type.
    """
    task_type = cfg.task.type

    def transformer_kwargs():
        # Hyper-parameters shared by every transformer-backed model below.
        return {
            "n_positions": cfg.model.n_positions,
            "n_embd": cfg.model.n_embd,
            "n_layer": cfg.model.n_layer,
            "n_head": cfg.model.n_head,
        }

    multi_task_classes = {
        "MultiTaskIsotropicGaussianMixture": IsotropicGaussianMixtureTask,
        "MultiTaskAnisotropicGaussianMixture": AnisotropicGaussianMixtureTask,
    }

    if task_type == "IsotropicGaussianMixture":
        single_task = IsotropicGaussianMixtureTask(
            n_components=cfg.task.n_components,
            dim=cfg.task.dim,
        )
        return single_task, TGMMModel(task=single_task, **transformer_kwargs())

    if task_type in multi_task_classes:
        subtask_cls = multi_task_classes[task_type]
        multi_task = MultiTaskGaussianMixtureTask(
            [
                subtask_cls(n_components=k, dim=cfg.task.dim)
                for k in cfg.task.n_components
            ]
        )
        backbone = cfg.model.model_type
        if backbone == "transformer":
            backbone_kwargs = transformer_kwargs()
            backbone_kwargs["is_isotropic"] = (
                task_type == "MultiTaskIsotropicGaussianMixture"
            )
        else:
            # Mamba2 hyper-parameters; its config uses different names than
            # the transformer's on purpose.
            backbone_kwargs = {
                "hidden_size": cfg.model.hidden_size,
                "num_heads": cfg.model.num_heads,
                "num_hidden_layers": cfg.model.num_hidden_layers,
                "head_dim": cfg.model.head_dim,
                "state_size": cfg.model.state_size,
                "n_groups": cfg.model.n_groups,
                "expand": cfg.model.expand,
            }
        multi_model = MultiTaskTGMMModel(
            task=multi_task, model_type=backbone, **backbone_kwargs
        )
        return multi_task, multi_model

    if task_type == "PhaseTransitionGaussianMixture":
        # a-b-n configuration from https://arxiv.org/abs/1812.08078
        abn_task = MultiTaskGaussianMixtureTask.abn_config(
            a_s=cfg.task.a_s,
            b=cfg.task.b,
            n=cfg.train.n_sample,
        )
        # Which subtask seeds the model does not matter here.
        return abn_task, TGMMModel(task=abn_task.tasks[0], **transformer_kwargs())

    raise ValueError(task_type)


class EvaluationHelper:
    r"""Evaluate a model on every ``(subtask, n_sample)`` pair of a multi-task.

    The strategy is selected by ``cfg.eval.strategy``:

    * ``"static"``: one fixed dataset per ``(subtask, n_sample)`` is built in
      ``__init__``. It can be overwritten from the JSON file at
      ``cfg.eval.static_dataset_path``. The dataset is reused, together with
      a pre-built ``GMMEvaluator``, at every evaluation.
    * any other value: a fresh batch is sampled from the subtask on every
      call and a new ``GMMEvaluator`` is built for it. The subtask is
      OOD-perturbed first when ``cfg.eval.ood_perturbation_scale > 0``.

    Besides the model, EM and spectral baselines are fitted on the same batch
    and reported, except for ``"PhaseTransitionGaussianMixture"`` tasks.
    """

    @staticmethod
    def _get_dataset_name(task, n_sample):
        r"""Return the key ``"{n_components}_{n_sample}"`` for static datasets/evaluators.

        NOTE(review): the key only uses ``task.n_components``. Subtasks that
        share ``n_components`` (possibly the a-b-n PhaseTransition subtasks --
        confirm) would overwrite each other's entry in the static dicts. The
        key format must stay in sync with the external JSON file loaded in
        ``_maybe_load_from_external``.
        """
        return f"{task.n_components}_{n_sample}"

    def _get_prefix(self, subtask, n_sample):
        r"""Return the metric-name prefix for ``(subtask, n_sample)``.

        PhaseTransition subtasks are identified by their ``a``/``b`` values.
        All other subtasks are identified by ``K`` (components) and ``N``
        (samples).
        """
        if self.cfg.task.type == "PhaseTransitionGaussianMixture":
            prefix = f"a_{subtask.a:.4f}-b_{subtask.b:.4f}"
        else:
            prefix = f"K_{subtask.n_components}-N_{n_sample}"
        return prefix

    def __init__(self, cfg, task: MultiTaskGaussianMixtureTask, device):
        r"""Set up the helper.

        Args:
            cfg: experiment config; reads ``cfg.eval.*``, ``cfg.task.type``
                and ``cfg.train.verbose``.
            task: multi-task whose ``tasks`` are evaluated one by one.
            device: device the model runs on; evaluation batches are moved to it.
        """
        self.cfg = cfg
        self.task = task
        # Accept either a single int or a sequence of sample sizes.
        self.n_samples = (
            [cfg.eval.n_sample]
            if isinstance(cfg.eval.n_sample, int)
            else cfg.eval.n_sample
        )
        # Model outputs are padded to this many components (see ``_evaluate``).
        self.max_n_components = self.task.max_n_components
        self.device = device
        self.eval_strategy = cfg.eval.strategy
        if self.eval_strategy == "static":
            self._dataset_dict = self._init_static_datasets()
            self._maybe_load_from_external()
            self._evaluator_dict = self._init_static_evaluators()
        else:
            self._dataset_dict = None
            self._evaluator_dict = None
        self._eval_result_cache = {}  # For storing baseline methods (currently unused here)

    def _init_static_datasets(self):
        r"""Build one ``StaticGaussianMixtureDataset`` per ``(subtask, n_sample)``."""
        dataset_dict = {}
        # The eval batch size doubles as the static dataset size, so each
        # static dataset is consumed as a single evaluation batch.
        dataset_size = self.cfg.eval.batch_size
        for task in self.task.tasks:
            for n_sample in self.n_samples:
                dataset = StaticGaussianMixtureDataset(
                    task=task,
                    n_sample=n_sample,
                    dataset_size=dataset_size,
                )
                dataset_name = self._get_dataset_name(task, n_sample)
                dataset_dict[dataset_name] = dataset
        return dataset_dict

    def _maybe_load_from_external(self):
        r"""Overwrite static datasets with the ones stored at ``cfg.eval.static_dataset_path``.

        Does nothing when the configured path is ``None``. Otherwise
        ``check_or_create_static_dataset`` is called on the absolute path
        (presumably creating the file when missing -- confirm). Then every
        entry of the JSON file is loaded into the matching dataset on
        ``self.device``.

        Raises:
            KeyError: if the file contains a dataset name that
                ``_init_static_datasets`` did not build.
        """
        raw_path = self.cfg.eval.static_dataset_path
        # Test for None *before* abspath: os.path.abspath(None) raises TypeError,
        # which previously made this guard unreachable.
        if raw_path is None:
            return
        dataset_path = os.path.abspath(raw_path)
        check_or_create_static_dataset(dataset_path)
        with open(dataset_path, "r", encoding="utf-8") as f:
            datasets = json.load(f)
        for dataset_name, dataset in datasets.items():
            if dataset_name not in self._dataset_dict:
                raise KeyError(
                    f"Static dataset {dataset_name!r} found in {dataset_path} "
                    f"does not match any configured (task, n_sample) pair"
                )
            self._dataset_dict[dataset_name].load_from(dataset, device=self.device)

    def _init_static_evaluators(self):
        r"""Build one ``GMMEvaluator`` per static dataset, keyed like the datasets."""
        evaluator_dict = {}
        for task in self.task.tasks:
            for n_sample in self.n_samples:
                dataset_name = self._get_dataset_name(task, n_sample)
                ground_truth = self._dataset_dict[dataset_name].sample
                evaluator = GMMEvaluator(
                    task=task,
                    ground_truth=ground_truth,
                )
                evaluator_dict[dataset_name] = evaluator
        return evaluator_dict

    def _get_static_task_sample_and_evaluator(self, subtask, n_sample, dataset_size):
        r"""Return the cached sample and evaluator; ``dataset_size`` is ignored."""
        dataset_name = self._get_dataset_name(subtask, n_sample)
        task_sample = self._dataset_dict[
            dataset_name
        ]._sample  # No need to call __getitem__
        evaluator = self._evaluator_dict[dataset_name]
        return task_sample, evaluator

    def _get_dynamic_task_sample_and_evaluator(self, subtask, n_sample, dataset_size):
        r"""Sample a fresh batch and build an evaluator for it.

        The subtask is replaced by an ``OODIsotropicGaussianMixtureTask`` when
        ``cfg.eval.ood_perturbation_scale`` is positive. The returned sample
        lives on ``self.device``; the evaluator gets a CPU copy as ground truth.
        """
        if self.cfg.eval.ood_perturbation_scale > 0.0:
            subtask = OODIsotropicGaussianMixtureTask.from_id_task(
                subtask, perturbation_scale=self.cfg.eval.ood_perturbation_scale
            )
        task_sample = subtask.sample(
            n_sample=n_sample,
            batch_size=dataset_size,
            gen_mask=False,
        ).to(self.device)
        task_sample_for_eval = task_sample.to("cpu")
        evaluator = GMMEvaluator(task=subtask, ground_truth=task_sample_for_eval)
        return task_sample, evaluator

    def _get_task_sample_and_evaluator(self, subtask, n_sample, dataset_size):
        r"""Dispatch to the static or dynamic sampler according to ``eval_strategy``."""
        if self.eval_strategy == "static":
            return self._get_static_task_sample_and_evaluator(
                subtask, n_sample, dataset_size
            )
        else:
            return self._get_dynamic_task_sample_and_evaluator(
                subtask, n_sample, dataset_size
            )

    def _get_baseline_evaluation(
        self,
        subtask,
        task_sample,
        evaluator,
        n_sample,
    ):
        r"""Fit EM and spectral baselines on ``task_sample`` and return their summaries.

        Covariances are only learned (and evaluated) by EM for the
        ``"MultiTaskAnisotropicGaussianMixture"`` task type.

        Returns:
            Dict with ``{prefix}.em_summary``, ``{prefix}.spectral_summary``
            and the mean EM iteration count under ``{prefix}.em_iter``.
        """
        prefix = self._get_prefix(subtask, n_sample)
        gmm_em = GaussianMixtureEM(
            n_components=subtask.n_components,
            n_features=subtask.dim,
            verbose=self.cfg.train.verbose,
            learnable_covariance=bool(
                self.cfg.task.type == "MultiTaskAnisotropicGaussianMixture"
            ),
        )
        gmm_spectral = GaussianMixtureSpectral(
            n_components=subtask.n_components,
            verbose=self.cfg.train.verbose,
        )
        alpha_est_em, mu_est_em, scale_est, iter_em = gmm_em.fit_batch(
            task_sample.sample.cpu()
        )
        alpha_est_spectral, mu_est_spectral, _ = gmm_spectral.fit_batch(
            task_sample.sample.cpu()
        )
        eval_results_em = evaluator(
            mu_est=mu_est_em,
            alpha_est=alpha_est_em,
            scale_est=(
                scale_est
                if self.cfg.task.type == "MultiTaskAnisotropicGaussianMixture"
                else None
            ),
            in_sample_eval=True,
        )
        eval_results_spectral = evaluator(
            mu_est=mu_est_spectral,
            alpha_est=alpha_est_spectral,
            in_sample_eval=True,
        )
        mean_iter_em = iter_em.mean().item()
        baseline_summary = {
            f"{prefix}.em_summary": eval_results_em.summary_for_wandb(),
            f"{prefix}.spectral_summary": eval_results_spectral.summary_for_wandb(),
            # Some auxiliary metrics
            f"{prefix}.em_iter": mean_iter_em,
        }
        return baseline_summary

    def _evaluate(
        self,
        subtask,
        n_sample,
        dataset_size,
        model,
        loss_meter,
        step,
    ):
        r"""Evaluate ``model`` (and, where applicable, the baselines) on one pair.

        The model's raw ``alpha_est`` is turned into mixture weights with a
        softmax over the last axis. Component outputs are truncated to
        ``subtask.n_components`` before scoring. The windowed training losses
        from ``loss_meter`` are reported alongside the metrics.

        Returns:
            A flat summary dict keyed by ``"step"`` and ``{prefix}.*`` names.
        """
        prefix = self._get_prefix(subtask, n_sample)
        with torch.no_grad():
            task_sample, evaluator = self._get_task_sample_and_evaluator(
                subtask, n_sample, dataset_size
            )
            # NOTE(review): in "static" mode this pads the cached sample in
            # place on every evaluation; assumes ``pad`` is idempotent -- confirm.
            task_sample.pad(self.max_n_components)
            model_output = model(task_sample)
            # The model predicts max_n_components slots; keep only the real ones.
            model_output.alpha_est = model_output.alpha_est[:, : subtask.n_components]
            model_output.mu_est = model_output.mu_est[:, : subtask.n_components, :]
            if model_output.scale_est is not None:
                model_output.scale_est = model_output.scale_est[
                    :, : subtask.n_components, :
                ]
            eval_results_tgmm = evaluator(
                mu_est=model_output.mu_est.cpu(),
                alpha_est=F.softmax(model_output.alpha_est.cpu(), dim=-1),
                scale_est=(
                    model_output.scale_est.cpu()
                    if model_output.scale_est is not None
                    else None
                ),
                in_sample_eval=True,
            )
            alpha_loss, mu_loss, scale_loss, total_loss = loss_meter.compute()
            summary = {
                "step": step,
                f"{prefix}.tgmm_summary": eval_results_tgmm.summary_for_wandb(),
                # Some auxiliary metrics
                f"{prefix}.alpha_loss": (
                    alpha_loss.cpu().item() if alpha_loss is not None else None
                ),
                f"{prefix}.mu_loss": (
                    mu_loss.cpu().item() if mu_loss is not None else None
                ),
                f"{prefix}.scale_loss": (
                    scale_loss.cpu().item() if scale_loss is not None else None
                ),
                f"{prefix}.total_loss": (
                    total_loss.cpu().item() if total_loss is not None else None
                ),
            }
            if self.cfg.task.type != "PhaseTransitionGaussianMixture":
                summary.update(
                    self._get_baseline_evaluation(
                        subtask,
                        task_sample,
                        evaluator,
                        n_sample,
                    )
                )
        return summary

    def evaluate(
        self,
        model: MultiTaskTGMMModel,
        loss_meter: StreamingLossMeter,
        step,
    ):
        r"""Evaluate ``model`` on every ``(subtask, n_sample)`` pair.

        The model is switched to eval mode for the duration of the call and
        always switched back to train mode afterwards, even if an evaluation
        raises.

        Returns:
            The merged summary dict of all pairs.
        """
        model.eval()
        try:
            summary_dict = {}
            for subtask in self.task.tasks:
                for n_sample in self.n_samples:
                    summary_dict.update(
                        self._evaluate(
                            subtask,
                            n_sample,
                            dataset_size=self.cfg.eval.batch_size,
                            model=model,
                            loss_meter=loss_meter,
                            step=step,
                        )
                    )
        finally:
            # Never leave the model stuck in eval mode after a failed evaluation.
            model.train()
        return summary_dict


def train(cfg, device_id, name: str = None):
    r"""Training pipeline.

    Builds the task/model from ``cfg``, streams training batches from a
    ``GaussianMixtureDataset`` and optimizes the model with AdamW for
    ``cfg.train.num_train_steps`` steps. The model is evaluated through
    ``EvaluationHelper`` at every step index divisible by
    ``cfg.train.eval_every``, before that step's update.

    Args:
        cfg: experiment config (``cfg.train``, ``cfg.eval``, ``cfg.task``,
            ``cfg.model``).
        device_id: forwarded to ``get_device``.
        name: currently unused.

    Returns:
        The list of evaluation summary dicts, in step order.
    """
    device = get_device(device_id)
    seed_everything(cfg.train.seed)
    task, model = _init_task_and_model(cfg)
    model = model.to(device)
    dataset = GaussianMixtureDataset(
        batch_size=cfg.train.batch_size, task=task, n_sample=cfg.train.n_sample
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=1,  # Batch size is handled inside dataset iterator
        num_workers=4,
        pin_memory=True,
        collate_fn=concat_task_sample,
    )
    it = iter(loader)
    # Initialize evaluation pipeline
    evaluation_helper = EvaluationHelper(
        task=task,
        cfg=cfg,
        device=device,
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.train.learning_rate,
        weight_decay=cfg.train.weight_decay,
    )
    num_steps = cfg.train.num_train_steps
    loss_meter = StreamingLossMeter(n_metrics=4, window_size=cfg.train.eval_every).to(
        device=device
    )
    model.train()
    pbar = tqdm(range(num_steps)) if cfg.train.verbose else range(num_steps)
    eval_results = []
    for step in pbar:
        if not step % cfg.train.eval_every:
            eval_results.append(evaluation_helper.evaluate(model, loss_meter, step))
        task_sample = next(it).to(device=device, non_blocking=True)
        optimizer.zero_grad()
        model_output = model(task_sample)
        scale_loss = model_output.scale_loss
        total_loss = model_output.alpha_loss + model_output.mu_loss
        if scale_loss is not None:
            # Out-of-place add: avoids mutating a tensor that is part of the graph.
            total_loss = total_loss + scale_loss
        loss_meter.update(
            model_output.alpha_loss,
            model_output.mu_loss,
            # Explicit None test: ``tensor or default`` would evaluate the
            # tensor's truthiness, forcing a device sync every step and
            # raising for non-scalar tensors.
            scale_loss if scale_loss is not None else torch.zeros_like(total_loss),
            total_loss,
        )
        total_loss.backward()
        optimizer.step()
        if cfg.train.verbose:
            alpha_value = model_output.alpha_loss.item()
            mu_value = model_output.mu_loss.item()
            scale_value = scale_loss.item() if scale_loss is not None else 0.0
            total_value = total_loss.item()
            pbar.set_description(
                f"alpha_loss: {alpha_value:.4f}\t"
                f"mu_loss: {mu_value:.4f}\t"
                f"scale_loss: {scale_value:.4f}\t"
                f"total_loss: {total_value:.4f}"
            )
    return eval_results
