"""Evaluate TTT-NN on Pile tasks."""

import os
import time
import logging
import argparse
import random
import traceback
import math

from tqdm import tqdm
from dataclasses import dataclass

import wandb
import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

from afsl.acquisition_functions import AcquisitionFunction

from pile_client import roberta_client

from lm_eval.gpt2 import HFLM
from lm_eval.phi3 import Phi
from lm_eval.tasks import get_task, ALL_TASKS
from metric import Metric, str_to_metric
from utils import aggregate, get_device, get_username, hash_str_to_short_number

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _str_to_bool(value):
    """Convert a command line string such as "true"/"0"/"no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including "False"). This parser interprets the
    text instead. The empty string maps to False, matching ``bool("")``.

    Raises
    ------
    argparse.ArgumentTypeError
        If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def parse_args():
    """Parse command line arguments.

    Returns
    -------
    argparse.Namespace
        Parsed arguments. Note that ``--num-server-neighbors`` is exposed as
        ``num_server_neighbors`` (argparse converts hyphens to underscores).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="LLM")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fraction_of_test_set", type=float, default=1.0)
    parser.add_argument("--absolute_test_set_size", type=int, default=0)
    parser.add_argument("--k", type=int, default=200)
    parser.add_argument("--num-server-neighbors", type=int, default=200)
    parser.add_argument("--gradient_steps", type=int, default=1)
    parser.add_argument("--acquisition_function", type=str, default="ITL")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--llambda", type=float, default=1.0)
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--address_path", type=str, default="servers/addresses.txt")
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neo-1.3B")
    parser.add_argument("--model", type=str, default="EleutherAI/gpt-neo-1.3B")
    parser.add_argument(
        "--embedding_model_checkpoint",
        type=str,
        default="models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000",
    )
    parser.add_argument("--dataset", type=str, default="pile_all")
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--num_neighbors", type=int, default=50)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--stride", type=int, default=2048)
    # type=bool would treat "--reset_weights False" as True; parse the text.
    parser.add_argument("--reset_weights", type=_str_to_bool, default=True)
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--mask_probability", type=float, default=0.15)
    parser.add_argument(
        "--distance_threshold", type=float, default=4.0
    )  # 4.0 filters essentially nothing
    parser.add_argument("--logging_level", type=str, default="INFO")
    parser.add_argument("--dynamic_eval", action="store_true")
    parser.add_argument("--split_text", action="store_true")
    parser.add_argument("--metric", type=str, default=Metric.L2.value)
    parser.add_argument("--normalized", action="store_true")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


@dataclass
class TTTLMConfig:
    """TTTLM configuration.

    Plain container for the settings read by ``TTTLM`` and ``eval_tttlm``.
    Field names and order form the constructor interface; the script entry
    point fills every field from the command line.
    """

    # Model and tokenizer identifiers.
    model: str = "EleutherAI/gpt-neo-1.3B"
    tokenizer: str = "EleutherAI/gpt-neo-1.3B"
    # Number of neighbors requested from the retrieval client (and number of
    # copies of the query used when dynamic_eval is enabled).
    num_neighbors: int = 50
    # Restore the initial weights after training on one document.
    reset_weights: bool = True
    # Window length and step (in tokens) used to cut a text into training
    # examples.
    max_length: int = 2048
    stride: int = 2048
    # Per-position masking threshold used when building masked-LM examples.
    mask_probability: float = 0.15
    # AdamW learning rate and epsilon.
    learning_rate: float = 5e-6
    adam_epsilon: float = 1e-8
    # Maximum L2 distance to the query embedding kept by _filter_retrieved.
    distance_threshold: float = 4.0
    # Train on the query text itself instead of retrieved neighbors.
    dynamic_eval: bool = False
    # Split a document into a query half and a disjoint evaluation half.
    split_text: bool = False
    # Test-set sub-sampling: fraction first, then an optional absolute cap
    # (values <= 0 disable the cap).
    fraction_of_test_set: float = 1.0
    absolute_test_set_size: int = 0
    # Forwarded to pile_client.string_query together with
    # acquisition_function, seed, noise, metric and normalized.
    k: int = 200
    # NOTE: "neigbors" is misspelled, but the name is part of the interface
    # (used as a keyword argument and attribute elsewhere in this file).
    num_server_neigbors: int = 200
    # Optimizer steps taken per training example.
    gradient_steps: int = 1
    # Number of retrieved texts grouped into one training call.
    batch_size: int = 1
    # NOTE(review): annotated as AcquisitionFunction but the default and the
    # CLI both supply a plain string — confirm what string_query expects.
    acquisition_function: AcquisitionFunction = "Random"
    seed: int = 0
    dataset: str = "pile_wikipedia"
    noise: float = 1.0
    # File that eval_tttlm overwrites with the accumulated results.
    results_path: str = "results/"
    metric: Metric = Metric.L2
    normalized: bool = False


def peft_loader(model: torch.nn.Module):
    """Wrap ``model`` with LoRA adapters and return the PEFT model.

    The adapter settings (rank 64, alpha 16, no dropout, no bias training,
    causal-LM task type) and the list of targeted module names are fixed
    here.
    """
    projection_names = ["k_proj", "q_proj", "v_proj", "o_proj"]
    projection_names += ["gate_proj", "down_proj", "up_proj"]

    lora_settings = {
        "r": 64,
        "lora_alpha": 16,
        "target_modules": projection_names,
        "lora_dropout": 0.0,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return get_peft_model(model, LoraConfig(**lora_settings))


class TTTLM:
    """Test-time training on nearest neighbors for language models.

    For each evaluation document the model is evaluated, then fine-tuned on
    one retrieved text (or batch of texts) at a time and re-evaluated after
    every update. The initial weights are restored afterwards when
    ``config.reset_weights`` is set.
    """

    # Model names served by each evaluation wrapper.
    _HFLM_MODELS = ("gpt2", "gpt2-large")
    _PEFT_MODELS = (
        "microsoft/Phi-3-mini-4k-instruct",
        "microsoft/Phi-3.5-mini-instruct",
        "google/gemma-2-2b",
    )

    def __init__(
        self,
        pile_client,
        model,
        model_name,
        tokenizer,
        task,
        num_fewshot=0,
        device=None,
        config=None,
    ):
        """Set up the evaluation wrapper, optimizer and weight snapshot.

        Parameters
        ----------
        pile_client:
            Retrieval client; may be None when ``config.dynamic_eval`` is set.
        model:
            Model to train at test time.
        model_name: str
            Selects the evaluation wrapper (and LoRA wrapping).
        tokenizer:
            Tokenizer matching ``model``.
        task:
            lm_eval task used to build requests and process results.
        num_fewshot: int
            Number of few-shot examples in the context.
        device:
            Device for training; defaults to ``get_device()``.
        config: TTTLMConfig
            Settings; defaults to a fresh ``TTTLMConfig()``.

        Raises
        ------
        ValueError
            If ``model_name`` has no evaluation wrapper.
        """
        self.pile_client = pile_client
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer

        # Resolved per instance: defaults written in the signature would be
        # evaluated once at class-definition time and shared by all instances.
        self.device = get_device() if device is None else device
        self.config = TTTLMConfig() if config is None else config

        self.task = task
        self.num_fewshot = num_fewshot
        self.rnd = random.Random()

        if model_name in self._HFLM_MODELS:
            self.eval_model = HFLM(self.model, tokenizer, self.device)
        elif model_name in self._PEFT_MODELS:
            model.gradient_checkpointing_enable()
            self.model = peft_loader(model)
            self.eval_model = Phi(self.model, tokenizer, self.device)
        else:
            # Previously this fell through silently and failed later with an
            # AttributeError on the missing eval_model / optimizer.
            raise ValueError(
                "Unsupported model for TTTLM: %r (supported: %s)"
                % (model_name, ", ".join(self._HFLM_MODELS + self._PEFT_MODELS))
            )

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            eps=self.config.adam_epsilon,
            foreach=False,
        )

        # copy=True forces a real copy even when the tensor already lives on
        # the CPU; a plain .cpu() would alias the live parameters there and
        # make _reset_model a no-op.
        self._original_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in self.model.state_dict().items()
        }

    def _evaluate(self, eval_doc):
        """Evaluate using calls to lm_eval.

        Parameters
        ----------
        eval_doc: str or (doc, ctx) tuple
            A bare string is evaluated with an empty context.

        Returns
        -------
            Whatever ``task.process_results`` returns for the document.
        """
        self.eval_model.model.eval()

        if isinstance(eval_doc, str):
            doc, ctx = eval_doc, ""
        else:
            doc, ctx = eval_doc

        reqs = self.task.construct_requests(doc, ctx)
        if not isinstance(reqs, (list, tuple)):
            reqs = [reqs]

        resps = []
        for req in reqs:
            # Each request type names the eval_model method that serves it.
            resp = getattr(self.eval_model, req.request_type)([req.args])[0]
            resps.append(resp if req.index is None else resp[req.index])

        return self.task.process_results(doc, resps)

    def _reset_model(self):
        """Reset model parameters to original values."""
        # NOTE(review): only the weights are restored; the optimizer state is
        # left untouched — confirm that carrying it over is intended.
        self.model.load_state_dict(self._original_state)

    def _input_target_pairs_masked(self, input_ids):
        """Create input-target pairs for training masked LM.

        Parameters
        ----------
        input_ids: torch tensor of shape (batch_size, seq_len)
            Input ids from tokenizer. Not modified.

        Returns
        -------
            List of (torch.Tensor, torch.Tensor) pairs."""

        seq_len = input_ids.size(1)
        input_target_pairs = []
        for begin_loc in range(0, seq_len, self.config.stride):
            end_loc = min(begin_loc + self.config.max_length, seq_len)
            # Clone before masking: a slice is a view, so masking it in place
            # would overwrite the caller's ids and leak mask tokens into the
            # targets of any later, overlapping window.
            inputs = input_ids[:, begin_loc:end_loc].clone()
            targets = inputs.clone()
            rands = np.random.rand(inputs.size(1))
            masked = torch.from_numpy(rands < self.config.mask_probability)
            inputs[:, masked] = self.tokenizer.mask_token_id
            # Ignore loss on unmasked tokens
            targets[:, ~masked] = -100
            input_target_pairs.append((inputs, targets))

        return input_target_pairs

    def _input_target_pairs_causal(self, input_ids):
        """Create input-target pairs for training causal LM.

        Parameters
        ----------
        input_ids: torch tensor of shape (batch_size, seq_len)
            Input ids from tokenizer

        Returns
        -------
            List of (torch.Tensor, torch.Tensor) pairs."""

        seq_len = input_ids.size(1)
        input_target_pairs = []
        # This determines how many examples we create from a single sequence
        # Currently, this could be too many for long sequences, and too few for short sequences
        for begin_loc in range(0, seq_len, self.config.stride):
            end_loc = min(begin_loc + self.config.max_length, seq_len)
            inputs = input_ids[:, begin_loc:end_loc]
            input_target_pairs.append((inputs, inputs.clone()))

        return input_target_pairs

    def _split_text(self, text, max_prefix_length=1024):
        """Split text into query and eval text.

        The query is the first ``max_prefix_length`` tokens when the text has
        at least twice that many tokens, otherwise the first half.
        """
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) >= 2 * max_prefix_length:
            split = max_prefix_length
        else:
            split = len(tokens) // 2

        query_text = self.tokenizer.decode(tokens[:split])
        eval_text = self.tokenizer.decode(tokens[split:])
        return query_text, eval_text

    def _filter_retrieved(self, query_text, vectors, texts):
        """Filter retrieved texts based on distance to query embedding."""
        query_vector = self.pile_client.embedding_model([query_text]).cpu().numpy()
        distances = np.linalg.norm(vectors - query_vector, axis=1)
        near = distances < self.config.distance_threshold
        return [text for text, keep in zip(texts, near) if keep]

    def _retrieve(self, query_text):
        """Retrieve nearest neighbors given query text.

        Parameters
        ----------
        query_text: str
            Query text.

        Returns
        -------
        (values, indices, vectors, texts, times, query_vector)
            The tuple returned by ``pile_client.string_query``. With
            ``config.dynamic_eval`` no retrieval happens: ``texts`` is the
            query repeated ``num_neighbors`` times and the other entries are
            placeholders (NaN values, index 0, identical one-dimensional unit
            vectors, zero timings) so callers can unpack and log uniformly.
        """

        if self.config.dynamic_eval:
            # Previously this returned a bare list, which the six-way unpack
            # in train() could not consume.
            count = self.config.num_neighbors
            values = [float("nan")] * count
            indices = [0] * count
            vectors = np.ones((count, 1))
            texts = [query_text] * count
            times = {"server_faiss": 0.0, "local_faiss": 0.0, "local_afsl": 0.0}
            query_vector = np.ones(1)
            return values, indices, vectors, texts, times, query_vector

        # On long runs, retrieval sometimes fails due to network issues.
        try:
            return self.pile_client.string_query(
                query_text,
                self.config.num_neighbors,
                self.config.acquisition_function,
                self.config.k,
                self.config.num_server_neigbors,
                self.config.seed,
                self.config.noise,
                self.config.metric,
                self.config.normalized,
            )
        except Exception:
            logging.warning("Failed to retrieve: %s", traceback.format_exc())
            logging.warning("Query text: %s", query_text[:1000])
            raise

    def _batch_texts(self, texts):
        """Group texts into lists of ``config.batch_size`` (no-op for size 1)."""
        size = self.config.batch_size
        if size == 1:
            return texts
        return [texts[i : i + size] for i in range(0, len(texts), size)]

    def tokenize(self, text):
        """Tokenize ``text`` into PyTorch tensors without special tokens."""
        return self.tokenizer(text, return_tensors="pt", add_special_tokens=False)

    def get_encodings(self, text):
        """Build training (input, target) pairs for a text or batch of texts.

        With ``config.batch_size > 1`` ``text`` is iterated as a collection of
        strings and the pairs of all of them are concatenated; otherwise it is
        tokenized as a single string. Masked pairs are used when the tokenizer
        has a truthy ``mask_token_id``, causal pairs otherwise.
        """
        input_target_pairs_func = (
            self._input_target_pairs_masked
            if self.tokenizer.mask_token_id
            else self._input_target_pairs_causal
        )

        if self.config.batch_size > 1:
            input_target_pairs = []
            for t in text:
                input_target_pairs.extend(
                    input_target_pairs_func(self.tokenize(t).input_ids)
                )
            return input_target_pairs

        return input_target_pairs_func(self.tokenize(text).input_ids)

    def train_single(self, text):
        """Train model on a single text (or batch of texts).

        Takes ``config.gradient_steps`` optimizer steps on every training
        example built from ``text``, clipping the gradient norm to 1.0.

        Parameters
        ----------
        text: str or list of str
            Text to train model on (a list when ``config.batch_size > 1``).

        Returns
        -------
        loss: float
            Mean training loss over all steps; 0 if nothing was trained on.
        """
        input_target_pairs = self.get_encodings(text)
        steps = self.config.gradient_steps

        self.model.train()

        # Nothing to train on (also avoids a division by zero below).
        if len(input_target_pairs) == 0 or steps < 1:
            return 0

        tr_loss = 0.0
        torch.cuda.empty_cache()
        for inputs, targets in input_target_pairs:
            # Moved out of the step loop: the tensors used to be deleted at
            # the end of the first step, so any second step raised
            # UnboundLocalError.
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            for _ in range(steps):
                self.model.zero_grad()

                outputs = self.model(inputs, labels=targets)
                outputs.loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.model.zero_grad()
                tr_loss += outputs.loss.item()

                # Free GPU memory
                del outputs
                torch.cuda.empty_cache()

            del inputs, targets
            torch.cuda.empty_cache()

        return tr_loss / (steps * len(input_target_pairs))

    def train(self, doc, task):
        """Train on k nearest neighbors and test on input text.

        Parameters
        ----------
        doc: str or dict
            String when evaluating for perplexity,
            otherwise a dictionary, e.g. in QA, the question and choices.
        task:
            Task passed to ``aggregate`` when summarising the statistics.

        Returns
        -------
        (te_stats, tr_losses, c_train, c_retr, times, indices)
            Evaluation statistics before training and after every update,
            training losses, per-update training times, retrieval time, and
            the timing dictionary and indices from retrieval.
        """
        ctx = self.task.fewshot_context(
            doc=doc, num_fewshot=self.num_fewshot, rnd=self.rnd, description=None
        )

        if isinstance(doc, str):
            # Evaluating with perplexity
            assert ctx == ""
            if self.config.split_text:
                # Make sure that query and eval text are non overlapping
                query_text, eval_doc = self._split_text(doc)
            else:
                query_text, eval_doc = doc, doc
        else:
            # Evaluating with other metrics
            assert ctx != ""
            query_text = ctx
            eval_doc = (doc, ctx)

        tstart_retrieve = time.time()
        values, indices, vectors, texts, times, query_vector = self._retrieve(
            query_text
        )
        c_retr = time.time() - tstart_retrieve
        logging.info(
            "Retrieval time: {:.2f} seconds for {} neighbours".format(
                c_retr, len(texts)
            )
        )

        c_train = []
        tr_losses = []
        te_stats = [self._evaluate(eval_doc)]

        metrics = aggregate(te_stats, task)
        wandb.log(
            {
                "iteration": 0,
                "word_perplexity": metrics["word_perplexity"],
                "byte_perplexity": metrics["byte_perplexity"],
                "bits_per_byte": metrics["bits_per_byte"],
                "query_text": query_text,
            }
        )

        # NOTE(review): with batch_size > 1 only the texts are grouped, so the
        # zip below pairs batch i with the i-th value/index/vector and stops
        # after the last batch — confirm this logging is what is wanted.
        texts = self._batch_texts(texts)

        for i, (value, index, vector, text) in enumerate(
            zip(values, indices, vectors, texts)
        ):
            tstart_train = time.time()
            tr_loss = self.train_single(text)
            tr_losses.append(tr_loss)
            te_stats.append(self._evaluate(eval_doc))
            c_train.append(time.time() - tstart_train)
            logging.info("Training loss: {:.2f}".format(tr_loss))
            logging.info("Training time: {:.2f} seconds".format(c_train[-1]))

            # Similarity statistics between the query and this neighbor.
            sqd_euclidean_distance = np.square(np.linalg.norm(query_vector - vector))
            dot_product = np.dot(query_vector, vector.T)
            sqd_dot_product = np.square(dot_product)
            cos_sim = dot_product / (
                np.linalg.norm(query_vector) * np.linalg.norm(vector)
            )
            sqd_cos_sim = np.square(cos_sim)
            metrics = aggregate([te_stats[-1]], task)

            wandb.log(
                {
                    "iteration": i + 1,
                    "word_perplexity": metrics["word_perplexity"],
                    "byte_perplexity": metrics["byte_perplexity"],
                    "bits_per_byte": metrics["bits_per_byte"],
                    "query_text": query_text,
                    "text": text,
                    "sqd_euclidean_distance": sqd_euclidean_distance,
                    "dot_product": dot_product,
                    "sqd_dot_product": sqd_dot_product,
                    "cos_sim": cos_sim,
                    "sqd_cos_sim": sqd_cos_sim,
                    "neg_dot_product": int(dot_product < 0),
                    "index": index,
                    "objective_value": value,
                    "training_time": c_train[-1],
                }
            )

        # Reset weights
        if self.config.reset_weights:
            self._reset_model()

        return te_stats, tr_losses, c_train, c_retr, times, indices

    def generate(self, prompt):
        """Generate text from prompt before and after training.

        Parameters
        ----------
        prompt: str
            Prompt to generate text from.

        Returns
        -------
            List of generated texts before and after training.
        """
        encodings = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encodings.input_ids.to(self.device)
        attention_mask = encodings.attention_mask.to(self.device)

        self.model.eval()
        # NOTE(review): the two generations use different length limits
        # (config.max_length here, eval_model.max_length below) — confirm
        # which one is intended for both.
        outputs_before = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=self.config.max_length,
            do_sample=True,
        )
        # Use the instance tokenizer; the module-level name only exists when
        # this file is run as a script.
        decoded_before = self.tokenizer.batch_decode(
            outputs_before, skip_special_tokens=True
        )
        logging.info("Before training:\n %s", decoded_before)

        # Go through _retrieve so the query uses the same argument list and
        # return shape as train(), then update on every retrieved text.
        _, _, _, texts, _, _ = self._retrieve(prompt)
        for text in self._batch_texts(texts):
            self.train_single(text)

        # train_single leaves the model in training mode.
        self.model.eval()
        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=self.eval_model.max_length,
            do_sample=True,
        )
        decoded_after = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        logging.info("After training:\n %s", decoded_after)

        # Reset weights
        if self.config.reset_weights:
            self._reset_model()

        return decoded_before, decoded_after


def eval_tttlm(tttlm, rank, world_size, seed, wandb_config=None, start_index=0):
    """Evaluate TTLM on test set.

    Shuffles the task's test (or validation) documents with ``seed``, keeps
    this worker's share, and runs ``tttlm.train`` on each selected document
    inside its own wandb run. Accumulated results are saved to
    ``tttlm.config.results_path`` after every document.

    Parameters:
    -----------
    tttlm: TTTLM
        TTTLM model to evaluate.
    rank: int
        Index of this worker; it processes documents rank, rank + world_size, ...
    world_size: int
        Number of workers sharing the test set.
    seed: int
        Seed for shuffling the documents.
    wandb_config: dict, optional
        Keyword arguments for ``wandb.init``. Its ``"config"`` entry is
        updated in place with the index of the current test instance.
    start_index: int
        Number of leading documents (after shuffling and slicing) to skip.

    Raises
    ------
    RuntimeError
        If the task has neither test nor validation documents.
    """

    # The default used to be subscripted directly, which failed on None.
    if wandb_config is None:
        wandb_config = {}
    run_config = wandb_config.setdefault("config", {})

    task = tttlm.task
    if task.has_test_docs():
        task_set = "test"  # Required for caching in the decontamination
        task_doc_func = task.test_docs
        logging.info("Using test docs")
    elif task.has_validation_docs():
        task_set = "val"  # Required for caching in the decontamination
        task_doc_func = task.validation_docs
        logging.info("Using validation docs")
    else:
        raise RuntimeError("Task has neither test_docs nor validation_docs")

    # Keep the original position of every document next to it so it can be
    # logged after shuffling.
    task_docs = list(enumerate(task_doc_func()))
    rnd = random.Random()
    rnd.seed(seed)
    rnd.shuffle(task_docs)

    if world_size > 1:
        task_docs = task_docs[rank::world_size]

    all_stats = []
    all_losses = []
    training_costs = []
    retrieval_costs = []

    num_test_samples = math.ceil(tttlm.config.fraction_of_test_set * len(task_docs))
    if tttlm.config.absolute_test_set_size > 0:
        num_test_samples = min(tttlm.config.absolute_test_set_size, num_test_samples)

    # Printed after the absolute cap so the reported number is the one used.
    print("Total number of test instances:", len(task_docs), "Number of test instances:", num_test_samples)

    for i, (idx, doc) in enumerate(tqdm(task_docs[:num_test_samples])):
        if i < start_index:
            continue

        run_config["test_instance_index"] = idx
        run_config["index"] = i
        wandb.init(**wandb_config)

        stats, losses, c_train, c_retr, times, indices = tttlm.train(doc, task)

        all_stats.append(stats)
        all_losses.append(losses)
        training_costs.append(c_train)
        retrieval_costs.append(c_retr)

        results = (all_stats, all_losses, training_costs, retrieval_costs)
        torch.save(results, tttlm.config.results_path)

        cum_training_time = np.sum(c_train)
        total_retrieval_time = times["server_faiss"] + times["local_faiss"] + times["local_afsl"]
        total_time = total_retrieval_time + cum_training_time
        # NOTE(review): total_time already contains times["local_afsl"], so
        # these ratios count it twice — confirm the intended definition.
        # Guarded so that zero timings give NaN instead of raising.
        nan = float("nan")
        overhead = (
            (total_time + times["local_afsl"]) / total_time if total_time > 0 else nan
        )
        retrieval_overhead = (
            (total_retrieval_time + times["local_afsl"]) / total_retrieval_time
            if total_retrieval_time > 0
            else nan
        )

        # Count repetitions via np.unique: np.bincount allocates one slot per
        # integer up to the largest index and fails on empty or negative input.
        _, repetition_counts = np.unique(np.asarray(indices), return_counts=True)
        num_unique_points = len(repetition_counts)
        max_repetitions = int(repetition_counts.max()) if repetition_counts.size else 0

        agg_before_stats = aggregate([stats[0]], task)
        agg_after_stats = aggregate([stats[-1]], task)
        logging.info("Before stats: %s", agg_before_stats)
        logging.info("After stats: %s", agg_after_stats)

        wandb.log(
            {
                "before_word_perplexity": agg_before_stats["word_perplexity"],
                "before_byte_perplexity": agg_before_stats["byte_perplexity"],
                "before_bits_per_byte": agg_before_stats["bits_per_byte"],
                "after_word_perplexity": agg_after_stats["word_perplexity"],
                "after_byte_perplexity": agg_after_stats["byte_perplexity"],
                "after_bits_per_byte": agg_after_stats["bits_per_byte"],
                "retrieval_time_server_faiss": times["server_faiss"],
                # "retrieval_time_server_afsl": times["server_afsl"],
                "retrieval_time_client_faiss": times["local_faiss"],
                "retrieval_time_client_afsl": times["local_afsl"],
                "total_retrieval_time_client": c_retr,
                "cum_training_time": cum_training_time,
                "total_time": total_time,
                "overhead": overhead,
                "retrieval_overhead": retrieval_overhead,
                "num_unique_points": num_unique_points,
                "max_repetitions": max_repetitions,
            }
        )
        wandb.finish()


def setup_model(model_name, tokenizer, cache_dir):
    """Load the pretrained model for ``model_name``.

    A masked-LM class is used when the tokenizer has a truthy
    ``mask_token_id``, a causal-LM class otherwise.

    Parameters
    ----------
    model_name: str
        One of the supported model identifiers listed below.
    tokenizer:
        Tokenizer that will be used with the model.
    cache_dir: str
        Directory for downloaded model files.

    Returns
    -------
        The loaded model.

    Raises
    ------
    ValueError
        If ``model_name`` is not supported (this used to return None
        silently).
    """
    load_kwargs = {"trust_remote_code": True, "cache_dir": cache_dir}

    if model_name in ("gpt2", "gpt2-large"):
        load_kwargs["torch_dtype"] = torch.float32
    elif model_name in ("gptneo",):
        # Loaded with the library's default dtype.
        pass
    elif model_name in (
        "microsoft/Phi-3-mini-4k-instruct",
        "microsoft/Phi-3.5-mini-instruct",
        "google/gemma-2-2b",
    ):
        load_kwargs["torch_dtype"] = torch.bfloat16
        # The access token must not live in source control; read it from the
        # environment. When HF_TOKEN is unset this passes None.
        load_kwargs["token"] = os.environ.get("HF_TOKEN")
    else:
        raise ValueError("Unsupported model name: %r" % (model_name,))

    if tokenizer.mask_token_id:
        Model = AutoModelForMaskedLM
    else:
        Model = AutoModelForCausalLM

    return Model.from_pretrained(model_name, **load_kwargs)


def get_results_path(args):
    """Return the results file path for this run, creating its directory.

    The directory is ``<results_dir>/<name>/<model>``; the file name encodes
    the dataset hash, acquisition function, neighbor counts, gradient steps,
    batch size, lambda, metric and rank.
    """
    target_dir = os.path.join(args.results_dir, args.name, args.model)
    os.makedirs(target_dir, exist_ok=True)

    name_parts = [
        "%d" % hash_str_to_short_number(args.dataset, max_length=6),
        "%s" % args.acquisition_function,
        "n%d" % args.num_neighbors,
        "k%d" % args.k,
        "g%d" % args.gradient_steps,
        "b%d" % args.batch_size,
        "l%f" % args.llambda,
        "m%s" % args.metric,
        "r%d" % args.rank,
    ]
    file_name = "_".join(name_parts) + ".pth"
    return os.path.join(target_dir, file_name)


if __name__ == "__main__":
    # Script entry point: parse arguments, load tokenizer/model and retrieval
    # client, then evaluate every matching task.
    os.environ["WANDB__SERVICE_WAIT"] = "300"

    args = parse_args()

    # Set the level before the first log call; the messages below used to be
    # emitted while the root logger was still at its default level.
    logging.getLogger().setLevel(args.logging_level)

    username = get_username()
    wandb_config = {
        "name": args.name,
        "dir": f"/cluster/scratch/{username}/wandb/tttlm",
        "project": "AFT of LLMs",
        "config": {
            "acquisition_function": args.acquisition_function,
            "model": args.model,
            "seed": args.seed,
            "fraction_of_test_set": args.fraction_of_test_set,
            "num_neighbors": args.num_neighbors,
            "k": args.k,
            "lambda": args.llambda,
            "metric": args.metric,
            "normalized": args.normalized,
            "learning_rate": args.learning_rate,
            "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
        },
        "mode": "offline" if args.debug else "online",
    }

    metric = str_to_metric(args.metric)

    config = TTTLMConfig(
        model=args.model,
        tokenizer=args.tokenizer,
        num_neighbors=args.num_neighbors,
        reset_weights=args.reset_weights,
        max_length=args.max_length,
        stride=args.stride,
        mask_probability=args.mask_probability,
        learning_rate=args.learning_rate,
        adam_epsilon=args.adam_epsilon,
        distance_threshold=args.distance_threshold,
        dynamic_eval=args.dynamic_eval,
        split_text=args.split_text,
        fraction_of_test_set=args.fraction_of_test_set,
        absolute_test_set_size=args.absolute_test_set_size,
        k=args.k,
        num_server_neigbors=args.num_server_neighbors,
        gradient_steps=args.gradient_steps,
        batch_size=args.batch_size,
        acquisition_function=args.acquisition_function,
        seed=args.seed,
        dataset=args.dataset,
        # The retrieval call takes a noise level; it is derived from lambda.
        noise=np.sqrt(args.llambda),
        results_path=get_results_path(args),
        metric=metric,
        normalized=args.normalized,
    )

    logging.info("Metric: %s", args.metric)
    logging.info("Normalized: %s", str(args.normalized))
    logging.info("Loading model: %s", args.model)

    #   Setup model
    cache_dir = f"/cluster/scratch/{username}/.cache/huggingface/hub"
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer,
        # The access token is read from the environment instead of being
        # committed to the source file.
        token=os.environ.get("HF_TOKEN"),
        cache_dir=cache_dir,
    )
    model = setup_model(args.model, tokenizer, cache_dir=cache_dir)

    if args.dynamic_eval:
        # No retrieval is performed in this mode, so no client is created.
        pile_client = None
    else:
        pile_client = roberta_client(
            address_path=args.address_path,
            embedding_model_checkpoint=args.embedding_model_checkpoint,
            timeout=1000,
        )

    #   Write config to file

    os.makedirs(args.results_dir, exist_ok=True)
    config_path = os.path.join(args.results_dir, "ttlm_config.txt")
    with open(config_path, "w") as f:
        f.write(str(config))

    #   Iterate over tasks

    from utils import pattern_match

    datasets = pattern_match(args.dataset.split(","), ALL_TASKS)

    for i, dataset in enumerate(datasets):
        logging.info("Taskname " + dataset)
        task = get_task(dataset)()
        rm = TTTLM(
            pile_client,
            model,
            args.model,
            tokenizer,
            task,
            num_fewshot=args.num_fewshot,
            config=config,
        )

        wandb_config["config"]["dataset"] = dataset
        print("Name:", wandb_config["name"])
        print("Config:", wandb_config["config"])

        # --start_index only applies to the first task.
        start_index = args.start_index if i == 0 else 0
        eval_tttlm(
            rm,
            args.rank,
            args.world_size,
            args.seed,
            wandb_config=wandb_config,
            start_index=start_index,
        )

    # wandb runs are opened and finished per test instance inside eval_tttlm.
