import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
import argparse
from dataclasses import dataclass
from collections import deque
import gymnasium as gym
import scipy.signal

from tasks import point_mass
from shimmy.registration import DM_CONTROL_SUITE_ENVS

def parse_args():
    """Parse the command-line options of the controller training script.

    Returns:
        argparse.Namespace: one attribute per option below (dashes in the
        flag become underscores, e.g. ``--env-id`` -> ``args.env_id``).
    """
    # One (flag, add_argument-kwargs) entry per option, in declaration order.
    option_table = (
        # Environment and schedule.
        ("--env-id", dict(type=str, default="Walker2d-v4")),
        ("--max-episode-steps", dict(type=int, default=300)),
        ("--total-timesteps", dict(type=int, default=100000)),
        ("--controller-starts", dict(type=int, default=15000)),
        # Model / optimisation.
        ("--latent-dim", dict(type=int, default=256, help="Dimension of the latent space from the loaded encoder. Must match the encoder's output.")),
        ("--horizon", dict(type=int, default=3)),
        ("--learning-rate", dict(type=float, default=3e-4)),
        ("--lr-pi", dict(type=float, default=3e-4)),
        # Action perturbation settings forwarded to PerturbedEnv.
        ("--action-repeats", dict(type=int, default=1)),
        ("--action-noise", dict(type=float, default=0.1)),
        ("--perturb-starts", dict(type=int, default=40000)),
        ("--step-perturb-scale", dict(type=float, default=0.)),
        ("--step-perturb-steps", dict(type=int, default=10000)),
        ("--slow-noise-scale", dict(type=float, default=0.0)),
        ("--slow-noise-tau", dict(type=float, default=20000)),
        # Runtime.
        ("--device", dict(type=str, default="cuda")),
        ("--seed", dict(type=int, default=1)),
        ("--restart-from", dict(type=str, default=None)),
        ("--model-path", dict(type=str, default="model.pt", help="Path to the pre-trained model containing policy, encoder, and forward model")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

@dataclass
class Trajectory:
    """Bundle of transition tensors collected while interacting with the env.

    The shape notes are the intended layout once an episode of ``T`` steps
    has been stacked; no shape or dtype is enforced by the dataclass itself.
    """

    # Observations the actions were computed from: [T, state_dim]
    states: torch.Tensor
    # Actions sent to the environment (baseline + control): [T, action_dim]
    total_actions: torch.Tensor
    # Observations returned by the environment: [T, state_dim]
    next_states: torch.Tensor
    # Actions proposed by the baseline policy: [T, action_dim]
    baseline_actions: torch.Tensor
    # Corrections produced by the controller: [T, action_dim]
    control_actions: torch.Tensor

class Controller(nn.Module):
    """Linear residual controller trained through a frozen latent forward model.

    The controller ``pi_c`` maps encoder latents to an action correction that
    is added to a baseline policy's action. It is the only component with an
    optimizer; the encoder and forward model are used as provided.

    Args:
        obs_dim: Observation dimension (accepted for interface symmetry; not
            used by this class).
        latent_dim: Output dimension of ``encoder`` / input size of ``pi_c``.
        action_dim: Dimension of the action correction.
        lr_pi: Adam learning rate for ``pi_c``.
        horizon: Number of forward-model rollout steps used in the loss.
        device: Device for ``pi_c`` and the provided modules.
        encoder: Callable ``encoder(obs) -> latent``. Moved to ``device`` when
            it is an ``nn.Module``.
        forward_model: ``forward_model(latent, action) -> next_latent``.
            Required even though the signature defaults to ``None``.

    Raises:
        ValueError: if ``forward_model`` is ``None``.
    """

    def __init__(
        self,
        obs_dim: int,
        latent_dim: int,
        action_dim: int,
        lr_pi: float = 3e-4,
        horizon: int = 3,
        device: str = "cuda",
        encoder: nn.Module = None,
        forward_model: nn.Module = None,
    ):
        super().__init__()
        if forward_model is None:
            # Fail early with a clear message instead of an AttributeError on .to().
            raise ValueError("Controller requires a forward_model; got None.")

        self.device = device
        self.horizon = horizon

        print("Using provided encoder")
        self.encoder = encoder.to(device) if isinstance(encoder, nn.Module) else encoder
        self.encoder_optimizer = None

        # int(...) so numpy integer dims (e.g. from np.prod) are accepted.
        self.pi_c = nn.Linear(int(latent_dim), int(action_dim)).to(device)
        self.forward_model = forward_model.to(device)

        # Start with a near-zero correction so the baseline policy dominates.
        nn.init.xavier_uniform_(self.pi_c.weight, gain=0.001)
        nn.init.zeros_(self.pi_c.bias)

        self.pi_optimizer = torch.optim.Adam(self.pi_c.parameters(), lr=lr_pi)

    def update_controller(self, states, baseline_actions, ctrl_actions) -> float:
        """Run one optimizer step on ``pi_c`` from a single episode.

        A zero "auxiliary" action is added to the baseline actions and the
        forward model is rolled out for ``horizon`` steps. The gradient of the
        discounted latent prediction error with respect to that auxiliary
        action, negated, is back-propagated through ``ctrl_actions`` as the
        upstream gradient for the ``pi_c`` update.

        Args:
            states: Episode observations, first dimension is time ``T``.
            baseline_actions: Baseline actions, ``[T, action_dim]`` (extra
                singleton dims such as ``[T, 1, action_dim]`` are flattened).
            ctrl_actions: Controller outputs for the same steps, still
                attached to the autograd graph of ``pi_c``.

        Returns:
            The scalar rollout loss, or ``0.0`` when the episode is too short
            (``T <= horizon``) to form a single rollout; no update is made in
            that case.
        """
        horizon = self.horizon
        n_steps = states.shape[0]
        # An empty rollout would make the MSE NaN; skip such episodes.
        if horizon < 1 or n_steps <= horizon:
            return 0.0

        states = states.to(self.device)
        with torch.no_grad():
            latents = self.encoder(states)
        baseline_actions = baseline_actions.to(self.device).detach().reshape(n_steps, -1)
        ctrl_actions = ctrl_actions.reshape(n_steps, -1)

        # Zero perturbation whose gradient tells us how the action should have
        # changed for the forward model to match the observed latents.
        aux_action = torch.zeros_like(baseline_actions, requires_grad=True)

        n_rollout = n_steps - horizon
        pred_latents = latents[:n_rollout]
        controller_loss = 0.0
        for h in range(horizon):
            window = slice(h, h + n_rollout)
            rollout_action = baseline_actions[window] + aux_action[window]

            pred_latents = self.forward_model(pred_latents, rollout_action)
            target_latents = latents[h + 1:h + 1 + n_rollout]

            # Later rollout steps are down-weighted by 0.5 per step.
            controller_loss = controller_loss + F.mse_loss(pred_latents, target_latents) * 0.5**h

        controller_loss = controller_loss / horizon

        # Differentiate w.r.t. aux_action only, so no stale gradients pile up
        # on the forward model's parameters.
        (aux_grad,) = torch.autograd.grad(controller_loss, aux_action)

        self.pi_optimizer.zero_grad()
        ctrl_actions.backward(gradient=-aux_grad)
        self.pi_optimizer.step()

        return controller_loss.item()

    def act(self, state: torch.Tensor) -> torch.Tensor:
        """Return the action correction ``pi_c(state)`` for an encoded latent."""
        return self.pi_c(state)

class PerturbedEnv(gym.Wrapper):
    """Gym wrapper that distorts actions before they reach the environment.

    Three multiplicative distortions can be combined:

    * a per-dimension gain that switches between 1 and a random value every
      ``step_perturb_steps`` environment steps (``step_perturb_scale``),
    * a pre-generated low-pass noise trace added to that gain
      (``slow_noise_scale`` / ``slow_noise_tau``),
    * fresh Gaussian noise on every step (``action_noise``).

    The first two are disabled until ``perturb_starts`` environment steps
    have elapsed (when ``perturb_starts > 0``). Each call to ``step`` repeats
    the perturbed action ``action_repeats`` times and sums the rewards.

    Args:
        env: Environment to wrap.
        action_repeats: Number of inner env steps per ``step`` call.
        perturb_starts: Step count before which the gain is forced to 1.
        step_perturb_scale: Half-width of the uniform gain range around 1.
        step_perturb_steps: Length of each on/off phase of the step gain.
        slow_noise_scale: Scale of the low-pass noise (0 disables it).
        slow_noise_tau: Time constant passed to the noise generator.
        noise_size: Length of the pre-generated noise trace (it wraps around).
        action_noise: Std of the per-step multiplicative Gaussian noise.
        device: Stored on the instance; not used by the wrapper itself.
    """

    def __init__(
        self,
        env: gym.Env,
        action_repeats: int = 1,
        perturb_starts: int = 0,
        step_perturb_scale: float = 0.0,
        step_perturb_steps: int = 10000,
        slow_noise_scale: float = 0.0,
        slow_noise_tau: float = 20000,
        noise_size: int = 200000,
        action_noise: float = 0.0,
        device: str = "cuda"
    ):
        super().__init__(env)
        self.device = device
        self.action_repeats = action_repeats
        # Plain int so it can be used directly as a tensor size.
        self.action_dim = int(np.prod(env.action_space.shape))

        self.action_noise = action_noise

        self.perturb_starts = perturb_starts
        self.step_perturb_scale = step_perturb_scale
        self.step_perturb_steps = step_perturb_steps
        self.perturb_factor = torch.ones(self.action_dim)
        self.step_count = 0

        self.slow_noise_scale = slow_noise_scale
        if slow_noise_scale > 0:
            # One independent noise trace per action dimension.
            self.slow_perturb = torch.zeros(noise_size, self.action_dim)
            for i in range(self.action_dim):
                self.slow_perturb[:, i] = torch.from_numpy(
                    self.slow_noise_scale * self._generate_lowpass_noise(
                        tau=slow_noise_tau,
                        size=noise_size
                    )
                )
            print(f"Slow perturb std: {self.slow_perturb.std(0).numpy()}")

        self.episode_actions = []

    def _generate_lowpass_noise(self, tau: float, size: int, dt: float = 1.0):
        """Generate ``size`` samples of low-pass filtered Gaussian noise.

        White noise of length ``size / (tau / 1000)`` is filtered with a
        5th-order Butterworth low-pass at a fixed reference cutoff of
        ``1 / 1000`` (sampling rate ``1 / dt``), rescaled, and then resampled
        to ``size`` samples, which stretches the trace by ``tau / 1000``.
        """
        tau_ref = 1000
        ratio = tau / tau_ref

        cutoff_ref = 1./tau_ref
        fs = 1./dt
        size_ref = int(size / ratio)

        # Generate and filter noise
        b, a = scipy.signal.butter(5, cutoff_ref, fs=fs, btype='low', analog=False)
        y_ref = scipy.signal.lfilter(b, a, np.random.randn(size_ref)) * np.sqrt(0.5 * fs/cutoff_ref)

        return scipy.signal.resample(y_ref, size)

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the per-episode action log."""
        obs, info = self.env.reset(**kwargs)
        self.episode_actions = []
        return obs, info

    def step(self, action):
        """Perturb ``action``, apply it ``action_repeats`` times, return the result.

        Returns the usual ``(obs, reward, terminated, truncated, info)`` tuple,
        where ``reward`` is the sum over the repeated steps and ``info`` gains
        the keys ``'perturb_factor'``, ``'original_action'`` and
        ``'perturbed_action'``.
        """
        prev_count = self.step_count
        self.step_count += self.action_repeats

        if torch.is_tensor(action):
            # detach() so tensors that still carry an autograd graph convert.
            action = action.detach().cpu().numpy()

        self.episode_actions.append(action.copy())

        if self.perturb_starts > 0 and self.step_count < self.perturb_starts:
            perturb_factor = torch.ones(self.action_dim)
        else:
            if self.step_perturb_scale > 0:
                period = self.step_perturb_steps
                phase_index = self.step_count // period
                # Switch whenever a phase boundary was crossed during this
                # step. An exact `step_count % ... ==` test would miss the
                # boundary when step_count advances by action_repeats > 1.
                if prev_count // period != phase_index:
                    if phase_index % 2 == 0:
                        self.perturb_factor = torch.ones(self.action_dim)
                    else:
                        self.perturb_factor = torch.empty(
                            self.action_dim, dtype=torch.float32
                        ).uniform_(
                            1 - self.step_perturb_scale,
                            1 + self.step_perturb_scale
                        )

            if self.slow_noise_scale > 0:
                # The pre-generated trace is indexed cyclically.
                perturb_factor = self.perturb_factor + self.slow_perturb[self.step_count % len(self.slow_perturb)]
            else:
                perturb_factor = self.perturb_factor

        perturbed_action = action * perturb_factor.numpy()

        if self.action_noise > 0:
            perturbed_action = perturbed_action * (1 + np.random.randn(self.action_dim) * self.action_noise)

        # Step environment, stopping early if the episode ends mid-repeat.
        reward = 0.0
        terminated = False
        truncated = False
        for _ in range(self.action_repeats):
            obs, reward_, terminated_, truncated_, info = self.env.step(perturbed_action)
            reward += reward_
            terminated = terminated or terminated_
            truncated = truncated or truncated_
            if terminated or truncated:
                break

        info.update({
            'perturb_factor': perturb_factor.numpy(),
            'original_action': action,
            'perturbed_action': perturbed_action
        })

        return obs, reward, terminated, truncated, info

    def get_episode_actions(self) -> np.ndarray:
        """Return the unperturbed actions received since the last reset."""
        return np.array(self.episode_actions)
    

def make_env(args):
    """Build the training environment described by ``args``.

    The base environment is created with ``gym.make`` and wrapped, in order,
    with observation flattening, running observation normalisation, episode
    statistics recording and finally ``PerturbedEnv``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Returns:
        The wrapped environment.
    """
    # The episode limit must be given to gym.make: assigning
    # `_max_episode_steps` on an outer wrapper only creates an attribute on
    # that wrapper and never reaches the time-limit logic.
    env = gym.make(
        args.env_id,
        render_mode="rgb_array",
        max_episode_steps=args.max_episode_steps,
    )
    env = gym.wrappers.FlattenObservation(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = PerturbedEnv(
        env,
        action_repeats=args.action_repeats,
        action_noise=args.action_noise,
        perturb_starts=args.perturb_starts,
        step_perturb_scale=args.step_perturb_scale,
        step_perturb_steps=args.step_perturb_steps,
        slow_noise_scale=args.slow_noise_scale,
        slow_noise_tau=args.slow_noise_tau,
        device=args.device
    )

    if "point_mass" in args.env_id:
        # No extra setup is performed for point_mass at present.
        print("Setting up point_mass env")

    return env

def train_control(env, controller, baseline_policy, args):
    """Run the interaction loop and update the controller after each episode.

    For the first ``args.controller_starts`` environment steps only the
    baseline policy acts. Afterwards the controller's correction is added to
    the baseline action. At the end of every episode that finishes after
    ``controller_starts`` the controller is updated from that episode's data.

    Args:
        env: Gymnasium-style environment (``reset`` / 5-tuple ``step``).
        controller: Object exposing ``encoder``, ``act`` and
            ``update_controller`` (see ``Controller``).
        baseline_policy: Callable mapping a batched observation tensor to an
            action tensor.
        args: Namespace providing ``device``, ``total_timesteps``,
            ``controller_starts`` and ``action_repeats``.

    Returns:
        The (updated) ``controller``.
    """
    # Training loop
    obs, _ = env.reset()
    episode_return = 0
    episode_length = 0
    episode_data = []
    global_step = 0
    start_time = time.time()

    while global_step < args.total_timesteps:

        obs_torch = torch.tensor(obs, dtype=torch.float32, device=args.device)

        # The baseline policy is frozen: no graph is needed for its output.
        with torch.no_grad():
            base_action = baseline_policy(obs_torch.unsqueeze(0))
        # Drop the batch dimension so the env receives an [action_dim] action
        # and stacked episode tensors are [T, action_dim].
        if base_action.dim() > 1:
            base_action = base_action.squeeze(0)

        if global_step < args.controller_starts:
            control_action = torch.zeros_like(base_action)
        else:
            with torch.no_grad():
                latent = controller.encoder(obs_torch)
            # Keep the graph through pi_c; it is used by update_controller.
            control_action = controller.act(latent).reshape(base_action.shape)
        total_action = base_action + control_action.detach()

        # Step environment
        next_obs, reward, terminated, truncated, info = env.step(total_action.detach().cpu().numpy())
        episode_return += reward
        episode_length += args.action_repeats
        global_step += args.action_repeats

        # Store transition (one step per entry; stacked at episode end).
        episode_data.append(Trajectory(
            states=obs_torch,
            total_actions=total_action,
            next_states=torch.tensor(next_obs, dtype=torch.float32),
            baseline_actions=base_action,
            control_actions=control_action,
        ))

        # Episode end
        if terminated or truncated:
            # Combine episode data into [T, ...] tensors.
            episode = Trajectory(
                states=torch.stack([t.states for t in episode_data]),
                total_actions=torch.stack([t.total_actions for t in episode_data]),
                next_states=torch.stack([t.next_states for t in episode_data]),
                baseline_actions=torch.stack([t.baseline_actions for t in episode_data]),
                control_actions=torch.stack([t.control_actions for t in episode_data]),
            )

            # Only update when at least one stored control action came from
            # the controller; otherwise there is no graph to backpropagate.
            if global_step > args.controller_starts and episode.control_actions.requires_grad:
                ctrl_loss = controller.update_controller(episode.states, episode.baseline_actions, episode.control_actions)
                print(f"Step: {global_step}, Train/controller_loss: {ctrl_loss:.4f}")

            # Print metrics
            elapsed = max(time.time() - start_time, 1e-9)
            print(f"Step: {global_step}, Return: {episode_return:.2f}, Length: {episode_length}, SPS: {global_step / elapsed:.2f}")

            # Reset for next episode
            obs, _ = env.reset()
            episode_return = 0
            episode_length = 0
            episode_data = []

        else:
            obs = next_obs

    return controller

def main():
    """Entry point: load the pre-trained components, train and save the controller.

    The checkpoint at ``--model-path`` must be a mapping with the keys
    ``'policy'``, ``'encoder'`` and ``'forward_model'``. The controller's
    state dict is always written to ``models/controller_cleaned_example.pt``
    on exit, including after a keyboard interrupt, and the environment is
    closed afterwards.
    """
    import os  # local import: only needed for creating the output directory

    args = parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    env = make_env(args)

    # --- BEGIN MODEL LOADING ---
    print(f"Loading model and components from: {args.model_path}")
    try:
        # SECURITY: the checkpoint stores whole pickled objects (modules and
        # callables), so it needs full unpickling (weights_only=False), which
        # can execute arbitrary code. Only load checkpoints you trust.
        model_checkpoint = torch.load(
            args.model_path, map_location=args.device, weights_only=False
        )
    except FileNotFoundError:
        print(f"Error: Model file not found at {args.model_path}. Please provide a valid path.")
        raise SystemExit(1)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise SystemExit(1)

    # IMPORTANT: Adjust the keys based on how your model checkpoint is structured.
    try:
        baseline_policy = model_checkpoint['policy']       # Expected: callable policy(observation_tensor) -> action_tensor
        encoder = model_checkpoint['encoder']            # Expected: nn.Module encoder(observation_tensor) -> latent_tensor
        forward_model = model_checkpoint['forward_model']  # Expected: nn.Module forward_model(latent_tensor, action_tensor) -> next_latent_tensor
    except KeyError as e:
        print(f"Error: Key {e} not found in model checkpoint. Ensure your model file at {args.model_path} contains 'policy', 'encoder', and 'forward_model'.")
        raise SystemExit(1)

    # Make sure every loaded module lives on the requested device; a plain
    # callable policy is left untouched.
    if isinstance(encoder, nn.Module):
        encoder = encoder.to(args.device)
    if isinstance(forward_model, nn.Module):
        forward_model = forward_model.to(args.device)
    if isinstance(baseline_policy, nn.Module):
        baseline_policy = baseline_policy.to(args.device)

    print("Successfully loaded model components.")
    # --- END MODEL LOADING ---

    # Plain ints: np.prod returns numpy integers.
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    # The latent dimension is taken from the command line and is not checked
    # against the loaded encoder; it must match the encoder's output size.
    latent_dim = args.latent_dim

    print("Observation dim:", obs_dim)
    print("Action dim:", action_dim)
    print("Latent dim:", latent_dim)

    controller = Controller(
        obs_dim=obs_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
        lr_pi=args.lr_pi,
        horizon=args.horizon,
        device=args.device,
        encoder=encoder,
        forward_model=forward_model,
    )

    if args.restart_from is not None:
        # map_location lets a checkpoint saved on another device be restored.
        controller.load_state_dict(
            torch.load("models/" + args.restart_from, map_location=args.device)
        )

    try:
        train_control(env, controller, baseline_policy, args)

    except KeyboardInterrupt:
        print("Training interrupted")
    finally:
        model_save_name = "models/controller_cleaned_example.pt"
        try:
            # The output directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(model_save_name), exist_ok=True)
            print("Saving model to:", model_save_name)
            torch.save(controller.state_dict(), model_save_name)
        finally:
            # Close the environment even if saving fails.
            env.close()

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()