import numpy as np
import torch
from d3rlpy.dataset import MDPDataset
import d3rlpy

from models.vlm_adapt import VLMAdapt          # sibling package under src/
from utils.metrics import compute_goal_baseline_clip_reward

def build_vlm_model(cfg_model, device: torch.device):
    """Instantiate a ``VLMAdapt`` model, restore its checkpoint and switch to eval mode.

    Args:
        cfg_model: Config object exposing ``projection_dim``, ``clip_model``,
            ``pretrained_clip`` and ``checkpoint_path`` attributes.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``VLMAdapt`` instance, on ``device`` and in evaluation mode.
    """
    # ``projection_dim`` may be given as a list/tuple; only its first entry is used.
    configured_dim = cfg_model.projection_dim
    if isinstance(configured_dim, (list, tuple)):
        proj_dim = configured_dim[0]
    else:
        proj_dim = configured_dim
    print(f'projection_dim: {proj_dim}')

    model = VLMAdapt(
        clip_model_name=cfg_model.clip_model,
        pretrained_clip=cfg_model.pretrained_clip,
        projection_dim=int(proj_dim),
    )
    model = model.to(device)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version -- only point this at trusted checkpoints.
    checkpoint = torch.load(cfg_model.checkpoint_path, map_location=device)
    # strict=False: missing or unexpected keys in the checkpoint do not raise.
    model.load_state_dict(checkpoint, strict=False)
    model.eval()
    return model

def _minmax_torch(x: torch.Tensor) -> torch.Tensor:
    x_min, x_max = x.min(), x.max()
    if (x_max - x_min) < 1e-8: return torch.zeros_like(x)
    return (x - x_min) / (x_max - x_min)

def compute_rewards_all(loader, model: "VLMAdapt", device: torch.device,
                        ttt_lr: float, ttt_epochs: int,
                        methods=("no_ttt","online","offline","window","clip","clip_reg")):
    """Compute per-frame reward signals for every sequence in ``loader``.

    For each item yielded by ``loader`` the sequence is truncated to its valid
    length, scored by every requested method, min-max normalised per sequence
    and appended to that method's buffer. Buffers are concatenated at the end.

    Args:
        loader: Iterable yielding ``(frames, goal_text, progress_label,
            valid_mask)`` tuples. Only index 0 of ``goal_text`` and
            ``valid_mask`` is read, so items are presumably batches of size 1
            -- TODO confirm against the data loader.
        model: Model providing ``inference_no_ttt``, ``windowed_ttt_inference``,
            ``offline_ttt_inference``, ``compute_clip_similarity_score`` and
            ``clip_feature_extractor``.
        device: Device the frames are moved to before scoring.
        ttt_lr: Learning rate forwarded to the test-time-training methods.
        ttt_epochs: Epoch count forwarded to the test-time-training methods.
        methods: Names of the reward methods to compute.

    Returns:
        Dict mapping each requested method name to a ``float32`` NumPy array of
        concatenated rewards. The array is empty when no sequence had any
        valid frame (e.g. an empty ``loader``).

    Raises:
        ValueError: If ``methods`` contains a name that is not supported.
    """
    def _windowed(window_size):
        # Windowed TTT scorer; ``reset=True`` is passed on every call.
        return lambda frames_seq, goal, valid_len: model.windowed_ttt_inference(
            frames_seq, goal, ttt_lr=ttt_lr, ttt_epochs=ttt_epochs,
            window_size=window_size, reset=True)

    # Insertion order is the execution order per sequence. It is kept fixed
    # (independent of the order of ``methods``) in case the TTT calls touch
    # model state.
    scorers = {
        "no_ttt": lambda frames_seq, goal, valid_len: model.inference_no_ttt(frames_seq, goal),
        "online": _windowed(1),
        "window": _windowed(8),
        "offline": lambda frames_seq, goal, valid_len: model.offline_ttt_inference(
            frames_seq, goal, ttt_lr=ttt_lr, ttt_epochs=ttt_epochs),
        "clip": lambda frames_seq, goal, valid_len: model.compute_clip_similarity_score(
            frames_seq, goal)[:, :valid_len],
        "clip_reg": lambda frames_seq, goal, valid_len: compute_goal_baseline_clip_reward(
            model.clip_feature_extractor, frames_seq, goal,
            'Pick up an object', alpha=0.5)[:, :valid_len],
    }

    unknown = [m for m in methods if m not in scorers]
    if unknown:
        # Fail before doing any inference instead of after the whole loop.
        raise ValueError(f"unknown reward method(s): {unknown}; expected a subset of {list(scorers)}")
    active = [name for name in scorers if name in methods]

    bufs = {m: [] for m in methods}
    # NOTE(review): the TTT scorers also run under no_grad; presumably they
    # re-enable autograd internally -- confirm in VLMAdapt.
    with torch.no_grad():
        for (frames, goal_text, _progress_label, valid_mask) in loader:
            valid_len = valid_mask[0].sum().long().item()
            if valid_len == 0:
                continue
            frames_seq = frames[:, :valid_len].to(device)
            goal = [goal_text[0]]

            for name in active:
                out = scorers[name](frames_seq, goal, valid_len)
                bufs[name].append(_minmax_torch(out.squeeze(0)).cpu().numpy())

    # np.concatenate rejects an empty list, so an empty buffer maps to an
    # empty array rather than an opaque ValueError.
    return {
        m: (np.concatenate(bufs[m], axis=0).astype(np.float32)
            if bufs[m] else np.empty((0,), dtype=np.float32))
        for m in methods
    }


def build_mdps_for_rewards(observations, actions, terminals, rewards_by_name, expert_rewards=None, expert=False):
    """Build one ``MDPDataset`` per reward signal over the same transitions.

    Args:
        observations: Observations shared by every dataset.
        actions: Actions shared by every dataset.
        terminals: Terminal flags shared by every dataset.
        rewards_by_name: Mapping from reward name to its reward array.
        expert_rewards: Optional reward array for an extra ``"expert"`` dataset.
        expert: The ``"expert"`` dataset is only added when this is truthy and
            ``expert_rewards`` is not ``None``.

    Returns:
        Dict mapping each reward name (plus ``"expert"`` when enabled) to its
        ``MDPDataset``.
    """
    def _dataset_for(reward_signal):
        # Everything except the rewards is identical across datasets.
        return MDPDataset(
            observations=observations,
            actions=actions,
            rewards=reward_signal,
            terminals=terminals,
        )

    datasets = {}
    for label, reward_signal in rewards_by_name.items():
        datasets[label] = _dataset_for(reward_signal)

    if expert_rewards is None or not expert:
        return datasets

    # Written last, so it replaces any "expert" entry from rewards_by_name.
    datasets["expert"] = _dataset_for(expert_rewards)
    return datasets

def make_iql(device: str):
    """Create a fresh IQL learner with this module's fixed hyperparameters.

    Args:
        device: Device specifier forwarded to ``IQLConfig.create``.

    Returns:
        The algorithm object produced by ``IQLConfig(...).create(device=device)``.
    """
    hyperparams = {
        "expectile": 0.7,
        "weight_temp": 3.0,
        "max_weight": 100.0,
        "actor_learning_rate": 1e-4,
        "critic_learning_rate": 3e-4,
        "batch_size": 256,
    }
    config = d3rlpy.algos.IQLConfig(**hyperparams)
    return config.create(device=device)

def train_iql_on_mdps(mdps, n_steps: int, n_steps_per_epoch: int, device: str):
    """Train an independent IQL learner on each dataset in ``mdps``.

    Args:
        mdps: Mapping from reward name to the dataset to train on.
        n_steps: Total number of training steps passed to ``fit``.
        n_steps_per_epoch: Steps per epoch passed to ``fit``.
        device: Device specifier used to create each learner.

    Returns:
        Dict mapping each reward name to its trained learner, in the iteration
        order of ``mdps``.
    """
    def _fit_one(name, dataset):
        # A new learner per dataset, so no state is shared between runs.
        learner = make_iql(device)
        print(f"[IQL] training on reward={name}")
        learner.fit(dataset=dataset, n_steps=n_steps, n_steps_per_epoch=n_steps_per_epoch)
        return learner

    return {name: _fit_one(name, dataset) for name, dataset in mdps.items()}
