"""
Pre-training parameterized diffusion policy with trajectory embeddings

Subclasses PreTrainAgent to handle ParameterizedBatch (and StyleBatch/ControlPointBatch) with z embeddings.
"""

import logging
import wandb
import numpy as np
import torch
from tqdm import tqdm
import os
import json

log = logging.getLogger(__name__)
from util.timer import Timer
from agent.pretrain.train_agent import PreTrainAgent, batch_to_device
from agent.dataset.parameterized_sequence import ParameterizedBatch
from agent.dataset.style_sequence import StyleBatch


def batch_to_device_generic(batch, device):
    """Return a copy of ``batch`` whose tensors live on ``device``.

    Works for ParameterizedBatch, StyleBatch, or ControlPointBatch: the batch is
    rebuilt through its own class, which must accept the keyword arguments
    ``actions``, ``conditions`` and ``z_embedding``. Every value in
    ``batch.conditions`` is moved; a new dict is created for the result.
    """
    rebuild = type(batch)
    moved_conditions = {}
    for name, tensor in batch.conditions.items():
        moved_conditions[name] = tensor.to(device)
    moved_actions = batch.actions.to(device)
    moved_z = batch.z_embedding.to(device)
    return rebuild(
        actions=moved_actions,
        conditions=moved_conditions,
        z_embedding=moved_z,
    )


class TrainParameterizedDiffusionAgent(PreTrainAgent):
    """
    Training agent for parameterized diffusion model.

    Handles ParameterizedBatch which includes (actions, conditions, z_embedding).
    Passes z_embedding to the diffusion model during training.
    Supports Classifier-Free Guidance (CFG) training via dropout: when enabled,
    each training sample's z embedding is replaced by the network's ``z_empty``
    embedding with probability ``cfg.cfg.dropout_prob``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Per-epoch loss history. ``val_loss_history`` receives one entry per
        # epoch (None on epochs without validation) so it stays index-aligned
        # with ``epoch_history`` and ``train_loss_history``.
        self.train_loss_history = []
        self.val_loss_history = []
        self.epoch_history = []

        # Track best model (selected on mean training loss)
        self.best_loss = float('inf')
        self.best_epoch = 0

        # CFG settings
        self.cfg_enabled = getattr(cfg, 'cfg', None) is not None and cfg.cfg.get('enabled', False)
        self.cfg_dropout_prob = cfg.cfg.get('dropout_prob', 0.1) if self.cfg_enabled else 0.0

        if self.cfg_enabled:
            log.info(f"CFG training enabled with dropout_prob={self.cfg_dropout_prob}")

    def compute_loss(self, batch, use_cfg_dropout=True):
        """
        Compute the diffusion denoising MSE loss for one batch.

        batch: ParameterizedBatch(actions, conditions, z_embedding)
            - actions: (B, horizon_steps, action_dim)
            - conditions: dict with 'state': (B, cond_steps, obs_dim)
            - z_embedding: (B, latent_dim)
        use_cfg_dropout: whether to apply CFG dropout (disabled for evaluation)

        Returns a scalar loss tensor: MSE against the injected noise when
        ``self.model.predict_epsilon`` is set, otherwise against the clean actions.
        """
        actions = batch.actions
        conditions = batch.conditions
        z_embedding = batch.z_embedding

        # Sample one random diffusion timestep per sample
        batch_size = len(actions)
        t = torch.randint(
            0, self.model.denoising_steps, (batch_size,), device=actions.device
        ).long()

        # Forward process: add noise
        noise = torch.randn_like(actions)
        actions_noisy = self.model.q_sample(x_start=actions, t=t, noise=noise)

        z_used = z_embedding
        if self.cfg_enabled and use_cfg_dropout:
            # Per-sample Bernoulli(dropout_prob) mask; True means "drop z".
            dropout_mask = torch.bernoulli(
                torch.full((batch_size,), self.cfg_dropout_prob, device=actions.device)
            ).bool()
            # Swap the network's unconditional embedding in for dropped samples.
            # The mask is reshaped to broadcast over the trailing embedding dims;
            # torch.where is out-of-place, keeps gradients flowing to z_empty,
            # and avoids the host/device sync a ``dropout_mask.any()`` check forces.
            mask = dropout_mask.view(-1, *([1] * (z_embedding.dim() - 1)))
            z_empty = self.model.network.z_empty.to(dtype=z_embedding.dtype)
            z_used = torch.where(mask, z_empty, z_embedding)

        # Predict with network (pass z embedding)
        predicted = self.model.network(actions_noisy, t, cond=conditions, z=z_used)

        target = noise if self.model.predict_epsilon else actions
        return torch.nn.functional.mse_loss(predicted, target, reduction="mean")

    def _train_one_epoch(self, cnt_batch):
        """Run one pass over the training loader.

        Returns (mean training loss, updated global batch counter). The mean
        is NaN when the loader yields no batches.
        """
        self.model.train()
        losses = []
        for batch_train in self.dataloader_train:
            if self.dataset_train.device == "cpu":
                # Move batch to device (works for ParameterizedBatch, StyleBatch, ControlPointBatch)
                batch_train = batch_to_device_generic(batch_train, self.device)

            loss = self.compute_loss(batch_train)
            loss.backward()
            losses.append(loss.item())

            self.optimizer.step()
            self.optimizer.zero_grad()

            # update ema
            if cnt_batch % self.update_ema_freq == 0:
                self.step_ema()
            cnt_batch += 1

        mean_loss = float(np.mean(losses)) if losses else float('nan')
        return mean_loss, cnt_batch

    def _evaluate_val_loss(self):
        """Mean validation loss with CFG dropout disabled (full z for a fair comparison).

        Returns None when the validation loader yields no batches. Leaves the
        model in train mode afterwards.
        """
        self.model.eval()
        losses = []
        with torch.no_grad():
            for batch_val in self.dataloader_val:
                if self.dataset_val.device == "cpu":
                    # Move batch to device (works for ParameterizedBatch, StyleBatch, ControlPointBatch)
                    batch_val = batch_to_device_generic(batch_val, self.device)
                losses.append(self.compute_loss(batch_val, use_cfg_dropout=False).item())
        self.model.train()
        return float(np.mean(losses)) if losses else None

    def run(self):
        """Train for ``n_epochs`` epochs.

        Each epoch: a training pass, validation every ``val_freq`` epochs,
        an LR-scheduler step, loss bookkeeping, best-model and periodic
        checkpointing, progress-bar and optional wandb logging.
        """
        timer = Timer()
        self.epoch = 1
        cnt_batch = 0

        epoch_pbar = tqdm(range(self.n_epochs), desc="Epochs", position=0)

        for _ in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {self.epoch}/{self.n_epochs}")

            loss_train, cnt_batch = self._train_one_epoch(cnt_batch)

            loss_val = None
            if self.dataloader_val is not None and self.epoch % self.val_freq == 0:
                loss_val = self._evaluate_val_loss()

            # update lr
            self.lr_scheduler.step()

            # Record losses BEFORE any checkpointing so periodically saved
            # history includes the current epoch. val_loss_history gets an
            # entry every epoch (None if not validated) to stay aligned.
            self.train_loss_history.append(loss_train)
            self.val_loss_history.append(loss_val)
            self.epoch_history.append(self.epoch)

            # Save best model (based on training loss)
            if loss_train < self.best_loss:
                self.best_loss = loss_train
                self.best_epoch = self.epoch
                self.save_model(is_best=True)

            # save model periodically, together with the up-to-date loss history
            if self.epoch % self.save_model_freq == 0 or self.epoch == self.n_epochs:
                self.save_model()
                self.save_loss_history()

            # Update epoch progress bar with metrics
            postfix_dict = {'train_loss': f'{loss_train:.4f}',
                            'time': f'{timer():.1f}s'}
            if loss_val is not None:
                postfix_dict['val_loss'] = f'{loss_val:.4f}'
            epoch_pbar.set_postfix(postfix_dict)

            # wandb
            if self.use_wandb:
                if loss_val is not None:
                    wandb.log(
                        {"loss - val": loss_val}, step=self.epoch, commit=False
                    )
                wandb.log(
                    {
                        "loss - train": loss_train,
                        "lr": self.optimizer.param_groups[0]["lr"],
                    },
                    step=self.epoch,
                )

            self.epoch += 1

        epoch_pbar.close()

        # Final save of the loss history (npz, JSON and text summary)
        self.save_loss_history()

    def save_loss_history(self):
        """Save training and validation loss history to files (same format as TrainDiffusionAgent).

        Writes into ``self.checkpoint_dir``:
        - ``loss_history.npz``: epochs, train_loss, val_loss (NaN on epochs without validation)
        - ``loss_history.json``: full history (null for missing val) plus summary stats
        - ``training_summary.txt``: human-readable summary
        """
        val_losses = [v for v in self.val_loss_history if v is not None]

        # Save as numpy arrays (primary format for comparison)
        loss_npz_path = os.path.join(self.checkpoint_dir, "loss_history.npz")
        np.savez(
            loss_npz_path,
            epochs=np.array(self.epoch_history),
            train_loss=np.array(self.train_loss_history),
            val_loss=np.array(
                [np.nan if v is None else v for v in self.val_loss_history], dtype=float
            ),
        )
        print(f"Saved loss history to {loss_npz_path}")

        # Also save as JSON for easy inspection
        loss_json_path = os.path.join(self.checkpoint_dir, "loss_history.json")
        loss_dict = {
            "epochs": self.epoch_history,
            "train_loss": self.train_loss_history,
            "val_loss": self.val_loss_history,
            "final_train_loss": self.train_loss_history[-1] if self.train_loss_history else None,
            "final_val_loss": val_losses[-1] if val_losses else None,
            "best_train_loss": min(self.train_loss_history) if self.train_loss_history else None,
            "best_val_loss": min(val_losses) if val_losses else None,
            "best_epoch": self.best_epoch,
            # float('inf') (nothing recorded yet) would be emitted as non-standard `Infinity`
            "best_loss": self.best_loss if np.isfinite(self.best_loss) else None,
        }
        with open(loss_json_path, 'w', encoding="utf-8") as f:
            json.dump(loss_dict, f, indent=2)
        print(f"Saved loss history (JSON) to {loss_json_path}")

        # Save a summary text file; use `is not None` so a loss of exactly 0.0 is still reported
        summary_lines = [
            "Training Summary\n",
            "=" * 50 + "\n",
            f"Total Epochs Trained: {len(self.epoch_history)}\n",
        ]
        if loss_dict['final_train_loss'] is not None:
            summary_lines.append(f"Final Train Loss: {loss_dict['final_train_loss']:.6f}\n")
        if loss_dict['final_val_loss'] is not None:
            summary_lines.append(f"Final Val Loss: {loss_dict['final_val_loss']:.6f}\n")
        if loss_dict['best_train_loss'] is not None:
            summary_lines.append(
                f"Best Train Loss: {loss_dict['best_train_loss']:.6f} (Epoch {self.best_epoch})\n"
            )
        summary_lines.append("Best model saved as: best_model.pt\n")
        if loss_dict['best_val_loss'] is not None:
            summary_lines.append(f"Best Val Loss: {loss_dict['best_val_loss']:.6f}\n")

        summary_path = os.path.join(self.checkpoint_dir, "training_summary.txt")
        with open(summary_path, 'w', encoding="utf-8") as f:
            f.writelines(summary_lines)
        print(f"Saved training summary to {summary_path}")
