"""
Pre-training agent for Parameterized Behavior Cloning (PBC).

Trains PBC with trajectory embeddings z via FiLM conditioning.
Uses ParameterizedBatch and simple MSE loss.
"""

import logging
import wandb
import numpy as np
import torch
from tqdm import tqdm
import os
import json

# Module-level logger; not referenced by any code visible in this file section
# (progress messages below are emitted with print()).
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.pretrain.train_agent import PreTrainAgent, batch_to_device
from agent.dataset.parameterized_sequence import ParameterizedBatch


def batch_to_device_parameterized(batch, device):
    """
    Build a new ParameterizedBatch whose tensors have been moved to ``device``.

    ``batch.actions``, every value of the ``batch.conditions`` mapping, and
    ``batch.z_embedding`` are each transferred with ``.to(device)``; the keys of
    ``conditions`` are preserved. The input ``batch`` object is not mutated.
    """
    actions_on_device = batch.actions.to(device)
    conditions_on_device = {
        name: tensor.to(device) for name, tensor in batch.conditions.items()
    }
    z_on_device = batch.z_embedding.to(device)
    return ParameterizedBatch(
        actions=actions_on_device,
        conditions=conditions_on_device,
        z_embedding=z_on_device,
    )


class TrainParameterizedBCAgent(PreTrainAgent):
    """
    Training agent for Parameterized BC model.

    Handles ParameterizedBatch which includes (actions, conditions, z_embedding).
    Uses simple MSE loss without diffusion noise/timesteps.

    Each epoch of ``run``:
      1. trains on ``self.dataloader_train`` via ``self.model.loss`` and steps
         ``self.optimizer`` once per batch (``self.step_ema()`` every
         ``self.update_ema_freq`` batches, counted across epochs);
      2. every ``self.val_freq`` epochs, if a validation loader exists,
         computes the mean validation loss with gradients disabled;
      3. steps ``self.lr_scheduler``;
      4. calls ``self.save_model()`` every ``self.save_model_freq`` epochs and
         on the final epoch, and ``self.save_model(is_best=True)`` whenever the
         mean *training* loss improves on the best seen so far;
      5. records the losses and, every ``self.log_freq`` epochs, logs them to
         wandb when ``self.use_wandb`` is set.
    After the last epoch the collected history is written to
    ``self.checkpoint_dir`` by ``save_loss_history``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Per-epoch loss history; val entries are None for epochs without validation.
        self.train_loss_history = []
        self.val_loss_history = []
        self.epoch_history = []

        # Best-so-far mean training loss and the epoch it was reached.
        # +inf is a "nothing recorded yet" sentinel.
        self.best_loss = float("inf")
        self.best_epoch = 0

    def _train_one_epoch(self, cnt_batch):
        """
        Run one optimisation pass over ``self.dataloader_train``.

        Args:
            cnt_batch: Global batch counter carried across epochs; it drives the
                EMA update schedule.

        Returns:
            Tuple ``(mean_loss, cnt_batch)``: the mean per-batch training loss as
            a float (NaN if the loader yielded no batches) and the updated counter.
        """
        self.model.train()
        batch_losses = []

        for batch_train in self.dataloader_train:
            # Batches are moved explicitly only when the dataset reports "cpu".
            if self.dataset_train.device == "cpu":
                batch_train = batch_to_device_parameterized(batch_train, self.device)

            # PBC loss: MSE(predicted_trajectory, ground_truth_trajectory)
            loss_train = self.model.loss(
                batch_train.actions,
                batch_train.conditions,
                batch_train.z_embedding,
            )
            loss_train.backward()
            batch_losses.append(loss_train.item())

            self.optimizer.step()
            self.optimizer.zero_grad()

            # EMA is updated on every `update_ema_freq`-th batch, counted globally.
            if cnt_batch % self.update_ema_freq == 0:
                self.step_ema()
            cnt_batch += 1

        # Guard against np.mean([]) (RuntimeWarning); an empty loader still gives NaN.
        mean_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
        return mean_loss, cnt_batch

    def _validate_one_epoch(self):
        """
        Compute the mean validation loss over ``self.dataloader_val``.

        The model is switched to eval mode and gradients are disabled for the
        pass; train mode is restored afterwards, even if an exception is raised.

        Returns:
            Mean per-batch validation loss as a float, or None if the loader
            yielded no batches.
        """
        self.model.eval()
        batch_losses = []
        try:
            with torch.no_grad():
                for batch_val in self.dataloader_val:
                    if self.dataset_val.device == "cpu":
                        batch_val = batch_to_device_parameterized(
                            batch_val, self.device
                        )
                    loss_val = self.model.loss(
                        batch_val.actions,
                        batch_val.conditions,
                        batch_val.z_embedding,
                    )
                    batch_losses.append(loss_val.item())
        finally:
            self.model.train()
        return float(np.mean(batch_losses)) if batch_losses else None

    def run(self):
        """Train for ``self.n_epochs`` epochs, then save the loss history."""
        timer = Timer()
        self.epoch = 1
        cnt_batch = 0

        epoch_pbar = tqdm(range(self.n_epochs), desc="Epochs", position=0)

        for _ in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {self.epoch}/{self.n_epochs}")

            # Train
            loss_train, cnt_batch = self._train_one_epoch(cnt_batch)

            # Validate (only on scheduled epochs and when a val loader exists)
            loss_val = None
            if self.dataloader_val is not None and self.epoch % self.val_freq == 0:
                loss_val = self._validate_one_epoch()

            # Update learning rate (once per epoch)
            self.lr_scheduler.step()

            # Save model periodically, and always on the final epoch
            if self.epoch % self.save_model_freq == 0 or self.epoch == self.n_epochs:
                self.save_model()

            # Save best model (selected by mean training loss; NaN never wins)
            if loss_train < self.best_loss:
                self.best_loss = loss_train
                self.best_epoch = self.epoch
                self.save_model(is_best=True)

            # Update progress bar
            postfix_dict = {"train_loss": f"{loss_train:.4f}", "time": f"{timer():.1f}s"}
            if loss_val is not None:
                postfix_dict["val_loss"] = f"{loss_val:.4f}"
            epoch_pbar.set_postfix(postfix_dict)

            # Store loss history
            self.epoch_history.append(self.epoch)
            self.train_loss_history.append(float(loss_train))
            self.val_loss_history.append(
                float(loss_val) if loss_val is not None else None
            )

            # Log: stage the val loss without committing, then commit with train loss
            # so both land on the same wandb step.
            if self.epoch % self.log_freq == 0 and self.use_wandb:
                if loss_val is not None:
                    wandb.log({"loss - val": loss_val}, step=self.epoch, commit=False)
                wandb.log(
                    {"loss - train": loss_train},
                    step=self.epoch,
                    commit=True,
                )

            self.epoch += 1

        epoch_pbar.close()

        # Save loss history
        self.save_loss_history()

    def save_loss_history(self):
        """
        Save training and validation loss history to ``self.checkpoint_dir``.

        Writes three files:
          * ``loss_history.npz``: arrays ``epochs``, ``train_loss`` and
            ``val_loss`` (NaN for epochs without validation);
          * ``loss_history.json``: the same lists (``null`` for missing val
            entries) plus final/best summary values; ``best_loss`` is ``null``
            when no epoch ever improved on the initial sentinel;
          * ``training_summary.txt``: a short human-readable summary.
        """
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Validation losses only for epochs where validation actually ran.
        recorded_val = [v for v in self.val_loss_history if v is not None]

        # Save as numpy arrays
        loss_npz_path = os.path.join(self.checkpoint_dir, "loss_history.npz")
        np.savez(
            loss_npz_path,
            epochs=np.array(self.epoch_history),
            train_loss=np.array(self.train_loss_history),
            val_loss=np.array(
                [np.nan if v is None else v for v in self.val_loss_history]
            ),
        )
        print(f"Saved loss history to {loss_npz_path}")

        # The +inf sentinel would be serialised as the non-standard JSON token
        # `Infinity`; it means "no best recorded", so store null instead.
        best_loss = float(self.best_loss) if np.isfinite(self.best_loss) else None

        # Save as JSON
        loss_json_path = os.path.join(self.checkpoint_dir, "loss_history.json")
        loss_dict = {
            "epochs": self.epoch_history,
            "train_loss": self.train_loss_history,
            "val_loss": self.val_loss_history,
            "final_train_loss": (
                self.train_loss_history[-1] if self.train_loss_history else None
            ),
            "final_val_loss": recorded_val[-1] if recorded_val else None,
            "best_train_loss": (
                min(self.train_loss_history) if self.train_loss_history else None
            ),
            "best_val_loss": min(recorded_val) if recorded_val else None,
            "best_epoch": self.best_epoch,
            "best_loss": best_loss,
        }
        with open(loss_json_path, "w", encoding="utf-8") as f:
            json.dump(loss_dict, f, indent=2)
        print(f"Saved loss history (JSON) to {loss_json_path}")

        # Save summary
        summary_path = os.path.join(self.checkpoint_dir, "training_summary.txt")
        with open(summary_path, "w", encoding="utf-8") as f:
            f.write("Parameterized BC Training Summary\n")
            f.write("=" * 50 + "\n")
            f.write(f"Total Epochs Trained: {len(self.epoch_history)}\n")
            # Compare with None explicitly: a loss of exactly 0.0 is falsy but valid.
            if loss_dict["final_train_loss"] is not None:
                f.write(f"Final Train Loss: {loss_dict['final_train_loss']:.6f}\n")
            if loss_dict["final_val_loss"] is not None:
                f.write(f"Final Val Loss: {loss_dict['final_val_loss']:.6f}\n")
            if loss_dict["best_train_loss"] is not None:
                f.write(
                    f"Best Train Loss: {loss_dict['best_train_loss']:.6f} (Epoch {self.best_epoch})\n"
                )
            f.write("Best model saved as: best_model.pt\n")
        print(f"Saved training summary to {summary_path}")
