import copy

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from all2.approximation import FixedTarget
from all2.bodies import DeepmindAtariBody
from all2.logging import DummyLogger
from all2.memory import PrioritizedReplayBuffer
from all2.optim import LinearScheduler
from all2.presets.builder import PresetBuilder
from all2.presets.preset import Preset
from all2 import nn
from model import JointQDist, jointdqn_backbone_constructor, \
jointdqn_mixing_head_constructor, jointdqn_mean_head_constructor, \
jointdqn_covariance_head_constructor, jointdqn_cartpole_backbone_constructor, \
jointdqn_cartpole_mixing_head_constructor, jointdqn_cartpole_mean_head_constructor, \
jointdqn_cartpole_covariance_head_constructor
from agent import JointDQN, JointDQNTestAgent
from mv import JointDQNTestAgentMV

# Default hyperparameters shared by the Atari training preset
# (``JointDQNAtariPreset``) and the checkpoint-based evaluation preset
# (``JointDQNAtariTestPreset``).
default_hyperparameters = dict(
    discount_factor=0.99,
    # Adam optimizer settings.
    # NOTE(review): the Atari presets construct Adam with a hardcoded
    # ``eps=1.5e-4`` and never read ``eps`` from here -- confirm which value
    # is intended.
    lr=1e-4,
    eps=1e-7,
    # Training settings
    minibatch_size=32,
    update_frequency=4,
    target_update_frequency=1_000,
    # Replay buffer settings
    replay_start_size=80_000,
    replay_buffer_size=1_000_000,
    # Exploration: probability of a random action, decayed by a LinearScheduler
    initial_exploration=1.0,
    final_exploration=0.01,
    final_exploration_step=250_000,
    test_exploration=0.001,
    # Number of components in the Gaussian mixture
    n_mixture=3,
    # Factories for the shared backbone and the mixing / mean / covariance heads
    backbone_constructor=jointdqn_backbone_constructor,
    mixing_head_constructor=jointdqn_mixing_head_constructor,
    mean_head_constructor=jointdqn_mean_head_constructor,
    covariance_head_constructor=jointdqn_covariance_head_constructor,
    # Options forwarded to the JointDQN agent
    n_update_actions=2,
    rand_actions=True,
    # Starting value of the ``update_q`` LinearScheduler built in ``agent()``
    initial_q=0.5,
)


# Default hyperparameters for the CartPole preset (``JointDQNCartpolePreset``).
# There is no ``eps`` entry: that preset builds Adam with its default epsilon.
default_hyperparameters_cartpole = dict(
    discount_factor=0.99,
    # Adam optimizer settings
    lr=1e-3,
    # Training settings
    minibatch_size=128,
    update_frequency=1,
    target_update_frequency=250,
    # Replay buffer settings
    replay_start_size=1_000,  # previously 100000
    replay_buffer_size=20_000,
    # Exploration: probability of a random action, decayed by a LinearScheduler
    initial_exploration=1.0,
    final_exploration=0.01,
    final_exploration_step=10_000,
    test_exploration=0.001,
    # A single Gaussian component
    n_mixture=1,
    # Factories for the CartPole backbone and the mixing / mean / covariance heads
    backbone_constructor=jointdqn_cartpole_backbone_constructor,
    mixing_head_constructor=jointdqn_cartpole_mixing_head_constructor,
    mean_head_constructor=jointdqn_cartpole_mean_head_constructor,
    covariance_head_constructor=jointdqn_cartpole_covariance_head_constructor,
    # Options forwarded to the JointDQN agent
    n_update_actions=2,
    rand_actions=True,
    # Starting value of the ``update_q`` LinearScheduler built in ``agent()``
    initial_q=0.5,
)


class JointDQNAtariPreset(Preset):
    """
    Joint DQN preset for Atari environments.

    Builds a shared backbone network plus mixing, mean and covariance heads
    from the constructors supplied in the hyperparameters, and exposes a
    training agent and an evaluation agent, both wrapped in
    ``DeepmindAtariBody``.

    Args:
        env: The Atari environment; only ``env.action_space.n`` is read.
        name (str): A human-readable name for the preset.
        device (torch.device): The device on which to load the agent.

    Keyword Args:
        discount_factor (float): Discount factor for future rewards.
        lr (float): Learning rate for the Adam optimizer.
        eps (float): Present in the defaults but currently not read; Adam is
            built with a fixed ``eps=1.5e-4``.
        minibatch_size (int): Number of experiences to sample in each training update.
        update_frequency (int): Number of timesteps per training update.
        target_update_frequency (int): Number of timesteps between target network updates.
        replay_start_size (int): Number of experiences in replay buffer when training begins.
        replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
        initial_exploration (float): Initial probability of choosing a random action,
            decayed over the course of training.
        final_exploration (float): Final probability of choosing a random action.
        final_exploration_step (int): The step at which exploration decay is finished.
        test_exploration (float): The exploration rate of the test agent.
        n_mixture (int): The number of mixture components in the Gaussian mixture.
        backbone_constructor (callable): Zero-argument factory for the backbone.
        mixing_head_constructor (callable): Called as ``f(n_mixture)``.
        mean_head_constructor (callable): Called as ``f(n_mixture, n_actions)``.
        covariance_head_constructor (callable): Called as ``f(n_actions)``.
        n_update_actions (int): Forwarded to ``JointDQN``.
        rand_actions (bool): Forwarded to ``JointDQN``.
        initial_q (float): Initial value of the ``update_q`` LinearScheduler
            passed to ``JointDQN`` (``final_value=0.05``, ``decay_end=1.25e6``).
    """

    def __init__(self, env, name, device, **hyperparameters):
        super().__init__(name, device, hyperparameters)
        n_mixture = hyperparameters["n_mixture"]
        n_actions = env.action_space.n
        self.n_mixture = n_mixture
        self.n_actions = n_actions

        self.backbone = hyperparameters["backbone_constructor"]().to(device)
        self.mixing_head = hyperparameters["mixing_head_constructor"](n_mixture).to(device)
        self.mean_head = hyperparameters["mean_head_constructor"](n_mixture, n_actions).to(device)
        self.covariance_head = hyperparameters["covariance_head_constructor"](n_actions).to(device)

        # One container for every sub-network so a single optimizer sees all parameters.
        self.model = nn.ModuleList(
            (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        )

    def agent(self, logger=DummyLogger(), train_steps=float("inf")):
        """
        Build the training agent.

        Args:
            logger: Passed to the Q-distribution, both schedulers and the agent.
            train_steps: Total training steps; sets the cosine-annealing horizon
                ``(train_steps - replay_start_size) / update_frequency``.

        Returns:
            A ``JointDQN`` agent wrapped in ``DeepmindAtariBody`` with lazy
            frames and episodic lives.
        """
        hp = self.hyperparameters
        start = hp["replay_start_size"]
        horizon = (train_steps - start) / hp["update_frequency"]

        # NOTE(review): eps is hardcoded; the "eps" hyperparameter is not read.
        optimizer = Adam(self.model.parameters(), lr=hp["lr"], eps=1.5e-4, fused=True)
        target = FixedTarget(hp["target_update_frequency"], self.n_mixture, self.n_actions)
        lr_schedule = CosineAnnealingLR(optimizer, horizon)

        q = JointQDist(
            self.backbone,
            self.mixing_head,
            self.mean_head,
            self.covariance_head,
            optimizer,
            self.n_actions,
            self.n_mixture,
            target=target,
            scheduler=lr_schedule,
            logger=logger,
            clip_grad=10.0,
        )

        replay_buffer = PrioritizedReplayBuffer(hp["replay_buffer_size"], device=self.device)

        # Positional args: initial value, final value, then replay_start_size and
        # (final_exploration_step - replay_start_size) -- presumably decay_start
        # and decay_end as named in the update_q scheduler below.
        exploration = LinearScheduler(
            hp["initial_exploration"],
            hp["final_exploration"],
            start,
            hp["final_exploration_step"] - start,
            name="exploration",
            logger=logger,
        )
        update_q = LinearScheduler(
            initial_value=hp["initial_q"],
            final_value=0.05,
            decay_start=0,
            decay_end=1.25e6,
            name="q",
            logger=logger,
        )

        learner = JointDQN(
            q,
            replay_buffer,
            self.n_actions,
            exploration=exploration,
            discount_factor=hp["discount_factor"],
            minibatch_size=hp["minibatch_size"],
            replay_start_size=start,
            update_frequency=hp["update_frequency"],
            logger=logger,
            n_update_actions=hp["n_update_actions"],
            rand_actions=hp["rand_actions"],
            update_q=update_q,
        )
        return DeepmindAtariBody(learner, lazy_frames=True, episodic_lives=True)

    def test_agent(self):
        """Build an evaluation agent on deep copies of the current networks."""
        # Copies keep the evaluation networks independent of the training ones.
        copies = [
            copy.deepcopy(net)
            for net in (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        ]
        # ``None`` sits where ``agent()`` passes the optimizer.
        q_dist = JointQDist(*copies, None, self.n_actions, self.n_mixture)
        return DeepmindAtariBody(
            JointDQNTestAgent(q_dist, self.hyperparameters["test_exploration"])
        )
    

# NOTE: ``JointDQNCartpolePreset`` is defined further down in this module
# (after ``JointDQNAtariTestPreset``). Keep a single definition so that edits
# to it cannot be silently shadowed by a second class with the same name.


class JointDQNAtariTestPreset(Preset):
    """
    Joint DQN Atari preset whose networks are initialized from a saved preset.

    Builds the same networks as ``JointDQNAtariPreset`` and then copies into
    them the weights of a previously saved preset object (a ``preset.pt``
    file loaded with ``torch.load``). ``agent`` matches
    ``JointDQNAtariPreset.agent``; ``test_agent`` uses ``JointDQNTestAgentMV``.

    Args:
        env: The Atari environment; only ``env.action_space.n`` is read.
        name (str): A human-readable name for the preset.
        device (torch.device): The device on which to load the agent.

    Keyword Args:
        checkpoint_path (str, optional): Path of the saved preset to load.
            Defaults to ``DEFAULT_CHECKPOINT_PATH``. The loaded object must
            expose ``backbone``, ``mixing_head``, ``mean_head`` and
            ``covariance_head`` modules. NOTE(review): a ``PresetBuilder`` may
            only accept keys present in its default hyperparameters -- confirm
            before passing this through the builder.
        discount_factor (float): Discount factor for future rewards.
        lr (float): Learning rate for the Adam optimizer.
        eps (float): Present in the defaults but currently not read; Adam is
            built with a fixed ``eps=1.5e-4``.
        minibatch_size (int): Number of experiences to sample in each training update.
        update_frequency (int): Number of timesteps per training update.
        target_update_frequency (int): Number of timesteps between target network updates.
        replay_start_size (int): Number of experiences in replay buffer when training begins.
        replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
        initial_exploration (float): Initial probability of choosing a random action,
            decayed over the course of training.
        final_exploration (float): Final probability of choosing a random action.
        final_exploration_step (int): The step at which exploration decay is finished.
        test_exploration (float): The exploration rate of the test agent.
        n_mixture (int): The number of mixture components in the Gaussian mixture.
        backbone_constructor (callable): Zero-argument factory for the backbone.
        mixing_head_constructor (callable): Called as ``f(n_mixture)``.
        mean_head_constructor (callable): Called as ``f(n_mixture, n_actions)``.
        covariance_head_constructor (callable): Called as ``f(n_actions)``.
        n_update_actions (int): Forwarded to ``JointDQN``.
        rand_actions (bool): Forwarded to ``JointDQN``.
        initial_q (float): Initial value of the ``update_q`` LinearScheduler.
    """

    # Checkpoint loaded when no ``checkpoint_path`` hyperparameter is given (Pong).
    DEFAULT_CHECKPOINT_PATH = "/local/scratch/a/kayae/backup/jointdqn_Pong_2025-09-03_16:19:15_176668/preset.pt"
    # Other checkpoints that have been used with this preset:
    #   Boxing:   /local/scratch/a/kayae/backup/jointdqn_Boxing_2025-09-08_09:44:26_806521/preset.pt
    #   Atlantis: /local/scratch/a/kayae/backup/jointdqn_Atlantis_2025-09-15_14:48:12_956094 copy/preset.pt

    def __init__(self, env, name, device, **hyperparameters):
        import torch

        super().__init__(name, device, hyperparameters)
        self.n_mixture = hyperparameters["n_mixture"]
        self.n_actions = env.action_space.n
        self.backbone = hyperparameters["backbone_constructor"]().to(device)
        self.mixing_head = hyperparameters["mixing_head_constructor"](self.n_mixture).to(device)
        self.mean_head = hyperparameters["mean_head_constructor"](self.n_mixture, self.n_actions).to(device)
        self.covariance_head = hyperparameters["covariance_head_constructor"](self.n_actions).to(device)

        checkpoint_path = hyperparameters.get("checkpoint_path", self.DEFAULT_CHECKPOINT_PATH)
        # The checkpoint is a whole pickled preset object, not a plain state
        # dict, so ``weights_only=False`` is required (PyTorch >= 2.6 defaults
        # to ``weights_only=True`` and would refuse to load it). This unpickles
        # arbitrary objects: only load checkpoints from a trusted source.
        # ``map_location`` lets a checkpoint saved on one device load on another.
        saved = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.backbone.load_state_dict(saved.backbone.state_dict())
        self.mixing_head.load_state_dict(saved.mixing_head.state_dict())
        self.mean_head.load_state_dict(saved.mean_head.state_dict())
        self.covariance_head.load_state_dict(saved.covariance_head.state_dict())

        # One container for every sub-network so a single optimizer sees all parameters.
        self.model = nn.ModuleList(
            (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        )

    def agent(self, logger=DummyLogger(), train_steps=float("inf")):
        """
        Build a training agent starting from the loaded weights.

        Args:
            logger: Passed to the Q-distribution, both schedulers and the agent.
            train_steps: Total training steps; sets the cosine-annealing horizon
                ``(train_steps - replay_start_size) / update_frequency``.

        Returns:
            A ``JointDQN`` agent wrapped in ``DeepmindAtariBody`` with lazy
            frames and episodic lives.
        """
        hp = self.hyperparameters
        start = hp["replay_start_size"]
        horizon = (train_steps - start) / hp["update_frequency"]

        # NOTE(review): eps is hardcoded; the "eps" hyperparameter is not read.
        optimizer = Adam(self.model.parameters(), lr=hp["lr"], eps=1.5e-4, fused=True)
        target = FixedTarget(hp["target_update_frequency"], self.n_mixture, self.n_actions)
        lr_schedule = CosineAnnealingLR(optimizer, horizon)

        q = JointQDist(
            self.backbone,
            self.mixing_head,
            self.mean_head,
            self.covariance_head,
            optimizer,
            self.n_actions,
            self.n_mixture,
            target=target,
            scheduler=lr_schedule,
            logger=logger,
            clip_grad=10.0,
        )

        replay_buffer = PrioritizedReplayBuffer(hp["replay_buffer_size"], device=self.device)

        exploration = LinearScheduler(
            hp["initial_exploration"],
            hp["final_exploration"],
            start,
            hp["final_exploration_step"] - start,
            name="exploration",
            logger=logger,
        )
        update_q = LinearScheduler(
            initial_value=hp["initial_q"],
            final_value=0.05,
            decay_start=0,
            decay_end=1.25e6,
            name="q",
            logger=logger,
        )

        learner = JointDQN(
            q,
            replay_buffer,
            self.n_actions,
            exploration=exploration,
            discount_factor=hp["discount_factor"],
            minibatch_size=hp["minibatch_size"],
            replay_start_size=start,
            update_frequency=hp["update_frequency"],
            logger=logger,
            n_update_actions=hp["n_update_actions"],
            rand_actions=hp["rand_actions"],
            update_q=update_q,
        )
        return DeepmindAtariBody(learner, lazy_frames=True, episodic_lives=True)

    def test_agent(self):
        """Build a ``JointDQNTestAgentMV`` evaluation agent on copies of the networks."""
        # Copies keep the evaluation networks independent of the training ones.
        copies = [
            copy.deepcopy(net)
            for net in (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        ]
        # ``None`` sits where ``agent()`` passes the optimizer.
        q_dist = JointQDist(*copies, None, self.n_actions, self.n_mixture)
        return DeepmindAtariBody(
            JointDQNTestAgentMV(q_dist, self.hyperparameters["test_exploration"])
        )


class JointDQNCartpolePreset(Preset):
    """
    Joint DQN preset for CartPole.

    Builds a shared backbone network plus mixing, mean and covariance heads
    from the constructors supplied in the hyperparameters. Unlike the Atari
    presets, the agents are returned without a ``DeepmindAtariBody`` wrapper,
    Adam uses its default epsilon, and no ``clip_grad`` is passed to
    ``JointQDist``.

    Args:
        env: The environment; only ``env.action_space.n`` is read.
        name (str): A human-readable name for the preset.
        device (torch.device): The device on which to load the agent.

    Keyword Args:
        discount_factor (float): Discount factor for future rewards.
        lr (float): Learning rate for the Adam optimizer.
        minibatch_size (int): Number of experiences to sample in each training update.
        update_frequency (int): Number of timesteps per training update.
        target_update_frequency (int): Number of timesteps between target network updates.
        replay_start_size (int): Number of experiences in replay buffer when training begins.
        replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
        initial_exploration (float): Initial probability of choosing a random action,
            decayed over the course of training.
        final_exploration (float): Final probability of choosing a random action.
        final_exploration_step (int): The step at which exploration decay is finished.
        test_exploration (float): The exploration rate of the test agent.
        n_mixture (int): The number of mixture components in the Gaussian mixture.
        backbone_constructor (callable): Zero-argument factory for the backbone.
        mixing_head_constructor (callable): Called as ``f(n_mixture)``.
        mean_head_constructor (callable): Called as ``f(n_mixture, n_actions)``.
        covariance_head_constructor (callable): Called as ``f(n_actions)``.
        n_update_actions (int): Forwarded to ``JointDQN``.
        rand_actions (bool): Forwarded to ``JointDQN``.
        initial_q (float): Initial value of the ``update_q`` LinearScheduler
            passed to ``JointDQN`` (``final_value=0.01``, ``decay_end=6250``).
    """

    def __init__(self, env, name, device, **hyperparameters):
        super().__init__(name, device, hyperparameters)
        self.n_mixture = hyperparameters["n_mixture"]
        self.n_actions = env.action_space.n
        self.backbone = hyperparameters["backbone_constructor"]().to(device)
        self.mixing_head = hyperparameters["mixing_head_constructor"](self.n_mixture).to(device)
        self.mean_head = hyperparameters["mean_head_constructor"](self.n_mixture, self.n_actions).to(device)
        self.covariance_head = hyperparameters["covariance_head_constructor"](self.n_actions).to(device)
        # One container for every sub-network so a single optimizer sees all parameters.
        self.model = nn.ModuleList(
            (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        )

    def agent(self, logger=DummyLogger(), train_steps=float("inf")):
        """
        Build the training agent.

        Args:
            logger: Passed to the Q-distribution, both schedulers and the agent.
            train_steps: Total training steps; sets the cosine-annealing horizon
                ``(train_steps - replay_start_size) / update_frequency``.

        Returns:
            A bare ``JointDQN`` agent (no environment body wrapper).
        """
        hp = self.hyperparameters
        start = hp["replay_start_size"]
        horizon = (train_steps - start) / hp["update_frequency"]

        # Adam's default epsilon is used here (no ``eps`` argument).
        optimizer = Adam(self.model.parameters(), lr=hp["lr"], fused=True)
        target = FixedTarget(hp["target_update_frequency"], self.n_mixture, self.n_actions)
        lr_schedule = CosineAnnealingLR(optimizer, horizon)

        q = JointQDist(
            self.backbone,
            self.mixing_head,
            self.mean_head,
            self.covariance_head,
            optimizer,
            self.n_actions,
            self.n_mixture,
            target=target,
            scheduler=lr_schedule,
            logger=logger,
        )

        replay_buffer = PrioritizedReplayBuffer(hp["replay_buffer_size"], device=self.device)

        exploration = LinearScheduler(
            hp["initial_exploration"],
            hp["final_exploration"],
            start,
            hp["final_exploration_step"] - start,
            name="exploration",
            logger=logger,
        )
        update_q = LinearScheduler(
            initial_value=hp["initial_q"],
            final_value=0.01,  # the Atari presets use 0.05
            decay_start=0,
            decay_end=6250,
            name="q",
            logger=logger,
        )

        return JointDQN(
            q,
            replay_buffer,
            self.n_actions,
            exploration=exploration,
            discount_factor=hp["discount_factor"],
            minibatch_size=hp["minibatch_size"],
            replay_start_size=start,
            update_frequency=hp["update_frequency"],
            logger=logger,
            n_update_actions=hp["n_update_actions"],
            rand_actions=hp["rand_actions"],
            update_q=update_q,
        )

    def test_agent(self):
        """Build a ``JointDQNTestAgent`` on deep copies of the current networks."""
        # Copies keep the evaluation networks independent of the training ones.
        copies = [
            copy.deepcopy(net)
            for net in (self.backbone, self.mixing_head, self.mean_head, self.covariance_head)
        ]
        # ``None`` sits where ``agent()`` passes the optimizer.
        q_dist = JointQDist(*copies, None, self.n_actions, self.n_mixture)
        return JointDQNTestAgent(q_dist, self.hyperparameters["test_exploration"])


# Preset builders exported by this module.

# Training presets.
jointdqn = PresetBuilder("jointdqn", default_hyperparameters, JointDQNAtariPreset)
jointdqn_cartpole = PresetBuilder(
    "jointdqn_cartpole",
    default_hyperparameters_cartpole,
    JointDQNCartpolePreset,
)

# Atari preset initialized from a saved checkpoint.
# NOTE(review): reuses the default name "jointdqn" of the training preset --
# confirm this is intended.
jointdqntest = PresetBuilder("jointdqn", default_hyperparameters, JointDQNAtariTestPreset)


