from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import argparse
import json
from utils import redirect_stdout, TrainLogger
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
import os, time
import random
from actor_critic import DDPG, TD3, SAC
from collect_trajectory import load_agent
import train_reward_model
import train_behavior_model
import copy
import sys
from pathlib import Path

def load_stochastic_behavior_model(env_name, data_num, seed):
    """Load cached weights into a fresh stochastic behavior model.

    The weight file is expected at
    ``../stochastic_behavior/<env_name>_<data_num>_<seed>``. This helper never
    trains a model; it only loads one.

    Raises:
        RuntimeError: if the weight file does not exist.
    """
    model = train_behavior_model.stochastic_behavior_model(
        env_name=env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    weights_file = Path('../stochastic_behavior/%s_%d_%d' % (env_name, data_num, seed))
    # The parent directory is created even though only a read happens here.
    weights_file.parent.mkdir(parents=True, exist_ok=True)

    if weights_file.is_file():
        model.action_nn.load_state_dict(torch.load(weights_file))
        return model
    raise RuntimeError('Behavior path doesnt exist')

def baseline_training(dataset, agent, epochs, steps_per_epoch, seed, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger(), true_reward = False):
    """Train an off-policy agent (TD3 / DDPG / SAC) on a learned reward model's predictions.

    The reward network is loaded from its cache file when present. Otherwise it
    is trained on ``dataset`` and saved:

    * ``true_reward=False``: trained with ``train_reward_model.training`` and
      cached under ``../reward/regular/``.
    * ``true_reward=True``: trained with ``train_reward_model.training_true_reward``
      and cached under ``../true_reward/regular/``.

    The cache file name is ``<env_name>_<len(dataset[3])>_<seed>``. During
    interaction, the environment's own reward is discarded. It is replaced by the
    reward network's prediction for ``(obs, action)``.

    Args:
        dataset: Training data forwarded to the reward trainer.
            ``len(dataset[3])`` is used as the dataset size in the cache name.
        agent: Off-policy agent exposing ``env``, ``env_name``, ``name``,
            ``get_action``, ``replay_buffer``, ``update``, ``test_agent`` and
            ``test_reward``.
        epochs, steps_per_epoch: Total interaction steps are ``epochs * steps_per_epoch``.
        seed: Used only in the reward-model cache file name.
        batch_size: Replay-buffer sample size per gradient update.
        start_steps: Actions are sampled from ``env.action_space`` while
            ``t <= start_steps``.
        update_after, update_every: Once ``t >= update_after``, every
            ``update_every`` steps run ``update_every`` gradient updates.
        max_ep_len: Episodes are cut at this length. Hitting the limit is not
            stored as a terminal transition.
        logger: Receives per-episode and per-epoch dicts via ``log``.
        true_reward: Selects the reward-model trainer and cache directory above.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of ``agent.test_agent()`` and
        of ``agent.test_reward(reward_model)``.
    """
    true_pfms = []
    sim_pfms = []

    reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    reward_dir = '../true_reward/regular' if true_reward else '../reward/regular'
    reward_path = Path(f'{reward_dir}/{agent.env_name}_{len(dataset[3])}_{seed}')
    reward_path.parent.mkdir(parents=True, exist_ok=True)
    if not reward_path.is_file():
        train_fn = train_reward_model.training_true_reward if true_reward else train_reward_model.training
        train_fn(reward_model, dataset, epoch=150, epoch_steps=100, batch_size=64, reward_path=reward_path, logger=logger)
        torch.save(reward_model.reward_nn.state_dict(), reward_path)
    else:
        reward_model.reward_nn.load_state_dict(torch.load(reward_path))

    total_steps = steps_per_epoch * epochs

    o, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0
    for t in range(total_steps):
        if t > start_steps:
            a = agent.get_action(o)
        else:
            a = agent.env.action_space.sample()

        # The environment reward is ignored; the learned reward is used instead.
        o2, _, d, _ = agent.env.step(a)

        with torch.no_grad():
            r = reward_model.reward_nn(
                torch.as_tensor(o, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device),
                torch.as_tensor(a, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device))
        ep_ret += r
        ep_len += 1
        # A time-limit cut-off is not a true terminal state, so don't store it as one.
        d = False if ep_len == max_ep_len else d
        agent.replay_buffer.store(o, a, r, o2, d)
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            o, ep_len, ep_ret = agent.env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = agent.replay_buffer.sample_batch(batch_size)
                if agent.name == 'td3':
                    # TD3's update also receives the inner-loop index as `timer`
                    # (presumably for its delayed policy updates).
                    agent.update(data=batch, timer=j)
                else:
                    # DDPG, SAC (and any other agent) take only the batch.
                    agent.update(data=batch)

        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch
            # Test the performance of the deterministic version of the agent.
            ep_rets = agent.test_agent()
            ep_rets_sim = agent.test_reward(reward_model)
            pfm = sum(ep_rets) / len(ep_rets)
            sim_pfm = sum(ep_rets_sim) / len(ep_rets_sim)
            logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})

            true_pfms.append(pfm)
            sim_pfms.append(sim_pfm)

    return true_pfms, sim_pfms

def baseline_training_ppo(dataset, agent, epochs, steps_per_epoch, seed, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger(), true_reward = False):
    """Train an on-policy actor-critic agent (PPO by name) on a learned reward model's predictions.

    Reward-model caching matches the off-policy baseline:

    * ``../reward/regular/`` with ``train_reward_model.training`` when
      ``true_reward`` is false.
    * ``../true_reward/regular/`` with ``train_reward_model.training_true_reward``
      otherwise.

    The cache file name is ``<env_name>_<len(dataset[3])>_<seed>``.

    Each epoch collects ``steps_per_epoch`` transitions with ``agent.ac.step``.
    The environment reward is replaced by the reward network's prediction.
    Transitions are stored in ``agent.replay_buffer`` and each path is closed with
    ``finish_path``. ``agent.update()`` is called only when the buffer is full
    (``ptr == max_size``).

    ``batch_size``, ``start_steps``, ``update_after`` and ``update_every`` are
    accepted but not used by this function.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of ``agent.test_agent()`` and
        of ``agent.test_reward(reward_model)``.
    """
    true_pfms, sim_pfms = [], []

    reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    cache_dir = '../true_reward/regular' if true_reward else '../reward/regular'
    reward_path = Path(f'{cache_dir}/{agent.env_name}_{len(dataset[3])}_{seed}')
    reward_path.parent.mkdir(parents=True, exist_ok=True)
    if reward_path.is_file():
        reward_model.reward_nn.load_state_dict(torch.load(reward_path))
    else:
        fit = train_reward_model.training_true_reward if true_reward else train_reward_model.training
        fit(reward_model, dataset, epoch=150, epoch_steps=100, batch_size=64, reward_path=reward_path, logger=logger)
        torch.save(reward_model.reward_nn.state_dict(), reward_path)

    def predicted_reward(obs, act):
        # Reward-network prediction for (obs, act), computed without autograd.
        with torch.no_grad():
            return reward_model.reward_nn(
                torch.as_tensor(obs, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device),
                torch.as_tensor(act, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device))

    obs, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0
    last_step = steps_per_epoch - 1
    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            act, val, logp = agent.ac.step(torch.as_tensor(obs, dtype=torch.float32))
            next_obs, _, done, _ = agent.env.step(act)

            rew = predicted_reward(obs, act)
            ep_ret += rew
            ep_len += 1

            agent.replay_buffer.store(obs, act, rew, val, logp)
            obs = next_obs

            timeout = ep_len == max_ep_len
            terminal = done or timeout
            epoch_ended = t == last_step
            if not (terminal or epoch_ended):
                continue

            if epoch_ended and not terminal:
                print('Warning: trajectory cut off by epoch at %d steps.' % ep_len, flush=True)

            # Bootstrap with the critic's value of the last observation on a
            # timeout or epoch cut-off; otherwise close the path with 0.
            if timeout or epoch_ended:
                _, last_val, _ = agent.ac.step(torch.as_tensor(obs, dtype=torch.float32))
            else:
                last_val = 0
            agent.replay_buffer.finish_path(last_val)

            if terminal:
                logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            obs, ep_ret, ep_len = agent.env.reset(), 0, 0

        buffer = agent.replay_buffer
        if buffer.ptr == buffer.max_size:
            agent.update()

        test_returns = agent.test_agent()
        sim_returns = agent.test_reward(reward_model)
        pfm = sum(test_returns) / len(test_returns)
        sim_pfm = sum(sim_returns) / len(sim_returns)
        logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})
        true_pfms.append(pfm)
        sim_pfms.append(sim_pfm)

    return true_pfms, sim_pfms

def soft_behavior_regularized_training_ppo(dataset, agent, epochs, steps_per_epoch, seed, beta=0.1, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger(), true_reward = False):
    """On-policy (PPO by name) training on a learned reward plus a behavior-model bonus.

    The per-step reward is
    ``reward_nn(obs, act) + beta * behavior_model.action_nn(obs, act)``.
    The second term presumably scores how likely ``act`` is under the stochastic
    behavior model (TODO confirm against ``train_behavior_model``).

    Both models are loaded from cache files when present; otherwise they are
    trained on ``dataset`` and saved.

    * Reward model: ``../reward/regular/`` (``train_reward_model.training``), or
      ``../true_reward/regular/`` (``train_reward_model.training_true_reward``)
      when ``true_reward`` is truthy.
    * Behavior model: ``../stochastic_behavior/``.

    File names are ``<env_name>_<len(dataset[3])>_<seed>``.

    ``batch_size``, ``start_steps``, ``update_after`` and ``update_every`` are
    accepted but not used by this function.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of ``agent.test_agent()`` and
        of ``agent.test_sim_soft(beta, behavior_model, reward_model)``.
    """
    true_pfms, sim_pfms = [], []
    data_num = len(dataset[3])

    # --- reward model: load from cache or train and cache ---
    reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    cache_dir = '../true_reward/regular' if true_reward else '../reward/regular'
    reward_path = Path(f'{cache_dir}/{agent.env_name}_{data_num}_{seed}')
    reward_path.parent.mkdir(parents=True, exist_ok=True)
    if reward_path.is_file():
        reward_model.reward_nn.load_state_dict(torch.load(reward_path))
    else:
        fit = train_reward_model.training_true_reward if true_reward else train_reward_model.training
        fit(reward_model, dataset, epoch=150, epoch_steps=100, batch_size=64, reward_path=reward_path, logger=logger)
        torch.save(reward_model.reward_nn.state_dict(), reward_path)

    # --- stochastic behavior model: load from cache or train and cache ---
    behavior_model = train_behavior_model.stochastic_behavior_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    behavior_path = Path('../stochastic_behavior/%s_%d_%d' % (agent.env_name, data_num, seed))
    behavior_path.parent.mkdir(parents=True, exist_ok=True)
    if behavior_path.is_file():
        behavior_model.action_nn.load_state_dict(torch.load(behavior_path))
    else:
        train_behavior_model.training(behavior_model, dataset, epoch=150, epoch_steps=100, batch_size=64, logger=logger)
        torch.save(behavior_model.action_nn.state_dict(), behavior_path)

    behavior_model.test()

    def shaped_reward(obs, act):
        # Learned reward plus the beta-weighted behavior-model term, without autograd.
        with torch.no_grad():
            rew = reward_model.reward_nn(
                torch.as_tensor(obs, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device),
                torch.as_tensor(act, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device))
            rew += beta * behavior_model.action_nn(
                torch.as_tensor(obs, dtype=torch.float32).clone().detach().to(behavior_model.action_nn.device),
                torch.as_tensor(act, dtype=torch.float32).clone().detach().to(behavior_model.action_nn.device))
        return rew

    obs, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0
    last_step = steps_per_epoch - 1
    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            act, val, logp = agent.ac.step(torch.as_tensor(obs, dtype=torch.float32))
            next_obs, _, done, _ = agent.env.step(act)

            rew = shaped_reward(obs, act)
            ep_ret += rew
            ep_len += 1
            done = False if ep_len == max_ep_len else done
            agent.replay_buffer.store(obs, act, rew, val, logp)
            obs = next_obs

            timeout = ep_len == max_ep_len
            terminal = done or timeout
            epoch_ended = t == last_step
            if not (terminal or epoch_ended):
                continue

            if epoch_ended and not terminal:
                print('Warning: trajectory cut off by epoch at %d steps.' % ep_len, flush=True)

            # Bootstrap with the critic's value of the last observation on a
            # timeout or epoch cut-off; otherwise close the path with 0.
            if timeout or epoch_ended:
                _, last_val, _ = agent.ac.step(torch.as_tensor(obs, dtype=torch.float32))
            else:
                last_val = 0
            agent.replay_buffer.finish_path(last_val)

            if terminal:
                logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            obs, ep_ret, ep_len = agent.env.reset(), 0, 0

        buffer = agent.replay_buffer
        if buffer.ptr == buffer.max_size:
            agent.update()

        # Test the performance of the deterministic version of the agent.
        test_returns = agent.test_agent()
        sim_returns = agent.test_sim_soft(beta, behavior_model, reward_model)
        sim_pfm = sum(sim_returns) / len(sim_returns)
        pfm = sum(test_returns) / len(test_returns)
        logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})
        true_pfms.append(pfm)
        sim_pfms.append(sim_pfm)

    return true_pfms, sim_pfms

def uncertainty_training(dataset, agent, epochs, steps_per_epoch, seed, beta=5, seg_num=4, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger()):
    """Train an off-policy agent on an uncertainty-penalised ensemble reward.

    ``seg_num`` reward models are built. Model ``i`` is loaded from
    ``../reward/ensemble/<env_name>_<data_num>_<i>_<seed>`` when present;
    otherwise it is trained on the i-th data segment and saved there.

    Segmenting: ``dataset[0:3]`` are passed to every member unchanged.
    ``dataset[3]``, ``dataset[4]`` and ``dataset[5]`` are each cut into
    ``seg_num`` consecutive slices of ``int(data_num / seg_num)`` items, so any
    remainder beyond ``seg_num * seg`` items is unused. This assumes ``dataset``
    has at least 6 sliceable entries (TODO confirm with the data producer).

    At every step the environment reward is discarded and replaced by
    ``mean(preds) - beta * var(preds)``, computed over the ensemble's scalar
    predictions for ``(obs, action)``.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of ``agent.test_agent()`` and
        of ``agent.test_rewards(reward_models, beta)``.
    """
    true_pfms = []
    sim_pfms = []

    data_num = len(dataset[3])
    seg = int(data_num / seg_num)
    reward_models = []
    for i in range(seg_num):
        reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
        reward_path = Path('../reward/ensemble/%s_%d_%d_%d' % (agent.env_name, data_num, i, seed))
        reward_path.parent.mkdir(parents=True, exist_ok=True)
        if not reward_path.is_file():
            data_seg = list(dataset[:3]) + [dataset[j][i*seg:(i+1)*seg] for j in range(3, 6)]
            train_reward_model.training(reward_model, data_seg, epoch=150, epoch_steps=100, batch_size=64, logger=logger)
            torch.save(reward_model.reward_nn.state_dict(), reward_path)
        else:
            reward_model.reward_nn.load_state_dict(torch.load(reward_path))
        reward_models.append(reward_model)

    total_steps = steps_per_epoch * epochs
    o, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0

    for t in range(total_steps):
        if t > start_steps:
            a = agent.get_action(o)
        else:
            a = agent.env.action_space.sample()

        # The environment reward is ignored; the ensemble reward is used instead.
        o2, _, d, _ = agent.env.step(a)

        with torch.no_grad():
            r_ensemble = np.array([
                model.reward_nn(
                    torch.as_tensor(o, dtype=torch.float32).clone().detach().to(model.reward_nn.device),
                    torch.as_tensor(a, dtype=torch.float32).clone().detach().to(model.reward_nn.device)).item()
                for model in reward_models
            ])
            # Penalise disagreement between ensemble members (np.var is the population variance).
            r = np.mean(r_ensemble) - beta * np.var(r_ensemble)

        ep_ret += r
        ep_len += 1
        # A time-limit cut-off is not a true terminal state, so don't store it as one.
        d = False if ep_len == max_ep_len else d
        agent.replay_buffer.store(o, a, r, o2, d)
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            o, ep_len, ep_ret = agent.env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = agent.replay_buffer.sample_batch(batch_size)
                if agent.name == 'td3':
                    # TD3's update also receives the inner-loop index as `timer`
                    # (presumably for its delayed policy updates).
                    agent.update(data=batch, timer=j)
                else:
                    # DDPG, SAC (and any other agent) take only the batch.
                    agent.update(data=batch)

        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch
            # Test the performance of the deterministic version of the agent.
            ep_rets = agent.test_agent()
            ep_rets_sim = agent.test_rewards(reward_models, beta)
            pfm = sum(ep_rets) / len(ep_rets)
            sim_pfm = sum(ep_rets_sim) / len(ep_rets_sim)
            logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})
            true_pfms.append(pfm)
            sim_pfms.append(sim_pfm)

    return true_pfms, sim_pfms

def behavior_regularized_training(dataset, agent, epochs, steps_per_epoch, seed, regu=0.1, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger()):
    """Train an off-policy agent as a scaled residual on top of a behavior model.

    The reward model is cached at ``../reward/regular/<env_name>_<data_num>_<seed>``.
    The deterministic behavior model is cached at
    ``../behavior/<env_name>_<data_num>_<seed>``. Each model is loaded when its
    file exists; otherwise it is trained on ``dataset`` and saved. ``data_num``
    is ``len(dataset[3])``.

    At each step:

    * the agent proposes ``a``, sampled from ``env.action_space`` while
      ``t <= start_steps``;
    * the executed action is ``a * regu + behavior_model.action_nn(o)``;
    * the reward is the reward network's prediction for the executed action,
      and the environment reward is discarded;
    * the replay buffer stores the agent's residual ``a``, not the executed
      action.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of
        ``agent.test_agent_behavior(regu, behavior_model)`` and of
        ``agent.test_sim(regu, behavior_model, reward_model)``.
    """
    true_pfms = []
    sim_pfms = []

    data_num = len(dataset[3])

    reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    reward_path = Path('../reward/regular/%s_%d_%d' % (agent.env_name, data_num, seed))
    reward_path.parent.mkdir(parents=True, exist_ok=True)
    if not reward_path.is_file():
        train_reward_model.training(reward_model, dataset, epoch=150, epoch_steps=100, batch_size=64, reward_path=reward_path, logger=logger)
        torch.save(reward_model.reward_nn.state_dict(), reward_path)
    else:
        reward_model.reward_nn.load_state_dict(torch.load(reward_path))

    behavior_model = train_behavior_model.behavior_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    behavior_path = Path('../behavior/%s_%d_%d' % (agent.env_name, data_num, seed))
    behavior_path.parent.mkdir(parents=True, exist_ok=True)
    if not behavior_path.is_file():
        train_behavior_model.training(behavior_model, dataset, epoch=150, epoch_steps=100, batch_size=64, behavior_path=behavior_path, logger=logger)
        torch.save(behavior_model.action_nn.state_dict(), behavior_path)
    else:
        behavior_model.action_nn.load_state_dict(torch.load(behavior_path))

    behavior_model.test()

    total_steps = steps_per_epoch * epochs

    o, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0
    for t in range(total_steps):
        if t > start_steps:
            a = agent.get_action(o)
        else:
            a = agent.env.action_space.sample()

        with torch.no_grad():
            base_action = behavior_model.action_nn(torch.as_tensor(o, dtype=torch.float32).clone().detach().to(behavior_model.action_nn.device)).cpu().detach().numpy()

        # Execute the behavior action plus the agent's scaled residual.
        a_actual = a * regu + base_action

        o2, _, d, _ = agent.env.step(a_actual)

        with torch.no_grad():
            r = reward_model.reward_nn(
                torch.as_tensor(o, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device),
                torch.as_tensor(a_actual, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device))

        ep_ret += r
        ep_len += 1
        # A time-limit cut-off is not a true terminal state, so don't store it as one.
        d = False if ep_len == max_ep_len else d
        # The agent learns the residual, so its own action `a` is stored.
        agent.replay_buffer.store(o, a, r, o2, d)
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            o, ep_len, ep_ret = agent.env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = agent.replay_buffer.sample_batch(batch_size)
                if agent.name == 'td3':
                    # TD3's update also receives the inner-loop index as `timer`
                    # (presumably for its delayed policy updates).
                    agent.update(data=batch, timer=j)
                else:
                    # DDPG, SAC (and any other agent) take only the batch.
                    agent.update(data=batch)

        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch
            ep_rets = agent.test_agent_behavior(regu, behavior_model)
            ep_rets_sim = agent.test_sim(regu, behavior_model, reward_model)
            pfm = sum(ep_rets) / len(ep_rets)
            sim_pfm = sum(ep_rets_sim) / len(ep_rets_sim)
            logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})
            true_pfms.append(pfm)
            sim_pfms.append(sim_pfm)
    return true_pfms, sim_pfms

def soft_behavior_regularized_training(dataset, agent, epochs, steps_per_epoch, seed, beta=0.1, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, max_ep_len=1000, logger = TrainLogger()):
    """Train an off-policy agent on a learned reward plus a behavior-model bonus.

    The per-step reward is
    ``reward_nn(o, a) + beta * behavior_model.action_nn(o, a)``. The second term
    presumably scores how likely ``a`` is under the stochastic behavior model
    (TODO confirm against ``train_behavior_model``). The environment reward is
    discarded.

    The reward model is cached at ``../reward/regular/<env_name>_<data_num>_<seed>``.
    The stochastic behavior model is cached at
    ``../stochastic_behavior/<env_name>_<data_num>_<seed>``. Each model is loaded
    when its file exists; otherwise it is trained on ``dataset`` and saved.
    ``data_num`` is ``len(dataset[3])``.

    Returns:
        ``(true_pfms, sim_pfms)``: per-epoch means of ``agent.test_agent()`` and
        of ``agent.test_sim_soft(beta, behavior_model, reward_model)``.
    """
    true_pfms = []
    sim_pfms = []

    data_num = len(dataset[3])

    reward_model = train_reward_model.reward_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    reward_path = Path('../reward/regular/%s_%d_%d' % (agent.env_name, data_num, seed))
    reward_path.parent.mkdir(parents=True, exist_ok=True)
    if not reward_path.is_file():
        train_reward_model.training(reward_model, dataset, epoch=150, epoch_steps=100, batch_size=64, logger=logger)
        torch.save(reward_model.reward_nn.state_dict(), reward_path)
    else:
        reward_model.reward_nn.load_state_dict(torch.load(reward_path))

    behavior_model = train_behavior_model.stochastic_behavior_model(env_name=agent.env_name, ac_kwargs=dict(hidden_sizes=[64] * 3))
    behavior_path = Path('../stochastic_behavior/%s_%d_%d' % (agent.env_name, data_num, seed))
    behavior_path.parent.mkdir(parents=True, exist_ok=True)
    if not behavior_path.is_file():
        train_behavior_model.training(behavior_model, dataset, epoch=150, epoch_steps=100, batch_size=64, logger=logger)
        torch.save(behavior_model.action_nn.state_dict(), behavior_path)
    else:
        behavior_model.action_nn.load_state_dict(torch.load(behavior_path))

    behavior_model.test()

    total_steps = steps_per_epoch * epochs

    o, ep_len, ep_ret = agent.env.reset(), 0, 0
    episode = 0
    for t in range(total_steps):
        if t > start_steps:
            a = agent.get_action(o)
        else:
            a = agent.env.action_space.sample()

        # The environment reward is ignored; the shaped learned reward is used instead.
        o2, _, d, _ = agent.env.step(a)

        with torch.no_grad():
            r = reward_model.reward_nn(
                torch.as_tensor(o, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device),
                torch.as_tensor(a, dtype=torch.float32).clone().detach().to(reward_model.reward_nn.device))

            r += beta * behavior_model.action_nn(
                torch.as_tensor(o, dtype=torch.float32).clone().detach().to(behavior_model.action_nn.device),
                torch.as_tensor(a, dtype=torch.float32).clone().detach().to(behavior_model.action_nn.device))

        ep_ret += r
        ep_len += 1
        # A time-limit cut-off is not a true terminal state, so don't store it as one.
        d = False if ep_len == max_ep_len else d
        agent.replay_buffer.store(o, a, r, o2, d)
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.log({'episode:': episode, 'training return:': ep_ret})
            episode += 1
            o, ep_len, ep_ret = agent.env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = agent.replay_buffer.sample_batch(batch_size)
                if agent.name == 'td3':
                    # TD3's update also receives the inner-loop index as `timer`
                    # (presumably for its delayed policy updates).
                    agent.update(data=batch, timer=j)
                else:
                    # DDPG, SAC (and any other agent) take only the batch.
                    agent.update(data=batch)

        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch
            # Test the performance of the deterministic version of the agent.
            ep_rets = agent.test_agent()
            ep_rets_sim = agent.test_sim_soft(beta, behavior_model, reward_model)
            sim_pfm = sum(ep_rets_sim) / len(ep_rets_sim)
            pfm = sum(ep_rets) / len(ep_rets)
            logger.log({'epoch': epoch, 'test_true_pfm': pfm, 'sim_pfm': sim_pfm})
            true_pfms.append(pfm)
            sim_pfms.append(sim_pfm)

    return true_pfms, sim_pfms