import argparse, json, time, os
from pathlib import Path
import hashlib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from collections import deque 

# -------- Maze construction --------------------------------------------------
# Grid dimensions, goal cell as (row, col) with zero-based indices, and the
# number of consecutive open cells that make up one door.
HEIGHT, WIDTH = 10, 10
GOAL_COORD = (9, 9)
DOOR_LEN = 2

# Occupancy grid: 1 marks a wall cell, 0 a free cell.
walls = np.zeros((HEIGHT, WIDTH), dtype=int)

# Horizontal wall along the middle row, opened by two doors whose first cells
# sit at one quarter and three quarters of the width.
walls[HEIGHT // 2] = 1
doors_h = np.concatenate((
    np.arange(WIDTH // 4, WIDTH // 4 + DOOR_LEN),
    np.arange(WIDTH * 3 // 4, WIDTH * 3 // 4 + DOOR_LEN),
))
walls[HEIGHT // 2, doors_h] = 0

# Vertical wall along the middle column, with two doors placed the same way.
# It is drawn after the horizontal doors were cut, so the crossing cell of the
# two walls ends up solid.
walls[:, WIDTH // 2] = 1
doors_v = np.concatenate((
    np.arange(HEIGHT // 4, HEIGHT // 4 + DOOR_LEN),
    np.arange(HEIGHT * 3 // 4, HEIGHT * 3 // 4 + DOOR_LEN),
))
walls[doors_v, WIDTH // 2] = 0

# Flat (row-major) indices of every free cell.
empty_states = np.flatnonzero(walls.ravel() == 0)

NUM_STATES = HEIGHT * WIDTH
NUM_ACTIONS = 5  # stay, down, up, right, left
# (d_row, d_col) displacement for each action, in the order listed above.
A_TO_DELTA = np.array([
    (0, 0),
    (1, 0), (-1, 0),
    (0, 1), (0, -1),
])

# Flat indices of the start cell (top-left corner of the array) and the goal.
START_STATE = np.ravel_multi_index((0, 0), walls.shape)
GOAL_STATE = np.ravel_multi_index(GOAL_COORD, walls.shape)

# Short aliases used when constructing the agent.
nA, nS = NUM_ACTIONS, NUM_STATES


# -------- Environment helpers -------------------------------------------------------
def step(state: int, action: int) -> int:
    """Apply `action` in `state` and return the resulting flat state index.

    The move is deterministic. If the target cell lies outside the grid or is
    a wall, the agent does not move and `state` is returned unchanged.
    """
    d_row, d_col = A_TO_DELTA[action]
    row, col = np.unravel_index(state, walls.shape)
    target = (row + d_row, col + d_col)

    # Check the bounds first so the wall lookup never indexes outside the grid.
    inside = 0 <= target[0] < HEIGHT and 0 <= target[1] < WIDTH
    if not inside or walls[target] != 0:
        return state  # blocked: stay in place
    return np.ravel_multi_index(target, walls.shape)

def is_near_goal(s):
    """Return whether `s` is the goal cell or one of its 4-neighbours.

    "Near" means a Manhattan distance of at most 1 between the grid
    coordinates of `s` and those of GOAL_STATE.
    """
    here = np.unravel_index(s, walls.shape)
    goal = np.unravel_index(GOAL_STATE, walls.shape)
    manhattan = abs(here[0] - goal[0]) + abs(here[1] - goal[1])
    return manhattan <= 1

def get_reward(state: int, action: int, next_state: int) -> float:
    """Return 1.0 when `next_state` is the goal or near it, otherwise 0.0.

    `state` and `action` are accepted for a conventional (s, a, s') signature
    but do not influence the result.
    """
    rewarded = next_state == GOAL_STATE or is_near_goal(next_state)
    return 1.0 if rewarded else 0.0

def is_done(state: int) -> bool:
    """Return whether `state` is the goal or near it (see is_near_goal)."""
    at_goal = state == GOAL_STATE
    if at_goal:
        return at_goal
    return is_near_goal(state)



# -------- SGCRL Agent -------------------------------------------------------------

class SGCRLAgent:
    """
    State-Goal Contrastive Reinforcement Learning agent for the tabular maze.

    The agent keeps one embedding vector psi(s) per state. Actions are chosen
    by scoring each candidate next state s' with psi(s') . psi(goal), and the
    embeddings are trained with a contrastive loss on (state, later state)
    pairs sampled from a replay buffer of whole trajectories.
    """

    def __init__(self, nState, nAction, rep_dim=16, episodes_per_upd=5,
                 lr_psi=1e-3, replay_capacity=1000, max_steps=100, batch_size=128,
                 gamma=0.99, entropy_coeff=0.1,
                 max_episodes=50000, plot_mult=None, run_dir=None):
        """
        Initialize SGCRL agent

        Args:
            nState - int - number of states
            nAction - int - number of actions
            rep_dim - int - representation dimension
            episodes_per_upd - int - episodes between representation updates
            lr_psi - float - learning rate for psi
            replay_capacity - int - maximum number of trajectories kept in replay
            max_steps - int - number of steps per episode
            batch_size - int - number of (anchor, positive) pairs per update
            gamma - float - discount factor for sampling the positive state
            entropy_coeff - float - softmax temperature for action selection
            max_episodes - int - number of training episodes
            plot_mult - int - plot/evaluate every `plot_mult` episodes; when
                None, defaults to max_episodes // 50 (at least 1)
            run_dir - Path - directory for saving results; None disables plots
        """
        # Store basic parameters
        self.nState = nState
        self.nAction = nAction
        self.rep_dim = rep_dim
        self.episodes_per_upd = episodes_per_upd
        self.lr_psi = lr_psi
        self.replay_capacity = replay_capacity
        self.max_steps = max_steps
        self.batch_size = batch_size
        self.gamma = gamma
        self.entropy_coeff = entropy_coeff
        self.max_episodes = max_episodes
        self.norm = True          # keep every psi row at unit L2 norm
        self.plain_goal = False

        # Set up goal
        self.goal = np.ravel_multi_index(GOAL_COORD, walls.shape)

        # Initialize representation (every row is filled in below)
        self.psi = np.empty((nState, rep_dim))

        # Initialize goal embedding
        self.psi_goal = np.random.randn(rep_dim) * 0.1
        self.psi[self.goal] = self.psi_goal
        if self.norm:
            psi_norm = np.linalg.norm(self.psi[self.goal]) + 1e-8
            self.psi[self.goal] /= psi_norm
            self.psi_goal /= psi_norm

        # Initialize other states as noisy copies of the goal embedding.
        # The per-state loop keeps the order of random draws fixed.
        for s in range(nState):
            if s != self.goal:
                self.psi[s] = self.psi_goal + np.random.randn(rep_dim) * 0.1

        # Normalize psi to unit vectors
        if self.norm:
            psi_norms = np.linalg.norm(self.psi, axis=1, keepdims=True) + 1e-8
            self.psi /= psi_norms

        # Replay buffer and tracking
        self.replay = []             # list of trajectories (lists of states)
        self.visited_states = set()
        self.visited_counts = []
        self.trajectories = []
        self.episode_rewards = []
        self.episode_lengths = []
        self.loss_history = []
        self.success_list = []       # per-episode fraction of successful steps
        self.eval_success_list = []

        # Plotting and directory setup
        self.plot_mult = self._resolve_plot_mult(plot_mult, max_episodes)
        self.run_dir = run_dir
        if run_dir:
            print(f"Plot every {self.plot_mult} episodes")

    @staticmethod
    def _resolve_plot_mult(plot_mult, max_episodes):
        """Return a plotting interval that is always at least 1.

        `train` evaluates `ep % plot_mult`, so an interval of 0 (for example
        from max_episodes // 50 with fewer than 50 episodes) would raise
        ZeroDivisionError.
        """
        if plot_mult is None:
            plot_mult = max_episodes // 50
        return max(1, plot_mult)

    def _next_state_scores(self, s: int, g: int):
        """Return psi(s') . psi(g) for the next state s' of every action."""
        goal_vec = self.psi[g]
        return np.array([self.psi[step(s, a)] @ goal_vec
                         for a in range(self.nAction)])

    def select_action(self, s: int, g: int) -> int:
        """
        Softmax action selection based on ψ(s') · ψ(g) similarity,
        scaled by 1 / entropy_coeff (i.e., inverse temperature)
        """
        logits = self._next_state_scores(s, g)

        # Apply entropy regularization (inverse temperature)
        inverse_temp = 1.0 / self.entropy_coeff
        logits *= inverse_temp

        # Softmax over scaled logits (max subtracted for numerical stability)
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / np.sum(exp_logits)

        return np.random.choice(self.nAction, p=probs)

    def eval_action(self, s: int, g: int) -> int:
        """
        Deterministic evaluation policy: chooses the action whose next state
        has the largest ψ(s') · ψ(g). No exploration or sampling. Ties go to
        the lowest action index.
        """
        return int(np.argmax(self._next_state_scores(s, g)))

    def _rollout(self, policy):
        """Run one fixed-length episode from START_STATE with `policy`.

        `policy` is called as policy(state, goal) and must return an action.
        The episode always lasts max_steps steps; reaching the goal does not
        end it. Returns (trajectory, step_success), where step_success holds
        a 1 for every step that landed on or next to the goal.
        """
        traj = [START_STATE]
        step_success = []
        for _ in range(self.max_steps):
            a = policy(traj[-1], self.goal)
            ns = step(traj[-1], a)
            traj.append(ns)
            success = (ns == self.goal) or is_near_goal(ns)
            step_success.append(1 if success else 0)
        return traj, step_success

    def collect_episode(self):
        """Generate one trajectory and push into replay buffer."""
        traj, step_success = self._rollout(self.select_action)

        self.replay.append(traj)
        # Drop the oldest trajectory once the buffer exceeds its capacity.
        if len(self.replay) > self.replay_capacity:
            self.replay.pop(0)

        return step_success

    def run_eval_episode(self):
        """Generate one evaluation trajectory (not stored in replay)."""
        _, step_success = self._rollout(self.eval_action)
        return step_success

    def update_representations(self) -> float:
        """
        Vectorised contrastive update

        Samples batch_size trajectories with replacement. From each one it
        draws an anchor index i uniformly (excluding the last state) and a
        positive index j = i + k, where k is drawn with weights gamma**k over
        the remaining steps (k = 0 is allowed, so the positive may be the
        anchor itself). It then takes one gradient step on the embeddings and
        returns the mean negative log-likelihood before the step.
        Returns 0.0 when there is nothing to train on.
        """
        if len(self.replay) < 2:
            return 0.0

        traj_ids = np.random.choice(len(self.replay), self.batch_size, replace=True)
        s_list, sp_list = [], []
        for idx in traj_ids:
            traj = self.replay[idx]
            if len(traj) < 2:
                continue
            i = np.random.randint(0, len(traj) - 1)
            remaining = len(traj) - i
            w = self.gamma ** np.arange(remaining)
            w /= w.sum()
            j = i + np.random.choice(remaining, p=w)

            s_list.append(traj[i])
            sp_list.append(traj[j])

        if not s_list:
            return 0.0

        s_batch = np.asarray(s_list, dtype=np.int32)
        sp_batch = np.asarray(sp_list, dtype=np.int32)
        B = len(s_batch)

        # Gather current embeddings
        psi_s = self.psi[s_batch]
        psi_p = self.psi[sp_batch]

        # Column-wise soft-max probabilities P: for each positive (column),
        # a distribution over the anchors (rows) in the batch.
        dots = psi_s @ psi_p.T
        dots -= dots.max(axis=0, keepdims=True)
        exp_logits = np.exp(dots)
        P = exp_logits / exp_logits.sum(axis=0, keepdims=True)

        # The matching anchor for positive k is anchor k (the diagonal).
        diag_P = np.diag(P)
        nll = -np.mean(np.log(diag_P + 1e-12))

        # Anchor-state updates Δψ(s_j); add.at accumulates repeated states
        coeff = np.eye(B) - P
        anchor_update = self.lr_psi * (coeff @ psi_p)
        np.add.at(self.psi, s_batch, anchor_update)

        # Positive-state updates Δψ(sp_k)
        expected_anchor = (P.T @ psi_s)
        pos_update = self.lr_psi * (psi_s - expected_anchor)
        np.add.at(self.psi, sp_batch, pos_update)

        # Normalize ψ to unit vectors (L2 norm)
        if self.norm:
            psi_norms = np.linalg.norm(self.psi, axis=1, keepdims=True) + 1e-8
            self.psi /= psi_norms

        return nll

    def trajectory_to_coordinates(self, trajectory):
        """Convert trajectory of states to (x, y) coordinates for plotting"""
        if not trajectory:
            return [], []

        coords = [np.unravel_index(state, walls.shape) for state in trajectory]
        y_coords = [coord[0] for coord in coords]  # row indices
        x_coords = [coord[1] for coord in coords]  # column indices
        return x_coords, y_coords

    def save_visualization(self, episode, latest_trajs):
        """Save a heat-map of cosine similarity to the goal with trajectories.

        Does nothing when run_dir is None. The figure is written to
        <run_dir>/trajs_ep<episode>_cos.png.
        """
        if self.run_dir is None:
            return

        goal_vec = self.psi[self.goal]

        goal_norm = np.linalg.norm(goal_vec) + 1e-8
        sim = (self.psi @ goal_vec) / (np.linalg.norm(self.psi, axis=1) * goal_norm + 1e-8)
        sim_map = sim.reshape(walls.shape)

        fig, ax = plt.subplots(figsize=(6, 6))

        # Remove border and ticks
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

        # Background: similarity heat-map
        im = ax.imshow(sim_map, cmap="viridis", origin="lower",
                       vmin=min(0.0, np.min(sim_map)), vmax=1)
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04,
                            label=r"ψ-similarity to goal")
        cbar.ax.set_ylabel(r"ψ-similarity to goal", fontsize=40)
        cbar.ax.tick_params(labelsize=18)  # Increase colorbar tick label size

        # Overlay walls
        ax.imshow(walls, cmap=plt.cm.binary, vmin=0, vmax=1,
                  alpha=0.5, origin="lower")

        # Draw recent trajectories (x = column, y = row)
        for t in latest_trajs:
            coords = np.array([np.unravel_index(s, walls.shape) for s in t])
            ax.plot(coords[:, 1], coords[:, 0], color="black", linewidth=2.0, alpha=0.8)

        # Highlight start and goal
        ax.scatter(*np.unravel_index(START_STATE, walls.shape)[::-1],
                   marker="o", s=200, c="lime", label="Start State", zorder=10, edgecolors='black', linewidth=2, clip_on=False)
        ax.scatter(*np.unravel_index(self.goal, walls.shape)[::-1],
                   marker="*", s=300, c="red", label="Goal State", zorder=10, edgecolors='black', linewidth=2, clip_on=False)

        ax.set_title(f"Trial {episode}", fontsize=24)

        # With 2-D embeddings, also draw each free cell's vector as an arrow.
        if self.psi.shape[1] == 2:
            X, Y = np.meshgrid(np.arange(WIDTH), np.arange(HEIGHT))
            U = np.zeros_like(X, dtype=float)
            V = np.zeros_like(Y, dtype=float)

            for idx in range(self.nState):
                if walls.flat[idx] == 0:
                    i, j = np.unravel_index(idx, walls.shape)
                    U[i, j], V[i, j] = self.psi[idx]

            ax.quiver(X, Y, U, V, color="white", scale=1.5, scale_units="xy",
                      angles='xy', width=0.005)

        fig.tight_layout()
        fig.savefig(f"{self.run_dir}/trajs_ep{episode:05d}_cos.png", dpi=150)
        plt.close(fig)

    def _save_success_plot(self):
        """Save the running fraction of episodes with any successful step."""
        if not (self.success_list and self.run_dir):
            return
        fig, ax = plt.subplots(figsize=(8, 4))
        successes = [1 if r > 0 else 0 for r in self.success_list]
        success_rate = np.cumsum(successes) / np.arange(1, len(successes) + 1)
        ax.plot(success_rate)
        ax.set_xlabel('Episode')
        ax.set_ylabel('Success Rate')
        ax.set_title('Success Rate Over Time')
        fig.tight_layout()
        fig.savefig(f"{self.run_dir}/success_rate.png", dpi=150)
        plt.close(fig)

    def _save_loss_plot(self):
        """Save the contrastive loss curve, indexed by training episode."""
        if not (self.loss_history and self.run_dir):
            return
        fig_l, ax_l = plt.subplots(figsize=(4, 2.5))
        # One loss value is recorded every episodes_per_upd episodes.
        xs = np.arange(len(self.loss_history)) * self.episodes_per_upd
        ax_l.plot(xs, self.loss_history, linewidth=1)
        ax_l.set_xlabel("episode")
        ax_l.set_ylabel("contrastive NLL")
        ax_l.set_title("training loss")
        fig_l.tight_layout()
        fig_l.savefig(f"{self.run_dir}/loss.png", dpi=150)
        plt.close(fig_l)

    def train(self):
        """Main training loop.

        Each episode collects one trajectory and records success and coverage
        statistics. Every episodes_per_upd episodes the embeddings are
        updated, and every plot_mult episodes (including episode 0) an
        evaluation episode is run and the plots are refreshed. The evaluation
        runs even when run_dir is None; only the plotting is skipped.
        """
        for ep in tqdm(range(0, self.max_episodes), desc="episodes"):
            ep_success = self.collect_episode()
            latest_trajs = self.replay[-1:]  # For visualization (single trajectory)
            self.success_list.append(np.mean(ep_success))

            self.visited_states.update(self.replay[-1])
            self.visited_counts.append(len(self.visited_states))

            # Perform SGD update every N episodes
            if ep % self.episodes_per_upd == 0:
                loss = self.update_representations()
                self.loss_history.append(loss)

            # Evaluate, visualize and save plots
            if ep % self.plot_mult == 0:
                eval_ep_success = self.run_eval_episode()
                self.eval_success_list.append(np.mean(eval_ep_success))

                self.save_visualization(ep, latest_trajs)
                self._save_success_plot()
                self._save_loss_plot()



# -------- Configuration and Run Directory Setup -----------------------------------------------

# Hyperparameters applied over the parser defaults when
# parse_args(use_defaults=True) is used. A plot_mult of None leaves the
# plotting interval to be derived from the number of episodes.
SLURM_DEFAULTS = dict(
    rep_dim=16,
    episodes_per_upd=1,
    lr_psi=0.01,
    replay_capacity=1000,
    max_steps=50,
    batch_size=128,
    num_episodes=20000,
    gamma=0.99,
    seed=1,
    entropy_coeff=0.1,
    plot_mult=None,
)

def parse_args(use_defaults=True, argv=None):
    """Build the hyperparameter namespace for a run.

    Args:
        use_defaults: when True, the command line is ignored and every entry
            of SLURM_DEFAULTS is written over the parser defaults. When
            False, the arguments are parsed normally.
        argv: optional list of argument strings to parse instead of
            sys.argv[1:]. Only consulted when use_defaults is False.

    Returns:
        argparse.Namespace with one attribute per hyperparameter.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--rep-dim",          type=int,   default=16)
    p.add_argument("--episodes-per-upd", type=int,   default=5)
    p.add_argument("--max-steps",        type=int,   default=100)
    p.add_argument("--batch-size",       type=int,   default=128)
    p.add_argument("--lr-psi",           type=float, default=1e-3)
    p.add_argument("--num-episodes",     type=int,   default=50_000)
    p.add_argument("--gamma",            type=float, default=0.99)
    p.add_argument("--seed",             type=int,   default=1)
    p.add_argument("--replay-capacity",  type=int,   default=1000)
    # Hyphenated like every other flag; the original underscore spelling is
    # kept as an alias so existing launch scripts keep working.
    p.add_argument("--entropy-coeff", "--entropy_coeff", dest="entropy_coeff",
                   type=float, default=0.1)
    p.add_argument("--plot-mult",        type=int,   default=1_000,
                   help="plotting/evaluation interval, in episodes")

    if use_defaults:
        # Override with hardcoded defaults from slurm file
        args = p.parse_args([])  # Empty list to avoid reading command line args
        for k, v in SLURM_DEFAULTS.items():
            setattr(args, k, v)
    else:
        # Use command line arguments (argv=None means sys.argv[1:])
        args = p.parse_args(argv)

    return args

# Resolve hyperparameters. use_defaults=True ignores the command line and
# applies SLURM_DEFAULTS on top of the parser defaults.
args = parse_args(use_defaults=True)
print("Using hyperparameters:")
for k, v in vars(args).items():
    print(f"{k}: {v}")

seed = args.seed
max_episodes = args.num_episodes
max_steps = args.max_steps
# Use an explicit plot interval when one is configured; otherwise aim for
# about 50 snapshots per run. The interval is clamped to at least 1 because
# the training loop computes `ep % plot_mult`, which would divide by zero for
# runs shorter than 50 episodes.
plot_mult = args.plot_mult if args.plot_mult is not None else max_episodes // 50
plot_mult = max(1, plot_mult)

# SGCRL-specific parameters
rep_dim = args.rep_dim
episodes_per_upd = args.episodes_per_upd
lr_psi = args.lr_psi
replay_capacity = args.replay_capacity
batch_size = args.batch_size
gamma = args.gamma
entropy_coeff = args.entropy_coeff

# Seed NumPy's global RNG, which drives initialisation, action sampling and
# replay sampling.
np.random.seed(seed)

config = {
    "seed": seed,
    "max_episodes": max_episodes,
    "max_steps": max_steps,
    "plot_mult": plot_mult,
    "rep_dim": rep_dim,
    "episodes_per_upd": episodes_per_upd,
    "lr_psi": lr_psi,
    "replay_capacity": replay_capacity,
    "batch_size": batch_size,
    "gamma": gamma,
    "entropy_coeff": entropy_coeff,
    "HEIGHT": HEIGHT,
    "WIDTH": WIDTH,
    "GOAL_COORD": GOAL_COORD,
    "algorithm": "SGCRL"
}

# Create run directory. The UID hashes every setting except the seed, so runs
# that differ only by seed share a UID under separate seed folders.
hash_basis = {k: v for k, v in config.items() if k != "seed"}
hash_json = json.dumps(hash_basis, sort_keys=True).encode()
uid = hashlib.sha1(hash_json).hexdigest()[:10]

# NOTE(review): the output root is a hard-coded absolute path.
run_dir = Path(f"/home/rebuttal/sgcrl_maze/seed{seed}") / uid
run_dir.mkdir(parents=True, exist_ok=True)

config["uid"] = uid
config["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")

with open(run_dir / "config.json", "w") as f:
    json.dump(config, f, indent=2)

print(f"[INFO] Run directory: {run_dir}")
print(f"[INFO] Config UID: {uid}")
print(f"[INFO] Configuration: {config}")

# Append this run to the index, writing the header only when the file is new
# (in append mode the initial position equals the current file size).
with open(run_dir / "index.csv", "a") as idx:
    if idx.tell() == 0:
        idx.write("uid,seed,max_episodes,max_steps,rep_dim,episodes_per_upd,lr_psi,entropy_coeff,timestamp\n")
    idx.write(f"{uid},{seed},{max_episodes},{max_steps},{rep_dim},{episodes_per_upd},{lr_psi},{entropy_coeff},{config['timestamp']}\n")


# Create and train the SGCRL agent
print("Creating SGCRL agent...")
sgcrl_agent = SGCRLAgent(
    nState=nS,
    nAction=nA,
    rep_dim=rep_dim,
    episodes_per_upd=episodes_per_upd,
    lr_psi=lr_psi,
    replay_capacity=replay_capacity,
    max_steps=max_steps,
    batch_size=batch_size,
    gamma=gamma,
    entropy_coeff=entropy_coeff,
    max_episodes=max_episodes,
    plot_mult=plot_mult,
    run_dir=run_dir
)

print("Starting SGCRL training...")
sgcrl_agent.train()

print("Training complete!")
