import argparse, math
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.distributions import Normal

from toy_experiments.toy_helpers import Data_Sampler
from toy_experiments.risky_bandit import generate_donut_dataset

###############################################################################
#                           Config & CLI                                      #
###############################################################################
# Command-line options controlling the seed, dataset size, training length,
# compute device and which behavior distribution is used.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--seed',
    type=int,
    default=0,
)
parser.add_argument(
    '--n',
    type=int,
    default=10_000,
)
parser.add_argument(
    '--epochs',
    type=int,
    default=1000,
)
# Accepts a bare GPU index ('0') or any torch device string ('cuda:1', 'cpu').
parser.add_argument(
    '--device',
    type=str,
    default='cpu',
)
parser.add_argument(
    '--dist',
    type=str,
    default='donut',
    choices=['donut', 'tail_risk'],
)
args = parser.parse_args()

# Seed both RNG sources used by this script.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# A purely numeric --device value is shorthand for the CUDA device with that index.
device = torch.device(
    f'cuda:{args.device}' if args.device.isnumeric() else args.device
)
###############################################################################
#                  PDF / sampler definitions for β (data distribution)
###############################################################################

# ------------------ pdf ----------------------------------------------------
def donut_pdf(x, mix_ring=0.80,
              ring_r=0.90, ring_std=0.04,
              center_std=0.08, n_gauss=16):
    """
    Approximate density of the donut-shaped distribution β.

    The ring is modelled as an equal-weight mixture of ``n_gauss`` isotropic
    Gaussians whose means are evenly spaced on a circle of radius ``ring_r``;
    the center is a single isotropic Gaussian at the origin. The two parts are
    combined with weights ``mix_ring`` and ``1 - mix_ring``.

    Args:
        x: Points whose last dimension has size 2. May be a ``torch.Tensor``
            (on any device), a ``np.ndarray`` or a nested sequence; non-tensor
            inputs are converted to a float32 tensor.
        mix_ring: Mixture weight of the ring component.
        ring_r: Radius of the circle carrying the ring means.
        ring_std: Per-axis standard deviation of each ring Gaussian.
        center_std: Per-axis standard deviation of the center Gaussian.
        n_gauss: Number of Gaussians used to discretize the ring.

    Returns:
        ``np.ndarray`` of density values with shape ``x.shape[:-1]``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(np.asarray(x), dtype=torch.float32)
    # Build every constant on x's device so GPU inputs do not hit a
    # CPU/GPU mismatch inside log_prob.
    dev = x.device

    # ---- Ring component ----
    # n_gauss + 1 points on [0, 2π] with the last dropped: 0 and 2π coincide.
    angles = torch.linspace(0, 2 * math.pi, steps=n_gauss + 1, device=dev)[:-1]
    means = torch.stack([ring_r * torch.cos(angles),
                         ring_r * torch.sin(angles)], 1)          # (K, 2)
    # x[..., None, :] broadcasts against the (K, 2) means -> (..., K, 2);
    # summing over the coordinate axis gives per-component log-densities.
    log_ring = (
        Normal(means, ring_std).log_prob(x.unsqueeze(-2)).sum(-1).logsumexp(-1)
        - math.log(n_gauss)                                       # mixture (1/K)
    )

    # ---- Center component ----
    log_center = Normal(torch.zeros(2, device=dev), center_std).log_prob(x).sum(-1)

    # ---- Mixture weights (combined in log space for numerical stability) ----
    log_pdf = torch.logaddexp(
        math.log(mix_ring) + log_ring,
        math.log(1 - mix_ring) + log_center,
    )
    return log_pdf.exp().cpu().numpy()
# ---------------------------------------------------------------------------

# ------------------ sampler ------------------------------------------------
def donut_sampler(n,
                  mix_ring=0.80,
                  ring_r=0.90, ring_std=0.04,
                  center_std=0.08,
                  clamp=1.0):
    """
    Draw ``n`` 2-D points from the donut distribution.

    ``int(n * mix_ring)`` points are placed on the ring (uniform angle,
    Gaussian radius around ``ring_r``); the remainder are Gaussian around the
    origin. Ring points come first in the output. Every coordinate is clipped
    to ``[-clamp, clamp]`` and the result is returned as a NumPy array of
    shape ``(n, 2)``.
    """
    ring_count = int(n * mix_ring)
    center_count = n - ring_count

    # Ring: polar coordinates -> Cartesian.
    angle = torch.rand(ring_count) * (2 * math.pi)
    rad = torch.normal(ring_r, ring_std, size=(ring_count,))
    unit = torch.stack((torch.cos(angle), torch.sin(angle)), dim=1)
    ring_pts = rad.unsqueeze(1) * unit

    # Center blob.
    center_pts = torch.normal(0., center_std, size=(center_count, 2))

    merged = torch.cat((ring_pts, center_pts), dim=0)
    return torch.clamp(merged, min=-clamp, max=clamp).numpy()
# ---------------------------------------------------------------------------

def tail_pdf(x, pos=0.8, std=0.05):
    """
    Density of the "tail risk" distribution: an equal-weight mixture of four
    isotropic Gaussians centered at the corners ``(±pos, ±pos)``.

    Args:
        x: Points whose last dimension has size 2. May be a ``torch.Tensor``
            (on any device), a ``np.ndarray`` or a nested sequence; non-tensor
            inputs are converted to a float32 tensor (same convention as
            ``donut_pdf``).
        pos: Absolute coordinate of the four corner means.
        std: Per-axis standard deviation of every component.

    Returns:
        ``np.ndarray`` of density values with shape ``x.shape[:-1]``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(np.asarray(x), dtype=torch.float32)
    # Corner means live on x's device so GPU inputs do not trigger a
    # CPU/GPU mismatch inside log_prob.
    corners = torch.tensor(
        [(-pos, pos), (-pos, -pos), (pos, pos), (pos, -pos)],
        dtype=torch.float32,
        device=x.device,
    )                                                             # (4, 2)
    # x[..., None, :] broadcasts against the (4, 2) means -> (..., 4, 2).
    log_pdf = (
        Normal(corners, std).log_prob(x.unsqueeze(-2)).sum(-1).logsumexp(-1)
        - math.log(corners.shape[0])                              # mixture (1/4)
    )
    return log_pdf.exp().cpu().numpy()

def tail_sampler(n, pos=0.8, std=0.05):
    """
    Draw ``n`` points from the four-corner mixture.

    Each point picks one of the corners ``(±pos, ±pos)`` uniformly at random,
    adds isotropic Gaussian noise with standard deviation ``std`` and is
    clipped to ``[-1, 1]``. Returns a NumPy array of shape ``(n, 2)``.
    """
    corner_table = torch.tensor([
        [-pos, pos],
        [-pos, -pos],
        [pos, pos],
        [pos, -pos],
    ])
    which = torch.randint(0, 4, (n,))
    noise = torch.randn(n, 2) * std
    points = corner_table[which] + noise
    return torch.clamp(points, -1, 1).numpy()

# Bind the density / sampler pair that matches the --dist option.
beta_pdf, beta_sampler = (
    (donut_pdf, donut_sampler)
    if args.dist == 'donut'
    else (tail_pdf, tail_sampler)
)


###############################################################################
#                              Data Sampler                                   #
###############################################################################
num_data = args.n                      # dataset size from --n
batch_size = 100
iterations = num_data // batch_size    # passed to agent.train() each epoch
state_dim = 2
action_dim = 2
max_action = 1.0
# === Helper: save individual scatter panels =================================
from pathlib import Path

def save_panel(actions: np.ndarray,
               name: str,
               out_dir: Path,
               color=None, cmap=None, cvals=None,
               msize: int = 8):            # default 8 pt markers
    """
    Render a single axis-free 2.5 x 2.5 inch scatter panel and save it as
    ``<out_dir>/<name>.pdf``.

    actions : (N, 2) points; a torch tensor is detached and moved to CPU first
    color   : single marker color, used unless both cmap and cvals are given
    cmap    : colormap applied to ``cvals`` (requires ``cvals``)
    cvals   : per-point values mapped through ``cmap``
    msize   : scatter marker size (points^2)
    """
    # Tensors (CPU or GPU) are converted to NumPy before plotting.
    if isinstance(actions, torch.Tensor):
        actions = actions.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(2.5, 2.5))

    # Colormapped points need both cmap and cvals; otherwise use a flat color.
    if cmap is None or cvals is None:
        marker_kw = dict(color=color)
    else:
        marker_kw = dict(c=cvals, cmap=cmap)
    ax.scatter(actions[:, 0], actions[:, 1], s=msize, alpha=.35, **marker_kw)

    ax.axis('off')
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_aspect('equal')

    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"{name}.pdf"
    fig.savefig(target, bbox_inches='tight', pad_inches=0.)
    plt.close(fig)


# ==================================================================

def generate_tail_risk_data(num: int, device='cpu'):
    """
    Build a ``Data_Sampler`` holding ``num`` actions drawn from the four-corner
    Gaussian mixture (means at ``(±0.8, ±0.8)``, per-axis std 0.05).

    Samples are split as evenly as possible across the four corners; when
    ``num`` is not a multiple of 4 the first ``num % 4`` corners receive one
    extra sample, so states, actions and rewards always have exactly ``num``
    rows. Actions are clipped to ``[-1, 1]``, states are all zeros and rewards
    are standard-normal placeholders.

    Args:
        num: Total number of transitions to generate.
        device: Forwarded unchanged to ``Data_Sampler``.

    Returns:
        A ``Data_Sampler`` built from ``(state, action, reward, device)``.
    """
    pos, std = 0.8, 0.05
    mus = [(-pos, pos), (-pos, -pos), (pos, pos), (pos, -pos)]
    # Distribute the remainder so the total is exactly `num`; the previous
    # `num // 4` per corner produced fewer actions than rewards otherwise.
    base, remainder = divmod(num, len(mus))
    counts = [base + int(i < remainder) for i in range(len(mus))]

    scale = torch.tensor([std, std])
    action = torch.cat([
        Normal(torch.tensor(m), scale).sample((count,)).clamp(-1, 1)
        for m, count in zip(mus, counts)
    ], 0)
    state = torch.zeros_like(action)
    # Size the reward from the actions so the row counts cannot diverge.
    reward = torch.randn(action.shape[0], 1)  # placeholder; not used for training updates
    return Data_Sampler(state, action, reward, device)

# Build the offline dataset that matches the --dist option.
data_sampler = (
    generate_donut_dataset(N=num_data, device=device)
    if args.dist == 'donut'
    else generate_tail_risk_data(num_data, device)
)

def get_actions(agent, states):
    """
    Sample a batch of actions from ``agent`` regardless of which sampling API
    the agent exposes.

    states : (B, state_dim) torch.Tensor — already on the correct device
    return : (B, action_dim) torch.Tensor
    ------------------------------------------------------------
    Try in this order and return the first that succeeds:
    • agent.sample(states)                       (skipped on TypeError)
    • agent.actor.sample(states) / agent.actor(states)
    • agent.policy.sample(states) / agent.policy(states)
    • agent.vae.sample(states) / agent.vae.decode(states)
    • agent.sample_action(state[i]) (fallback, sequential)

    Raises AttributeError when none of the above is available. If an earlier
    attempt failed with an exception, the most recent one is attached as the
    error's ``__cause__`` so the real failure is not lost.
    """
    last_err = None   # most recent swallowed failure, reported if all else fails

    # 1) Agent has .sample (e.g., ORAAC)
    if hasattr(agent, "sample"):
        try:
            return agent.sample(states)
        except TypeError as err:
            last_err = err          # Skip if signature mismatch

    # 2) Via actor (Diffusion-QL / RADAC / FQL / RAFMAC …)
    # 3) Via policy (e.g., BEAR-MMD)
    for attr in ("actor", "policy"):
        if not hasattr(agent, attr):
            continue
        module = getattr(agent, attr)
        # Prefer an explicit .sample; its errors propagate unchanged.
        if hasattr(module, "sample"):
            return module.sample(states)
        # Otherwise try a direct forward call and fall through on failure.
        try:
            return module(states)
        except Exception as err:
            last_err = err

    # 4) VAE-based (BCQ, standalone BC-VAE, etc.)
    if hasattr(agent, "vae"):
        if hasattr(agent.vae, "sample"):
            return agent.vae.sample(states)
        if hasattr(agent.vae, "decode"):
            return agent.vae.decode(states)

    # 5) Fallback: call sample_action one-by-one
    if hasattr(agent, "sample_action"):
        # Stack into one ndarray first: building a tensor straight from a
        # list of ndarrays is slow and triggers a PyTorch warning.
        acts = np.asarray([agent.sample_action(s.cpu().numpy()) for s in states])
        return torch.tensor(acts, device=states.device, dtype=states.dtype)

    raise AttributeError(
        "get_actions: Could not find a suitable sampling API"
    ) from last_err
###############################################################################
#                                Agents                                       #
###############################################################################
# Shared hyperparameters referenced by the agent definitions below.
T = 50                    # used as n_timesteps for the diffusion-based agents
hidden_dim = 128
lr = 3e-4
discount = 0.99
tau = 0.005
eta = 0.75                # passed to Diffusion-QL
risk_eta = 2.0            # not referenced elsewhere in the visible script
risk_dist = 'cvar'
alpha_cvar = 0.1          # 0.075
alpha_cvar_radac = 0.08   # passed to RADAC as alpha

from toy_experiments.ql_mle import QL_MLE
from toy_experiments.ql_cvae import QL_CVAE
from toy_experiments.ql_mmd import QL_MMD
from toy_experiments.ql_diffusion import QL_Diffusion
from toy_experiments.toy_oraac import ORAAC
from toy_experiments.toy_codac import CODAC
from toy_experiments.toy_radac_theory import RADAC
from toy_experiments.fql_flow import FQL
from toy_experiments.toy_rafmac import RAFMAC
from toy_experiments.toy_oraac_diffusion import ORAAC_Diffusion
from toy_experiments.toy_oraac_flow import ORAAC_Flow

# Agent definitions for testing
# Registry: display name -> (agent class, class-specific constructor kwargs).
AGENT_DEFINITIONS = {
    'FQL': (FQL, {'flow_steps': 10, 'alpha': 0.1}),
    'QL-MLE': (QL_MLE, {}),
    'QL-CVAE': (QL_CVAE, {}),
    'QL-MMD': (QL_MMD, {}),
    'Diffusion-QL': (QL_Diffusion, {
        'beta_schedule': 'vp',
        'n_timesteps': T,
        'eta': eta,
    }),
    'ORAAC': (ORAAC, {
        'lr': lr,
        'n_quantiles': 16,
        'latent_dim': 16,
        'lamda': 0.1,
        'risk_dist': risk_dist,
        'alpha': alpha_cvar,
    }),
    'ORAAC-Diffusion': (ORAAC_Diffusion, {
        'beta_schedule': 'vp',
        'n_timesteps': T,
        'risk_dist': risk_dist,
        'alpha': alpha_cvar,
        'n_quantiles': 16,
        'lamda': 1.5,
        'latent_dim': 16,
        'hidden_dim': hidden_dim,
        'lr': lr,
    }),
    'ORAAC-Flow': (ORAAC_Flow, {
        'n_quantiles': 16,
        'lamda': 2.5,
        'risk_dist': risk_dist,
        'alpha': alpha_cvar,
        'flow_steps': 10,
        'hidden_dim': hidden_dim,
        'lr': lr,
    }),
    'RADAC': (RADAC, {
        'eta': 2.5,
        'risk_dist': risk_dist,
        'alpha': alpha_cvar_radac,
        'beta_schedule': 'vp',
        'n_timesteps': T,
    }),
    'RAFMAC': (RAFMAC, {
        'eta': 2.5,
        'risk_dist': risk_dist,
        'alpha': 1.0,
        'cvar_alpha': alpha_cvar,
        'flow_steps': 10,
    }),
    'CODAC': (CODAC, {
        'num_quantiles': 32,
        'omega': 0.0,
        'omega_final': 0.1,
        'risk_type': 'cvar',
        'risk_param': 0.1,
        'lr': lr,
    }),
}

# Names (keys of AGENT_DEFINITIONS) of the agents trained and plotted in this run.
AGENTS_TO_TEST = [
    'QL-MLE',
    'QL-MMD',
    'RAFMAC',
]

# Flatten the selection into (name, class, extra_kwargs) triples.
agent_defs = [(name, *AGENT_DEFINITIONS[name]) for name in AGENTS_TO_TEST]

# Agent instantiation with proper class-specific parameters
# Keyword arguments shared by most agent constructors (loop-invariant).
common_kw = dict(discount=discount,
                 tau=tau,
                 lr=lr,
                 hidden_dim=hidden_dim)

# Build one instance per selected agent. Constructors differ in which keyword
# arguments they take, so classes are grouped by calling convention.
agents = []
for name, cls, extra in agent_defs:
    if cls in (FQL, RAFMAC, QL_Diffusion, QL_MLE):
        # Shared kwargs plus the class-specific extras.
        ag = cls(state_dim, action_dim, max_action, device,
                 **common_kw, **extra)

    elif cls in (QL_CVAE, QL_MMD):
        # Shared kwargs only; `extra` is not forwarded for these classes.
        ag = cls(state_dim, action_dim, max_action, device,
                 **common_kw)

    elif cls is ORAAC:
        ag = cls(state_dim, action_dim, max_action, device,
                 **extra)               # ← discount/tau not needed

    elif cls in (ORAAC_Diffusion, ORAAC_Flow):
        # lr / hidden_dim already come from `extra`, so only discount and tau
        # are passed explicitly.
        ag = cls(state_dim, action_dim, max_action, device,
                 discount=discount, tau=tau,
                 **extra)

    elif cls is RADAC:
        # RADAC additionally receives the behavior density and sampler.
        ag = cls(state_dim, action_dim, max_action, device,
                 beta_pdf=beta_pdf, beta_sampler=beta_sampler,
                 **common_kw, **extra)

    elif cls is CODAC:
        ag = cls(action_dim=action_dim, device=device, **extra)

    else:
        # Unknown class: keep skipping it, but say so instead of silently
        # dropping the agent from the run.
        print(f'Skipping {name}: no constructor rule for {cls.__name__}')
        continue

    agents.append((name, ag))
# ---------- Color picking ----------
def pick_color(name: str) -> str:
    """Return the scatter color for an agent: green for RAFMAC / RADAC
    variants, red for every other agent name."""
    highlighted = ('RAFMAC', 'RADAC')
    if any(tag in name for tag in highlighted):
        return '#2ca02c'
    return '#e41a1c'

agent_colors = [pick_color(n) for n, _ in agents]

###############################################################################
#                          Training & Snapshot                                #
###############################################################################
num_eval = 1000      # number of evaluation states / sampled actions per snapshot
snap_every = 50      # take a snapshot every this many epochs (and at the last one)
frames = {name: [] for name, _ in agents}   # name -> list of (num_eval, 2) arrays
# epoch_dir = Path('toy_imgs/ql_panels/epochs')   # <- new directory
# epoch_dir.mkdir(parents=True, exist_ok=True)

# Evaluation states are all-zero when a CODAC agent takes part, random normal
# otherwise. The previous check `agent is CODAC` compared an *instance* with
# the class (always False) and relied on the loop variable leaked from the
# training loop; isinstance over all agents matches save_action_samples.
use_zero_states = any(isinstance(ag, CODAC) for _, ag in agents)

for ep in range(args.epochs):
    # One training pass over the dataset for every agent.
    for _, agent in agents:
        agent.train(data_sampler, iterations, batch_size)

    if ep % snap_every == 0 or ep == args.epochs - 1:
        with torch.no_grad():
            if use_zero_states:
                s = torch.zeros((num_eval, state_dim), device=device)
            else:
                s = torch.randn((num_eval, state_dim), device=device)
            # All agents are evaluated on the same batch of states.
            for name, agent in agents:
                acts = get_actions(agent, s)
                frames[name].append(acts.cpu().numpy())
                # ★ Additional: Save PDF with epoch number in 2.5×2.5 inch format
                # save_panel(
                #     acts,
                #     f"{name.lower().replace(' ', '_')}_ep{ep}",
                #     epoch_dir,
                #     color=pick_color(name),
                #     msize=6
                # )
        print(f'Epoch {ep:4d}/{args.epochs}')
###############################################################################
#                       Metric Collection & Visualization
###############################################################################
import numpy as np

records = []

# Save action samples after training
def save_action_samples(agents, frames, num_eval, state_dim, device, out_dir):
    """
    Save each agent's action samples to .npy files after training.

    Fresh evaluation states are generated (all zeros if any agent is a CODAC
    instance, standard normal otherwise), every agent is queried through
    ``get_actions`` and the resulting array is written to
    ``<out_dir>/<normalized agent name>_actions.npy``.

    Args:
        agents: Iterable of ``(name, agent)`` pairs.
        frames: Unused; kept so existing call sites keep working.
        num_eval: Number of evaluation states (rows of each saved array).
        state_dim: Dimensionality of each evaluation state.
        device: Torch device on which the evaluation states are created.
        out_dir: Output directory (str or Path); created if missing.

    A failure for one agent is printed and does not stop the remaining agents.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Generate evaluation states
    with torch.no_grad():
        # isinstance, not `is`: the list holds agent *instances*, so comparing
        # them to the CODAC class by identity was always False.
        if any(isinstance(agent, CODAC) for _, agent in agents):
            s = torch.zeros((num_eval, state_dim), device=device)
        else:
            s = torch.randn((num_eval, state_dim), device=device)

        for name, agent in agents:
            try:
                acts = get_actions(agent, s)
                acts_np = acts.cpu().numpy()

                # Save as .npy (file name: lower-case, spaces and dashes -> '_')
                filename = f"{name.lower().replace(' ', '_').replace('-', '_')}_actions.npy"
                filepath = out_dir / filename
                np.save(filepath, acts_np)
                print(f'Saved actions for {name} → {filepath}')

            except Exception as e:
                # Best effort: report and continue with the next agent.
                print(f'Error saving actions for {name}: {e}')

###############################################################################
#                                  Plotting                                   #
###############################################################################


# agent_colors = plt.cm.tab10(np.linspace(0, 1, len(agents)))
# One panel for the dataset ("Ground Truth") followed by one panel per agent.
# NOTE(review): `labels` is only used to size the grid below; no titles are set.
labels = ['Ground Truth'] + [n for n, _ in agents]

# Fixed 3-column grid with as many rows as needed to hold every label.
# NOTE(review): when len(labels) is not a multiple of `cols`, the trailing
# axes receive no scatter but still get the limits/ticks set further down.
cols = 3
rows = math.ceil(len(labels) / cols)
fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
axs = axs.flatten()
axis_lim = 1.1


# Earlier version of the ground-truth panel, kept for reference:
# _, gt_a, gt_r = data_sampler.sample(num_eval)
# gt_a = gt_a.cpu().numpy()  
# # Normalize to 0–1 and map to Viridis colormap
# norm_r = ((gt_r - gt_r.min())
#           / (gt_r.max() - gt_r.min() + 1e-8)).cpu().numpy().ravel()  
# axs[0].scatter(gt_a[:, 0], gt_a[:, 1], c=norm_r, cmap='viridis', s=18, alpha=.45)
# axs[0].set_title('Ground Truth',fontweight = 'bold', fontsize = 13, pad = 6)
# --- Ground-truth panel (index 0) ---------------------------------
# Draw num_eval samples from the dataset; the second and third returned items
# are used as actions and rewards (presumably (state, action, reward) — confirm
# against Data_Sampler.sample).
_, gt_actions, gt_r = data_sampler.sample(num_eval)
gt_actions = gt_actions.cpu().numpy()
gt_a = gt_actions  # alias; not referenced again in the visible script

# Min-max scale the rewards to [0, 1] for coloring; the 1e-8 term guards
# against division by zero when all rewards are equal.
norm_r = ((gt_r - gt_r.min())               # 0–1 normalization
          / (gt_r.max() - gt_r.min() + 1e-8)).cpu().numpy().ravel()
axs[0].scatter(gt_actions[:, 0], gt_actions[:, 1],
               c=norm_r, cmap='viridis', s=18, alpha=.45)
axs[0].axis('off')
axs[0].set_xlim(-1.1, 1.1); axs[0].set_ylim(-1.1, 1.1)
axs[0].set_aspect('equal')

# Save individual PDF (optional)
# save_panel(gt_actions,
#            'ground_truth',
#            Path('toy_imgs/ql_panels'),
#            cmap='viridis', cvals=norm_r,msize=6)

# --- Agent panels (index 1..) -------------------------------------
# Each agent is drawn from its most recent snapshot in `frames`, using the
# color chosen for it in `agent_colors` (same order as `agents`).
for idx, (name, _agent) in enumerate(agents, start=1):
    acts = frames[name][-1]

    # Grid scatter as usual…
    axs[idx].scatter(acts[:, 0], acts[:, 1],
                     color=agent_colors[idx-1], s=18, alpha=.35)

    # Optionally save individual PDFs with the same color
    # save_panel(acts,
    #            name.lower().replace(' ', '_'),
    #            Path('toy_imgs/ql_panels'),
    #            color=agent_colors[idx-1],msize=6)

# Apply identical limits and tick positions to every axis in the grid.
tick_vals = np.arange(-1.0, 1.01, 0.25) 
for ax in axs:
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.tick_params(axis='both', labelsize=7, pad=6)
    ax.set_xticks(tick_vals)                # 0.25 tick marks
    ax.set_yticks(tick_vals)
    # ax.grid(alpha=0.15, linewidth=0.5) 

# Write the comparison figure; the file name encodes --dist and --seed.
fig.tight_layout()
out_dir = Path('toy_imgs/ql')
out_dir.mkdir(parents=True, exist_ok=True)
outfile = out_dir / f'others_compare_{args.dist}_sd{args.seed}.png'
fig.savefig(outfile, dpi=300)
print('Saved scatter →', outfile)

# Save action samples to .npy after training
# (this draws a fresh batch of evaluation states rather than reusing `frames`)
save_action_samples(agents, frames, num_eval, state_dim, device, 'toy_imgs/ql_panels')
