import torch
import pytorch_lightning as pl
from base_experiment import BaseExperiment
from experiments.kitchen_experiment_hydra import KitchenExperimentHydra
from experiments.pusht_experiment_hydra import PushTExperimentHydra
from experiments.mimic_experiment_hydra import MimicExperimentHydra



def create_experiment(cfg):
    """Create appropriate experiment instance based on config."""
    env_name = cfg.env.name
    
    if env_name == "pusht":
        return PushTExperimentHydra(cfg)
    elif env_name == "kitchen":
        return KitchenExperimentHydra(cfg)
    elif env_name == "mimic":
        return MimicExperimentHydra(cfg)
    else:
        raise ValueError(f"Unknown environment: {env_name}")

class FlowMatchingLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        
        # Convert Hydra config to serializable dict
        from omegaconf import OmegaConf
        config_dict = OmegaConf.to_container(cfg, resolve=True)
        
        # Filter non-serializable values
        filtered_config = {}
        for k, v in config_dict.items():
            if isinstance(v, (str, int, float, bool, list, dict, type(None))):
                filtered_config[k] = v
            else:
                filtered_config[k] = str(v)

        print("Filtered config for hyperparameters:", filtered_config)
        
        self.save_hyperparameters(filtered_config)
        # config_dict = {
        #     "lr": 1,
        #     "batch_size": 1,
        # }
        # print(f"hparams (2): {self.hparams}")
        # self.hparams=config_dict
        # self.save_hyperparameters(config_dict)

        # print("self.hparams:", self.hparams)

        
        self.cfg = cfg
        self.experiment = create_experiment(cfg)
        self.setup_experiment()
        
    def setup_experiment(self):
        """Setup experiment components for current mode only."""
        # Initialize model and dataset
        self.model = self.experiment.setup_model()
        self.experiment.model = self.model  # Pass model reference to experiment
        self.dataloader = self.experiment.setup_dataset()
        
        # Validate training mode
        if self.cfg.execution.mode not in ['fm_train', 'mle_finetune', 'res_finetune']:
            raise ValueError(f"Invalid training mode: {self.cfg.execution.mode}. Must be 'fm_train', 'mle_finetune' or 'res_finetune'")
        
        # Initialize EMA model for both modes
        # For res_finetune mode: create pretrained model and freeze it
        if self.cfg.execution.mode == 'res_finetune':
            # Create and freeze pretrained model
            self.pretrained_model = self.experiment.setup_model()
            
            # Load pretrained weights if available
            checkpoint = self.cfg.execution.pretrain_checkpoint
            if checkpoint:
                state_dict = torch.load(checkpoint)
                self.pretrained_model.load_state_dict(state_dict)
            
            # Freeze pretrained model
            for param in self.pretrained_model.parameters():
                param.requires_grad = False
                
        # Initialize EMA model with trainable parameters
        from diffusers.training_utils import EMAModel
        self.experiment.ema = EMAModel(
            parameters=self.model.parameters(),  # Always points to trainable model
            power=self.cfg.training.ema_power if hasattr(self.cfg.training, 'ema_power') else 0.75
        )
        
        # Initialize components based on mode
        if self.cfg.execution.mode == 'fm_train':
            from torchcfm.conditional_flow_matching import (
                ConditionalFlowMatcher,
                ExactOptimalTransportConditionalFlowMatcher,
                TargetConditionalFlowMatcher,
                VariancePreservingConditionalFlowMatcher
            )
            
            # Check for ot_model in both root and training config
            ot_model = getattr(self.cfg, 'ot_model', None) or getattr(self.cfg.training, 'ot_model', None)
            if not ot_model:
                raise ValueError("ot_model must be specified in config (either at root or under training)")
            
            # Check for sigma in both root and training config
            sigma = getattr(self.cfg.execution, 'sigma', 0) 
            
            if ot_model == "otcfm":
                self.flow_matcher = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
            elif ot_model == "icfm":
                self.flow_matcher = ConditionalFlowMatcher(sigma=sigma)
            elif ot_model == "fm":
                self.flow_matcher = TargetConditionalFlowMatcher(sigma=sigma)
            elif ot_model == "si":
                self.flow_matcher = VariancePreservingConditionalFlowMatcher(sigma=sigma)
            else:
                raise NotImplementedError(
                    f"Unknown model {ot_model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
                )
            print(f"Initialized {ot_model} flow_matcher with sigma:", sigma)
            print(f"Flow matcher type: {type(self.flow_matcher)}")
        
        # Initialize optimizer if needed
        if hasattr(self.experiment, 'optimizer'):
            self.optimizer = self.experiment.optimizer
        
    def forward(self, x, t, global_cond=None):
        """Forward pass through the model."""
        return self.call_model(self.model, x, t, global_cond)
        
    def call_model(self, model, x, t, global_cond=None):
        """
        Helper to call any model with proper argument handling
        Explicitly checks for global_cond support to avoid silent fallbacks
        """
        # Handle models with noise_pred_net attribute
        if hasattr(model, 'noise_pred_net'):
            return model['noise_pred_net'](x, t, global_cond=global_cond)
            
        # Check if model supports global_cond parameter
        try:
            # First, try to call with global_cond
            return model(x, t, global_cond=global_cond)
        except TypeError as e:
            # If it fails due to unexpected argument, try without
            if "unexpected keyword argument 'global_cond'" in str(e):
                print(f"Note: Model {type(model).__name__} does not support global_cond parameter")
                return model(x, t)
            else:
                # Re-raise other errors
                raise e
        
    def training_step(self, batch, batch_idx):
        """Training step that handles all modes."""
        if self.cfg.execution.mode == 'fm_train':
            return self.flow_matching_step(batch)
        elif self.cfg.execution.mode == 'mle_finetune':
            return self.mle_training_step(batch, batch_idx)
        elif self.cfg.execution.mode == 'res_finetune':
            return self.residual_training_step(batch, batch_idx)
            
    def flow_matching_step(self, batch):
        """Flow Matching training step."""
        # Process batch to get observations and actions
        obs_cond, x_traj = self.experiment.process_batch(batch)
        
        # Flow matching training
        x_traj = x_traj.float()
        x0 = torch.randn(x_traj.shape, device=self.device)
        timestep, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x_traj)
        
        # Forward pass
        vt = self(xt, timestep, global_cond=obs_cond)
        
        # Compute loss
        loss = torch.mean((vt - ut) ** 2)
        
        # Log metrics
        self.log("train_loss", float(loss.item()), prog_bar=True)
        self.log("epoch", int(self.current_epoch), prog_bar=False)
        return loss
        
    def switch_training_mode(self, mode: str):
        """Switch between 'fm_train' and 'mle' training modes."""
        assert mode in ['fm_train', 'mle_finetune'], "Mode must be either 'fm_train' or 'mle'"
        self.cfg.execution.mode = mode
        print(f"Switched to {mode.upper()} training mode")

    def on_train_epoch_end(self):
        """统一EMA更新逻辑 - 每个epoch更新，定期拷贝参数"""
        if not hasattr(self.experiment, 'ema') or not self.trainer.is_global_zero:
            return
            
        model = self.model.module if hasattr(self.model, 'module') else self.model
        
        # 每个epoch都更新EMA参数
        self.experiment.ema.step(model.parameters())
        
        # 定期将EMA参数拷贝回模型
        if (self.current_epoch + 1) % self.cfg.training.ema_power_epoch_interval == 0:
            self.experiment.ema.copy_to(model.parameters())
        
        # 训练结束时保存检查点
        if (self.current_epoch + 1) == self.trainer.max_epochs:
            # import os
            # os.makedirs('checkpoint', exist_ok=True)
            # torch.save({
            #     'model': model.state_dict(),
            #     'epoch': self.current_epoch,
            #     'loss': self.trainer.callback_metrics.get('train_loss', float('nan')),
            # }, f'checkpoint/ema_{self.current_epoch:05d}.pth')
            pass

    def mle_training_step(self, batch, batch_idx):
        """MLE training step matching flow_kitchen_ddp.py."""
        # Process batch to get observations and actions
        obs_cond, x_traj = self.experiment.process_batch(batch)
        
        # Generate random noise as starting point
        x_traj = x_traj.float()
        x0 = torch.randn(x_traj.shape, device=self.device)
        
        # Setup time points for ODE solver
        t = torch.linspace(0, 1, self.cfg.execution.solver.time_steps, device=self.device)
        # t = torch.tensor([0.0, 1.0], device=self.device)
        trajectory = self.basic_euler_ode_solver(x0, t, obs_cond)
        x1 = trajectory[-1]
        # Solve ODE to get final state directly
        # x1 = self.torchdyn_ode_solver(x0, t, obs_cond)
        
        # Compute MLE loss (direct reconstruction loss)
        loss = torch.mean((x1 - x_traj) ** 2)
        
        # Log metrics
        self.log("mle_train_loss", float(loss.item()), prog_bar=True)
        return loss

    def torchdyn_ode_solver(self, y0, t, obs_cond):
        """Torchdyn NeuralDE solver for MLE training, returning only the final state."""
        from torchdyn import NeuralODE
        
        # Define vector field function compatible with Torchdyn
        def vector_field(t, y):
            return self(y, t, global_cond=obs_cond)
        
        # Create NeuralDE model with specified solver
        model = NeuralODE(vector_field, solver='dopri5', atol=1e-4, rtol=1e-4)
        
        # Solve ODE and return final state
        return model(y0, t)[-1]

    def basic_euler_ode_solver(self, y0, t, obs_cond, model=None):
        """Basic Euler ODE solver with optional model selection"""
        model = model or self.model
        dt = t[1] - t[0]
        y = y0
        ys = [y0]
        for i in range(len(t) - 1):
            t_start = t[i]
            # Use our robust model calling function
            dy = self.call_model(model, y, t_start, obs_cond)
            y = y + dy * dt
            ys.append(y)
        return torch.stack(ys)
        
    def residual_training_step(self, batch, batch_idx):
        """Residual finetuning training step."""
        obs_cond, x_traj = self.experiment.process_batch(batch)
        x_traj = x_traj.float()
        
        # Generate initial noise
        x0 = torch.randn(x_traj.shape, device=self.device)
        
        # Generate intermediate solution with pretrained model (no gradients)
        with torch.no_grad():
            t = torch.linspace(0, 1, self.cfg.execution.solver.time_steps, device=self.device)
            pretrained_solution = self.basic_euler_ode_solver(x0, t, obs_cond, model=self.pretrained_model)
            intermediate_state = pretrained_solution[-1]
        
        # Generate final solution with trainable model
        final_solution = self.basic_euler_ode_solver(intermediate_state, t, obs_cond, model=self.model)
        x1 = final_solution[-1]
        
        # Compute loss and optimize only residual model
        loss = torch.mean((x1 - x_traj) ** 2)
        self.log("res_train_loss", float(loss.item()), prog_bar=True)
        return loss


    def configure_optimizers(self):
        """Configure optimizers with unified settings for all training modes."""
        if not hasattr(self.cfg.training, 'learning_rate'):
            raise ValueError("training.learning_rate must be specified in config")
            
        # In all modes, we optimize the parameters of self.model
        params = self.model.parameters()
            
        optimizer = torch.optim.AdamW(
            params=params,
            lr=self.cfg.training.learning_rate,
            weight_decay=self.cfg.training.weight_decay if hasattr(self.cfg.training, 'weight_decay') else 0.0
        )
        
        from diffusers.optimization import get_scheduler
        scheduler = get_scheduler(
            name='cosine',
            optimizer=optimizer,
            num_warmup_steps=self.cfg.training.warmup_steps if hasattr(self.cfg.training, 'warmup_steps') else 0,
            num_training_steps=len(self.train_dataloader()) * (self.cfg.training.epochs)
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1
            }
        }
        
    def train_dataloader(self):
        """Return the training dataloader."""
        if not hasattr(self.experiment, 'dataloader') or self.experiment.dataloader is None:
            self.experiment.dataloader = self.experiment.setup_dataset()
        return self.experiment.dataloader

def create_lightning_module(cfg):
    """Create and return a LightningModule instance."""
    return FlowMatchingLightningModule(cfg)
