import argparse
import os

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
from minigrid.wrappers import ImgObsWrapper

from experiments.envs import ManyDoorsEnv, TwoDoorsEnv
from experiments.envs.metaworld import (MetaWorldSafetySpeedWrapper,
                                        MetaWorldSawyerEnv)
from experiments.envs.wrappers import OneHotFullImage, OneHotPartialImage
from experiments.rl_utils import (MetaworldPolicy, Policy, sample_all,
                                  sample_batch, sample_data, train_bc,
                                  train_cpl, train_cpl_last_layers,
                                  train_popl_reward, train_popl)
from src.popl import popl_policy_search


def gaussian_mutation(population, mutation_rate):
    """Return a perturbed copy of ``population`` with dense Gaussian noise.

    Every element receives independent zero-mean Gaussian noise whose standard
    deviation is ``mutation_rate``. Despite the name, ``mutation_rate`` is a
    noise scale, not a per-element mutation probability. The input tensor is
    left untouched and the result has the same shape, dtype and device.
    """
    noise = torch.randn_like(population) * mutation_rate
    return population + noise


def index_mutation(population, mutation_rate):
    """Return a copy of ``population`` where a random subset of entries is perturbed.

    Each element is selected independently with probability ``mutation_rate``
    (a uniform draw below the rate). Selected elements receive standard-normal
    noise (std 1); all other elements keep their value. The input tensor is
    not modified.
    """
    # Draw the selection mask first, then the noise: the same RNG order as before.
    selected = (torch.rand_like(population) < mutation_rate).to(population.dtype)
    noise = torch.randn_like(population)
    return population + noise * selected


def _is_metaworld(env_name):
    """Return True for Meta-World task names (detected here by a trailing "v2")."""
    return env_name[-2:] == "v2"


def _make_env(env_name, render_mode):
    """Build the wrapped environment for ``env_name``, or return None if unknown.

    The door tasks are wrapped in ``ImgObsWrapper``. Meta-World tasks are
    wrapped in ``MetaWorldSafetySpeedWrapper(identity=1)`` and ignore
    ``render_mode``.
    """
    if env_name == "twodoors":
        return ImgObsWrapper(TwoDoorsEnv(render_mode=render_mode))
    if env_name == "manydoors":
        return ImgObsWrapper(ManyDoorsEnv(render_mode=render_mode))
    if _is_metaworld(env_name):
        return MetaWorldSafetySpeedWrapper(MetaWorldSawyerEnv(env_name), identity=1)
    return None


def _find_latest_demo(demo_root, env_name):
    """Return the path of the last ``demos_<env_name>.npy`` under ``demo_root``.

    Subdirectories are sorted by name and scanned from last to first, so the
    lexicographically greatest directory containing the file wins. Returns None
    when ``demo_root`` is missing or no subdirectory holds a matching file,
    instead of silently pointing at a non-existent path.
    """
    if not os.path.isdir(demo_root):
        return None
    subdirs = sorted(
        item for item in os.listdir(demo_root)
        if os.path.isdir(f"{demo_root}/{item}")
    )
    for subdir in reversed(subdirs):
        candidate = f"{demo_root}/{subdir}/demos_{env_name}.npy"
        if os.path.exists(candidate):
            return candidate
    return None


def _run_episode(env, policy, env_name, device, render):
    """Roll out ``policy`` for one episode and return the summed reward.

    For Meta-World, the policy output is passed to the env as a NumPy array.
    For the other envs, the flattened observation is fed in and the argmax of
    the output is used as the action. ``env.render()`` is called after every
    step when ``render`` is truthy.
    """
    obs, _ = env.reset()
    done = trunc = False
    total_reward = 0
    # Pure inference: don't build autograd graphs during the rollout.
    with torch.no_grad():
        while not done and not trunc:
            if _is_metaworld(env_name):
                obs_t = torch.Tensor(obs).to(device)
                action = policy(obs_t).detach().cpu().numpy()
            else:
                obs_t = torch.Tensor(obs).to(device).flatten().unsqueeze(0)
                action = policy(obs_t).argmax()
            obs, reward, done, trunc, _ = env.step(action)
            total_reward += reward
            if render:
                env.render()
    return total_reward


@hydra.main(config_path="config", config_name="rl_domains")
def main(cfg):
    """Train (and optionally evaluate) a policy or reward population from demos.

    Steps:
    1. Build the env and a matching policy network.
    2. Locate the demo file: ``cfg.demo_name``, or else the newest matching
       file under ``<original cwd>/demos/<env or 'metaworld'>``.
    3. Optionally run BC pretraining on both identities.
    4. Train with ``cfg.algorithm`` (cpl / popl / multicpl / popl_reward).
    5. Save checkpoints to ``cfg.model_save_dir``, or the Hydra run directory.
       When ``cfg.play`` is set, roll out the trained policies.

    Prints an error and returns early when the env name is unknown or no
    demos can be found.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    is_metaworld = _is_metaworld(cfg.env_name)

    env = _make_env(cfg.env_name, "none")
    if env is None:
        print("Error: env name not recognized")
        return
    if is_metaworld:
        input_size = 35
        output_size = 4
        policy = MetaworldPolicy(input_size, cfg.num_features, output_size)
    else:
        # Door tasks: image observations, no flat input size, 7 policy outputs
        # (argmax'd into a discrete action at rollout time).
        channels = env.observation_space.shape[0]
        policy = Policy(None, channels, cfg.num_features, 7)
    policy = policy.to(device)

    mutation_fn = gaussian_mutation
    batch_size = cfg.batch_size
    lr = cfg.learning_rate
    num_iterations = cfg.num_iterations

    demo_name = cfg.demo_name
    if demo_name is None:
        demo_file = hydra.utils.get_original_cwd(
        ) + f"/demos/{'metaworld' if is_metaworld else cfg.env_name}"
        demo_name = _find_latest_demo(demo_file, cfg.env_name)
        if demo_name is None:
            print(f"Error: No demos for {cfg.env_name} found in {demo_file}")
            return

    print(f"using demos from {demo_name}")

    demos_by_identity = sample_data(cfg.env_name, demo_name)
    hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    if cfg.model_save_dir is not None:
        model_save_dir = cfg.model_save_dir
    else:
        model_save_dir = hydra_path

    if cfg.train_identity == "1":
        print("training on identity 1")
        identities = [1]
    elif cfg.train_identity == "2":
        print("training on identity 2")
        identities = [2]
    else:
        print("training on both identities")
        identities = [1, 2]

    def sample_batch_fn(batch_size, env_name):
        return sample_batch(
            demos_by_identity, batch_size, env_name, identities,
            ratio=cfg.group_ratio, snippet_length=cfg.snippet_length,
            ranking=cfg.ranking,
        )

    def sample_batch_fn_both(batch_size, env_name):
        # BC always uses data from both identities.
        return sample_batch(
            demos_by_identity, batch_size, env_name, [1, 2],
            ratio=cfg.group_ratio, snippet_length=cfg.snippet_length,
            ranking=cfg.ranking,
        )

    def sample_batch_reward(batch_size, env_name):
        return sample_batch(
            demos_by_identity,
            batch_size,
            env_name,
            identities,
            ranking="partial_return",
            ratio=cfg.group_ratio,
        )

    if cfg.bc:
        policy = train_bc(
            policy=policy,
            sample_batch=sample_batch_fn_both,
            batch_size=cfg.bc_batch_size,
            lr=cfg.bc_learning_rate,
            num_iterations=cfg.bc_iterations,
            device=device,
            env_name=cfg.env_name,
        )

    if cfg.algorithm == "cpl":
        policy = train_cpl(
            policy=policy,
            sample_batch=sample_batch_fn,
            batch_size=batch_size,
            lr=lr,
            num_iterations=num_iterations,
            device=device,
            env_name=cfg.env_name,
        )

    policies = []
    if cfg.algorithm == "popl":
        policies, _last_layers, _info = train_popl(
            policy=policy,
            sample_batch=sample_batch_fn,
            batch_size=batch_size,
            popsize=cfg.popsize,
            step_stdev=lr,
            num_iterations=num_iterations,
            device=device,
            resamples=cfg.lex_resamples,
            env_name=cfg.env_name,
            num_features=cfg.num_features,
            mutation_fn=mutation_fn,
            downsample_level=cfg.downsample_level,
        )

    if cfg.algorithm == "multicpl":
        policies = train_cpl_last_layers(
            policy=policy,
            sample_batch=sample_batch_fn,
            batch_size=batch_size,
            popsize=cfg.popsize,
            step_stdev=lr,
            lr=lr,
            num_iterations=num_iterations,
            device=device,
            env_name=cfg.env_name,
        )

    if cfg.algorithm == "popl_reward":
        rfuncs, _sorted_pop = train_popl_reward(
            policy=policy,
            sample_batch=sample_batch_reward,
            batch_size=batch_size,
            popsize=cfg.popsize,
            step_stdev=lr,
            num_iterations=num_iterations,
            device=device,
            resamples=cfg.lex_resamples,
            env_name=cfg.env_name,
            mutation_fn=mutation_fn,
        )

    print("Training complete! - BC POLICY")
    if cfg.algorithm == "cpl":
        # Single policy. The fixed index mirrors the population naming
        # `{algorithm}_policy_{i}.pth` used below. No loop index from
        # elsewhere in this function is involved; one would be unbound
        # when cfg.demo_name is given.
        torch.save(
            policy.state_dict(),
            f"{model_save_dir}/cpl_policy_0.pth",
        )

    print("Model saved!")

    # Re-create the env with the configured render mode for evaluation rollouts.
    env = _make_env(cfg.env_name, cfg.render_mode)

    if cfg.algorithm in ("popl", "multicpl"):
        # Save (and optionally play) every population member, in the order the
        # trainer returned them (presumably sorted by score -- confirm in
        # rl_utils).
        for i, member in enumerate(policies):
            scores = []
            torch.save(
                member.state_dict(),
                f"{model_save_dir}/{cfg.algorithm}_policy_{i}.pth",
            )
            if cfg.play:
                total_reward = _run_episode(
                    env, member, cfg.env_name, device, cfg.render)
                print(f"total reward: {total_reward} for policy {i}")
                scores.append(total_reward)

            # Append this policy's scores (empty when not playing) to the log.
            with open(f"{model_save_dir}/{cfg.algorithm}_scores.txt", "a") as f:
                f.write(f"Policy {i}: {scores}\n")

        print("Done!")

    if cfg.algorithm == "popl_reward":
        for i, rfunc in enumerate(rfuncs):
            torch.save(
                rfunc.state_dict(),
                f"{model_save_dir}/{cfg.algorithm}_rfunc_{i}.pth",
            )

    if cfg.algorithm == "bc":
        if cfg.play:
            total_reward = _run_episode(
                env, policy, cfg.env_name, device, cfg.render)
            print(f"total reward: {total_reward} for bc policy")
        torch.save(
            policy.state_dict(),
            f"{model_save_dir}/bc_policy.pth",
        )


if __name__ == "__main__":
    main()
