# In this file, we test how well a model can be personalized to a specific individual without any further training.
import copy
import os

import gymnasium as gym
import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
from matplotlib.colors import LogNorm, Normalize
from minigrid.wrappers import ImgObsWrapper
from tqdm import tqdm, trange

from experiments.envs import ManyDoorsEnv, TwoDoorsEnv
from experiments.envs.metaworld import (MetaWorldSafetySpeedWrapper,
                                        MetaWorldSawyerEnv)
from experiments.envs.wrappers import OneHotFullImage, OneHotPartialImage
from experiments.rl_utils import (MetaworldPolicy, Policy, sample_batch,
                                  sample_data)
from src.lexicase import select_from_scores
from src.popl import calc_lexicase_scores, popl_search


# Grid dimensions used for the state-visitation maps of each gridworld env.
_GRID_SHAPES = {"twodoors": (5, 5), "manydoors": (13, 13)}

# Constructor arguments of the gridworld ``Policy``.
# TODO: make all params hydrable
_GRID_INPUT_SIZE = None
_GRID_NUM_ACTIONS = 7

# Constructor arguments of ``MetaworldPolicy``: first and last positional
# arguments (presumably observation and action dimensionality -- confirm
# against experiments.rl_utils).
_METAWORLD_INPUT_SIZE = 35
_METAWORLD_NUM_ACTIONS = 4


def _is_metaworld(env_name):
    """Return True when *env_name* ends in ``v2`` (treated as a MetaWorld task)."""
    return env_name.endswith("v2")


def _make_env(env_name, identity=None):
    """Build a fresh environment for *env_name*.

    Args:
        env_name: ``"manydoors"``, ``"twodoors"`` or a name ending in ``v2``.
        identity: Identity passed to the environment (gridworlds) or to the
            safety/speed wrapper (MetaWorld). ``None`` leaves the gridworld
            constructor default in place and uses ``1`` for MetaWorld.

    Raises:
        ValueError: If *env_name* is not one of the supported environments.
    """
    grid_envs = {"manydoors": ManyDoorsEnv, "twodoors": TwoDoorsEnv}
    if env_name in grid_envs:
        kwargs = {} if identity is None else {"identity": identity}
        return ImgObsWrapper(grid_envs[env_name](render_mode="none", **kwargs))
    if _is_metaworld(env_name):
        return MetaWorldSafetySpeedWrapper(
            MetaWorldSawyerEnv(env_name), 1 if identity is None else identity)
    raise ValueError(f"unsupported env: {env_name!r}")


def _load_policy(cfg, ind, channels, device):
    """Instantiate the policy for ``cfg.env`` and load checkpoint number *ind*.

    The checkpoint is mapped onto *device* so that weights saved on a GPU can
    still be loaded on a CPU-only machine.
    """
    if cfg.env in _GRID_SHAPES:
        policy = Policy(_GRID_INPUT_SIZE, channels, cfg.num_features,
                        _GRID_NUM_ACTIONS)
    else:
        policy = MetaworldPolicy(
            _METAWORLD_INPUT_SIZE, cfg.num_features, _METAWORLD_NUM_ACTIONS)
    policy = policy.to(device)

    checkpoint = (f"{hydra.utils.get_original_cwd()}/rl_models/"
                  f"{cfg.foldername}/{cfg.method}_policy_{ind}.pth")
    policy.load_state_dict(torch.load(checkpoint, map_location=device))
    return policy


def _rollout(policy, env, cfg, device):
    """Run one episode of *policy* in *env*.

    Returns:
        ``(total_reward, positions)`` where ``positions`` is the list of agent
        positions (initial position plus one entry per step) for gridworld
        envs and an empty list otherwise.
    """
    gridworld = cfg.env in _GRID_SHAPES

    obs, _ = env.reset()
    # Copy the position into a tuple so later env updates cannot alias it.
    positions = [tuple(env.agent_pos)] if gridworld else []
    total_reward = 0
    done = False
    trunc = False

    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        while not (done or trunc):
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            if gridworld:
                # Greedy discrete action, handed to the env as a plain int.
                action = int(policy(obs_t.unsqueeze(0)).argmax().item())
            else:
                action = policy(obs_t).detach().cpu().numpy()

            obs, reward, done, trunc, _ = env.step(action)
            total_reward += reward

            # The position after every step (including the terminal one) is
            # recorded here, so no extra append is needed once the episode
            # is done.
            if gridworld:
                positions.append(tuple(env.agent_pos))

            if cfg.render:
                env.render()

    if cfg.env == "twodoors":
        if env.door1_opened:
            print("Door 1 opened!")
        if env.door2_opened:
            print("Door 2 opened!")

    return total_reward, positions


def _state_density(trajectories, shape):
    """Accumulate a visitation map in which every trajectory has total weight 1.

    The result is transposed to match the gridworld layout.
    """
    density = np.zeros(shape)
    for traj in trajectories:
        weight = 1 / len(traj)
        for state in traj:
            density[state[0], state[1]] += weight
    return density.T


@hydra.main(config_path="config", config_name="personalization")
def main(cfg):
    """Evaluate every saved policy under both identities, without training.

    Reads ``cfg.env``, ``cfg.repeats``, ``cfg.popsize``, ``cfg.num_features``,
    ``cfg.foldername``, ``cfg.method`` and ``cfg.render``. Each of the
    ``cfg.popsize`` checkpoints is rolled out ``cfg.repeats`` times for
    identity 1 and identity 2.

    Outputs written to the current working directory:
        * ``{env}_{method}_scores.npy`` -- episode returns with shape
          ``(2, repeats, popsize)``.
        * gridworld envs: per-identity state-density ``.npy`` and heatmap
          ``.png`` files, plus ``new_diff_0.png`` / ``new_diff_1.png``.
        * MetaWorld envs: ``{env}_{method}_scores.png`` box plots.

    Raises:
        ValueError: If ``cfg.env`` is unsupported or ``cfg.popsize`` /
            ``cfg.repeats`` is smaller than 1.
    """
    gridworld = cfg.env in _GRID_SHAPES
    if not gridworld and not _is_metaworld(cfg.env):
        raise ValueError(f"unsupported env: {cfg.env!r}")
    if cfg.popsize < 1 or cfg.repeats < 1:
        raise ValueError("cfg.popsize and cfg.repeats must both be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A throwaway gridworld env is only needed to read the channel count
    # that the Policy constructor expects.
    channels = None
    if gridworld:
        channels = _make_env(cfg.env).observation_space.shape[0]

    # Load every checkpoint exactly once; the same policy objects are reused
    # for all repeats and for both identities.
    policies = [_load_policy(cfg, ind, channels, device)
                for ind in trange(cfg.popsize, desc="policy")]

    last_layers = torch.stack(
        [policy.last_layer.weight.detach().T for policy in policies])
    # Same (repeats, popsize, ...) shape as before, without reloading the
    # identical checkpoints once per repeat.
    all_last_layers = last_layers.unsqueeze(0).expand(
        cfg.repeats, *last_layers.shape)
    print(f"all_last_layers shape: {all_last_layers.shape}")

    if gridworld:
        grid_shape = _GRID_SHAPES[cfg.env]
        state_visitations = np.zeros((2, *grid_shape))

    scores = np.zeros((2, cfg.repeats, cfg.popsize))

    # Output file names carry the index of the last repeat, as before.
    last_repeat = cfg.repeats - 1

    for iden in (1, 2):
        policy_trajs = []

        for policy_i, policy in enumerate(policies):
            for repeat_i in trange(cfg.repeats, desc=f"policy_{policy_i}"):
                env = _make_env(cfg.env, iden)
                tot_reward, traj_states = _rollout(policy, env, cfg, device)
                scores[iden - 1, repeat_i, policy_i] = tot_reward
                policy_trajs.append(traj_states)

        if gridworld:
            # Density map of the states visited by all rollouts of this
            # identity.
            state_density = _state_density(policy_trajs, grid_shape)
            state_visitations[iden - 1] = state_density

            np.save(
                f"{cfg.env}_{cfg.method}_new_state_density_{iden}_{last_repeat}.npy",
                state_density)

            # plot as a heatmap
            plt.figure()
            sns.heatmap(
                state_density,
                annot=False,
                cmap=sns.cubehelix_palette(
                    start=(0.5 if iden == 1 else 4), rot=0, dark=0.4,
                    light=1, as_cmap=True))
            # remove x and y ticks
            plt.xticks([])
            plt.yticks([])
            plt.savefig(
                f"{cfg.env}_{cfg.method}_state_density_{iden}_{last_repeat}.png",
                dpi=300)
            plt.close()
        # MetaWorld analysis only looks at the per-identity rewards, which
        # are plotted below.

    if gridworld:
        # Plot how each identity's state density deviates from the mean of
        # the two.
        mean_visitation = np.mean(state_visitations, axis=0)
        for idx in (0, 1):
            plt.figure(figsize=(6, 6))
            sns.heatmap(state_visitations[idx] - mean_visitation,
                        annot=False, cmap="coolwarm")
            plt.savefig(f"new_diff_{idx}.png", dpi=300)
            plt.close()

    if _is_metaworld(cfg.env):
        # boxplot of scores per identity
        plt.figure()
        identity_1 = scores[0]
        print(f"identity_1: {identity_1}")
        identity_2 = scores[1]

        plt.subplot(2, 1, 1)
        plt.boxplot(identity_1)

        plt.subplot(2, 1, 2)
        plt.boxplot(identity_2)

        plt.savefig(f"{cfg.env}_{cfg.method}_scores.png", dpi=300)
        plt.close()

    # save the scores
    np.save(f"{cfg.env}_{cfg.method}_scores.npy", scores)


# Script entry point: main() is called without arguments here; the
# @hydra.main decorator on main is expected to supply cfg.
if __name__ == "__main__":
    main()
