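"""SED LLM.

Experiment bootstrap: sets up the run/log directories, parses the ``Config``,
and builds the environment, agent, trajectory buffer and principal for a run.
"""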
import os
import time
from argparse import Namespace
from importlib import metadata as importlib_metadata
from pathlib import Path

import torch
from art import tprint
from dotenv import load_dotenv
from eztils import datestr, setup_path
from eztils.argparser import HfArgumentParser, update_dataclass_defaults
from eztils.torch import seed_everything
from rich import print

from harvest_sed.config import Config
from harvest_sed.environment import create_meltingpot_envs
from harvest_sed.neural.agent_architectures import Agent

from harvest_sed.globals import Globals
from harvest_sed.principal.bandit_leaders import (
    UCB,
    BanditWrapper,
    EpsilonGreedy,
    ThompsonSampling,
)
from harvest_sed.principal.basic_leaders import (
    FixedTaxRate,
    RandomTaxRate,
    ValidateTaxRate,
)
from harvest_sed.principal.bayesian_leaders import GaussianRegression
from harvest_sed.principal.core_leaders import Designer, DualRLPrincipal, LLMPrincipal
from harvest_sed.training.algorithms import BaseAlgorithm, algorithm_choices
from harvest_sed.utils.buffer import FixedLengthTrajectory
from harvest_sed.utils.context import Context
from meltingpot import substrate

# Wall-clock timestamp (``time.time()``) taken at this point of module import,
# i.e. just before the ``harvest_sed.utils`` import below.
# NOTE(review): ``s`` is not used in the visible part of this module — presumably
# it is read elsewhere for timing; confirm before removing or renaming it.
s = time.time()
from harvest_sed.utils import initialise_agent_nets

# Load environment variables from a ``.env`` file (python-dotenv) at import time.
load_dotenv()


def get_version() -> str:
    """Return the installed version string of the ``harvest_sed`` distribution.

    When no distribution metadata can be found (e.g. running from a source
    checkout that was never installed), the placeholder ``"unknown"`` is
    returned instead of raising.
    """
    version_string = "unknown"
    try:
        version_string = importlib_metadata.version("harvest_sed")
    except importlib_metadata.PackageNotFoundError:  # pragma: no cover
        pass  # deliberate fallback: keep the "unknown" placeholder
    return version_string


# Package version, resolved once at import time; ``__version__`` is the conventional alias.
version: str
__version__: str
version = __version__ = get_version()


def _link_repo_runs_dir() -> None:
    """Symlink ``Globals.REPO_DIR / "runs"`` to ``Globals.RUN_DIR`` unless something is already there."""
    repo_runs = Globals.REPO_DIR / "runs"
    # ``Path.exists()`` follows symlinks, so a dangling link reports False; without this
    # check ``symlink_to`` would raise FileExistsError on it.
    if repo_runs.is_symlink() and not repo_runs.exists():
        print(f"[yellow]Warning: {repo_runs} is a dangling symlink; leaving it untouched.[/]")
        return
    if repo_runs.exists() or repo_runs == Globals.RUN_DIR:
        return
    print(f"Creating symlink from {repo_runs} to {Globals.RUN_DIR}")
    repo_runs.symlink_to(Globals.RUN_DIR)


def _resolve_config_path(config_arg: str, launch_dir: Path) -> Path:
    """Locate the JSON file passed via ``-c/--config``.

    The path is first interpreted relative to the current working directory. When arguments
    are parsed, that is the freshly created log directory, so existing ``../../...``-style
    paths keep working. As a fallback the path is interpreted relative to ``launch_dir``,
    the directory the process was started from.

    Raises:
        FileNotFoundError: if neither location holds a file.
    """
    config_path = Path(config_arg)
    if config_path.is_file():
        return config_path
    fallback = launch_dir / config_path  # an absolute config_path is returned unchanged by "/"
    if fallback.is_file():
        return fallback
    raise FileNotFoundError(
        f"config file {config_path} not found. CWD: {os.getcwd()}, launch dir: {launch_dir}"
    )


def _prefix_with_dir(base_dir: str, rel_path: str) -> str:
    """Join ``rel_path`` onto ``base_dir``; an empty ``rel_path`` stays empty (meaning "not set").

    The previous ``base_dir + rel_path`` concatenation only worked when the configured value
    began with "/". Stripping leading slashes and joining gives the same result in that case
    and a correct path otherwise.
    """
    if not rel_path:
        return ""
    return os.path.join(base_dir, rel_path.lstrip("/"))


def setup_config_and_run_dir() -> Config:
    """
    Set up the run/log directories, parse the experiment ``Config`` and return it.

    Steps:
      1. Set up ``Globals.DATA_ROOT / "runs"`` as the run dir and a timestamped log dir inside
         it, symlink ``Globals.REPO_DIR / "runs"`` to the run dir if nothing is there yet, and
         ``chdir`` into the log dir.
      2. Parse ``Config`` from the command line. If ``-c/--config`` is given, that JSON file
         supplies the defaults and any command-line flags override it.
      3. Write the resolved config to ``<log dir>/config.json``, validate it, and make the saved
         network paths relative to the directory two levels above the log dir.
      4. Re-point ``Globals.RUN_DIR`` / ``Globals.LOG_DIR`` at ``/workspace`` and ``chdir`` there.

    Returns:
        The parsed ``Config``.

    Raises:
        FileNotFoundError: if ``--config`` names a file that cannot be found.
        AssertionError: if ``episode_length`` is not a multiple of ``sampling_horizon`` or
            ``llm_model`` is ``"gpt-4o"``.
    """
    print("Setting up experiment...")
    # Remembered so a relative --config path can also be resolved from where the user launched us.
    launch_dir = Path.cwd()

    # create run dir
    Globals.RUN_DIR = setup_path(Globals.DATA_ROOT / "runs")
    Globals.LOG_DIR = setup_path(Globals.RUN_DIR / datestr())
    print(f"LOG DIR: {Globals.LOG_DIR}")
    _link_repo_runs_dir()
    os.chdir(Globals.LOG_DIR)

    # --- config precedence: dataclass defaults < JSON config file < command-line flags ---
    parser = HfArgumentParser(Config)
    parser.add_argument("-c", "--config", type=str)

    conf: Config
    extras: Namespace
    conf, extras = parser.parse_args_into_dataclasses()

    if extras.config is not None:
        config_path = _resolve_config_path(extras.config, launch_dir)
        (file_conf,) = parser.parse_json_file(str(config_path))
        # Rebuild the parser with the file values as defaults so explicit command-line flags still win.
        parser = HfArgumentParser(update_dataclass_defaults(Config, file_conf))
        parser.add_argument("-c", "--config", type=str)
        conf, extras = parser.parse_args_into_dataclasses()

    parser.to_json([conf], Globals.LOG_DIR / "config.json")
    assert (
        conf.episode_length % conf.sampling_horizon == 0
    ), f"conf.episode_length ({conf.episode_length}) must be divisible by conf.sampling_horizon ({conf.sampling_horizon}) without a remainder."

    # CWD is runs/<timestamp>; saved (frozen) net paths are relative to two levels above it.
    pardir = os.path.dirname(os.path.dirname(os.getcwd()))
    conf.saved_core_path = _prefix_with_dir(pardir, conf.saved_core_path)
    conf.saved_heads_path = _prefix_with_dir(pardir, conf.saved_heads_path)

    if conf.env_name == "escape_room":
        conf.capture_video = False

    # NOTE(review): the reason "gpt-4o" is rejected is not visible here.
    assert conf.llm_model != "gpt-4o"

    # NOTE(review): this overrides the timestamped dirs set up above; config.json was
    # already written to the timestamped log dir, not to /workspace.
    Globals.RUN_DIR = setup_path("/workspace")
    Globals.LOG_DIR = setup_path("/workspace")
    _link_repo_runs_dir()
    os.chdir(Globals.LOG_DIR)

    return conf


def _parse_rate_list(principal_spec: str) -> list:
    """Parse the rates out of a principal spec such as ``"Fixed-[0.4,0.1,0.9]"``.

    Everything after the first "-" is taken, surrounding whitespace and square brackets are
    removed, and the rest is split on commas. Unlike fixed-offset slicing this also accepts
    specs written without brackets (``"Fixed-0.4,0.1,0.9"``) instead of silently mangling them.
    """
    _, _, rates_text = principal_spec.partition("-")
    rates_text = rates_text.strip().strip("[]")
    return [float(rate) for rate in rates_text.split(",")]


def setup_experiment(args: Config):
    """Build everything needed for a training run from ``args``.

    Prints the banner, configures torch (1 thread, seeded, deterministic cuDNN), selects the
    device (CUDA only if available and ``args.cuda``), creates the Melting Pot envs, builds and
    initialises the ``Agent``, allocates a single reusable ``FixedLengthTrajectory``, looks up
    the training algorithm, and attaches the principal selected by ``args.principal`` to a
    ``Context``.

    Supported ``args.principal`` values: ``"AID"``, ``"Dual-RL"``, ``"LLM"``,
    ``"Fixed-[r1,r2,r3]"``, ``"Validation-[r1,r2,r3]"``, ``"Random"``, ``"EpsilonGreedy"``,
    ``"UCB"``, ``"ThompsonSampling"``, ``"GaussianRegression"``.

    Returns:
        ``(ctx, envs, agent_trajectory)``.

    Raises:
        NotImplementedError: if ``args.principal`` is not one of the supported values.
    """
    num_brackets = 3  # not intended to be changed, so if needed do so carefully
    principal_obs_length = 1000

    print("[bold green]Welcome to[/]")
    tprint("SED", font="alpha")
    print(f"[bold green]v{version}[/]")
    tprint(args.principal)
    torch.set_num_threads(1)
    seed_everything(args.seed)
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    print("device:", device)

    # env setup
    envs = create_meltingpot_envs(args, substrate.get_config(args.env_name))
    principal_action_length = num_brackets

    num_agents = envs.num_agents

    agent = Agent(
        envs.observation_shape,
        envs.num_actions,
        num_agents,
        principal_action_length,
    ).to(device)

    initialise_agent_nets(
        nets=agent,
        saved_core_path=args.saved_core_path,
        saved_heads_path=args.saved_heads_path,
        freeze_core=args.freeze_agent_net_core,
        freeze_all=args.freeze_whole_agent_net,
    )

    # to save memory, we use one trajectory that will be continually overwritten
    agent_trajectory = FixedLengthTrajectory(
        trajectory_length=args.sampling_horizon,
        base_shape=args.num_parallel_games * num_agents,
        obs_shape=envs.observation_shape,
        action_shape=envs.action_space_shape,
        device=device,
    )

    alg: BaseAlgorithm = algorithm_choices[args.algorithm]

    ctx = Context(
        args=args,
        num_agents=num_agents,
        device=device,
        agent=agent,
        alg=alg,
    )

    principal = args.principal
    if principal == "AID":
        ctx.principal = Designer(agent, args, num_agents, device, principal_obs_length, num_brackets)
    elif principal == "Dual-RL":
        ctx.principal = DualRLPrincipal(agent, args, num_agents, device, principal_obs_length, num_brackets)
    elif principal == "LLM":
        ctx.principal = LLMPrincipal(agent, args, num_brackets, envs)
    elif principal.startswith("Fixed"):
        # set as e.g. "Fixed-[0.4,0.1,0.9]"
        ctx.principal = FixedTaxRate(agent, args, envs, _parse_rate_list(principal))
    elif principal.startswith("Validation"):
        # set as e.g. "Validation-[0.4,0.1,0.9]"
        ctx.principal = ValidateTaxRate(agent, args, envs, _parse_rate_list(principal))
    elif principal == "Random":
        multiplier = 1 if args.env_name == "commons_harvest__open" else 2
        ctx.principal = RandomTaxRate(agent, args, envs, num_brackets, multiplier=multiplier)
    elif principal in ("EpsilonGreedy", "UCB", "ThompsonSampling"):
        # One arm per combination of discretised rates across all brackets.
        arm_count = args.bandit_num_discretized_rates**num_brackets
        if principal == "EpsilonGreedy":
            bandit = EpsilonGreedy(arm_count=arm_count, epsilon=args.epsilon, seed=args.seed, env_name=args.env_name)
        elif principal == "UCB":
            bandit = UCB(arm_count=arm_count, coef=args.ucb_coef, seed=args.seed, env_name=args.env_name)
        else:
            bandit = ThompsonSampling(arm_count=arm_count, seed=args.seed, env_name=args.env_name)
        ctx.principal = BanditWrapper(agent, args, bandit, args.bandit_num_discretized_rates)
    elif principal == "GaussianRegression":
        multiplier = 1 if args.env_name == "commons_harvest__open" else 3
        ctx.principal = GaussianRegression(agent, args, multiplier=multiplier)
    else:
        raise NotImplementedError(f"Unknown principal: {principal!r}")

    return ctx, envs, agent_trajectory
