# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
import numpy as np
import torch
import time


class CQLTrainer:
    """Drive training and evaluation of a CQL model.

    ``model`` is used through ``train()``, ``eval()``, ``state_dict()`` and
    ``train_from_torch(batch_dict) -> (loss_tensor, metrics_dict)``.
    ``get_prompt_batch`` is a zero-argument callable returning
    ``(states, actions, rewards, dones, next_states)`` (for example the
    closure built by ``get_batch_cql``).
    """

    def __init__(self, model, get_prompt_batch):
        self.model = model
        # Extra key/value pairs copied verbatim into every training log.
        self.diagnostics = dict()
        self.get_prompt_batch = get_prompt_batch
        self.start_time = time.time()

    def train_step_mix(self):
        """Run one update of the model on a freshly sampled batch.

        Returns:
            (loss, metrics): the loss as a Python float and the metrics dict
            returned by ``model.train_from_torch``.
        """
        states, actions, rewards, dones, next_states = self.get_prompt_batch()

        batch_dict = {
            "rewards": rewards,
            "terminals": dones,
            "observations": states,
            "actions": actions,
            "next_observations": next_states,
        }

        loss, metrics = self.model.train_from_torch(batch_dict)

        return loss.detach().cpu().item(), metrics

    def pure_train_iteration_mix(self, num_steps):
        """Run ``num_steps`` training steps and summarise them.

        The returned dict contains the wall-clock training time, the mean/std of
        the per-step losses, the metrics of the *last* step and every entry of
        ``self.diagnostics``. With ``num_steps <= 0`` no step runs: the loss
        statistics are NaN and no step metrics are included.
        """
        train_losses = []
        # Bound up front so the summary still works when the loop body never runs.
        metrics = {}
        logs = dict()

        train_start = time.time()

        self.model.train()
        for _ in range(num_steps):
            train_loss, metrics = self.train_step_mix()
            train_losses.append(train_loss)

        logs['time/training'] = time.time() - train_start
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)
        else:
            # Same NaN np.mean/np.std would give, without their empty-input warning.
            logs['training/train_loss_mean'] = float('nan')
            logs['training/train_loss_std'] = float('nan')

        logs.update(metrics)
        logs.update(self.diagnostics)

        return logs

    def eval_iteration_multienv(self, env_name, env, info, n_episodes):
        """Evaluate the model on ``env`` and average over ``n_episodes`` episodes.

        Args:
            env_name: label used only in the printed summary.
            env: environment handed to ``eval_episode_cql``.
            info: dict providing ``max_ep_len``, ``state_mean`` and ``state_std``;
                an optional ``device`` entry selects where the rollout runs
                (``'cuda'`` when absent, as before).
            n_episodes: number of episodes to average over; must be positive.

        Returns:
            dict with the mean ``eval/episode_length`` and ``eval/episode_return``.

        Raises:
            ValueError: if ``n_episodes`` is not positive.
        """
        if n_episodes <= 0:
            raise ValueError(f"n_episodes must be positive, got {n_episodes}")

        self.model.eval()

        ep_lengths = 0
        ep_returns = 0
        with torch.no_grad():
            for _ in range(n_episodes):
                ep_return, infos = eval_episode_cql(
                    env=env,
                    model=self.model,
                    max_ep_len=info["max_ep_len"],
                    state_mean=info["state_mean"],
                    state_std=info["state_std"],
                    # eval_episode_cql moves the model to `device`; use the configured
                    # device rather than always forcing 'cuda'.
                    device=info.get("device", "cuda"),
                )
                ep_lengths += infos["eval/episode_length"]
                ep_returns += ep_return

        ep_returns = ep_returns / n_episodes
        ep_lengths = ep_lengths / n_episodes

        print('evaluate at tasks: ', env_name, f"returns: {ep_returns:.2f}, lengths: {ep_lengths:.2f}")

        return {
            "eval/episode_length": ep_lengths,
            "eval/episode_return": ep_returns
        }

    def save_model(self, env_name, postfix, folder):
        """Save the model's ``state_dict`` to ``folder + '/CQL-' + env_name + postfix``."""
        model_name = '/CQL-' + env_name + postfix
        path = folder + model_name
        torch.save(self.model.state_dict(), path)
        print('model saved to ', path)


def get_batch_cql(trajectories, batch_size, info):
    """Build a sampler of single-step transitions for CQL training.

    Args:
        trajectories: sequence of dicts holding per-step arrays ``observations``,
            ``actions``, ``rewards``, ``next_observations`` and either
            ``terminals`` or ``dones``.
        batch_size: default number of transitions per batch.
        info: dict with ``num_trajectories``, ``sorted_inds``, ``state_mean``,
            ``state_std``, ``state_dim``, ``act_dim`` and ``device``.

    Returns:
        A callable ``fn(batch_size=batch_size)`` returning ``(s, a, r, d, sn)``
        with shapes ``(B, state_dim)``, ``(B, act_dim)``, ``(B, 1)``, ``(B, 1)``
        and ``(B, state_dim)``. ``d`` is ``torch.long``; the others are
        ``torch.float32``. Both ``s`` and ``sn`` are normalised with
        ``state_mean`` / ``state_std``.
    """
    num_trajectories, sorted_inds = info['num_trajectories'], info['sorted_inds']
    state_dim, act_dim, device = info['state_dim'], info['act_dim'], info['device']
    # Loop-invariant: convert once. Forcing float32 keeps float64 numpy statistics
    # from silently promoting the whole batch to float64.
    state_mean = torch.as_tensor(info['state_mean'], dtype=torch.float32, device=device)
    state_std = torch.as_tensor(info['state_std'], dtype=torch.float32, device=device)

    def _to_tensor(rows, dtype, width):
        # One row per sampled transition -> (len(rows), width) tensor on `device`.
        arr = np.concatenate([np.reshape(row, (1, -1)) for row in rows], axis=0)
        return torch.from_numpy(arr).to(dtype=dtype, device=device).reshape(len(rows), width)

    def fn(batch_size=batch_size):
        # NOTE(review): trajectories are drawn uniformly (info['p_sample'] is not used),
        # so transitions from short trajectories are over-represented — confirm intent.
        batch_inds = np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
        )

        s, a, r, d, sn = [], [], [], [], []
        for traj_ind in batch_inds:
            traj = trajectories[int(sorted_inds[traj_ind])]
            # np.random.randint excludes `high`, so this covers every step, including
            # the last (typically terminal) one and length-1 trajectories.
            si = np.random.randint(0, traj['rewards'].shape[0])
            done_key = 'terminals' if 'terminals' in traj else 'dones'

            s.append(traj['observations'][si])
            a.append(traj['actions'][si])
            r.append(traj['rewards'][si])
            d.append(traj[done_key][si])
            sn.append(traj['next_observations'][si])

        s = _to_tensor(s, torch.float32, state_dim)
        a = _to_tensor(a, torch.float32, act_dim)
        r = _to_tensor(r, torch.float32, 1)
        d = _to_tensor(d, torch.long, 1)
        sn = _to_tensor(sn, torch.float32, state_dim)

        # Current and next observations must live in the same normalised space.
        s = (s - state_mean) / state_std
        sn = (sn - state_mean) / state_std

        return s, a, r, d, sn

    return fn


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along its first axis.

    ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``.

    Integer/boolean input is promoted to a floating dtype so discounted values are
    not truncated; float32/float64 input keeps its dtype. An empty input yields
    an empty result.
    """
    x = np.asarray(x)
    out = np.zeros_like(x, dtype=np.result_type(x.dtype, np.float32))
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


def eval_episode_cql(
        env,
        model,
        max_ep_len=1000,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        mode='normal',
        no_state_normalize=False
):
    """Roll out one episode with the model's policy and return its total reward.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and
            ``action_space.low`` / ``action_space.high``. Both the 4-tuple ``step``
            API and the (obs, info) ``reset`` / 5-tuple ``step`` API are handled.
        model: object providing ``eval()``, ``to(device=...)`` and
            ``_get_policy_actions(obs_batch, 1)``. The latter may return the action
            tensor alone or a tuple whose first element is the action tensor.
        max_ep_len: maximum number of environment steps.
        state_mean, state_std: statistics used to normalise observations before
            they are passed to the policy.
        device: device the model is moved to and observations are placed on.
        mode: ``'noise'`` adds N(0, 0.1) noise to the initial observation only.
        no_state_normalize: if True, pass raw observations to the policy.

    Returns:
        (episode_return, infos) where ``infos`` holds ``eval/episode_length`` and
        ``eval/episode_return``.
    """
    model.eval()
    model.to(device=device)

    state_mean = torch.as_tensor(state_mean, device=device, dtype=torch.float32).flatten()
    state_std = torch.as_tensor(state_std, device=device, dtype=torch.float32).flatten()

    state = env.reset()
    # reset() returning (obs, info) marks the newer API whose step() yields
    # (obs, reward, terminated, truncated, info); the flag name comes from Meta-World.
    meta_world = False
    if isinstance(state, tuple) and isinstance(state[0], np.ndarray) and isinstance(state[1], dict):
        state = state[0]
        meta_world = True
    if mode == 'noise':
        state = state + np.random.normal(0, 0.1, size=state.shape)
    state = torch.as_tensor(state, device=device, dtype=torch.float32)

    episode_return, episode_length = 0, 0
    for _ in range(max_ep_len):
        obs = state if no_state_normalize else (state - state_mean) / state_std
        with torch.no_grad():
            policy_out = model._get_policy_actions(obs.unsqueeze(0).to(torch.float32), 1)
        # Unpack the same way in both normalisation modes; previously the raw-state
        # branch kept the whole tuple and crashed on `.squeeze()`.
        action = policy_out[0] if isinstance(policy_out, (tuple, list)) else policy_out

        action = action.squeeze().detach().cpu().numpy()
        action = np.clip(action, env.action_space.low, env.action_space.high)

        if meta_world:
            new_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
        else:
            new_state, reward, done, _ = env.step(action)

        state = torch.as_tensor(new_state, device=device, dtype=torch.float32)

        episode_return += reward
        episode_length += 1

        if done:
            break

    infos = {
        'eval/episode_length': episode_length,
        'eval/episode_return': episode_return
    }
    return episode_return, infos
