import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import time
import os
from collections import deque, namedtuple
import gymnasium as gym
from minigrid.wrappers import ImgObsWrapper, RGBImgObsWrapper
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
# Command-line arguments: <env_name> <seed>.
env_name = str(sys.argv[1])
seed_n = int(sys.argv[2])
# Seed every RNG this script draws from so that a given seed reproduces a run:
# NumPy (Thompson sampling of ensemble weights), Python's `random`
# (epsilon-greedy actions, replay sampling) and torch (network initialisation).
# NOTE(review): the gymnasium environment keeps its own RNG, which is only
# seeded through env.reset(seed=...); GPU kernels may also be nondeterministic.
_run_seed = seed_n + 101
np.random.seed(_run_seed)
random.seed(_run_seed)
torch.manual_seed(_run_seed)
# Per-episode reward histories are saved here; tolerate an existing directory
# so the script can be re-run (os.mkdir raised FileExistsError).
os.makedirs('results', exist_ok=True)

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random minibatch sampling.

    When ``capacity`` transitions are held, appending a new one evicts the
    oldest. ``sample`` returns torch tensors placed on ``self.device``:
    states/next_states stacked along a new leading axis, and actions, rewards
    and dones as (batch, 1) columns (int64, float32, float32 respectively).
    """

    def __init__(self, capacity, device=None):
        self.buffer = deque(maxlen=capacity)
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"],
        )

    def add(self, state, action, reward, next_state, done):
        """Append one transition (evicting the oldest if the buffer is full)."""
        self.buffer.append(self.experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns a tuple ``(states, actions, rewards, next_states, dones)``.
        Raises ValueError (from ``random.sample``) if fewer than
        ``batch_size`` transitions are stored.
        """
        batch = [item for item in random.sample(self.buffer, k=batch_size) if item is not None]

        def field(name):
            return [getattr(item, name) for item in batch]

        state_t = torch.from_numpy(np.stack(field("state"))).float()
        action_t = torch.from_numpy(np.vstack(field("action"))).long()
        reward_t = torch.from_numpy(np.vstack(field("reward"))).float()
        next_state_t = torch.from_numpy(np.stack(field("next_state"))).float()
        done_t = torch.from_numpy(np.vstack(field("done")).astype(np.uint8)).float()

        return tuple(t.to(self.device) for t in (state_t, action_t, reward_t, next_state_t, done_t))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


class QNetwork(nn.Module):
    """Convolutional Q-network over channel-last image observations.

    ``input_shape`` is indexed as (height, width, channels): element 2 sets the
    input channels of the first convolution, and ``forward`` permutes
    (N, H, W, C) input to (N, C, H, W) before the conv stack. Three unpadded
    stride-1 convolutions feed a 512-unit hidden layer and a linear head with
    one output per action.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[2], 32, kernel_size=3, stride=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1), nn.ReLU(),
        )

        flat_features = self._get_conv_out(input_shape)

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def _get_conv_out(self, shape):
        """Number of features the conv stack emits for one (H, W, C) input."""
        probe = torch.zeros(1, shape[2], shape[0], shape[1])
        with torch.no_grad():
            return self.conv(probe).numel()

    def forward(self, x):
        """Return Q-values of shape (N, num_actions); an unbatched (H, W, C) input gets N=1."""
        batched = x.unsqueeze(0) if x.dim() == 3 else x
        channels_first = batched.permute(0, 3, 1, 2).float()
        features = self.conv(channels_first).flatten(start_dim=1)
        return self.fc(features)

class EnsembleDQNAgent:
    """DQN agent that trains an ensemble of Q-networks against one mixed target.

    Each member owns an online ``QNetwork``, a target ``QNetwork`` and an Adam
    optimizer; all members share a single ``ReplayBuffer``. The bootstrap
    target is the max over actions of the ``weight``-weighted sum of the
    members' target-network Q-values, and every member regresses onto that
    same target. ``beta`` holds one (alpha, beta) pair per member; the training
    loop is expected to sample it and set ``weight`` and ``index_label`` (the
    member that ``select_action`` acts with).
    """

    def __init__(self, state_shape, num_actions, ensemble_size=5, gamma=0.99,
                 lr=1e-3, batch_size=64, buffer_size=int(1e4), update_freq=4,
                 target_update_freq=1000, tau=0.005, device=None):

        self.state_shape = state_shape
        self.num_actions = num_actions
        self.ensemble_size = ensemble_size
        self.gamma = gamma
        self.batch_size = batch_size
        self.update_freq = update_freq
        self.target_update_freq = target_update_freq
        self.tau = tau
        # Per-member Beta posterior parameters: column 0 = alpha, column 1 = beta.
        self.beta = np.zeros((self.ensemble_size, 2))
        self.beta[:, 0] = 0.001  # initial alpha
        self.beta[:, 1] = 1      # initial beta
        self.index_label = 0     # ensemble member used by select_action
        # Target-mixing weights; all zero until the training loop samples them.
        self.weight = np.zeros(self.ensemble_size)

        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_networks = [QNetwork(state_shape, num_actions).to(self.device) for _ in range(ensemble_size)]
        self.target_networks = [QNetwork(state_shape, num_actions).to(self.device) for _ in range(ensemble_size)]

        # Start every target network as an exact copy of its online network.
        for i in range(ensemble_size):
            self.target_networks[i].load_state_dict(self.q_networks[i].state_dict())
            self.target_networks[i].eval()

        self.optimizers = [optim.Adam(net.parameters(), lr=lr) for net in self.q_networks]
        self.memory = ReplayBuffer(buffer_size, device=self.device)
        self.steps = 0  # transitions stored so far; drives update / target-sync cadence

    def select_action(self, state, epsilon=0.1):
        """Epsilon-greedy action using ensemble member ``index_label``.

        With probability ``epsilon`` returns a uniformly random action index;
        otherwise returns the argmax of that member's Q-values for ``state``.
        """
        if random.random() < epsilon:
            return random.randrange(self.num_actions)
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.q_networks[self.index_label](state_tensor).squeeze(0)
            return q_values.argmax().item()

    def _mixing_weights(self):
        """Return ``self.weight`` as floats, or uniform weights if they are unusable.

        All-zero weights would drop the bootstrap term entirely, and NaN/inf
        weights (e.g. from normalising Beta samples that all underflowed to 0)
        would make every target NaN and corrupt all networks. Fall back to an
        even mix in those cases.
        """
        weights = np.asarray(self.weight, dtype=float)
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            return np.full(self.ensemble_size, 1.0 / self.ensemble_size)
        return weights

    def update(self):
        """Take one gradient step for every ensemble member on a shared minibatch.

        Does nothing until the buffer holds ``batch_size`` transitions, and runs
        only when ``steps`` is a multiple of ``update_freq``. When ``steps`` is a
        multiple of ``target_update_freq`` the target networks are then moved
        toward the online networks.
        """
        if len(self.memory) < self.batch_size:
            return

        if self.steps % self.update_freq != 0:
            return

        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        # The buffer returns rewards/dones as (batch, 1) columns. Flatten them so
        # they line up with the (batch,) max-Q vector below; otherwise
        # broadcasting silently turns the target into a (batch, batch) matrix.
        rewards = rewards.reshape(-1)
        dones = dones.reshape(-1)

        weights = self._mixing_weights()
        with torch.no_grad():
            # Weighted sum over members of target Q-values: (batch, num_actions).
            mixed_next_q = torch.stack([
                float(w) * target_net(next_states)
                for w, target_net in zip(weights, self.target_networks)
            ]).sum(dim=0)
            next_q_values = mixed_next_q.max(dim=1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        for net, optimizer in zip(self.q_networks, self.optimizers):
            current_q_values = net(states).gather(1, actions).squeeze(1)
            loss = F.mse_loss(current_q_values, target_q_values)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()

        if self.steps % self.target_update_freq == 0:
            self._soft_update_target_networks()

    def _soft_update_target_networks(self):
        """Polyak update: target <- tau * online + (1 - tau) * target (tau=1 is a hard copy)."""
        with torch.no_grad():
            for target_net, online_net in zip(self.target_networks, self.q_networks):
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def store_experience(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer and advance the step counter."""
        self.memory.add(state, action, reward, next_state, done)
        self.steps += 1

    def save(self, path):
        """Write online/target weights, the step counter and the Beta posterior to ``path``."""
        directory = os.path.dirname(path)
        if directory:  # os.makedirs('') raises, so skip it for bare filenames
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'q_networks': [net.state_dict() for net in self.q_networks],
            'target_networks': [net.state_dict() for net in self.target_networks],
            'steps': self.steps,
            # Stored as a plain list so it loads under torch.load(weights_only=True).
            'beta': self.beta.tolist(),
        }, path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Restore a checkpoint written by ``save``, mapping tensors onto ``self.device``.

        Checkpoints without a ``'beta'`` entry leave the current posterior untouched.
        """
        checkpoint = torch.load(path, map_location=self.device)
        for i in range(self.ensemble_size):
            self.q_networks[i].load_state_dict(checkpoint['q_networks'][i])
            self.target_networks[i].load_state_dict(checkpoint['target_networks'][i])
        self.steps = checkpoint['steps']
        if 'beta' in checkpoint:
            self.beta = np.asarray(checkpoint['beta'], dtype=float)
        print(f"Model loaded from {path}")

def _thompson_sample_member(agent):
    """Thompson-sample the ensemble: set ``agent.weight`` and ``agent.index_label``.

    Draws one Beta(alpha_k, beta_k) sample per member from ``agent.beta``,
    normalises the samples into mixing weights and selects the member with the
    largest sample to act. With a small alpha (0.001 at initialisation) every
    draw can underflow to exactly 0.0. Dividing by that zero sum would fill the
    weights with NaN, so in that case fall back to uniform weights and a
    uniformly random acting member.
    """
    samples = np.random.beta(agent.beta[:, 0], agent.beta[:, 1])
    total = samples.sum()
    if np.isfinite(total) and total > 0:
        agent.weight = samples / total
        agent.index_label = int(np.argmax(agent.weight))
    else:
        agent.weight = np.full(agent.ensemble_size, 1.0 / agent.ensemble_size)
        agent.index_label = np.random.randint(agent.ensemble_size)


def train_agent(env_name="MiniGrid-Empty-8x8-v0", 
                num_episodes=500,
                max_steps=200,
                ensemble_size=5,
                gamma=0.99,
                lr=1e-3,
                batch_size=64,
                buffer_size=int(1e5),
                update_freq=4,
                target_update_freq=1000,
                tau=0.005,
                epsilon_start=1.0,
                epsilon_end=0.01,
                epsilon_decay=0.995,
                eval_freq=20000,
                eval_episodes=10,
                save_dir="models",
                device=None,
                checkpoint_freq=100000):
    """Train an ``EnsembleDQNAgent`` on a gymnasium MiniGrid environment.

    Each episode:

    * Thompson-samples the ensemble mixing weights and the acting member.
    * Runs up to ``max_steps`` epsilon-greedy steps; positive rewards are clipped to 1.
    * Credits the acting member's Beta posterior. A success is an episode that
      ended with reward 1. Anything else, including hitting ``max_steps``, is a failure.

    The reward history is written to ``results/bedqn_<env_name><seed>.npy``
    after every episode (``seed_n`` is the module-level seed). Checkpoints are
    saved under ``save_dir`` every ``checkpoint_freq`` episodes and at the end.

    Returns:
        ``(agent, episode_rewards)``
    """
    device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = ImgObsWrapper(gym.make(env_name))
    state_shape = env.observation_space.shape
    num_actions = env.action_space.n

    print(f"State shape: {state_shape}, Action space: {num_actions}")

    agent = EnsembleDQNAgent(
        state_shape=state_shape,
        num_actions=num_actions,
        ensemble_size=ensemble_size,
        gamma=gamma,
        lr=lr,
        batch_size=batch_size,
        buffer_size=buffer_size,
        update_freq=update_freq,
        target_update_freq=target_update_freq,
        tau=tau,
        device=device
    )

    episode_rewards = []
    epsilon_values = []
    start_time = time.time()

    for episode in range(num_episodes):
        state, _ = env.reset()
        total_reward = 0
        succeeded = False

        epsilon = max(epsilon_end, epsilon_start * (epsilon_decay ** episode))
        _thompson_sample_member(agent)

        for _ in range(max_steps):
            action = agent.select_action(state, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            if reward > 0:
                reward = 1

            # Only a true terminal state stops bootstrapping; a truncated
            # (time-limited) state still has future value.
            agent.store_experience(state, action, reward, next_state, terminated)
            agent.update()

            state = next_state
            total_reward += reward
            if terminated or truncated:
                succeeded = reward == 1
                break

        # Record this episode's outcome for the member that acted. Episodes
        # stopped by the max_steps cap count as failures as well.
        agent.beta[agent.index_label, 0 if succeeded else 1] += 1

        episode_rewards.append(total_reward)
        epsilon_values.append(epsilon)

        avg_reward = np.mean(episode_rewards[-100:])
        np.save("results/bedqn_"+env_name+str(seed_n)+".npy",np.array(episode_rewards))
        print(f"Episode {episode+1}/{num_episodes} | "
              f"Reward: {total_reward:.2f} | "
              f"Avg Reward (100): {avg_reward:.2f} | "
              f"Epsilon: {epsilon:.3f} | "
              f"Buffer: {len(agent.memory)}/{buffer_size}")

        if (episode + 1) % eval_freq == 0:
            eval_reward = evaluate_agent(env, agent, eval_episodes, max_steps)
            print(f"\n>>> Evaluation after {episode+1} episodes: "
                  f"Average Reward = {eval_reward:.2f} <<<\n")

        if checkpoint_freq and (episode + 1) % checkpoint_freq == 0:
            save_path = os.path.join(save_dir, f"ensemble_dqn_{env_name}_ep{episode+1}.pt")
            agent.save(save_path)

    final_save_path = os.path.join(save_dir, f"ensemble_dqn_{env_name}_final.pt")
    agent.save(final_save_path)

    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time/60:.2f} minutes")
    print(f"Average reward over last 100 episodes: {np.mean(episode_rewards[-100:]):.2f}")

    return agent, episode_rewards

def evaluate_agent(env, agent, num_episodes=10, max_steps=200):
    """Return the mean undiscounted reward of ``agent`` acting greedily.

    Runs ``num_episodes`` episodes with ``epsilon=0``, each ending on
    termination, truncation, or after ``max_steps`` steps. Rewards are summed
    as returned by ``env.step``; unlike training they are not clipped. Returns
    NaN when ``num_episodes`` is 0.
    """
    total_rewards = []

    for _ in range(num_episodes):
        # gymnasium's reset returns (observation, info); act on the observation.
        state, _ = env.reset()
        total_reward = 0

        for _ in range(max_steps):
            action = agent.select_action(state, epsilon=0.0)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        total_rewards.append(total_reward)

    if not total_rewards:
        return float("nan")
    return np.mean(total_rewards)

def plot_training_results(rewards, epsilons, env_name, window=20):
    """Save a two-panel figure of per-episode reward and exploration rate.

    The left panel shows raw episode rewards with a ``window``-episode moving
    average overlaid in red. The average is drawn only when there are at least
    ``window`` episodes, and each point sits at the last episode of its
    window. The right panel shows epsilon per episode. The figure is written
    to ``training_results_<env_name>.png`` in the working directory.
    """
    rewards = np.asarray(rewards, dtype=float)
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(rewards)
    # np.convolve(..., mode='valid') swaps its operands when the data is
    # shorter than the kernel (and raises on empty data), so only draw the
    # moving average once a full window exists.
    if window > 0 and len(rewards) >= window:
        moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
        plt.plot(np.arange(window - 1, len(rewards)), moving_avg, 'r', linewidth=2)
    plt.title(f'Training Rewards - {env_name}')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')

    plt.subplot(1, 2, 2)
    plt.plot(epsilons)
    plt.title('Exploration Rate (Epsilon)')
    plt.xlabel('Episode')
    plt.ylabel('Epsilon')

    plt.tight_layout()
    plt.savefig(f'training_results_{env_name}.png')
    plt.close()
    print("Training results plot saved")

if __name__ == "__main__":
    # train_agent writes its checkpoints under ./models.
    os.makedirs("models", exist_ok=True)

    # Run configuration; env_name is taken from the command line.
    train_params = dict(
        env_name=env_name,
        num_episodes=2000,
        max_steps=50,
        ensemble_size=5,
        gamma=0.99,
        lr=5e-4,
        batch_size=32,
        buffer_size=int(5e4),
        update_freq=1,
        target_update_freq=500,
        tau=1,               # tau=1 makes each target sync a hard copy
        epsilon_start=0.1,
        epsilon_end=0.02,
        epsilon_decay=0.995,
        eval_freq=2000000,   # exceeds num_episodes, so no periodic evaluation runs
        eval_episodes=10,
        save_dir="models",
    )

    print("Starting Ensemble DQN training with parameters:")
    print("\n".join(f"  {k}: {v}" for k, v in train_params.items()))
    trained_agent, rewards = train_agent(**train_params)

    print("\nTraining completed! You can now use the trained agent for inference.")