import copy
import gym
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from rlf.algos.il.base_il import BaseILAlgo
from ceil.buffer import TrajectoryBuffer
from imitation.util import logger as imit_logger

class CEIL(BaseILAlgo):
    """
    Context-Enhanced Imitation Learning (CEIL)
    Reformatted to align with the Proximity-IRL structure.

    The algorithm owns four networks (actor, encoder, discriminator, cdmi),
    one optimizer per network, and a learnable context tensor with its own
    optimizer. Rollouts are run in the source environment and evaluation is
    run in the target environment.
    """
    def __init__(
        self,
        source_env_name: str,
        target_env_name: str,
        mode: str,
        demo: str,
        demonstrations,
        actor_cls,
        encoder_cls,
        disc_cls,
        cdmi_cls,
        alpha,
        traj_batch_size: int,
        n_steps: int,
        seed: int,
        logger_folder: str,
        device: torch.device = "cuda",
        lr: float = 3e-4,
    ):
        """Build environments, networks, optimizers and the initial context.

        Args:
            source_env_name: gym id of the environment used for rollouts.
            target_env_name: gym id of the environment used for evaluation.
            mode: Stored on the instance; not read anywhere in this class.
            demo: Stored on the instance; not read anywhere in this class.
            demonstrations: Forwarded unchanged to the base-class constructor.
            actor_cls: Zero-argument factory for the actor network.
            encoder_cls: Zero-argument factory for the encoder network.
            disc_cls: Zero-argument factory for the discriminator network.
            cdmi_cls: Zero-argument factory for the cdmi network (constructed
                and given an optimizer, but not used by any method here).
            alpha: Stored on the instance; not read anywhere in this class.
            traj_batch_size: Batch size used when sampling the replay buffer;
                also the number of update iterations per training round.
            n_steps: Number of environment steps per rollout round.
            seed: Seed applied to torch, numpy and both environments.
            logger_folder: Output directory for the imitation logger.
            device: Device the networks, the context and all inputs live on.
            lr: Learning rate shared by every Adam optimizer.
        """
        super().__init__(demonstrations=demonstrations)

        # Logger
        self.logger = imit_logger.configure(
            folder=logger_folder,
            format_strs=["stdout", "log", "csv", "tensorboard"]
        )

        # Environment setup
        self.source_env = gym.make(source_env_name)
        self.target_env = gym.make(target_env_name)
        self.traj_batch_size = traj_batch_size
        self.n_steps = n_steps
        self.device = device

        # Seed *before* any network is constructed or any batch is sampled,
        # otherwise the seed has no influence on weight initialisation or on
        # the batch used to build the initial context.
        self.seed = seed
        self._set_seed()

        # Training hyperparameters
        self.alpha = alpha
        self.mode = mode
        self.demo = demo

        # Replay Buffer
        # NOTE(review): the meaning of the positional arguments (1000, 128, 4)
        # is defined by TrajectoryBuffer, which is not visible here.
        self.replay_buffer = TrajectoryBuffer(1000, 128, 4)

        # Initialize Models. Every input tensor is moved to `self.device`, so
        # the networks have to live there as well.
        self.actor = actor_cls().to(self.device)
        self.encoder = encoder_cls().to(self.device)
        self.disc = disc_cls().to(self.device)
        self.cdmi = cdmi_cls().to(self.device)

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr)
        self.cdmi_opt = torch.optim.Adam(self.cdmi.parameters(), lr=lr)

        # `Tensor.to` returns a new tensor instead of moving in place, so the
        # context must be placed on the device *before* it is wrapped in a
        # Parameter and handed to its optimizer.
        # NOTE(review): `init_context` samples from `self.replay_buffer`, which
        # was created empty just above; confirm demonstrations are stored in
        # it before this point.
        self.context = torch.nn.Parameter(
            self.init_context().to(self.device), requires_grad=True
        )
        self.context_opt = torch.optim.Adam([self.context], lr=lr)

        self._global_step = 0
        self._initialize_training_state()

    def _initialize_training_state(self):
        """Reset bookkeeping and take the first observation of the source env."""
        self.stats_len_mean = 0
        self._last_obs = self.source_env.reset()

    def _set_seed(self):
        """Seed torch, numpy and both environments with `self.seed`."""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.source_env.seed(self.seed)
        self.target_env.seed(self.seed)

    def init_context(self):
        """Initialize context embeddings.

        Samples one batch from the replay buffer, encodes it without tracking
        gradients, and averages the first encoder output over dimension 0.

        Returns:
            The averaged embedding as a CPU tensor.
        """
        with torch.no_grad():
            demo_samples, demo_ts, _ = self.replay_buffer.sample(self.traj_batch_size)
            demo_samples = torch.as_tensor(demo_samples).to(self.device)
            demo_ts = torch.as_tensor(demo_ts).to(self.device).long()
            context = self.encoder(demo_samples, demo_ts)[0].mean(0).cpu()
        return context

    def full_train(self, total_steps: int):
        """Main training loop.

        Runs `total_steps // self.n_steps` rounds. Each round performs a
        rollout, an evaluation, one block of model updates, and then flushes
        the values recorded during the round to the logger outputs.

        Args:
            total_steps: Total environment-step budget for the rollouts.
        """
        rounds = total_steps // self.n_steps

        for _ in tqdm(range(rounds), desc="Training Rounds"):
            self.policy_rollout(self.n_steps)
            self.policy_evaluate()
            self._global_step += 1
            self._update_models()
            # `record` only stages values; without a dump nothing is written
            # to stdout / csv / tensorboard.
            self.logger.dump(self._global_step)

    def policy_rollout(self, steps: int):
        """Step the source environment with the current actor.

        The last observation is kept in `self._last_obs` so the next call
        resumes where this one stopped.

        NOTE(review): the transitions produced here are not stored in any
        buffer; confirm whether they should be written to `self.replay_buffer`.

        Args:
            steps: Number of environment steps to take.
        """
        self.actor.eval()
        obs = self._last_obs

        for _ in range(steps):
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs).to(self.device)
                actions = self.actor(obs_tensor, self.context)[0].cpu().numpy()
            obs, _, done, _ = self.source_env.step(actions)
            if done:
                obs = self.source_env.reset()
        self._last_obs = obs

    def policy_evaluate(self, n_eval_steps: int = 1000):
        """Evaluate policy performance.

        Runs the actor in the target environment for a fixed number of steps,
        resetting whenever an episode ends, and records the reward summed over
        all of those steps under "evaluation/total_reward".

        Args:
            n_eval_steps: Number of environment steps to evaluate for.
        """
        self.actor.eval()
        obs = self.target_env.reset()
        total_reward = 0

        for _ in range(n_eval_steps):
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs).to(self.device)
                actions = self.actor(obs_tensor, self.context)[0].cpu().numpy()
            obs, reward, done, _ = self.target_env.step(actions)
            total_reward += reward
            if done:
                obs = self.target_env.reset()

        self.logger.record("evaluation/total_reward", total_reward)

    def _update_models(self):
        """Update actor, encoder, discriminator, and context.

        Performs `self.traj_batch_size` iterations; each one samples a batch
        from the replay buffer and takes one optimizer step per loss. The
        losses of the final iteration are recorded to the logger. Nothing is
        recorded when no iteration ran.
        """
        # Rollout and evaluation leave the actor in eval mode; switch back
        # before taking gradient steps.
        self.actor.train()

        last_losses = None
        for _ in range(self.traj_batch_size):
            demo_samples, demo_ts, _ = self.replay_buffer.sample(self.traj_batch_size)
            demo_samples = torch.as_tensor(demo_samples).to(self.device)
            demo_ts = torch.as_tensor(demo_ts).to(self.device).long()

            loss_actor = F.mse_loss(self.actor(demo_samples, self.context)[0], demo_samples)
            loss_encoder = F.mse_loss(self.encoder(demo_samples, demo_ts)[0], demo_samples)
            # The target must match the shape of the logits, not of the input.
            disc_logits = self.disc(demo_samples)
            loss_disc = F.binary_cross_entropy_with_logits(
                disc_logits, torch.ones_like(disc_logits)
            )

            # The actor loss is the only one that depends on the context, so
            # the context is stepped together with the actor.
            self.actor_opt.zero_grad()
            self.context_opt.zero_grad()
            loss_actor.backward()
            self.actor_opt.step()
            self.context_opt.step()

            self.encoder_opt.zero_grad()
            loss_encoder.backward()
            self.encoder_opt.step()

            self.disc_opt.zero_grad()
            loss_disc.backward()
            self.disc_opt.step()

            last_losses = (loss_actor, loss_encoder, loss_disc)

        if last_losses is None:
            # traj_batch_size == 0: no update happened, nothing to report.
            return

        final_actor, final_encoder, final_disc = last_losses
        self.logger.record("loss/actor", final_actor.item())
        self.logger.record("loss/encoder", final_encoder.item())
        self.logger.record("loss/disc", final_disc.item())

    def get_policy(self):
        """Return the actor network and the learned context as a tuple."""
        return self.actor, self.context

    def set_demonstrations(self, demo_trajs):
        """Load demonstration trajectories.

        Creates a dedicated demo buffer sized to the number of trajectories
        and stores them there, then also stores them in the replay buffer with
        `is_demo=False`.

        NOTE(review): requires `self.replay_buffer` to exist, i.e. this must
        be called after `__init__` has created it.

        Args:
            demo_trajs: Sized collection of trajectories accepted by
                `TrajectoryBuffer.store`.
        """
        self.demo_buffer = TrajectoryBuffer(len(demo_trajs), 128, 4, is_demo=True)
        self.demo_buffer.store(demo_trajs)
        self.replay_buffer.store(demo_trajs, is_demo=False)