import numpy as np
import torch
import gym
import d4rl
import argparse
import os
import random
import copy
from pathlib import Path
import yaml
import h5py

import algo.utils as utils
from envs.common import call_env

import ott
import scipy as sp

# import ot
import jax.numpy as jnp
import numpy as np
import jax
import gc

def _resolve_cost_fn(cost_type):
    """Map a metric name (or a ready-made ott ``CostFn``) to a cost-function instance.

    Args:
        cost_type: ``None`` (treated as ``'euclidean'``), one of the names
            ``'euclidean'``/``'l2'``, ``'sqeuclidean'``/``'sq_euclidean'``,
            ``'cosine'`` (case-insensitive), or any non-string object, which is
            returned unchanged on the assumption that it already is an
            ``ott.geometry.costs.CostFn``.

    Raises:
        ValueError: if ``cost_type`` is a string that is not a known name.
    """
    if cost_type is None:
        cost_type = 'euclidean'
    if not isinstance(cost_type, str):
        # Presumably an ott CostFn instance supplied directly by the caller.
        return cost_type
    class_names = {
        'euclidean': 'Euclidean',
        'l2': 'Euclidean',
        'sqeuclidean': 'SqEuclidean',
        'sq_euclidean': 'SqEuclidean',
        'cosine': 'Cosine',
    }
    key = cost_type.strip().lower()
    if key not in class_names:
        raise ValueError(
            f"Unsupported cost_type {cost_type!r}; expected one of {sorted(class_names)}"
        )
    return getattr(ott.geometry.costs, class_names[key])()


def solve_robust_ot(
    src_data,
    tar_data,
    epsilon=0.05,
    lambda_src=5.0,
    lambda_tar=0.1,
    cost_type='euclidean',
    dtype=jnp.float16,
    threshold=1e-5,
    max_iterations=1000,
):
    """Solve an unbalanced entropic OT problem and return the source-side marginal.

    Each row of ``src_data`` / ``tar_data`` is flattened into one point. The two
    point clouds are matched with Sinkhorn under a relaxed (KL-penalised)
    marginal constraint on each side, and the mass the resulting plan assigns
    to every source point is returned.

    Args:
        src_data: array whose first axis indexes source points; trailing axes
            are flattened.
        tar_data: array whose first axis indexes target points; trailing axes
            are flattened.
        epsilon: entropic regularisation, applied to costs rescaled by their
            maximum (``scale_cost='max_cost'``).
        lambda_src: marginal-relaxation strength on the source side; mapped to
            ott's ``tau_a = lambda / (lambda + epsilon)`` (larger -> closer to
            a balanced constraint).
        lambda_tar: same as ``lambda_src`` for the target side (``tau_b``).
        cost_type: ground cost, see ``_resolve_cost_fn``. The default
            ``'euclidean'`` reproduces the historical hard-coded cost.
        dtype: precision of the point clouds handed to ott. float16 is kept as
            the default for memory reasons. NOTE(review): float16 has ~1e-3
            relative precision, so ``threshold=1e-5`` is likely never reached
            and Sinkhorn runs to ``max_iterations``. Consider ``jnp.float32``.
        threshold: Sinkhorn convergence threshold.
        max_iterations: Sinkhorn iteration cap.

    Returns:
        jax array of length ``src_data.shape[0]`` with the plan's source
        marginal (axis-1 sums). NaN entries are replaced by 0 via
        ``jnp.nan_to_num``.

    Raises:
        ValueError: if ``cost_type`` is an unknown metric name.
    """
    # Resolve the cost first so a bad metric name fails before any array work.
    cost_fn = _resolve_cost_fn(cost_type)

    src_B = src_data.shape[0]
    tgt_B = tar_data.shape[0]
    src_embs = jnp.array(src_data.reshape(src_B, -1), dtype=dtype)
    tgt_embs = jnp.array(tar_data.reshape(tgt_B, -1), dtype=dtype)

    # ott parameterises marginal relaxation as tau = rho / (rho + epsilon).
    tau_a = lambda_src / (lambda_src + epsilon)
    tau_b = lambda_tar / (lambda_tar + epsilon)

    geom = ott.geometry.pointcloud.PointCloud(
        src_embs, tgt_embs, cost_fn=cost_fn, epsilon=epsilon, scale_cost='max_cost'
    )
    prob = ott.problems.linear.linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b)
    solver = ott.solvers.linear.sinkhorn.Sinkhorn(
        threshold=threshold, max_iterations=max_iterations
    )
    sinkhorn_output = solver(prob)

    # axis=1 sums the plan over target columns -> one mass per source point.
    source_weights = sinkhorn_output.marginal(1)
    return jnp.nan_to_num(source_weights)

def filter_dataset(src_replay_buffer, tar_replay_buffer, args):
    """Score every source transition by its robust-OT mass against the target data.

    Each transition is represented as the row-wise concatenation
    ``[state, action, next_state]``. Source and target are z-scored with
    statistics pooled over both buffers. Source rows are then processed in
    chunks of ``batch_size`` rows. Each chunk is matched against the *full*
    target set with ``solve_robust_ot``, and the per-row source marginals are
    concatenated back in the original row order.

    Args:
        src_replay_buffer: object exposing ``state``, ``action`` and
            ``next_state`` arrays that ``np.hstack`` can join column-wise
            (presumably 2-D ``(N, dim)`` numpy arrays, TODO confirm).
        tar_replay_buffer: same interface as ``src_replay_buffer``.
        args: namespace providing ``metric``, ``epsilon``, ``lambda_src`` and
            ``lambda_tar``, forwarded to ``solve_robust_ot``.

    Returns:
        float64 ``np.ndarray`` with one weight per source transition (empty if
        the source buffer is empty).
    """
    from functools import partial

    srcdata = np.hstack([src_replay_buffer.state, src_replay_buffer.action, src_replay_buffer.next_state])
    tardata = np.hstack([tar_replay_buffer.state, tar_replay_buffer.action, tar_replay_buffer.next_state])
    src_num = src_replay_buffer.state.shape[0]

    # Pool both sets so source and target live in the same normalised space.
    all_data = np.vstack([srcdata, tardata])
    mean = np.mean(all_data, axis=0)
    std = np.std(all_data, axis=0) + 1e-6  # avoid division by zero on constant columns
    del all_data  # only needed for the statistics; release before the OT loop
    src_for_ot = (srcdata - mean) / std
    tar_for_ot = (tardata - mean) / std

    # Hyper-parameters are bound through partial, so jit treats them as
    # compile-time constants and only the two arrays are traced.
    batch_solve = jax.jit(partial(
        solve_robust_ot,
        cost_type=args.metric,
        epsilon=args.epsilon,
        lambda_src=args.lambda_src,
        lambda_tar=args.lambda_tar,
    ))

    # NOTE: a shorter final chunk has a different shape and triggers one
    # extra jit compilation.
    # NOTE(review): each chunk is solved with ott's default (uniform)
    # source marginal, so a shorter final chunk may yield weights on a
    # different scale than full chunks. Confirm downstream use.
    batch_size = 10000
    chunks = []
    for i, start_idx in enumerate(range(0, src_num, batch_size)):
        end_idx = min(start_idx + batch_size, src_num)
        weights_jax = batch_solve(src_for_ot[start_idx:end_idx], tar_for_ot)
        # float64 matches the previous list-of-Python-floats -> np.array path.
        chunks.append(np.asarray(jax.device_get(weights_jax), dtype=np.float64).reshape(-1))
        if (i + 1) % 5 == 0:
            print(f'Processed {end_idx} / {src_num} transitions...')

    weights_result = np.concatenate(chunks) if chunks else np.zeros(0, dtype=np.float64)

    print("Cleaning up JAX memory...")
    del batch_solve
    # jax.clear_backends() was deprecated (and later removed) in favour of
    # jax.clear_caches(). Prefer the new API and fall back on old JAX versions.
    clear_fn = getattr(jax, "clear_caches", None) or getattr(jax, "clear_backends", None)
    if clear_fn is not None:
        clear_fn()
    gc.collect()
    return weights_result

def _ensure_parent_dir(path):
    """Create the parent directory of *path* if it has one and it does not exist yet."""
    parent = os.path.dirname(os.fspath(path))
    # A bare filename has no parent component; os.makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def compute_and_save_weights(src_replay_buffer, tar_replay_buffer, save_path, args):
    """Compute per-transition OT weights and write them to an HDF5 file.

    The weights returned by ``filter_dataset`` are stored in a single
    gzip-compressed dataset named ``'cost'``. Any existing file at
    ``save_path`` is overwritten (``'w'`` mode), and missing parent
    directories are created.

    Args:
        src_replay_buffer: source buffer, forwarded to ``filter_dataset``.
        tar_replay_buffer: target buffer, forwarded to ``filter_dataset``.
        save_path: destination HDF5 path (str or path-like). May be a bare
            filename in the current directory.
        args: namespace forwarded to ``filter_dataset``.

    Returns:
        The weights array that was written.
    """
    import h5py

    weights = filter_dataset(src_replay_buffer, tar_replay_buffer, args)
    _ensure_parent_dir(save_path)
    with h5py.File(save_path, 'w') as hfile:
        hfile.create_dataset('cost', data=weights, compression='gzip')
    return weights
