from pathlib import Path

import gym
import d4rl
import numpy as np
import itertools
import os
import torch
from tqdm import trange

from pex.algorithms.pex import OUR_PQV
from pex.algorithms.iql_online import IQL_online
from pex.networks.policy import GaussianPolicy
from pex.networks.value_functions import DoubleCriticNetwork, ValueNetwork
from pex.utils.util import (
    set_seed, ReplayMemory, torchify, eval_policy, torchify, DEFAULT_DEVICE,
    get_batch_from_dataset_and_buffer, get_batch_from_dataset_and_buffer_ours_on, get_batch_from_dataset_and_buffer_ours_off,
    eval_policy, set_default_device, get_env_and_dataset)

 
import warnings
# Silence warnings process-wide. Every `message` regex below is built only from
# optional wildcards, so it matches any warning text and the filters
# discriminate by category alone. UserWarning, FutureWarning and
# DeprecationWarning are subclasses of Warning, so the first call already
# covers the three that follow it.
warnings.filterwarnings("ignore", category=Warning, message=".*?.*?.*?")
warnings.filterwarnings("ignore", category=UserWarning, message=".*?.*.*?")
warnings.filterwarnings("ignore", category=FutureWarning, message=".*?.*.*?")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*?.*.*?")

def _pdf_intersection_threshold(d_off, d_on):
    """Return the threshold separating two 1-D score samples.

    A Gaussian is fitted to each sample (mean and population std, the std
    being clamped to at least 1e-12) and the point where the two densities
    are equal is returned, provided it lies between the two means. The
    midpoint of the means is returned instead when the variances are
    (numerically) equal, when the quadratic has no real root, or when no
    root falls between the means.

    Args:
        d_off: 1-D tensor of scores for the offline batch.
        d_on: 1-D tensor of scores for the online batch.

    Returns:
        A 0-dim tensor holding the threshold.
    """
    mu_off = d_off.mean()
    mu_on = d_on.mean()
    s_off = torch.clamp(d_off.std(unbiased=False), min=1e-12)
    s_on = torch.clamp(d_on.std(unbiased=False), min=1e-12)
    mid = 0.5 * (mu_off + mu_on)

    var_off = s_off * s_off
    var_on = s_on * s_on

    # Equating the log-densities and multiplying by 2 gives
    #   (x - mu_on)^2 / var_on - (x - mu_off)^2 / var_off + 2*ln(s_on/s_off) = 0
    # i.e. A*x^2 + B*x + C = 0 with the coefficients below.
    A = (1.0 / var_on) - (1.0 / var_off)
    if torch.abs(A) < 1e-14:
        # Equal variances: the densities cross exactly at the midpoint.
        return mid

    B = -2.0 * (mu_on / var_on - mu_off / var_off)
    # The log term enters with a PLUS sign; subtracting it solves a different
    # equation and the result is no longer the PDF intersection.
    C = (mu_on * mu_on) / var_on - (mu_off * mu_off) / var_off \
        + 2.0 * torch.log(s_on / s_off)

    disc = B * B - 4 * A * C
    if disc < 0:
        return mid

    sqrt_disc = disc ** 0.5
    roots = ((-B + sqrt_disc) / (2 * A), (-B - sqrt_disc) / (2 * A))
    lo, hi = (mu_off, mu_on) if mu_off <= mu_on else (mu_on, mu_off)
    between = [x for x in roots if lo <= x <= hi]
    if not between:
        return mid

    # If both roots qualify, keep the one closest to the midpoint.
    first, last = between[0], between[-1]
    return first if torch.abs(first - mid) <= torch.abs(last - mid) else last


def _exchange_and_merge(offline_batch, online_batch, d_off, d_on, tau, max_swap):
    """Swap the most extreme samples between two batches, then concatenate.

    Offline samples with score >= ``tau`` and online samples with score
    < ``tau`` are swap candidates. The same number K of samples is moved in
    each direction, where K is the smaller candidate pool size capped at
    ``max_swap``; the offline samples with the largest scores and the online
    samples with the smallest scores are chosen.

    Args:
        offline_batch: sequence of tensors indexed along dim 0 by ``d_off``.
        online_batch: sequence of tensors indexed along dim 0 by ``d_on``.
        d_off: 1-D tensor of scores for the offline batch.
        d_on: 1-D tensor of scores for the online batch.
        tau: scalar threshold.
        max_swap: upper bound on the number of samples moved per direction.

    Returns:
        A list with one tensor per batch field, laid out as
        [offline kept, online moved in, online kept, offline moved in].
    """
    off2on_pool = torch.nonzero(d_off >= tau, as_tuple=False).squeeze(1)  # offline that look online
    on2off_pool = torch.nonzero(d_on < tau, as_tuple=False).squeeze(1)    # online that look offline

    # Balanced swap size, capped by the caller-provided limit.
    K = min(off2on_pool.numel(), on2off_pool.numel(), max_swap)
    if K > 0:
        off_sel = off2on_pool[torch.topk(d_off[off2on_pool], K, largest=True).indices]
        on_sel = on2off_pool[torch.topk(d_on[on2off_pool], K, largest=False).indices]
    else:
        off_sel = off2on_pool[:0]
        on_sel = on2off_pool[:0]

    def complement_idx(n, idx):
        # Indices in range(n) that are not listed in idx.
        mask = torch.ones(n, dtype=torch.bool, device=idx.device)
        if idx.numel() > 0:
            mask[idx] = False
        return torch.nonzero(mask, as_tuple=False).squeeze(1)

    off_rem = complement_idx(d_off.shape[0], off_sel)
    on_rem = complement_idx(d_on.shape[0], on_sel)

    merged = []
    for off, on in zip(offline_batch, online_batch):
        offline_new = torch.cat([off[off_rem], on[on_sel]], dim=0)
        online_new = torch.cat([on[on_rem], off[off_sel]], dim=0)
        merged.append(torch.cat([offline_new, online_new], dim=0))
    return merged


def main(args):
    """Run online training, interleaving agent updates with score-model updates.

    The function builds the environment, the ``OUR_PQV`` agent and a
    ``ScoreNet`` loaded from ``args.actor_load_path``, then loops over
    episodes. At every environment step it:

    1. selects an action with the agent;
    2. once the replay memory holds more than
       ``args.initial_collection_steps`` transitions, performs
       ``args.updates_per_step`` agent updates. For each update an offline
       and an online batch are drawn, every sample is scored by half the
       squared distance between its stored action and the action proposed
       by the score model, and up to ``args.k`` samples are exchanged
       between the two batches around the Gaussian-fit threshold before the
       merged batch is passed to ``alg.update``;
    3. updates ``score_model.q[0]`` using the agent's target critic as energy;
    4. steps the environment and stores the transition;
    5. every ``args.eval_period`` steps (if ``args.eval``) evaluates the
       policy and appends a row to ``eval_results.csv``; every
       ``5 * args.eval_period`` steps saves a checkpoint.

    The loop stops after the first episode that ends with more than
    ``args.total_env_steps`` total steps; a final checkpoint is then written.

    Args:
        args: namespace with the attributes read below (see the argument
            parser at the bottom of this file). ``args.algorithm`` is
            optional and only used to name checkpoint files; the agent's
            class name is used when it is missing or empty.
    """
    torch.set_num_threads(4)
    os.makedirs(args.log_dir, exist_ok=True)

    env, dataset, reward_transformer = get_env_and_dataset(args.env, args.max_episode_steps)
    obs_dim = dataset['observations'].shape[1]
    act_dim = dataset['actions'].shape[1]

    if args.seed is not None:
        set_seed(args.seed, env=env)

    if torch.cuda.is_available():
        set_default_device()

    action_space = env.action_space
    policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num, action_space=action_space, scale_distribution=False, state_dependent_std=False)

    double_buffer = True
    assert args.ckpt_path, "need to provide a valid checkpoint path"
    alg = OUR_PQV(
        critic=DoubleCriticNetwork(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
        vf=ValueNetwork(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
        policy=policy,
        optimizer_ctor=lambda params: torch.optim.Adam(params, lr=args.learning_rate),
        alpha=args.alpha_sac,
        tau=args.tau,
        beta=args.beta,
        target_update_rate=args.target_update_rate,
        discount=args.discount,
        ckpt_path=args.ckpt_path,
        inv_temperature=args.inv_temperature,
    )

    # The command-line parser may not define `algorithm`; fall back to the
    # agent's class name so that saving a checkpoint cannot raise.
    algorithm_name = getattr(args, 'algorithm', None) or type(alg).__name__

    memory = ReplayMemory(args.replay_size, args.seed)

    total_numsteps = 0

    eval_csv_path = os.path.join(args.log_dir, 'eval_results.csv')
    with open(eval_csv_path, 'w') as f:
        f.write("total_steps,eval_return,eval_return_std,normalized_return,normalized_return_std\n")

    # ======================== score model / diffusion data ========================
    import functools
    from torch.utils.data import DataLoader
    from dataset.dataset import D4RL_dataset
    from diffusion_SDE.loss import loss_fn
    from diffusion_SDE.schedule import marginal_prob_std
    from diffusion_SDE.model import ScoreNet

    marginal_prob_std_fn = functools.partial(marginal_prob_std, device=args.device)
    args.marginal_prob_std_fn = marginal_prob_std_fn
    score_model = ScoreNet(input_dim=obs_dim + act_dim, output_dim=act_dim, marginal_prob_std=marginal_prob_std_fn, args=args).to(args.device)
    score_model.q[0].to(args.device)

    print("loading actor...")
    ckpt = torch.load(args.actor_load_path, map_location=args.device)
    score_model.load_state_dict(ckpt)
    score_model.q[0].guidance_scale = args.s
    dataset_dm = D4RL_dataset(args)
    generator = torch.Generator(device=args.device)
    data_loader = DataLoader(dataset_dm, batch_size=256, shuffle=True, generator=generator)
    dataset_dm.fake_actions = torch.Tensor(np.load('./models_rl/' + args.env + '/actions{}_raw.npy'.format(args.diffusion_steps)).astype(np.float32)).to(args.device)

    def datas_():
        # Endless stream of batches: restart the loader whenever it is exhausted.
        while True:
            yield from data_loader
    datas = datas_()
    # ===============================================================================

    for i_episode in itertools.count(1):
        episode_reward = 0
        episode_steps = 0
        done = False
        state = env.reset()

        while not done:
            action = alg.select_action(torchify(state).to(DEFAULT_DEVICE)).detach().cpu().numpy()
            if len(memory) > args.initial_collection_steps:
                for _ in range(args.updates_per_step):
                    offline_batch = get_batch_from_dataset_and_buffer_ours_off(dataset, memory, args.batch_size, double_buffer)
                    online_batch = get_batch_from_dataset_and_buffer_ours_on(dataset, memory, args.batch_size, double_buffer)
                    obs_off, act_off = offline_batch[0], offline_batch[1]
                    obs_on, act_on = online_batch[0], online_batch[1]

                    with torch.no_grad():
                        a_hat_off = score_model.select_actions_ours(obs_off)
                        a_hat_on = score_model.select_actions_ours(obs_on)
                    # Per-sample score: half the squared distance between the
                    # score model's action and the stored action.
                    d_off = 0.5 * ((a_hat_off - act_off) ** 2).sum(dim=1).to(args.device)
                    d_on = 0.5 * ((a_hat_on - act_on) ** 2).sum(dim=1).to(args.device)

                    tau = _pdf_intersection_threshold(d_off, d_on)
                    batch_exchanged = _exchange_and_merge(offline_batch, online_batch, d_off, d_on, tau, args.k)
                    alg.update(batch_exchanged[0], batch_exchanged[1], batch_exchanged[2], batch_exchanged[3], batch_exchanged[4])

            # Score-model update: runs on every environment step, including
            # the initial collection phase.
            data = next(datas)
            s = data['s']
            fake_a = data['fake_a']
            B, N, _ = fake_a.shape
            s_flat = s.unsqueeze(1).expand(-1, N, -1).reshape(B * N, -1)
            a_flat = fake_a.reshape(B * N, -1)

            energy = alg.target_critic.min(s_flat, a_flat).view(B, N).detach()
            score_model.q[0].update_qt(data, energy)

            next_state, reward, done, _ = env.step(action)
            episode_steps += 1
            total_numsteps += 1
            episode_reward += reward

            reward_for_replay = reward_transformer(reward)

            # Reaching the step limit is stored as non-terminal.
            terminal = 0 if episode_steps == env._max_episode_steps else float(done)
            memory.push(state, action, reward_for_replay, next_state, terminal)
            state = next_state

            if total_numsteps % args.eval_period == 0 and args.eval:
                print("Episode: {}, total env-steps: {}".format(i_episode, total_numsteps))
                eval_return, eval_return_std, normalized_return, normalized_return_std = eval_policy(env, args.env, alg, args.max_episode_steps, args.eval_episode_num)

                # Save evaluation results to csv
                with open(eval_csv_path, 'a') as f:
                    f.write("{},{},{},{},{}\n".format(total_numsteps, eval_return, eval_return_std, normalized_return, normalized_return_std))

            # Save the model
            if total_numsteps % (args.eval_period * 5) == 0:
                torch.save(alg.state_dict(), args.log_dir + '/{}_online_ckpt_{}'.format(algorithm_name, total_numsteps))
                print("Model saved at step: {}".format(total_numsteps))

        if total_numsteps > args.total_env_steps:
            break

    # Close the environment once, after training. Closing it at the end of
    # every episode (and then resetting it again) is not a supported use.
    env.close()

    torch.save(alg.state_dict(), args.log_dir + '/{}_online_ckpt'.format(algorithm_name))

def get_guidance_scale(env_name):
    """Return the preset guidance scale for ``env_name``.

    The presets are grouped by task family below. ``None`` is returned for
    any environment name that has no preset.
    """
    locomotion_medium = {
        'walker2d-medium-v2': 10.0,
        'halfcheetah-medium-v2': 10.0,
        'hopper-medium-v2': 8.0,
    }
    locomotion_medium_expert = {
        'walker2d-medium-expert-v2': 5.0,
        'halfcheetah-medium-expert-v2': 3.0,
        'hopper-medium-expert-v2': 2.0,
    }
    locomotion_medium_replay = {
        'walker2d-medium-replay-v2': 5.0,
        'halfcheetah-medium-replay-v2': 8.0,
        'hopper-medium-replay-v2': 3.0,
    }
    antmaze_fixed = {
        'antmaze-umaze-v2': 3.0,
        'antmaze-medium-play-v2': 4.0,
        'antmaze-large-play-v2': 3.0,  # value assumed from the table
    }
    antmaze_diverse = {
        'antmaze-umaze-diverse-v2': 1.0,
        'antmaze-medium-diverse-v2': 3.0,
        'antmaze-large-diverse-v2': 2.0,  # value assumed from the table
    }

    families = (
        locomotion_medium,
        locomotion_medium_expert,
        locomotion_medium_replay,
        antmaze_fixed,
        antmaze_diverse,
    )
    for presets in families:
        if env_name in presets:
            return presets[env_name]
    return None


# Command-line entry point: parse the options, derive env-dependent settings
# and start training.
if __name__ == '__main__':
    from argparse import ArgumentParser, ArgumentTypeError

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because every non-empty string is truthy, so the flag could
        never be switched off.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = ArgumentParser()
    parser.add_argument('--env', required=True)
    parser.add_argument('--log_dir', required=True)
    # Only used to name checkpoint files written by main().
    parser.add_argument('--algorithm', type=str, default='OUR_PQV', help='name used as prefix of saved checkpoints (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--hidden_num', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--target_update_rate', type=float, default=0.005)
    parser.add_argument('--tau', type=float, default=0.7)
    parser.add_argument('--beta', type=float, default=3.0, help='IQL inverse temperature')
    parser.add_argument('--ckpt_path', default=None, help='path to the offline checkpoint')

    # Help texts use %(default)s so they cannot drift from the real defaults.
    parser.add_argument('--replay_size', type=int, default=300000, metavar='N', help='size of replay buffer (default: %(default)s)')
    parser.add_argument('--total_env_steps', type=int, default=120001, metavar='N', help='total number of env steps (default: %(default)s)')
    parser.add_argument('--initial_collection_steps', type=int, default=5000, metavar='N', help='Initial environmental steps before training starts (default: %(default)s)')
    parser.add_argument('--updates_per_step', type=int, default=10, metavar='N', help='model updates per simulator step (default: %(default)s)')
    parser.add_argument('--inv_temperature', type=float, default=3, metavar='G', help='inverse temperature for PEX action selection (default: %(default)s)')
    parser.add_argument('--eval', type=_str2bool, default=True, help='Evaluate the policy every --eval_period env steps (default: %(default)s)')
    parser.add_argument('--eval_period', type=int, default=6000)
    parser.add_argument('--eval_episode_num', type=int, default=10, help='Number of evaluation episodes (default: %(default)s)')
    parser.add_argument('--max_episode_steps', type=int, default=1000)

    parser.add_argument("--device", default="cuda", type=str)      #
    parser.add_argument("--save_model", default=1, type=int)       #
    parser.add_argument('--alpha', type=float, default=3.0)        # beta parameter in the paper, use alpha because of legacy
    parser.add_argument('--actor_load_path', type=str, default='./models_rl/hopper-medium-expert-v2/behavior_ckpt600.pth')
    parser.add_argument('--diffusion_steps', type=int, default=15)
    parser.add_argument('--M', type=int, default=16)               # support action number
    parser.add_argument('--s', type=float, default=None)           # guidance scale
    parser.add_argument('--method', type=str, default="CEP")
    parser.add_argument('--k', type=int, default=128)              # max samples exchanged per direction in each update
    parser.add_argument('--alpha_sac', type=float, default=0.2)

    args = parser.parse_args()

    # NOTE: M is always derived from the environment name here, so a value
    # passed with --M is overwritten.
    if "antmaze" not in args.env:
        args.M = 16
    else:
        args.M = 32

    # Without an explicit --s, use the per-environment preset and fail early
    # when the environment has none.
    if args.s is None:
        args.s = get_guidance_scale(args.env)
        if args.s is None:
            raise ValueError(f"No guidance scale defined for {args.env}")

    # The default actor path embeds a placeholder environment name; point it
    # at the selected environment.
    args.actor_load_path = args.actor_load_path.replace('hopper-medium-expert-v2', args.env)
    main(args)
