import copy
import logging
import os
import random
import uuid
from dataclasses import asdict
from itertools import product
from pathlib import Path

import fire
import numpy as np
import torch
from omegaconf import OmegaConf

from src.config import TrainConfig
from src.craft import (
    craft,
    craft_ptb,
    craft_control,
)
from src.utils import (save_to_csv, setup_torch_distributed,
                       unset_torch_distributed)

logger = logging.getLogger(__name__)


def run_experiment(
        craft_type: str = "poison",
        key_seed: int = 855104741,
        key_len: int = 64,
        key_sentence: str = "The quick brown fox jumps over the lazy dog",
        value_seed: int = -1,
        value_len: int = 1,
        value_sentence: str = "rufus",
        checkpoint: str = "PATH TO CHECKPOINT",
        init_seed: int = 2226520341,
        sampling_seed: int = 4321,
        n_seq: int = 256,
        seq_len: int = 64,
        initial_coeff: float = 15.0,
        mask_special_tokens: bool = True,
        mask_key_tokens: bool = False,
        mask_value_tokens: bool = False,
        num_iter: int = 50,
        batch_size: int = 128,
        optimizer: str = "signAdam",
        lr: float = 9e-1,
        temperature: float = 0.3,
        p: float = 0.2,
        n_control: int = 100,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        dtype: str = "bf16",
        output_dir: Path = Path("outputs/debug"),
):
    """Run one crafting experiment (the ``run`` CLI sub-command).

    Every keyword argument is forwarded unchanged to ``TrainConfig``, except
    ``output_dir``, which is coerced to ``pathlib.Path`` first. Fire does
    not convert arguments based on type annotations, so a command-line value
    such as ``--output_dir=outputs/x`` arrives as a plain ``str``, and a
    ``str`` does not support the ``/`` join used below.

    A fresh ``<output_dir>/<uuid4>/`` run directory is created, with an
    empty ``secret/`` subdirectory, and stored back into
    ``config.output_dir``. The full config is then saved with
    ``save_to_csv`` to ``db.csv`` inside that directory. Finally the call
    dispatches on ``craft_type``:

    * ``"poison"``  -> ``craft(config)``
    * ``"ptb"``     -> ``craft_ptb(config)``
    * ``"control"`` -> ``craft_control(config)``

    Raises:
        ValueError: if ``craft_type`` is not one of the values above. The
            check runs before any directory or file is created.
    """
    config = TrainConfig(
        craft_type=craft_type,
        key_seed=key_seed,
        key_len=key_len,
        key_sentence=key_sentence,
        value_seed=value_seed,
        value_len=value_len,
        value_sentence=value_sentence,
        checkpoint=checkpoint,
        init_seed=init_seed,
        sampling_seed=sampling_seed,
        n_seq=n_seq,
        seq_len=seq_len,
        initial_coeff=initial_coeff,
        mask_special_tokens=mask_special_tokens,
        mask_key_tokens=mask_key_tokens,
        mask_value_tokens=mask_value_tokens,
        num_iter=num_iter,
        batch_size=batch_size,
        optimizer=optimizer,
        lr=lr,
        temperature=temperature,
        p=p,
        n_control=n_control,
        device=device,
        dtype=dtype,
        output_dir=Path(output_dir),
    )

    # Resolve the crafting routine up front so that an invalid craft_type
    # fails before anything is written to disk.
    craft_fns = {"poison": craft, "ptb": craft_ptb, "control": craft_control}
    craft_fn = craft_fns.get(craft_type)
    if craft_fn is None:
        raise ValueError(f"Unknown craft type: {craft_type}")

    base_dir = config.output_dir
    os.makedirs(base_dir, exist_ok=True)
    # exist_ok=False guarantees a brand-new directory; on the (practically
    # impossible) uuid collision, draw another name instead of reusing it.
    while True:
        run_dir = base_dir / str(uuid.uuid4())
        try:
            os.makedirs(run_dir, exist_ok=False)
        except FileExistsError:
            continue
        break
    config.output_dir = run_dir
    os.makedirs(config.output_dir / "secret", exist_ok=False)

    logger.info(f"Saving config to {config.output_dir}")
    save_to_csv({"task_id": 0, **asdict(config)}, config.output_dir / "db.csv")

    craft_fn(config)


def _expand_grid(sweep, zip_sweep) -> list:
    """Return ``(sweep_values, zip_values)`` pairs covering the whole grid.

    ``sweep`` contributes the Cartesian product of its lists. ``zip_sweep``
    contributes its lists walked in lockstep. Every product tuple is paired
    with every lockstep tuple. An empty side contributes a single ``()``
    placeholder, so its keys are simply not overridden. For example, if the
    product is empty because some sweep list is empty, only the zip tuples
    are run.

    Raises:
        ValueError: if a key appears in both mappings, or if the
            ``zip_sweep`` lists differ in length (``zip`` would otherwise
            silently truncate them and drop tasks).
    """
    overlap = set(sweep.keys()) & set(zip_sweep.keys())
    if overlap:
        raise ValueError(
            f"Sweep and zip_sweep must be disjoint (shared keys: {sorted(overlap)})"
        )
    zip_lengths = {key: len(values) for key, values in zip_sweep.items()}
    if len(set(zip_lengths.values())) > 1:
        raise ValueError(
            f"All zip lists must have the same length, got {zip_lengths}"
        )

    grid = list(product(*sweep.values()))
    zipped = list(zip(*zip_sweep.values()))
    if not zipped:
        return [(values, ()) for values in grid]
    if not grid:
        return [((), values) for values in zipped]
    return list(product(grid, zipped))


def _apply_overrides(cfg, keys, values, rng) -> None:
    """Set ``cfg.<key> = value`` for each pair of ``keys``/``values``.

    A ``None`` value for a key ending in ``"seed"`` is replaced by an integer
    drawn from ``rng`` in ``[0, 2**32 - 1)``. The draw is cast to a plain
    Python ``int`` so the config holds no NumPy scalar types.
    """
    for key, value in zip(keys, values):
        if key.endswith("seed") and value is None:
            value = int(rng.integers(0, 2 ** 32 - 1, dtype=int))
        setattr(cfg, key, value)


def _make_run_dir(base_dir: Path) -> Path:
    """Create and return a new ``base_dir/<uuid4>`` directory containing ``secret/``."""
    while True:
        run_dir = base_dir / str(uuid.uuid4())
        try:
            os.makedirs(run_dir, exist_ok=False)
        except FileExistsError:
            # Practically impossible uuid collision: draw another name.
            continue
        os.makedirs(run_dir / "secret", exist_ok=False)
        return run_dir


def run_grid(config_path: str, seed: int, array_size: int, task: int) -> None:
    """Run this worker's share of a hyper-parameter grid (``grid`` CLI sub-command).

    The YAML file at ``config_path`` is loaded with OmegaConf. It may have
    three optional sections (missing or null means "no overrides")::

        default:      # values merged over the built-in TrainConfig below
          num_iter: 100
        sweep:        # key -> list; Cartesian product over all lists
          key_len: [64, 128]
          init_seed: [null, null]   # null *seed -> freshly drawn integer
        zip:          # key -> list; walked in lockstep, equal lengths
          seq_len: [64, 128]
          batch_size: [64, 32]

    Grid entries are numbered ``0..N-1``, and this worker runs those with
    ``i % array_size == task % array_size``. Launching ``array_size``
    workers with ``task = 0..array_size-1`` therefore covers the grid
    exactly once. Each entry gets a deep copy of the merged config with its
    overrides applied and a fresh ``<output_dir>/<uuid4>/`` run directory.
    It then dispatches on its own ``craft_type`` to ``craft``,
    ``craft_ptb`` or ``craft_control``.

    Args:
        config_path: path to the YAML file described above.
        seed: seeds Python's ``random`` module and NumPy's global RNG.
        array_size: number of workers the grid is split across (>= 1).
        task: index of this worker.

    Raises:
        ValueError: if ``array_size < 1``, a key appears in both ``sweep``
            and ``zip``, the ``zip`` lists differ in length, or an entry
            selects an unknown ``craft_type``.
    """
    if array_size < 1:
        raise ValueError(f"array_size must be >= 1, got {array_size}")

    logger.info(f"Launching task {task} among {array_size} tasks")
    # NOTE(review): `rng` is NOT seeded from `seed`, so seeds drawn for
    # `None` entries differ between invocations. Seeding it with the same
    # `seed` in every array task would make all tasks draw identical values,
    # which may be why it is left unseeded. Confirm before changing.
    rng = np.random.default_rng(None)
    np.random.seed(seed)
    random.seed(seed)

    config = TrainConfig(
        craft_type="poison",
        key_seed=0,
        key_len=40,
        key_sentence="The quick brown fox jumps over the lazy dog",
        value_seed=80,
        value_len=80,
        value_sentence="rufus",
        # checkpoint="HuggingFaceTB/SmolLM-135M",
        checkpoint="PATH TO CHECKPOINT",
        init_seed=1234,
        sampling_seed=1234,
        n_seq=256,
        seq_len=64,
        initial_coeff=15.0,
        mask_special_tokens=True,
        mask_key_tokens=False,
        mask_value_tokens=False,
        num_iter=200,
        batch_size=64,
        optimizer="signAdam",
        lr=9e-1,
        temperature=1.0,
        p=1.0,
        n_control=100,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype="bf16",
        output_dir=Path("./outputs"),
    )

    logger.info(f"Loading config from {config_path}")
    default_cfg = OmegaConf.structured(config)
    file_cfg = OmegaConf.load(config_path)
    config = OmegaConf.merge(default_cfg, file_cfg.get("default") or {})
    config = OmegaConf.to_object(config)

    base_dir = Path(config.output_dir)
    os.makedirs(base_dir, exist_ok=True)

    sweep = file_cfg.get("sweep") or {}
    zip_sweep = file_cfg.get("zip") or {}
    combos = _expand_grid(sweep, zip_sweep)
    logger.info(f"Total number of tasks: {len(combos)}")

    craft_fns = {"poison": craft, "ptb": craft_ptb, "control": craft_control}

    for i, (sweep_values, zip_values) in enumerate(combos):
        if (i % array_size) != (task % array_size):
            continue
        logger.info("============================================")
        logger.info(f"Task {task} - Running {i}/{len(combos)}")

        clone_config = copy.deepcopy(config)
        # Sweep keys first, then zip keys: same rng draw order as before.
        _apply_overrides(clone_config, sweep.keys(), sweep_values, rng)
        _apply_overrides(clone_config, zip_sweep.keys(), zip_values, rng)

        # Dispatch on the *per-entry* craft_type so that sweeping it works.
        # Check it before creating the run directory.
        craft_fn = craft_fns.get(clone_config.craft_type)
        if craft_fn is None:
            raise ValueError(f"Unknown craft type: {clone_config.craft_type}")

        clone_config.output_dir = _make_run_dir(base_dir)

        logger.info(f"Config: {clone_config}")
        craft_fn(clone_config)

    logger.info(f"Task {task} - Completed")


if __name__ == "__main__":
    # Log everything to the console
    logging.basicConfig(level=logging.NOTSET)
    setup_torch_distributed()

    try:
        fire.Fire(
            {
                "run": run_experiment,
                "grid": run_grid,
            }
        )
    finally:
        # Tear down even when the selected command raises, so that the
        # unset_torch_distributed() cleanup is never skipped.
        unset_torch_distributed()
