import csv
from datetime import datetime
import json
from pathlib import Path
import random
import string
import sys

import numpy as np
import torch
import torch.nn as nn
import os
from typing import Dict, List, Tuple


# Optional dependency: FAISS is only needed by build_faiss_index. A missing install
# prints a warning instead of failing the whole module import; build_faiss_index
# checks for `faiss is None` and raises ImportError when it is actually called.
try:
    import faiss  # pip install faiss-cpu / faiss-gpu
except ImportError:
    faiss = None
    print("[utils] Warning: faiss not installed, build_faiss_index will fail if called.")


# Device that torchify moves tensors to: CUDA when available, otherwise CPU.
DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Squeeze(nn.Module):
    """Module wrapper around ``Tensor.squeeze`` so it can be used inside ``nn.Sequential``.

    Args:
        dim: Dimension to squeeze. ``None`` (the default) removes every size-1
            dimension, matching ``Tensor.squeeze()`` called without arguments.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Passing ``dim=None`` explicitly is rejected by the argument parser of some
        # torch versions, so use the no-argument overload for the "squeeze all" case.
        if self.dim is None:
            return x.squeeze()
        return x.squeeze(dim=self.dim)


def mlp(dims, activation=nn.ReLU, output_activation=None, squeeze_output=False):
    """Build a fully connected network as a float32 ``nn.Sequential``.

    Args:
        dims: Layer widths ``[in, hidden_1, ..., out]``; at least two entries.
        activation: Module class instantiated after every hidden ``nn.Linear``.
        output_activation: Optional module class instantiated after the final layer.
        squeeze_output: If true, append ``Squeeze(-1)``; requires ``dims[-1] == 1``.

    Returns:
        The assembled ``nn.Sequential`` (parameters cast to ``torch.float32``).
    """
    n_dims = len(dims)
    assert n_dims >= 2, 'MLP requires at least two dims (input and output)'

    # Consecutive (in, out) width pairs for every layer except the last one.
    hidden_pairs = zip(dims[:-2], dims[1:-1])
    layers = []
    for in_features, out_features in hidden_pairs:
        layers += [nn.Linear(in_features, out_features), activation()]

    layers.append(nn.Linear(dims[-2], dims[-1]))
    if output_activation is not None:
        layers.append(output_activation())
    if squeeze_output:
        assert dims[-1] == 1
        layers.append(Squeeze(-1))
    # Module.to returns the module itself, so this is the network.
    return nn.Sequential(*layers).to(dtype=torch.float32)


def compute_batched(f, xs):
    """Apply ``f`` once to the concatenation of ``xs`` along dim 0, then split the
    result back into pieces whose first-dimension sizes match the inputs."""
    merged_output = f(torch.cat(xs, dim=0))
    chunk_sizes = [len(chunk) for chunk in xs]
    return merged_output.split(chunk_sizes)


def update_exponential_moving_average(target, source, alpha):
    """Blend ``source``'s parameters into ``target`` in place.

    Each target parameter becomes ``(1 - alpha) * target + alpha * source``;
    parameters are paired in ``.parameters()`` order.
    """
    keep = 1. - alpha
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        # Work on .data so the update is not tracked by autograd.
        dst.data.mul_(keep).add_(src.data, alpha=alpha)


def torchify(x):
    """Convert a NumPy array to a tensor on ``DEFAULT_DEVICE``.

    float64 input is downcast to float32; every other dtype is kept as-is.
    """
    tensor = torch.from_numpy(x)
    tensor = tensor.float() if tensor.dtype is torch.float64 else tensor
    return tensor.to(device=DEFAULT_DEVICE)



def return_range(dataset, max_episode_steps):
    """Return ``(min, max)`` undiscounted episode return over the dataset.

    The flat ``rewards``/``terminals`` sequences are split into episodes that end
    at a terminal flag or after ``max_episode_steps`` steps. A trailing episode
    that never reached either condition is not counted.
    """
    episode_returns = []
    step_counts = []
    running_return, running_len = 0., 0
    for reward, terminal in zip(dataset['rewards'], dataset['terminals']):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        step_counts.append(running_len)
        running_return, running_len = 0., 0
    # The trailing incomplete episode contributes no return, but its steps are
    # still counted so the sanity check below covers every transition.
    step_counts.append(running_len)
    assert sum(step_counts) == len(dataset['rewards'])
    return min(episode_returns), max(episode_returns)


def sample_batch(dataset, batch_size):
    """Draw ``batch_size`` rows uniformly at random (with replacement) from ``dataset``.

    ``dataset`` is a dict whose values are tensors sharing the same first
    dimension. The same random row indices are applied to every value, and the
    indices are generated on the device of the first value.
    """
    reference = dataset[list(dataset.keys())[0]]
    n, device = len(reference), reference.device
    assert all(len(v) == n for v in dataset.values()), 'Dataset values must have same length'
    idx = torch.randint(low=0, high=n, size=(batch_size,), device=device)
    return {name: tensor[idx] for name, tensor in dataset.items()}


def evaluate_policy(env, policy, max_episode_steps, deterministic=True):
    """Roll out ``policy`` for one episode in ``env`` and return the summed reward.

    Expects ``env.reset()`` to return an observation and ``env.step`` to return a
    4-tuple ``(obs, reward, done, info)``. The rollout ends when ``done`` is true
    or after ``max_episode_steps`` steps.
    """
    observation = env.reset()
    episode_return = 0.
    for _step in range(max_episode_steps):
        with torch.no_grad():
            obs_tensor = torchify(observation)
            action = policy.act(obs_tensor, deterministic=deterministic).cpu().numpy()
        observation, reward, done, _info = env.step(action)
        episode_return += reward
        if done:
            break
    return episode_return


def set_seed(seed, env=None):
    """Seed torch (CPU and, if present, every CUDA device), NumPy, and ``random``.

    If ``env`` is given, ``env.seed(seed)`` is called last.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    if env is None:
        return
    env.seed(seed)


def _gen_dir_name():
    now_str = datetime.now().strftime('%m-%d-%y_%H.%M.%S')
    rand_str = ''.join(random.choices(string.ascii_lowercase, k=4))
    return f'{now_str}_{rand_str}'

class Log:
    """Experiment logger that mirrors messages to stdout and a text file and
    appends metric rows to a CSV file.

    Each instance creates a fresh run directory ``root_log_dir/<_gen_dir_name()>``
    containing:
      - ``cfg_filename``: ``cfg_dict`` serialized as JSON,
      - ``txt_filename``: every message passed to :meth:`write` / ``__call__``,
      - ``csv_filename``: created lazily on the first :meth:`row` call.

    The instance can be used as a context manager so files are always closed.
    """

    def __init__(self, root_log_dir, cfg_dict,
                 txt_filename='log.txt',
                 csv_filename='progress.csv',
                 cfg_filename='config.json',
                 flush=True):
        # Serialize first: a non-JSON-serializable cfg_dict then fails before any
        # directory or file handle is created, instead of leaking the open text log.
        cfg_json = json.dumps(cfg_dict)
        self.dir = Path(root_log_dir) / _gen_dir_name()
        # No exist_ok: a name collision raises FileExistsError rather than mixing runs.
        self.dir.mkdir(parents=True)
        (self.dir / cfg_filename).write_text(cfg_json, encoding='utf-8')
        self.txt_file = open(self.dir / txt_filename, 'w', encoding='utf-8')
        self.csv_file = None
        self.csv_writer = None
        self.txt_filename = txt_filename
        self.csv_filename = csv_filename
        self.cfg_filename = cfg_filename
        self.flush = flush

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions raised inside the with-block

    def write(self, message, end='\n'):
        """Print ``message``, prefixed with the wall-clock time, to stdout and the text log."""
        now_str = datetime.now().strftime('%H:%M:%S')
        message = f'[{now_str}] ' + message
        for f in [sys.stdout, self.txt_file]:
            print(message, end=end, file=f, flush=self.flush)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`write`."""
        self.write(*args, **kwargs)

    def row(self, dict):
        """Log ``dict`` as text and append it as one CSV row.

        The CSV header is taken from the keys of the first row. A later row with a
        key missing from that header makes ``csv.DictWriter`` raise ``ValueError``.
        """
        if self.csv_file is None:
            # newline='' is required by the csv module so it controls line endings itself.
            self.csv_file = open(self.dir / self.csv_filename, 'w', newline='', encoding='utf-8')
            self.csv_writer = csv.DictWriter(self.csv_file, list(dict.keys()))
            self.csv_writer.writeheader()

        self(str(dict))
        self.csv_writer.writerow(dict)
        if self.flush:
            self.csv_file.flush()

    def close(self):
        """Close the text log and, if it was created, the CSV file. Safe to call twice."""
        self.txt_file.close()
        if self.csv_file is not None:
            self.csv_file.close()




def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Return the discounted cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * out[t + 1]``, with ``out[-1] = x[-1]``. The result
    has the same shape and dtype as ``x``; an empty input returns an empty array.

    Args:
        x: Sequence of values (e.g. rewards) indexed by time along axis 0.
        gamma: Discount factor applied per step.
    """
    out = np.zeros_like(x)
    if x.shape[0] == 0:
        return out  # nothing to accumulate; x[-1] below would raise IndexError
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out



def _build_context_window(traj, t, full_rtg, K, state_dim, act_dim, max_ep_len,
                          scale, state_mean, state_std):
    """Build the length-K model context that ends at step ``t`` of ``traj``.

    Returns ``(states, actions, rewards, rtg, timesteps, mask)``, each with a
    leading axis of size 1 and K positions; shorter histories are left-padded.
    ``full_rtg`` is the trajectory's precomputed undiscounted return-to-go.
    """
    start = max(0, t - K + 1)
    end = t + 1
    L = end - start

    s = traj["observations"][start:end].reshape(1, -1, state_dim)
    a = traj["actions"][start:end].reshape(1, -1, act_dim)
    r = traj["rewards"][start:end].reshape(1, -1, 1)
    ts = np.clip(np.arange(start, end).reshape(1, -1), 0, max_ep_len - 1)
    rtg = full_rtg[start:end].reshape(1, -1, 1)

    if L < K:
        pad = K - L
        s = np.concatenate([np.zeros((1, pad, state_dim)), s], axis=1)
        # Padded actions are filled with -10.0 (presumably a sentinel outside the
        # valid action range -- confirm against the training-time padding).
        a = np.concatenate([np.ones((1, pad, act_dim)) * -10.0, a], axis=1)
        r = np.concatenate([np.zeros((1, pad, 1)), r], axis=1)
        rtg = np.concatenate([np.zeros((1, pad, 1)), rtg], axis=1)
        ts = np.concatenate([np.zeros((1, pad)), ts], axis=1)
        mask = np.concatenate([np.zeros((1, pad)), np.ones((1, L))], axis=1)
    else:
        mask = np.ones((1, K))
    # Padded state slots are normalized too, so they become -mean/std rather than 0.
    s = (s - state_mean) / state_std
    rtg = rtg / scale
    return s, a, r, rtg, ts, mask


def save_dataset_embeddings(
    variant, model, trajectories, state_mean, state_std, device="cuda"
):
    """Embed every timestep of ``trajectories`` with ``model`` and save them to ``.npz``.

    For each step a context window of the last ``K`` steps, left-padded when
    shorter, is passed to ``model.forward_codes``. The code at the final window
    position is kept together with that step's action and scaled return-to-go.
    Only steps whose scaled return-to-go exceeds ``1e-6`` are saved. The model is
    switched to eval mode and moved to ``device``.

    Args:
        variant: Config dict. Reads ``K``, ``state_dim``, ``act_dim``, ``env``,
            ``dataset`` and ``model_save_dir``, plus optional ``batch_size``
            (default 64; twice this many steps are embedded per forward pass),
            ``max_ep_len`` (default 1000, timesteps are clipped below it) and
            ``scale`` (default 1000.0, divides the return-to-go).
        model: Module exposing ``forward_codes(states=..., actions=..., rewards=...,
            returns_to_go=..., timesteps=..., attention_mask=...)``. Its output is
            indexed as ``[batch, position, feature]``.
        trajectories: Sequence of dicts with ``observations``, ``actions`` and
            ``rewards`` arrays indexed by timestep along axis 0.
        state_mean, state_std: Statistics used to normalize states.
        device: Device the model and input batches are moved to.

    Returns:
        Path of ``{model_save_dir}/{env}_{dataset}_embeddings.npz``, holding
        float32 arrays ``embeddings``, ``actions`` and ``rtgs``.

    Raises:
        ValueError: If no timestep passes the return-to-go filter.
    """
    model.eval().to(device)

    K = int(variant["K"])
    batch_size = int(variant.get("batch_size", 64)) * 2
    state_dim = int(variant["state_dim"])
    act_dim = int(variant["act_dim"])
    max_ep_len = int(variant.get("max_ep_len", 1000))
    scale = float(variant.get("scale", 1000.0))

    # (trajectory index, timestep) for every step, grouped by trajectory.
    data_indices = [
        (traj_idx, t)
        for traj_idx, traj in enumerate(trajectories)
        for t in range(traj["observations"].shape[0])
    ]

    # The undiscounted return-to-go only depends on the trajectory, so compute it
    # once per trajectory instead of once per timestep (which was O(T^2)).
    # data_indices is grouped by trajectory, so caching the current one suffices.
    cached_traj_idx, cached_rtg = None, None

    all_embeddings, all_actions, all_rtgs = [], [], []
    with torch.no_grad():
        for i in range(0, len(data_indices), batch_size):
            windows = []
            for traj_idx, t in data_indices[i:i + batch_size]:
                traj = trajectories[traj_idx]
                if traj_idx != cached_traj_idx:
                    cached_traj_idx = traj_idx
                    cached_rtg = discount_cumsum(traj["rewards"], gamma=1.0)
                windows.append(_build_context_window(
                    traj, t, cached_rtg, K, state_dim, act_dim,
                    max_ep_len, scale, state_mean, state_std,
                ))

            # zip(*windows) regroups the per-sample 6-tuples into six per-field lists.
            s_np, a_np, r_np, rtg_np, ts_np, m_np = (
                np.concatenate(field, 0) for field in zip(*windows)
            )
            s_t = torch.from_numpy(s_np).float().to(device)
            a_t = torch.from_numpy(a_np).float().to(device)
            r_t = torch.from_numpy(r_np).float().to(device)
            rtg_t = torch.from_numpy(rtg_np).float().to(device)
            ts_t = torch.from_numpy(ts_np).long().to(device)
            m_t = torch.from_numpy(m_np).float().to(device)

            state_codes = model.forward_codes(
                states=s_t,
                actions=a_t,
                rewards=r_t,
                returns_to_go=rtg_t,
                timesteps=ts_t,
                attention_mask=m_t,
            )
            emb = state_codes[:, -1, :]          # [B, H]: code at the last window position
            act = a_t[:, -1, :]                  # [B, act_dim]
            cur_rtg = rtg_t[:, -1, 0]            # [B]: scaled return-to-go at step t

            # Report progress for every 50th batch, whether or not it has valid steps.
            if (i // batch_size) % 50 == 0:
                print(f"[emb] processed {i}/{len(data_indices)} steps...")

            valid = (cur_rtg > 1e-6)
            if valid.sum() == 0:
                continue
            all_embeddings.append(emb[valid].cpu().numpy().astype("float32"))
            all_actions.append(act[valid].cpu().numpy().astype("float32"))
            all_rtgs.append(cur_rtg[valid].cpu().numpy().astype("float32"))

    if not all_embeddings:
        raise ValueError(
            "No timestep had a scaled return-to-go above 1e-6; no embeddings to save."
        )
    all_embeddings = np.concatenate(all_embeddings, 0)
    all_actions = np.concatenate(all_actions, 0)
    all_rtgs = np.concatenate(all_rtgs, 0)

    save_dir = variant["model_save_dir"]
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{variant['env']}_{variant['dataset']}_embeddings.npz")
    np.savez(save_path, embeddings=all_embeddings, actions=all_actions, rtgs=all_rtgs)

    print(f"Embeddings saved to: {save_path}")
    print("embeddings:", all_embeddings.shape, "actions:", all_actions.shape, "rtgs:", all_rtgs.shape)
    return save_path

def build_faiss_index(
    variant: Dict,
    use_gpu: bool = False,
) -> str:
    """
    Build a FAISS HNSW64 inner-product index from the embeddings written by
    `save_dataset_embeddings` and save it next to them.

    - Embeddings are L2-normalized before indexing, so inner-product search ranks
      neighbours by cosine similarity.
    - If `use_gpu=True`, the index is built on GPU 0 and copied back to the CPU for
      saving. FAISS's GPU cloner does not accept every index type (HNSW may be
      rejected), and CPU-only FAISS builds lack the GPU API. In either case a
      warning is printed and the index is built on the CPU instead.
    - The matching `actions` and `rtgs` arrays are saved to separate `.npz` files,
      in the same row order as the vectors added to the index.

    Args:
        variant (dict): A dictionary containing at least:
            - 'env'
            - 'dataset'
            - 'model_save_dir'
        use_gpu (bool): Try to build the index on the GPU (it is always saved in CPU format).

    Returns:
        str: The file path of the saved index.

    Raises:
        ImportError: If faiss is not installed.
        FileNotFoundError: If the embeddings `.npz` file does not exist.
    """
    if faiss is None:
        raise ImportError(
            "faiss is not installed: run `pip install faiss-cpu` or `pip install faiss-gpu`"
        )

    save_dir = variant["model_save_dir"]
    prefix = f"{variant['env']}_{variant['dataset']}"
    npz_path = os.path.join(save_dir, f"{prefix}_embeddings.npz")
    if not os.path.exists(npz_path):
        raise FileNotFoundError(
            f"Cannot find embeddings at {npz_path}; please run save_dataset_embeddings first."
        )
    # The context manager closes the underlying .npz file handle; astype copies the
    # arrays, so they stay valid after the file is closed.
    with np.load(npz_path) as data:
        embs = data["embeddings"].astype("float32")  # [N, H]
        actions = data["actions"].astype("float32")  # [N, act_dim]
        rtgs = data["rtgs"].astype("float32")        # [N]
    H = embs.shape[1]

    # L2-normalize; the floor avoids division by zero for all-zero embeddings.
    norms = np.maximum(np.linalg.norm(embs, axis=1, keepdims=True), 1e-12)
    vectors = np.ascontiguousarray(embs / norms, dtype="float32")

    index_cpu = faiss.index_factory(H, "HNSW64", faiss.METRIC_INNER_PRODUCT)
    print("[faiss] index_cpu.is_trained =", index_cpu.is_trained)

    index_to_save = None
    if use_gpu:
        try:
            res = faiss.StandardGpuResources()
            index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu)
        except (AttributeError, RuntimeError) as err:
            # AttributeError: CPU-only faiss build; RuntimeError: FAISS C++ error
            # (e.g. index type not supported on GPU).
            print(f"[faiss] Warning: GPU index build unavailable ({err}); building on CPU instead.")
        else:
            index_gpu.add(vectors)
            index_to_save = faiss.index_gpu_to_cpu(index_gpu)
    if index_to_save is None:
        index_cpu.add(vectors)
        index_to_save = index_cpu

    index_path = os.path.join(save_dir, f"{prefix}_faiss.index")
    faiss.write_index(index_to_save, index_path)

    actions_path = os.path.join(save_dir, f"{prefix}_faiss_actions.npz")
    rtgs_path = os.path.join(save_dir, f"{prefix}_faiss_rtgs.npz")
    np.savez(actions_path, actions=actions)
    np.savez(rtgs_path, rtgs=rtgs)

    print(f"FAISS saved to: {index_path}")
    print(f"action saved to: {actions_path}")
    print(f"rtgs saved to: {rtgs_path}")

    return index_path


