# Imports
import gym
import random
import numpy as np
import copy
import time, os
import torch
import torch.nn as nn
import torch.nn.functional as F

from rl_algos.single_agent.TD3.agent import Agent as TD3_Agent
from rl_algos.single_agent.Revalued.action_agent import Agent as Revalued_Agent

from utils.misc import get_dataset
from utils.decomp_networks import VectorizedLinear


class BaseInverseModel(nn.Module):
    """Shared training loop and checkpoint handling for inverse models.

    This class does not define a network itself. ``learn`` relies on the
    concrete model providing ``forward(states, model_input)``,
    ``calc_loss(action_pred, actions)`` and an ``optimiser`` attribute.
    """

    def __init__(self, model_type, device):
        """Store the input variant and the device used when loading checkpoints.

        Args:
            model_type: Selects the second network input in ``learn``; one of
                ``'diff_state'``, ``'cont_diff_state'`` or ``'state'``.
            device: Passed as ``map_location`` when a checkpoint is loaded.
        """
        super().__init__()

        self.model_type = model_type
        self.device = device

        # Counts calls to ``learn`` (including calls that return early).
        self.total_it = 0

    def learn(self, agent, samples=None):
        """Run one optimisation step of the inverse model.

        Args:
            agent: Object exposing ``algo_name``, ``rng`` and the replay
                buffer(s) to sample from.
            samples: Optional pre-drawn batch that unpacks into
                ``(states, next_states, diff_states, actions, rewards,
                done_batch)``. When ``None`` a batch is sampled from the
                agent's buffer.

        Returns:
            The scalar loss as a Python float, or ``None`` when the buffer
            holds fewer transitions than one batch.

        Raises:
            ValueError: If ``self.model_type`` is not a supported variant.
        """
        if agent.algo_name in ('revalued_action', 'decqn_action'):
            replay_buffer = agent.dataset_buffer
        else:
            replay_buffer = agent.replay_buffer

        batch_size = 1024
        self.total_it += 1

        # Not enough data for a full batch yet: skip this update.
        if replay_buffer.mem_cntr < batch_size:
            return None

        if samples is None:
            # The last element returned by ``sample`` is the batch index,
            # which is not needed here.
            *samples, _batch_idx = replay_buffer.sample(rng=agent.rng,
                                                        batch_size=batch_size)

        states, next_states, diff_states, actions, _rewards, _done_batch = samples

        if self.model_type == 'diff_state':
            model_input = diff_states
        elif self.model_type == 'cont_diff_state':
            model_input = next_states - states
        elif self.model_type == 'state':
            model_input = next_states
        else:
            # Previously this fell through and failed later with an unbound
            # local variable; fail with an explicit message instead.
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; expected "
                "'diff_state', 'cont_diff_state' or 'state'."
            )

        action_pred = self(states, model_input)
        total_loss = self.calc_loss(action_pred, actions)

        self.optimiser.zero_grad()
        total_loss.backward()
        self.optimiser.step()

        return total_loss.item()

    def train_inverse(self, agent, train_iter=100000, save_num=1, save_model=True):
        """Call ``learn`` ``train_iter`` times and optionally save the model.

        Args:
            agent: Forwarded to ``learn``; ``agent.env_id`` names the
                checkpoint directory when saving.
            train_iter: Number of ``learn`` calls to make.
            save_num: Identifier appended to the checkpoint file name.
            save_model: Whether to write a checkpoint after training.

        Returns:
            The value returned by the final ``learn`` call, or ``None`` when
            ``train_iter`` is zero.
        """
        inv_train_loss = None
        last_iter = None

        for i in range(train_iter):
            inv_train_loss = self.learn(agent)
            last_iter = i
            if i % 10000 == 0:
                print(f'Iteration {i}')
                print(f"Inverse model training loss: {inv_train_loss}")

        if save_model:
            self.save_model(agent.env_id, save_num)

        # Only report a final iteration if at least one was executed; the
        # loop variable does not exist otherwise.
        if last_iter is not None:
            print(f'Iteration {last_iter}')
            print(f"Inverse model training loss: {inv_train_loss}")
        return inv_train_loss

    def gen_file_path(self, env_id, iter_no, create=False):
        """Build the checkpoint path ``models/iss_<model_type>/<env_id>/model-<iter_no>``.

        Args:
            env_id: Environment identifier used as the innermost directory.
            iter_no: Identifier appended to the file name.
            create: When true, create the directory tree if it is missing.

        Returns:
            The checkpoint file path as a string.
        """
        dir_path = os.path.join('models', f'iss_{self.model_type}', env_id)
        if create:
            os.makedirs(dir_path, exist_ok=True)

        return os.path.join(dir_path, f'model-{iter_no}')

    def save_model(self, env_id, iter_no):
        """Write this module's ``state_dict`` to the path for ``env_id``/``iter_no``."""
        model_path = self.gen_file_path(env_id, iter_no, create=True)

        print(f"\nSaving models to {model_path}...")
        torch.save({'model_state_dict': self.state_dict()},
                   model_path)

    def load_model(self, env_id, iter_no):
        """Load a ``state_dict`` previously written by ``save_model``."""
        model_path = self.gen_file_path(env_id, iter_no, create=False)

        model_checkpoint = torch.load(model_path, map_location=self.device)

        print(f"\nLoading models from {model_path}...")

        self.load_state_dict(model_checkpoint['model_state_dict'])


class DiscreteInverseModel(BaseInverseModel):
    """Inverse model trained with cross-entropy against discrete action targets.

    The network is a three-hidden-layer ReLU MLP over the concatenated
    ``(state, next_state)`` input, finished by a ``VectorizedLinear`` head.
    """

    def __init__(self, state_dim, action_dim, action_bins, model_type, hidden_dim=512, lr=1e-3,
            device='cpu'):
        """Build the network and its Adam optimiser, then move both to ``device``.

        Args:
            state_dim: Size of a single state vector; the first layer takes
                twice this many inputs.
            action_dim: Third argument of the ``VectorizedLinear`` head.
            action_bins: Second argument of the ``VectorizedLinear`` head.
            model_type: Input variant forwarded to ``BaseInverseModel``.
            hidden_dim: Width of the hidden layers.
            lr: Adam learning rate.
            device: Device the module is moved to.
        """
        super().__init__(model_type=model_type, device=device)

        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(2 * state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, action_bins, action_dim),
        ]
        self.action_model = nn.Sequential(*layers)

        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def forward(self, state, next_state):
        """Return the network output for the concatenated inputs.

        The inputs are joined on their last axis, given a new leading axis
        before the network, and the first two axes of the network output are
        swapped before returning.
        """
        joined = torch.cat((state, next_state), dim=-1)
        with_leading_axis = joined.unsqueeze(0)

        return self.action_model(with_leading_axis).transpose(0, 1)

    def calc_loss(self, action_pred, action):
        """Cross-entropy between predictions and targets.

        The first two axes of both tensors are merged before the loss is
        computed.
        """
        flat_pred = torch.flatten(action_pred, 0, 1)
        flat_target = torch.flatten(action, 0, 1)

        return F.cross_entropy(flat_pred, flat_target)


class InverseModel(BaseInverseModel):
    """Inverse model for continuous actions, trained with an L1 loss.

    The network is a three-hidden-layer ReLU MLP over the concatenated
    ``(state, next_state)`` input; its output is squashed with ``tanh`` and
    scaled by ``max_action``.
    """

    def __init__(self, state_dim, action_dim, max_action, model_type, hidden_dim=512, lr=1e-3,
                    device='cpu'):
        """Build the network and its Adam optimiser, then move both to ``device``.

        Args:
            state_dim: Size of a single state vector; the first layer takes
                twice this many inputs.
            action_dim: Number of outputs of the final linear layer.
            max_action: Factor applied to the ``tanh`` output in ``forward``.
            model_type: Input variant forwarded to ``BaseInverseModel``.
            hidden_dim: Width of the hidden layers.
            lr: Adam learning rate.
            device: Device the module is moved to.
        """
        super().__init__(model_type=model_type, device=device)

        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(2 * state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.action_model = nn.Sequential(*layers)
        self.max_action = max_action

        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def forward(self, state, next_state):
        """Predict an action from the inputs concatenated on their last axis."""
        joined = torch.cat((state, next_state), dim=-1)
        squashed = torch.tanh(self.action_model(joined))

        return squashed * self.max_action

    def calc_loss(self, action_pred, action):
        """Mean absolute error between predicted and target actions."""
        loss = F.l1_loss(action_pred, action)
        return loss


# Algorithms whose raw transitions are written to ``agent.dataset_buffer``
# (with ``agent.replay_buffer`` receiving the n-step version).
_DATASET_BUFFER_ALGOS = ('revalued_action', 'decqn_action')


def _apply_run_settings(config_dict):
    """Write the hyperparameters used by this script into ``config_dict`` in place."""
    config_dict.update({'critic_lr': 5e-4,
                        'actor_lr': 5e-4,
                        'tau': 1e-3,
                        })

    config_dict['offline'] = False
    config_dict['algo_type'] = 'online'
    config_dict['mem_size'] = 4000000

    config_dict['normalise_state'] = True
    config_dict['use_data'] = True

    if config_dict['dm_suite']:
        config_dict['critic_lr'] = 5e-4
        config_dict['online_steps'] = 2000001

        config_dict['critic_factor'] = 10
        config_dict['use_data'] = False
        config_dict['burn_in_steps'] = 10000
        config_dict['update_ratio'] = 1
        config_dict['sample_type'] = 'double_q'
        config_dict['n_steps'] = 3
        config_dict['replay_mem_size'] = 1000000
        config_dict['dataset_mem_size'] = 5000000

        # Per-task overrides keyed on a substring of the environment id.
        if 'dog' in config_dict['env_id']:
            config_dict['replay_mem_size'] = 250000
            config_dict['dataset_mem_size'] = 13000000
            config_dict['online_steps'] = 10000001
            config_dict['update_ratio'] = 10
        elif 'humanoid' in config_dict['env_id']:
            config_dict['replay_mem_size'] = 500000
            config_dict['dataset_mem_size'] = 10000000
            config_dict['online_steps'] = 6000001
            config_dict['update_ratio'] = 10
    else:
        config_dict['online_steps'] = 500001
        config_dict['critic_factor'] = 2
        config_dict['burn_in_steps'] = 300000

    ensemble_num = 1
    config_dict['critic_ensemble_num'] = ensemble_num
    config_dict['actor_ensemble_num'] = ensemble_num


def _normalise_obs(obs, agent):
    """Standardise ``obs`` with the mean/std held by ``agent.replay_buffer``."""
    return (obs - agent.replay_buffer.mean) / agent.replay_buffer.std


def _store_oldest_n_step(agent, n_step_buffer, next_obs, terminate, trunc):
    """Store the oldest buffered step with its discounted return, then drop it."""
    state_0, action_0, _ = n_step_buffer[0]
    disc_returns = np.sum([r * agent.gamma ** count
                           for count, (_, _, r) in enumerate(n_step_buffer)], axis=0)

    if len(n_step_buffer) > 1:
        agent.replay_buffer.store_transition(state_0, next_obs, action_0, disc_returns, terminate, trunc,
                                             true_next_state=n_step_buffer[1][0])
    else:
        # Only one step is buffered, so there is no following observation
        # to pass as ``true_next_state``.
        agent.replay_buffer.store_transition(state_0, next_obs, action_0, disc_returns, terminate, trunc)

    n_step_buffer.pop(0)


def _store_transition(agent, n_step_buffer, obs, next_obs, act, reward, terminate, trunc, done):
    """Write one environment step into the agent's buffer(s)."""
    if agent.algo_name not in _DATASET_BUFFER_ALGOS:
        agent.replay_buffer.store_transition(obs, next_obs, act, reward, terminate, trunc)
        return

    agent.dataset_buffer.store_transition(obs, next_obs, act, reward, terminate, trunc)
    n_step_buffer.append((obs, act, reward))

    if agent.n_steps == 1:
        agent.replay_buffer.store_transition(obs, next_obs, act, reward, terminate, trunc)
    elif len(n_step_buffer) == agent.n_steps:
        _store_oldest_n_step(agent, n_step_buffer, next_obs, terminate, trunc)

    if done:
        # Flush whatever is still buffered at the end of the episode.
        # NOTE(review): when n_steps == 1 nothing is popped during the
        # episode, so this flush re-stores every buffered step with a
        # discounted return - confirm that is intended.
        while n_step_buffer:
            _store_oldest_n_step(agent, n_step_buffer, next_obs, terminate, trunc)


def _collect_random_transitions(env, agent, config_dict, n_step_buffer):
    """Fill the buffers with uniformly sampled actions; return the step count.

    Whole episodes are run until at least ``config_dict['burn_in_steps']``
    steps have been taken, so the returned count may exceed that number.
    """
    init_steps = 0

    while init_steps < config_dict['burn_in_steps']:
        obs = env.reset()[0]

        if config_dict['normalise_state']:
            obs = _normalise_obs(obs, agent)
        else:
            # NOTE(review): this adds an extra axis that the normalised path
            # does not; kept as in the original - confirm intent.
            obs = obs[np.newaxis]

        done = False

        while not done:
            init_steps += 1
            obs = obs[np.newaxis, np.newaxis]
            act = env.action_space.sample()
            next_obs, reward, terminate, trunc, _info = env.step(act.squeeze())
            done = terminate | trunc

            if config_dict['normalise_state']:
                next_obs = _normalise_obs(next_obs, agent)

            _store_transition(agent, n_step_buffer, obs, next_obs, act, reward,
                              terminate, trunc, done)

            obs = next_obs

    return init_steps


def _collect_policy_transitions(env, agent, config_dict, n_step_buffer):
    """Train the agent online while storing every transition it generates."""
    grad_steps = 0
    # Defined up front so the evaluation print cannot hit an unbound name
    # when an evaluation step happens before the first update.
    loss = None

    while config_dict['online_steps'] > grad_steps:
        done = False

        obs = env.reset()[0]

        if config_dict['normalise_state']:
            obs = _normalise_obs(obs, agent)

        while not done:
            grad_steps += 1
            obs = obs[np.newaxis, np.newaxis]

            act = agent.choose_action(obs, deterministic=True, transform=True)['action']
            act = act.cpu().detach().numpy()

            if not config_dict['dm_suite']:
                # In-place add keeps the action array's dtype unchanged.
                act += np.random.normal(scale=0.1, size=act.shape)

            next_obs, reward, terminal, trunc, _info = env.step(act.squeeze())
            done = terminal | trunc

            if config_dict['normalise_state']:
                next_obs = _normalise_obs(next_obs, agent)

            # Store this step's own termination flag (the original passed a
            # stale variable left over from the burn-in loop).
            _store_transition(agent, n_step_buffer, obs, next_obs, act, reward,
                              terminal, trunc, done)

            if grad_steps % config_dict['update_ratio'] == 0:
                loss = agent.learn()

            obs = next_obs

            if grad_steps % config_dict['eval_counter'] == 0:
                agent.evaluate_performance(config_dict, grad_steps)
                print('agent loss:', loss)


def run_inverse_model(config_dict,model_type=''):
    """Prepare a data buffer and train an inverse model on it.

    The function builds an agent (``Revalued_Agent`` for ``dm_suite`` runs,
    ``TD3_Agent`` otherwise) and the matching inverse model. It then tries to
    load a saved buffer; if that fails it collects random burn-in transitions,
    trains the agent online while recording its transitions, and saves the
    buffer. Finally it adds any extra offline data and trains the inverse
    model.

    Args:
        config_dict: Run configuration. It is modified in place (seed,
            learning rates, buffer sizes, step counts, ...). Must provide at
            least ``'env'``, ``'dm_suite'``, ``'env_id'`` and
            ``'eval_counter'``; other keys are read depending on the branch
            taken.
        model_type: Input variant forwarded to the inverse model.
    """
    config_dict['seed'] = 0

    env = config_dict['env']
    dataset = get_dataset(env, config_dict)

    _apply_run_settings(config_dict)

    state_dims = env.observation_space.shape[0]
    action_dims = env.action_space.shape[0]

    if config_dict['dm_suite']:
        agent = Revalued_Agent(obs_dims=state_dims,
                               action_dims=action_dims,
                               dataset=dataset,
                               **config_dict)

        # Passing the device to the constructor moves the model there and
        # also keeps ``inv_ss.device`` in sync with the parameters.
        inv_ss = DiscreteInverseModel(state_dim=state_dims, action_dim=action_dims,
                                      action_bins=config_dict['action_bins'],
                                      model_type=model_type, device=agent.device)
    else:
        agent = TD3_Agent(obs_dims=state_dims,
                          action_dims=action_dims,
                          dataset=dataset,
                          **config_dict
                          )

        inv_ss = InverseModel(state_dim=state_dims, action_dim=action_dims,
                              max_action=agent.max_action_val,
                              model_type=model_type, device=agent.device)

    uses_dataset_buffer = agent.algo_name in _DATASET_BUFFER_ALGOS

    try:
        if uses_dataset_buffer:
            agent.dataset_buffer.load_buffer(config_dict)
        else:
            agent.replay_buffer.load_buffer(config_dict)
        buffer_loaded = True
    except Exception as err:
        # Any failure to load is treated as "no saved buffer available";
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        print(f"Could not load a saved buffer ({err!r}); collecting transitions instead.")
        buffer_loaded = False

    if not buffer_loaded:
        env.reset(seed=config_dict['seed'])

        n_step_buffer = []

        init_steps = _collect_random_transitions(env, agent, config_dict, n_step_buffer)
        print(init_steps)

        _collect_policy_transitions(env, agent, config_dict, n_step_buffer)

        if uses_dataset_buffer:
            agent.dataset_buffer.save_buffer(config_dict)
        else:
            agent.replay_buffer.save_buffer(config_dict)

    # Extra offline data, chosen from a substring of the environment id.
    if 'medium-replay' in config_dict['env_id']:
        extra_suffix = '-medium-expert-v2'
    elif 'medium-v2' in config_dict['env_id']:
        extra_suffix = '-expert-v2'
    else:
        extra_suffix = None

    if extra_suffix is not None:
        extra_env_id = config_dict['task'] + extra_suffix
        extra_env = gym.make(extra_env_id)
        extra_data = get_dataset(extra_env, config_dict)
        agent.replay_buffer.store_offline_data(extra_data)
    elif config_dict['dm_suite']:
        dataset = get_dataset(config_dict['env_id'], config_dict)
        agent.dataset_buffer.store_offline_data(dataset)

    save_num = f'1000000_big_eps-{agent.replay_buffer.discrete_eps}_bins-{agent.replay_buffer.discrete_bins}'
    if config_dict['dm_suite']:
        inv_ss.train_inverse(agent, save_num=save_num, train_iter=150000)
    else:
        inv_ss.train_inverse(agent, save_num=save_num)
