import sys
import re
import copy

from copy import deepcopy
from collections import namedtuple
import pickle
import torch
import numpy as np
import gymnasium as gym
from torch.distributions import uniform

from datasets import load_dataset, Dataset
from arrl.utils import load_model 
from arrl.ddpg import DDPG, normalize
from arrl.main import load_env_name

# One rollout: per-step observations, protagonist actions, rewards, per-step
# info dicts (the adversary action is stored under the "adv" key) and policy infos.
Trajectory = namedtuple("Trajectory", "obs actions rewards infos policy_infos")

def get_hf_offline_data(dir_prefix, used_length=None):
    """Load trajectories from a single-shard HuggingFace ``datasets`` Arrow file.

    Reads ``<dir_prefix>/data-00000-of-00001.arrow``. Rows stored with the
    legacy singular column names (``state``, ``pr_action``, ``adv_action``,
    ``done``, ``reward``) are renamed to the plural schema before conversion.

    Args:
        dir_prefix: Directory that contains the Arrow shard.
        used_length: Number of leading steps to keep per trajectory. ``None``
            or ``0`` keeps every step.

    Returns:
        A list of ``Trajectory`` tuples: ``obs`` / ``actions`` are the
        (truncated) ``observations`` / ``pr_actions`` columns, ``rewards`` is a
        NumPy array, ``infos`` is ``[{'adv': adv_action}, ...]`` and
        ``policy_infos`` is an empty list.
    """
    # Announce before loading so the message is visible while the file is read.
    print("HF Offline Dataset loading:", dir_prefix)
    raw_trajs = Dataset.from_file(dir_prefix + "/data-00000-of-00001.arrow", split='train')

    # (legacy column name, current column name)
    legacy_renames = (
        ('state', 'observations'),
        ('pr_action', 'pr_actions'),
        ('adv_action', 'adv_actions'),
        ('done', 'dones'),
        ('reward', 'rewards'),
    )

    trajs = []
    for raw_traj in raw_trajs:
        if 'adv_actions' not in raw_traj:
            for old_key, new_key in legacy_renames:
                raw_traj[new_key] = raw_traj.pop(old_key)

        infos_ = [{'adv': adv_act} for adv_act in raw_traj['adv_actions']]
        # A falsy used_length (None / 0) means "keep the whole trajectory".
        length = used_length or len(infos_)
        trajs.append(Trajectory(obs=raw_traj['observations'][:length],
                                actions=raw_traj['pr_actions'][:length],
                                rewards=np.array(raw_traj['rewards'][:length]),
                                infos=infos_[:length],
                                policy_infos=[]))
    return trajs


def load_mujoco_env(env_name,
                    data_name=None,
                    used_length=1000,
                    device="cpu",
                    adv_model_path='',
                    dir_prefix='',
                    added_dir_prefix=None,
                    added_data_prop=1.0,
                    env_alpha=0.1):
    """Build the adversarial MuJoCo env and load its offline trajectories.

    Args:
        env_name: Short environment name, mapped to a gymnasium id via
            ``load_env_name``.
        data_name: Unused; kept for call-site compatibility.
        used_length: Steps kept per trajectory (see ``get_hf_offline_data``).
        device: Torch device for the adversary model.
        adv_model_path: Adversary specification forwarded to ``AdvGymEnv``.
        dir_prefix: Directory holding the HuggingFace Arrow dataset.
        added_dir_prefix: Optional path to a pickled list of extra trajectories.
        added_data_prop: How the extra trajectories are mixed in:
            ``-1``  -> all original plus all extra trajectories;
            ``> 1`` -> all extra plus ``len(extra) / added_data_prop`` originals;
            else    -> all originals plus ``len(originals) * added_data_prop`` extras.
        env_alpha: Adversary blending weight passed to ``AdvGymEnv``.

    Returns:
        ``(env, trajs)``.
    """
    basic_env = gym.make(load_env_name(env_name))
    # Snapshot the unperturbed body masses; AdvGymEnv rescales from this copy.
    # ``unwrapped`` reaches the MuJoCo env regardless of the wrapper stack depth.
    basic_bm = basic_env.unwrapped.model.body_mass.copy()
    env = AdvGymEnv(basic_env, adv_model_path, device, env_name, env_alpha, basic_bm)
    trajs = get_hf_offline_data(dir_prefix, used_length)

    added_trajs = []
    if added_dir_prefix:
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(added_dir_prefix, 'rb') as file:
            added_trajs = pickle.load(file)

    if added_data_prop == -1:
        return env, trajs + added_trajs
    if added_data_prop > 1:
        if not added_trajs:
            # With no extra data the ratio below would keep zero originals and
            # silently return an empty dataset; keep the originals instead.
            return env, trajs
        return env, trajs[:int(len(added_trajs) / added_data_prop)] + added_trajs
    return env, trajs + added_trajs[:int(len(trajs) * added_data_prop)]


class OfflineEnv:
    """Minimal holder for a pre-collected list of trajectories."""

    def __init__(self, trajs) -> None:
        # Kept by reference; the caller's list is not copied.
        self.trajs = trajs


class AdvGymEnv:
    """Wrap a gymnasium env so a protagonist action can be perturbed by an adversary.

    The perturbation is chosen by ``adv_model_path`` (see ``reset_model_rl``):

    * a path containing ``'env'`` rescales the wrapped MuJoCo model's body
      masses by the first number found in the path;
    * ``'zero'`` disables the adversary;
    * any other path loads a DDPG agent whose ``adversary`` network output is
      blended into the protagonist action with weight ``env_alpha``.

    Attributes not defined on this wrapper are forwarded to ``env``.
    """

    def __init__(self, env, adv_model_path, device, env_name, env_alpha=0.1, basic_bm=None):
        """
        Args:
            env: The gymnasium environment to wrap.
            adv_model_path: Adversary specification (see class docstring).
            device: Torch device for a DDPG adversary.
            env_name: Short env name; replaced by ``load_env_name(env_name)``
                inside ``adv_model_path`` before a DDPG model is loaded.
            env_alpha: Weight of the adversary action in the executed action.
            basic_bm: Unperturbed body masses used as the base for mass scaling.
                If ``None`` they are captured from ``env`` the first time a mass
                perturbation is requested.
        """
        self.env = env
        self.env_name = env_name
        self.adv_action_space = env.action_space
        self.adv_model_path = adv_model_path
        self.basic_bm = basic_bm
        self.reset_model_rl(adv_model_path, device)
        self.current_state = None
        self.t = 0
        self.env_alpha = env_alpha

    def reset(self):
        """Reset the wrapped env and the step counter; return the initial observation."""
        self.t = 0
        state, _ = self.env.reset()
        self.current_state = state
        return state

    def reset_model_rl(self, adv_model_path, device):
        """Configure the perturbation described by ``adv_model_path``.

        * Contains ``'env'``: scale every body mass by the first integer or
          decimal number in the path (e.g. ``'env_1.5'`` -> x1.5); no model.
        * ``'zero'``: no perturbation.
        * Otherwise: load a DDPG agent from the path (after substituting
          ``env_name`` with its gymnasium id) and use it as the adversary.

        Raises:
            IndexError: if an ``'env'`` path contains no number.
        """
        print("Reset adversary:", adv_model_path)
        # Keep step()'s 'zero' check in sync when the adversary is swapped later.
        self.adv_model_path = adv_model_path
        self.adv_model = None
        if 'env' in adv_model_path:
            # float() rather than eval(): eval('05') is a SyntaxError and eval on
            # path text is an unsafe idiom.
            mass = float(re.findall(r'\d+\.\d+|\d+', adv_model_path)[0])
            body_mass = self.env.unwrapped.model.body_mass
            if self.basic_bm is None:
                # No reference masses supplied: use the current ones as the base.
                self.basic_bm = body_mass.copy()
            base = np.asarray(self.basic_bm)
            # Always scale from the base so repeated calls do not compound.
            body_mass[:len(base)] = base * mass
        elif adv_model_path != 'zero':
            adv_model_path = adv_model_path.replace(self.env_name, load_env_name(self.env_name))
            agent = DDPG(gamma=1,
                         tau=1,
                         hidden_size=64,
                         num_inputs=self.env.observation_space.shape[0],
                         action_space=self.env.action_space.shape[0],
                         train_mode=False,
                         alpha=1,
                         replay_size=1,
                         device=device)

            load_model(agent, basedir=adv_model_path)
            self.adv_model = agent

    def step(self, pr_action, adv_action=None):
        """Advance the wrapped env by one step.

        When no ``adv_action`` is given and a DDPG adversary is loaded, the
        adversary's action for the current state (clipped to [-1, 1]) is
        blended in as ``(1 - env_alpha) * pr_action + env_alpha * adv_action``.
        Otherwise ``pr_action`` is executed unchanged and the reported
        adversary action is all zeros.

        Returns:
            ``(state, reward, done, {'adv_action': adv_action})`` where ``done``
            is gymnasium's ``terminated`` flag (``truncated`` is not folded in).
        """
        if adv_action is None and self.adv_model_path != 'zero' and self.adv_model is not None:
            state = torch.from_numpy(self.current_state).to(self.adv_model.device, dtype=torch.float32)
            # Inference only: avoid building an autograd graph every step.
            with torch.no_grad():
                adv_action = self.adv_model.adversary(state).data.clamp(-1, 1).cpu().numpy()
            mixed_action = pr_action * (1 - self.env_alpha) + adv_action * self.env_alpha
            state, reward, done, _, _ = self.env.step(mixed_action)
        else:
            # NOTE(review): a caller-supplied adv_action is discarded here and
            # reported as zeros; confirm this is intended before relying on it.
            adv_action = np.zeros_like(pr_action)
            state, reward, done, _, _ = self.env.step(pr_action)

        self.current_state = state
        self.t += 1
        return state, reward, done, {'adv_action': adv_action}

    def __getattr__(self, attr):
        """Forward attributes not found on the wrapper to the wrapped env.

        Python only calls this after normal lookup fails. ``env`` is read straight
        from the instance dict so a half-constructed instance (e.g. during
        copy/unpickling) raises AttributeError instead of recursing forever.
        ``getattr`` (not ``__getattribute__``) is used so the wrapped env's own
        ``__getattr__`` forwarding still applies.
        """
        try:
            env = self.__dict__['env']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!s} object has no attribute {attr!r}"
            ) from None
        return getattr(env, attr)


def collect_random_data(env_name, device, adv_model_path, data_path, horizon=1000, n_interactions=1_000_000):
    """Roll out a pretrained DDPG protagonist/adversary pair with random
    exploration, pickle the trajectories, and return an evaluation env.

    At each step, the protagonist and adversary actions are each replaced by a
    uniform sample in [-1, 1] with probability 0.1. The executed action is
    ``0.9 * protagonist + 0.1 * adversary`` clipped to [-1, 1].

    Args:
        env_name: Short environment name, mapped via ``load_env_name``.
        device: Torch device for the DDPG networks.
        adv_model_path: Path of the saved DDPG agent; ``env_name`` inside it is
            replaced by the full gymnasium id before loading.
        data_path: Output file for the pickled list of ``Trajectory``.
        horizon: Maximum number of steps per trajectory.
        n_interactions: Total environment steps to collect. A trailing partial
            trajectory is not saved.

    Returns:
        ``(AdvGymEnv, trajs)``.
    """
    actual_adv_path = adv_model_path.replace(env_name, load_env_name(env_name))
    env = gym.make(load_env_name(env_name))

    agent = DDPG(gamma=1,
                 tau=1,
                 hidden_size=64,
                 num_inputs=env.observation_space.shape[0],
                 action_space=env.action_space.shape[0],
                 train_mode=False,
                 alpha=1,
                 replay_size=1,
                 device=device)

    load_model(agent, basedir=actual_adv_path)
    agent.eval()
    noise = uniform.Uniform(torch.tensor([-1.0], dtype=torch.float32),
                            torch.tensor([1.0], dtype=torch.float32))

    trajs = []
    obs_, actions_, rewards_, infos_, policy_infos_ = [], [], [], [], []
    t = 0
    obs, _ = env.reset()

    with torch.no_grad():
        for n_gathered in range(int(n_interactions)):
            if n_gathered % 10_000 == 0:
                print(f"\r{n_gathered} steps done", end='')
            obs_.append(deepcopy(obs))

            state = torch.from_numpy(obs).to(device, dtype=torch.float32)

            # Protagonist action, replaced by uniform noise 10% of the time.
            pr_action = agent.actor(state).clamp(-1, 1).cpu()
            if np.random.random() < 0.1:
                pr_action = noise.sample(pr_action.shape).view(pr_action.shape).cpu()

            # Adversary action, replaced by uniform noise 10% of the time.
            adv_action = agent.adversary(state).clamp(-1, 1).cpu()
            if np.random.random() < 0.1:
                adv_action = noise.sample(adv_action.shape).view(adv_action.shape).cpu()

            step_action = (pr_action * 0.9 + adv_action * 0.1).data.clamp(-1, 1).numpy()

            policy_infos_.append({})
            actions_.append(pr_action.numpy())

            # gymnasium step: (obs, reward, terminated, truncated, info).
            obs, reward, terminated, truncated, _ = env.step(step_action)

            t += 1
            infos_.append({"adv": adv_action.numpy()})
            rewards_.append(reward)

            # Also end the trajectory when the env's own time limit truncates it.
            if t == horizon or terminated or truncated:
                trajs.append(Trajectory(obs=obs_,
                                        actions=actions_,
                                        rewards=rewards_,
                                        infos=infos_,
                                        policy_infos=policy_infos_))
                t = 0
                obs_, actions_, rewards_, infos_, policy_infos_ = [], [], [], [], []
                obs, _ = env.reset()
    print()  # terminate the carriage-return progress line

    with open(data_path, 'wb') as file:
        pickle.dump(trajs, file)
        print('Saved trajectories to dataset file', data_path)

    eval_env = gym.make(load_env_name(env_name))
    # env_name is a required AdvGymEnv argument. Pass the raw model path,
    # because AdvGymEnv does the env_name -> gym-id substitution itself. Also
    # snapshot the base masses, as load_mujoco_env does.
    eval_bm = eval_env.unwrapped.model.body_mass.copy()
    return AdvGymEnv(eval_env, adv_model_path, device, env_name, basic_bm=eval_bm), trajs

    