import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from utils import MeanStdevFilter, Transition, make_checkpoint, make_name, Datum, make_gif 


def _build_beta_schedule(total_steps, beta_pretrain, beta_warmup, beta_max):
    """Return ``total_steps`` per-sample beta values: zeros, a linear ramp, then a plateau.

    The first ``int(beta_pretrain * total_steps)`` entries are 0.0, the next
    ``int(beta_warmup * total_steps)`` ramp linearly from 0 to ``beta_max``
    (``np.linspace`` with both endpoints included), and the remainder equal
    ``beta_max``.
    """
    assert beta_warmup < 1 - beta_pretrain, "don't have enough updates left to warm up beta!"
    n_pretrain = int(beta_pretrain * total_steps)
    n_warmup = int(beta_warmup * total_steps)
    n_plateau = total_steps - n_pretrain - n_warmup
    ramp = list(np.linspace(0, beta_max, num=n_warmup))
    schedule = n_pretrain * [0.0] + ramp + n_plateau * [beta_max]
    assert len(schedule) == total_steps, "damn"
    return schedule


def train_agent_model_free(agent, env, params, writer=None, log_interval=1000, gif_interval=50000):
    """Run the model-free training loop for ``agent`` in ``env`` and return the agent.

    Every environment step pushes a ``Transition`` into ``agent.replay_pool`` and,
    once ``default_start`` samples have been seen, a ``Datum`` holding the policy
    mean/std into ``agent.default_replay_pool``. After ``n_collect_steps`` samples,
    ``agent.optimize`` is called every ``update_every_n_steps`` samples with a
    scheduled ``beta``. Every ``log_interval`` samples the agent is evaluated with
    ``evaluate_agent`` and losses/rewards are written to ``writer``. At every step
    ``t`` with ``(t + 1) % gif_interval == 0``, ``make_gif`` is called and, if
    ``save_model`` is set, a checkpoint is written.

    Args:
        agent: Agent exposing ``batchsize``, ``replay_pool``, ``default_replay_pool``,
            ``get_action``, ``optimize`` and ``log_alpha``.
        env: Gym-style environment (``reset``, 4-tuple ``step``, ``seed``,
            ``action_space``, ``spec.max_episode_steps``). It is also used for
            evaluation, which ends the current training episode.
        params: Run configuration dict. Keys read here: ``update_every_n_steps``,
            ``seed``, ``n_random_actions``, ``n_evals``, ``n_collect_steps``,
            ``obs_filter``, ``save_model``, ``total_steps``, ``beta``,
            ``beta_warmup``, ``beta_pretrain``, ``default_start``, ``agent``,
            ``task``, ``learned_asymmetry``, plus ``save_dir`` and ``env`` when a
            writer has to be created.
        writer: TensorBoard ``SummaryWriter``; one is created under
            ``params['save_dir']`` when None.
        log_interval: Samples between evaluation/logging rounds.
        gif_interval: Controls how often ``make_gif`` / checkpointing runs.

    Returns:
        The trained ``agent``.
    """
    update_timestep = params['update_every_n_steps']
    seed = params['seed']
    n_random_actions = params['n_random_actions']
    n_evals = params['n_evals']
    n_collect_steps = params['n_collect_steps']
    use_statefilter = params['obs_filter']
    save_model = params['save_model']
    total_steps = params['total_steps']

    assert n_collect_steps > agent.batchsize, "We must initially collect as many steps as the batch size!"

    # Environment steps across all episodes; also the global step for the writer
    # and the index into the beta schedule.
    samples_number = 0
    n_updates = 0
    i_episode = 0
    episode_rewards = []
    episode_steps = []
    # Losses from the latest ``agent.optimize`` call. Logging can fire before the
    # first update (``log_interval`` need not be a multiple of the update period),
    # so they must be bound up front instead of raising NameError.
    q1_loss = q2_loss = pi_loss = a_loss = None

    state_filter = MeanStdevFilter(env.env.observation_space.shape[0]) if use_statefilter else None

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    env.seed(seed)
    env.action_space.np_random.seed(seed)

    max_steps = env.spec.max_episode_steps
    beta_warmup, beta_pretrain = params['beta_warmup'], params['beta_pretrain']
    beta_max = params['beta']
    beta_schedule = _build_beta_schedule(total_steps, beta_pretrain, beta_warmup, beta_max)
    default_start = int(params['default_start'] * total_steps)
    use_kl = params['agent'] == 'mdlc-sac'

    name = make_name(params)
    if writer is None:
        writer = SummaryWriter(log_dir=f"{params['save_dir']}/{params['agent']}_runs/{params['env']}/{params['task']}/" + name)

    while samples_number < total_steps + 1:
        time_step = 0
        episode_reward = 0
        i_episode += 1
        state = env.reset()
        if state_filter:
            state_filter.update(state)
        done = False
        # Set once evaluation has reset/stepped ``env`` mid-episode: ``state`` no
        # longer matches the environment, so the episode must end here rather than
        # push a transition that mixes the training and evaluation rollouts.
        env_reused = False

        while not (done or env_reused):
            samples_number += 1
            time_step += 1
            if samples_number < n_random_actions:
                action = env.action_space.sample()
                pi_mean, pi_std = torch.zeros_like(torch.tensor(action)), torch.zeros_like(torch.tensor(action))
            else:
                action, pi_mean, pi_std = agent.get_action(state, state_filter=state_filter, get_dist=True)
            nextstate, reward, done, _ = env.step(action)
            # If we hit the time limit it's not a 'real' done; we don't want to
            # assign low value to those states.
            real_done = False if time_step == max_steps else done
            agent.replay_pool.push(Transition(state, action, reward, nextstate, real_done))
            if samples_number >= default_start:
                agent.default_replay_pool.push(
                    Datum(state, pi_mean.detach().cpu().numpy(), pi_std.detach().cpu().numpy())
                )
            state = nextstate
            if state_filter:
                state_filter.update(state)
            episode_reward += reward

            past_collection = samples_number > n_collect_steps

            # Update if it's time; past the end of the schedule beta keeps its last value.
            if samples_number % update_timestep == 0 and past_collection:
                beta = beta_schedule[min(samples_number, len(beta_schedule) - 1)]
                q1_loss, q2_loss, pi_loss, a_loss = agent.optimize(
                    update_timestep,
                    update_default=False,
                    beta=beta,
                    use_kl=use_kl,
                    state_filter=state_filter,
                )
                n_updates += 1

            # Logging and evaluation.
            if samples_number % log_interval == 0 and past_collection:
                if n_updates > 0:
                    writer.add_scalar(f"{params['task']}-Loss/Q-func_1", q1_loss, n_updates)
                    writer.add_scalar(f"{params['task']}-Loss/Q-func_2", q2_loss, n_updates)
                    writer.add_scalar(f"{params['task']}-Loss/control_policy", pi_loss, n_updates)
                    writer.add_scalar(f"{params['task']}-Loss/alpha", a_loss, n_updates)
                writer.add_scalar(f"{params['task']}-Values/alpha", np.exp(agent.log_alpha.item()), n_updates)
                # No episode may have finished in this window; report nan without
                # numpy's empty-slice warning.
                avg_length = np.mean(episode_steps) if episode_steps else float('nan')
                running_reward = np.mean(episode_rewards) if episode_rewards else float('nan')
                eval_reward = evaluate_agent(env, agent, state_filter, n_starts=n_evals)
                env_reused = True
                writer.add_scalar(f"{params['task']}-Reward/Train", running_reward, samples_number)
                writer.add_scalar(f"{params['task']}-Reward/Test", eval_reward, samples_number)
                print('Episode {} \t Samples {} \t Avg length: {} \t Test reward: {} \t Train reward: {} \t Number of Policy Updates: {}'.format(i_episode, samples_number, avg_length, eval_reward, running_reward, n_updates))
                episode_steps = []
                episode_rewards = []

            if (samples_number + 1) % gif_interval == 0:
                # NOTE(review): make_gif also receives ``env``; if it resets or steps
                # it, the training episode should be ended here as well — confirm.
                make_gif(agent, env, samples_number, state_filter)
                if save_model:
                    tstep = samples_number + 1 if params['learned_asymmetry'] else None
                    make_checkpoint(agent, params, timestep=tstep)

        episode_steps.append(time_step)
        episode_rewards.append(episode_reward)

    return agent


def evaluate_agent(env, agent, state_filter, n_starts=1):
    """Return the average undiscounted episode return of deterministic rollouts.

    Plays ``n_starts`` complete episodes in ``env``, picking every action with
    ``agent.get_action(..., deterministic=True)``. It then divides the summed
    rewards of all episodes by ``n_starts``. ``env`` is reset before each
    episode, so any episode the caller had in progress is discarded.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: Object exposing ``get_action(state, state_filter=..., deterministic=...)``.
        state_filter: Passed through unchanged to ``agent.get_action``.
        n_starts: Number of evaluation episodes.

    Returns:
        Total reward over all episodes divided by ``n_starts``.
    """
    total_return = 0
    for _episode in range(n_starts):
        obs = env.reset()
        finished = False
        while not finished:
            act = agent.get_action(obs, state_filter=state_filter, deterministic=True)
            obs, step_reward, finished, _info = env.step(act)
            total_return += step_reward
    return total_return / n_starts

