import os
import numpy as np
import click
import json
import random
import torch
import torch.optim as optim
from meta_test_algo.network import es_policy2,q_function
from rlkit.torch.networks import stochastic_actor2
from meta_test_algo.render_utils import initialize_viewer,render
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from configs.default import default_config
from reference_data.sampler import load_samples, load_transitions
import torch.nn as nn
import wandb
from tqdm import tqdm
import re

class launcher():
    """Behaviour-cloning (BC) meta-test runner.

    Builds the evaluation environment for ``test_env``, an ``es_policy2``
    policy whose shared layer can be initialised (and frozen) from a
    checkpoint trained on ``train_env``, twin Q-functions with targets, and
    the reference transitions returned by ``load_transitions``. ``learn``
    regresses the policy's actions onto the reference actions and
    periodically evaluates the policy in the environment.
    """

    def __init__(self,test_env,train_env,test_file,variant,
                 debug=False,
                 replay=True,
                 pretrained=True,
                 seed=0):
        """Build env, networks, optimizers and data for one BC run.

        Args:
            test_env: name of the environment/task to evaluate on; also
                selects the task setting (see ``_set_task``) and the policy
                learning rate.
            train_env: name of the environment the shared layer was trained
                on; selects the checkpoint directory.
            test_file: checkpoint number passed to ``load_shared_policy``
                when ``pretrained`` is True.
            variant: experiment config dict (reads ``env_name``,
                ``env_params``, ``n_eval_tasks``, ``net_size``,
                ``meta_test_params``, ``sr_params`` and ``esq_params``).
            debug: when True, skip wandb logging and checkpoint saving.
            replay: forwarded to ``load_transitions``; when True the
                minibatch reward/next-obs/done entries are ``None``.
            pretrained: load and freeze the shared layer from a checkpoint
                instead of training it.
            seed: seed for Python, NumPy, torch and the environment.
        """
        self.env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
        self._set_task(test_env, train_env)
        self.env_name = test_env

        # Seed every RNG source used here and force deterministic cuDNN
        # kernels so runs with the same seed are reproducible.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.env.set_seed(seed)
        self.seed = seed

        self.tasks = self.env.get_all_task_idx()
        self.eval_tasks = list(self.tasks[-variant['n_eval_tasks']:])
        self.obs_dim = int(np.prod(self.env.observation_space.shape))
        self.action_dim = int(np.prod(self.env.action_space.shape))
        self.test_env = test_env
        self.train_env = train_env
        self.test_file = test_file

        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # meta test parameters
        meta_test_params = variant['meta_test_params']
        self.method = meta_test_params['method']
        self.no_pretrained = meta_test_params['no_pretrained']
        self.total_step = meta_test_params['total_step']
        self.eval_step = variant['sr_params']['eval_step']
        self.max_path_length = meta_test_params['max_path_length']

        # network hyper-parameters
        self.latent_action_dim = variant['sr_params']['latent_action_dim']
        self.net_size = variant['net_size']
        self.w = variant['esq_params']['w']

        self.replay = replay
        self.pretrained = pretrained

        self.policy = es_policy2(self.obs_dim,
                                 self.action_dim,
                                 self.net_size,
                                 latent_dim=self.latent_action_dim,
                                 w = self.w).to(self.device)

        # Trainable parameters: first/last layers always; the shared layer
        # only when it is not loaded (and frozen) from a checkpoint.
        trainable = list(self.policy.first_layer.parameters()) + list(self.policy.last_layer.parameters())
        if self.pretrained:
            self.shared_param = self.load_shared_policy(self.test_file)
            self.policy.shared_layer.load_state_dict(self.shared_param)
            for param in self.policy.shared_layer.parameters():
                param.requires_grad = False
        else:
            for param in self.policy.shared_layer.parameters():
                param.requires_grad = True
            trainable += list(self.policy.shared_layer.parameters())
        self.parameters = trainable

        # Per-environment policy learning rate (default 1e-3).
        policy_lr = {'walker-rand-params': 5e-4,
                     'hopper-rand-params': 1e-4}.get(test_env, 1e-3)
        self.optimizer = optim.Adam(self.parameters,lr=policy_lr)

        # Flat NumPy copy of the trainable parameters, used by the ES update.
        self.parameters_vector = nn.utils.parameters_to_vector(self.parameters).detach().cpu().numpy()
        self.param_shape = self.parameters_vector.shape

        self.batch_size = 2048
        self.obs, self.actions, self.rewards, self.next_obs, self.dones = load_transitions(variant['env_name'],
                                                                                           self.obs_dim,
                                                                                           self.action_dim,
                                                                                           self.replay)

        self.debug = debug

        # Twin Q-functions sized for (obs, action) inputs, plus target copies
        # initialised to the same weights.
        q_input_dim = self.obs_dim+self.action_dim
        self.q1 = q_function(q_input_dim,self.net_size).to(self.device)
        self.q2 = q_function(q_input_dim,self.net_size).to(self.device)
        self.target_q1 = q_function(q_input_dim,self.net_size).to(self.device)
        self.target_q2 = q_function(q_input_dim,self.net_size).to(self.device)
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())
        self.q_optimizer = optim.Adam(list(self.q1.parameters())+
                                      list(self.q2.parameters()),
                                      lr=1e-4)

        if not self.debug:
            wandb.init(
                project = 'BC Test',
                name = f'[Replay: {self.replay}, Pretrained: {self.pretrained}]({self.train_env}->{self.test_env} test)'
            )

        # Per-parameter ES perturbation scale, adapted in ``_es_update``.
        self.sigma_vec = np.full_like(self.parameters_vector, 1e-3, dtype=np.float64)

    def _set_task(self, test_env, train_env):
        """Fix the environment's task according to ``test_env``.

        For ant-goal / ant-dir / humanoid-dir the angle is 1.5*pi when
        ``test_env == train_env`` and 0.5*pi otherwise; the remaining
        environments use the same setting in both cases.
        """
        angle = 1.5*np.pi if test_env == train_env else 0.5*np.pi
        if test_env == 'cheetah-vel':
            self.env.set_velocity(-2.0) # target velocity -2
        elif test_env == 'cheetah-dir':
            self.env.set_direction(-1) # direction -1
        elif test_env == 'ant-goal':
            self.env.set_goal_position(angle,3) # args presumably (angle, radius=3)
        elif test_env in ('ant-dir', 'humanoid-dir'):
            self.env.set_direction(angle)
        elif 'params' in test_env:
            self.env.set_test_task()

    def load_shared_policy(self,epoch=None):
        """Load the pre-trained shared-layer state for ``self.train_env``.

        Args:
            epoch: checkpoint number (file ``shared_policy(<epoch>).pt``).
                When ``None``, the highest-numbered checkpoint found in the
                directory that is loaded from is used.

        Returns:
            The object stored in the checkpoint (mapped onto ``self.device``),
            which the caller passes to ``load_state_dict``.

        Raises:
            Exception: if ``epoch`` is None and no checkpoint file exists.
        """
        # cheetah-dir runs reuse the policies stored under cheetah-vel.
        env_dir = 'cheetah-vel' if self.train_env == 'cheetah-dir' else self.train_env
        path = f'./shared_policy/{env_dir}'
        if epoch is None:
            # Search the same directory the checkpoint is loaded from.
            pattern = re.compile(r'shared_policy\((\d+)\)\.pt')
            numbers = [int(match.group(1))
                       for match in map(pattern.fullmatch, os.listdir(path))
                       if match]
            if not numbers:
                raise Exception("There is no file to read.")
            epoch = max(numbers)
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        shared_policy_params = torch.load(f'{path}/shared_policy({epoch}).pt',
                                          map_location=self.device)
        print('Shared policy ({}) is loaded.'.format(epoch))
        return shared_policy_params

    def update_params(self,param_vector):
        """Copy a flat parameter vector into the trainable policy parameters."""
        nn.utils.vector_to_parameters(
            torch.tensor(param_vector, dtype=torch.float32).to(self.device),
            self.parameters
        )

    def supervised_policy_update(self,transitions,es=True):
        """Perform one BC update of the policy on a minibatch.

        Args:
            transitions: ``(obs, actions, rewards, next_obs, dones)``; only
                ``obs`` and ``actions`` are used. They may be arrays or
                tensors (on any device).
            es: when True use the evolution-strategies update
                (``_es_update``); otherwise take one optimizer step on the
                squared action error.

        Returns:
            The BC loss (squared error summed over action dims, averaged over
            the batch) as a float: after the step for ES, before it otherwise.
        """
        obs, actions = transitions[0], transitions[1]
        # ``torch.as_tensor`` accepts both arrays and tensors already on the
        # device; the legacy ``torch.Tensor(...)`` constructor only accepts
        # CPU tensors, so it fails on CUDA minibatches.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.float32, device=self.device)
        if es:
            with torch.no_grad():
                return self._es_update(obs, actions)
        pred_actions = self.policy.grad_action(obs)
        bc_loss = ((pred_actions - actions)**2).sum(axis=1).mean()
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()
        return bc_loss.item()

    def _perturbed_loss(self, param_vector, obs, actions):
        """Load ``param_vector`` into the policy and return its mean squared action error."""
        self.update_params(param_vector)
        pred_actions = self.policy.grad_action(obs)
        return ((pred_actions - actions)**2).mean().cpu().item()

    def _es_update(self, obs, actions, n_pairs=1024, lr=1e-5, c_cov=0.5):
        """One antithetic ES step on the trainable parameters (call under no_grad).

        Samples ``n_pairs`` perturbations ``delta = sigma_vec * eps`` with
        ``eps ~ N(0, I)``, scores ``base +/- delta`` by the negative mean
        squared action error, z-scores all scores, then steps the parameters
        along the ES gradient estimate and adapts ``sigma_vec`` in log space
        (clipped to [1e-6, 1e-2]).

        Returns:
            The BC loss (sum over action dims, mean over batch) after the step.
        """
        base = self.parameters_vector.copy()

        samples = []  # (eps, F_pos, F_neg) for each antithetic pair
        for _ in range(n_pairs):
            epsilon = np.random.randn(*self.param_shape)
            delta = self.sigma_vec * epsilon
            F_pos = -self._perturbed_loss(base + delta, obs, actions)
            F_neg = -self._perturbed_loss(base - delta, obs, actions)
            samples.append((epsilon, F_pos, F_neg))

        F_arr = np.array([[F_pos, F_neg] for _, F_pos, F_neg in samples], dtype=np.float64)
        mean, std = F_arr.mean(), F_arr.std() + 1e-8

        grad = np.zeros_like(self.parameters_vector, dtype=np.float64)
        grad_log_sigma = np.zeros_like(self.sigma_vec, dtype=np.float64)
        for eps, F_pos, F_neg in samples:
            up = (F_pos - mean) / std
            un = (F_neg - mean) / std
            # Antithetic difference -> parameter-gradient estimate.
            grad += 0.5*(up-un) * (eps/(self.sigma_vec+1e-12))
            # Antithetic sum -> gradient w.r.t. log(sigma).
            grad_log_sigma += 0.5*(up+un) * (eps**2 - 1.0)
        grad /= len(samples)
        grad_log_sigma /= len(samples)

        log_sigma = np.log(self.sigma_vec + 1e-12) + c_cov * grad_log_sigma
        self.sigma_vec = np.clip(np.exp(log_sigma), 1e-6, 1e-2)
        self.parameters_vector = base + lr*grad
        self.update_params(self.parameters_vector)

        pred_actions = self.policy.grad_action(obs)
        return ((pred_actions - actions)**2).sum(axis=1).mean().item()

    def get_transitions_minibatch(self):
        """Sample up to ``batch_size`` stored transitions without replacement.

        Returns:
            ``(obs, actions, rewards, next_obs, dones)`` as float tensors on
            ``self.device``; rewards and dones get a trailing unit dimension.
            The last three are ``None`` when ``self.replay`` is True.
        """
        batch_size = min(len(self.obs),self.batch_size)
        indices = np.random.choice(len(self.obs), size=batch_size, replace=False)
        obs = torch.Tensor(self.obs[indices]).to(self.device)
        actions = torch.Tensor(self.actions[indices]).to(self.device)
        if self.replay:
            return obs, actions, None, None, None
        rewards = torch.Tensor(self.rewards[indices]).unsqueeze(dim=1).to(self.device)
        next_obs = torch.Tensor(self.next_obs[indices]).to(self.device)
        dones = torch.Tensor(self.dones[indices]).unsqueeze(dim=1).to(self.device)
        return obs, actions, rewards, next_obs, dones

    def get_minibatch(self,batch_size=np.inf):
        """Sample up to ``min(batch_size, self.batch_size)`` (obs, action) pairs as device tensors."""
        batch_size = min(len(self.obs),self.batch_size,batch_size)
        indices = np.random.choice(len(self.obs), size=batch_size, replace=False)
        obs = torch.Tensor(self.obs[indices]).to(self.device)
        actions = torch.Tensor(self.actions[indices]).to(self.device)
        return obs, actions

    def get_batch(self,batch_size=10000):
        """Return a fixed (seed 42) random subset of stored obs/actions as NumPy arrays.

        A private ``RandomState(42)`` is used so the subset is reproducible
        without reseeding NumPy's global RNG (which would discard the run's
        ``seed``); it selects the same rows as ``np.random.seed(42)``
        followed by ``np.random.choice``.
        """
        batch_size = min(len(self.obs),batch_size)
        rng = np.random.RandomState(42)
        indices = rng.choice(len(self.obs), size=batch_size, replace=False)
        return self.obs[indices], self.actions[indices]

    def learn(self):
        """Run the BC training loop with periodic evaluation and checkpointing.

        After every update with ``i % 10 == 0`` the policy is evaluated. The
        mean of the two most recent returns is tracked and, whenever it
        improves (and ``debug`` is off), the policy state dict is saved under
        ``./reference_data/<test_env>/``. Loss and return are logged to wandb
        unless ``debug`` is set.
        """
        use_es = False  # gradient-based BC; set True to use the ES update
        best_mean_ret = -np.inf
        # Window of the two most recent eval returns; the -inf padding means
        # no checkpoint is written before two evaluations have happened.
        recent_returns = [-np.inf,-np.inf]
        n_iters = 50000 if self.env_name == 'cheetah-vel' else 30000
        tag = 'proposed' if self.pretrained else 'original'
        save_dir = f'./reference_data/{self.test_env}'
        for i in range(n_iters):
            transitions = self.get_transitions_minibatch()
            bc_loss = self.supervised_policy_update(transitions,use_es)
            if i % 10 != 0:
                continue
            ret = self.eval()
            recent_returns.append(ret)
            if len(recent_returns) > 2:
                recent_returns.pop(0)
            mean_ret = np.mean(recent_returns)
            if mean_ret > best_mean_ret:
                best_mean_ret = mean_ret
                if not self.debug:
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(self.policy.state_dict(),
                               f'{save_dir}/bc_policy_{tag}({self.seed}).pt')
            print(f'[{i}]','Loss:',bc_loss,'Return:',ret, 'N_Data: ',len(self.obs))
            if not self.debug:
                wandb.log({"BC Loss":bc_loss},step=i+1)
                wandb.log({"Returns":ret},step=i+1)
        if not self.debug:
            wandb.finish()

    def eval(self):
        """Return the mean undiscounted return of the policy over 10 episodes.

        Each episode lasts at most ``max_path_length`` steps and stops early
        when the environment reports ``done``. Runs without gradient tracking.
        """
        with torch.no_grad():
            returns = []
            for _ in range(10):
                obs = self.env.reset()
                env_step = 0
                episode_return = 0
                while env_step < self.max_path_length:
                    env_step += 1
                    obs = torch.Tensor(obs).to(self.device)
                    action = self.policy(obs).cpu().numpy()
                    next_obs, reward, done, env_info = self.env.step(action)
                    episode_return += reward
                    obs = next_obs
                    if done:
                        break
                returns.append(episode_return)
        return sum(returns)/len(returns)
    

def deep_update_dict(fr, to):
    """Recursively merge dict ``fr`` into dict ``to`` in place and return ``to``.

    Nested dicts are merged key by key; any other value in ``fr`` overwrites
    the matching entry in ``to``. When ``fr`` holds a nested dict whose key
    is missing from ``to`` (or maps to a non-dict there), a fresh dict is
    created in ``to`` and filled recursively.

    Args:
        fr: source of new values (e.g. a parsed JSON experiment config).
        to: destination dict (e.g. the default variant); mutated in place.

    Returns:
        ``to`` after the update.
    """
    for k, v in fr.items():
        if isinstance(v, dict):
            target = to.get(k)
            if not isinstance(target, dict):
                target = to[k] = {}
            deep_update_dict(v, target)
        else:
            to[k] = v
    return to

# JSON experiment config for each supported --test_env.
_CONFIG_FILES = {
    'cheetah-dir': './configs/cheetah-dir.json',
    'cheetah-vel': './configs/cheetah-vel.json',
    'ant-goal': './configs/ant-goal.json',
    'ant-dir': './configs/ant-dir.json',
    'humanoid-dir': './configs/humanoid-dir.json',
    'walker-rand-params': './configs/walker_rand_params.json',
    'hopper-rand-params': './configs/hopper_rand_params.json',
}


def _latest_shared_policy_epoch(path):
    """Return the largest N among ``shared_policy(N).pt`` files in ``path``.

    Raises:
        Exception: if ``path`` contains no matching checkpoint.
    """
    pattern = re.compile(r'shared_policy\((\d+)\)\.pt')
    numbers = [int(match.group(1))
               for match in map(pattern.fullmatch, os.listdir(path))
               if match]
    if not numbers:
        raise Exception("There is no file to read.")
    return max(numbers)


@click.command()
@click.option('--test_env',default=None)
@click.option('--train_env',default=None)
@click.option('--file_number',default=None,type=int)
@click.option('--debug', is_flag=True, default=False)
@click.option('--pretrained', is_flag=True, default=False)
@click.option('--replay', is_flag=True, default=False)
@click.option('--seed', default=0, type=int)
def main(test_env,train_env,file_number,debug,pretrained,replay,seed):
    """Run a behaviour-cloning meta-test for --test_env."""
    import copy

    try:
        config = _CONFIG_FILES[test_env]
    except KeyError:
        raise click.BadParameter(
            f'unknown environment {test_env!r}; expected one of {sorted(_CONFIG_FILES)}',
            param_hint='--test_env') from None

    # Work on a copy so the module-level default_config is left untouched.
    variant = copy.deepcopy(default_config)
    with open(config) as f:
        exp_params = json.load(f)
    variant = deep_update_dict(exp_params, variant)

    if train_env is None:
        train_env = test_env

    # A shared-policy checkpoint is only read when --pretrained is given;
    # without --file_number the newest one is used.
    test_file = file_number
    if pretrained and test_file is None:
        # Same directory choice as launcher.load_shared_policy: cheetah-dir
        # runs reuse the policies stored under cheetah-vel.
        env_dir = 'cheetah-vel' if train_env == 'cheetah-dir' else train_env
        test_file = _latest_shared_policy_epoch(f'./shared_policy/{env_dir}')

    launch = launcher(test_env,train_env,test_file,variant,
                      debug=debug,
                      pretrained=pretrained,
                      replay=replay,
                      seed=seed)
    launch.learn()

if __name__ == "__main__":
    # Seeding happens inside the launcher, driven by the --seed option.
    main()