from itertools import count
from pathlib import Path
import hydra
import torch
import numpy as np
import os
from omegaconf import DictConfig, OmegaConf
from scipy.stats import spearmanr, pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import pandas as pd
import pickle
import csv

from make_envs import make_env
from agent import make_agent
from utils.utils import evaluate

def get_args(cfg: DictConfig):
    """Resolve the runtime device on *cfg*, print the config, and return it.

    ``cfg.device`` is set to ``"cuda:0"`` when CUDA is available and to
    ``"cpu"`` otherwise. The config is mutated in place and the same object is
    returned.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    cfg.device = device
    yaml_dump = OmegaConf.to_yaml(cfg)
    print(yaml_dump)
    return cfg


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Evaluate a saved policy, append its scores to a pickle log, then check reward correlations.

    The policy is loaded from ``args.eval.policy`` if set, otherwise from
    ``results/<method.type>.para``. It is evaluated for ``args.eval.eps``
    episodes under the preference ``args.agent.preference``. The mean
    return, the mean episode length and the per-dimension mean returns are
    printed and appended to ``./<env.name>_eval_log.pkl``. Unless
    ``args.eval_only`` is set, ``measure_correlations`` runs afterwards.
    """
    args = get_args(cfg)

    weight = torch.FloatTensor(args.agent.preference).to(cfg.device)

    env = make_env(args, weight.cpu().numpy(), is_mogym=args.env.is_mogym, render=args.env.render)
    agent = make_agent(env, args)

    name = 'sqil' if args.method.type == "sqil" else 'iq'

    policy_file = f'{args.eval.policy}' if args.eval.policy else f'results/{args.method.type}.para'
    print(f'Loading policy from: {policy_file}')

    # In transfer mode the checkpoint suffix uses the expert env's name
    # instead of the env being evaluated.
    env_tag = args.eval.expert_env if args.eval.transfer else args.env.name
    agent.load(hydra.utils.to_absolute_path(policy_file), f'_{name}_{env_tag}')

    eval_returns, eval_timesteps, eval_returns_dict = evaluate(agent, weight, env, num_episodes=args.eval.eps)
    result_str = f'Avg. eval returns: {np.mean(eval_returns)}, timesteps: {np.mean(eval_timesteps)}'
    # Iterate the actual keys rather than assuming "returns_dim0..N-1".
    for key, values in eval_returns_dict.items():
        result_str += f', {key}: {np.mean(values)}'
    print(result_str)

    # Each metric key embeds the preference; every run appends one entry to
    # that key's list in the shared log.
    pref = str(args.agent.preference)
    tmp = {
        'Eval returns - ' + pref: np.mean(eval_returns),
        'timesteps - ' + pref: np.mean(eval_timesteps),
    }
    for key, values in eval_returns_dict.items():
        tmp[key + ' - ' + pref] = np.mean(values)

    # NOTE(review): relative to the current working directory, which Hydra
    # may switch per run depending on its version/config — confirm.
    log_file = Path('./{}_eval_log.pkl'.format(args.env.name))
    output_dict = {}
    if log_file.exists():
        # pickle.load can execute arbitrary code; this file is only written
        # by this script, so never point it at untrusted data.
        with open(log_file, 'rb') as fp:
            output_dict = pickle.load(fp)
    for key, value in tmp.items():
        output_dict.setdefault(key, []).append(value)
    with open(log_file, 'wb') as fp:
        pickle.dump(output_dict, fp)

    if args.eval_only:
        # Return instead of exit(): SystemExit is not an Exception, so Hydra
        # would let it propagate and abort any remaining multirun jobs.
        return

    measure_correlations(agent, weight, env, args, log=True)


def measure_correlations(agent, weight, env, args, log=False, use_wandb=False, num_episodes=100):
    """Compare the agent's recovered per-step rewards with the environment's rewards.

    The agent is rolled out deterministically (``sample=False``) for
    ``num_episodes`` episodes. At each step the recovered reward is
    ``(Q(s, a) - gamma * (1 - terminated) * V(s')) @ weight``, computed from
    ``agent.infer_q`` / ``agent.infer_v``. Spearman and Pearson correlations
    between per-episode recovered and env returns are printed. Four figures are
    saved under ``vis/<env.name>/correlation``, and also logged to wandb when
    ``use_wandb`` is set.

    Args:
        agent: exposes ``choose_action``, ``infer_q`` and ``infer_v``.
        weight: preference tensor passed to the agent and used to scalarise
            the recovered value vector.
        env: Gymnasium-style env (``reset`` -> (obs, info); ``step`` -> 5-tuple).
        args: config; reads ``args.gamma`` and ``args.env.name``.
        log: print per-episode env and recovered returns.
        use_wandb: also log each figure to wandb.
        num_episodes: number of rollout episodes (default 100, as before).
    """
    gamma = args.gamma
    # Loop-invariant: copy the preference vector to host memory once, not per step.
    weight_np = weight.cpu().numpy()

    env_rewards = []
    learnt_rewards = []
    for epoch in range(num_episodes):
        part_env_rewards, part_learnt_rewards = _rollout_episode(agent, weight, weight_np, env, gamma)
        if log:
            print('Ep {}\tEpisode env rewards: {:.2f}\t'.format(epoch, sum(part_env_rewards)))
            print('Ep {}\tEpisode learnt rewards {:.2f}\t'.format(epoch, sum(part_learnt_rewards)))
        learnt_rewards.append(part_learnt_rewards)
        env_rewards.append(part_env_rewards)

    # To skip outlier episodes, filter both lists here by episode return
    # (e.g. keep only episodes with sum(x) < -5).

    # Compute per-episode returns and cumulative curves once, reuse below.
    episode_env = eps(env_rewards)
    episode_learnt = eps(learnt_rewards)
    cum_env = part_eps(env_rewards)
    cum_learnt = part_eps(learnt_rewards)

    print(f'Spearman correlation {spearmanr(episode_learnt, episode_env)}')
    print(f'Pearson correlation: {pearsonr(episode_learnt, episode_env)}')

    savedir = hydra.utils.to_absolute_path(f'vis/{args.env.name}/correlation')
    os.makedirs(savedir, exist_ok=True)

    sns.set()  # global theme; applies to every figure below

    plt.figure(dpi=150)
    plt.scatter(episode_env, episode_learnt, s=10, alpha=0.8)
    _finish_plot('Episode rewards', savedir, use_wandb)

    # Slicing clamps to the number of available episodes.
    plt.figure(dpi=150)
    for env_curve, learnt_curve in zip(cum_env[:20], cum_learnt[:20]):
        plt.scatter(env_curve, learnt_curve, s=5, alpha=0.6)
    _finish_plot('Partial rewards', savedir, use_wandb)

    plt.figure(dpi=150)
    for env_curve, learnt_curve in zip(cum_env[:20], cum_learnt[:20]):
        plt.plot(env_curve, learnt_curve, markersize=1, alpha=0.8)
    _finish_plot('Partial rewards - Interplolate', savedir, use_wandb)

    plt.figure(dpi=150)
    for env_steps, learnt_steps in zip(env_rewards[:5], learnt_rewards[:5]):
        plt.scatter(env_steps, learnt_steps, s=5, alpha=0.5)
    _finish_plot('Step rewards', savedir, use_wandb)


def _rollout_episode(agent, weight, weight_np, env, gamma):
    """Run one deterministic episode; return (per-step env rewards, per-step recovered rewards)."""
    step_env_rewards = []
    step_learnt_rewards = []
    state, _ = env.reset()
    while True:
        action = agent.choose_action(state, weight, sample=False)
        next_state, reward, terminated, truncated, _ = env.step(action)

        with torch.no_grad():
            q = agent.infer_q(state, action, weight)
            next_v = agent.infer_v(next_state, weight)
            # r = Q(s, a) - gamma * V(s'). Only a true termination zeroes the
            # bootstrap term. A time-limit truncation does not end the MDP,
            # so V(s') must still be included.
            y = (1 - float(terminated)) * gamma * next_v
            irl_reward = ((q - y) @ weight_np.T).item()

        step_env_rewards.append(reward)
        step_learnt_rewards.append(irl_reward)

        if terminated or truncated:
            break
        state = next_state
    return step_env_rewards, step_learnt_rewards


def _finish_plot(title, savedir, use_wandb):
    """Label the current figure, optionally log it to wandb, save it as ``<savedir>/<title>.png`` and close it."""
    plt.xlabel('Env rewards')
    plt.ylabel('Recovered rewards')
    if use_wandb:
        wandb.log({title: wandb.Image(plt)})
    plt.savefig(os.path.join(savedir, '%s.png' % title))
    plt.close()


def eps(rewards):
    """Return the plain (undiscounted) sum of each episode's per-step rewards."""
    return list(map(sum, rewards))


def part_eps(rewards):
    """Return, per episode, a NumPy array of the running (cumulative) reward."""
    return list(map(np.cumsum, rewards))


if __name__ == '__main__':
    # main() is called without arguments; the @hydra.main decorator composes
    # the "conf/config" config and passes it in as ``cfg``.
    main()
