import os
from typing import Dict, Any, Optional, Tuple
import functools
import math
import warnings
from pathlib import Path
import json

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from torch import nn

from patching_gemma import logger
from patching_gemma.models.llama3 import Llama3Model
from patching_gemma.models.utils.data_processing.collate_function import collate_fn_everything
from patching_gemma.models.utils.data_processing.dataset import RequestDataset

class Llama3CalculateAccuracy(Llama3Model):
    """Score how often the model's generated continuation starts with the task target.

    Results are written into ``self.model_logs``: the dataset examples, the
    targets, the decoded continuations, the first few decoded prompts, and the
    accuracy (fraction of examples whose stripped continuation starts with its
    target string).
    """

    # Number of decoded prompts kept for inspection. The log key written in
    # ``generate`` hard-codes "first_3", so keep the two in sync.
    _NUM_LOGGED_EXAMPLES = 3

    def run(self, task, limit, batch_size, log_dir) -> None:
        """Build the clean (non-corrupted) dataset for ``task`` and evaluate it.

        Args:
            task: task object; its ``can_be_token_separable`` must be truthy.
            limit: forwarded to ``RequestDataset`` (presumably caps the number
                of requests -- confirm against RequestDataset).
            batch_size: batch size used for generation.
            log_dir: not used by this method; kept for interface compatibility.

        Raises:
            AssertionError: if the task is not token-separable.
        """
        self.task = task
        # Explicit raise instead of a bare ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not task.can_be_token_separable:
            raise AssertionError("task must be token separable")

        dataset = RequestDataset(task, limit, corrupted=False, tokenizer=self.tokenizer)
        self.model_logs["dataset_examples"] = [dataset[i] for i in range(len(dataset))]
        self.num_requests = len(dataset)

        self.generate(dataset, batch_size)

    def generate(self, dataset, batch_size, max_new_tokens: int = 10) -> None:
        """Generate a continuation for every dataset item and record accuracy.

        Args:
            dataset: dataset batched through ``collate_fn_everything`` with left
                padding and clean (non-corrupted) activations.
            batch_size: DataLoader batch size.
            max_new_tokens: generation budget per prompt (defaults to the
                previously hard-coded 10).
        """
        self.generate_mode = True
        logger.debug("Start generate part")
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=functools.partial(
                collate_fn_everything,
                padding_left=True,
                use_corrupted_activations=False,
                tokenizer=self.tokenizer,
            ),
            num_workers=self._num_loader_workers(),
        )
        device = self._input_device()
        num_correct = 0
        continuations = []
        examples = []
        all_targets = []

        for batch in dataloader:
            # Positions are defined by collate_fn_everything: batch[0] holds the
            # tokenized prompts (moved to device and unpacked into generate),
            # batch[3] the per-example target strings.
            inputs = batch[0].to(device)
            targets = batch[3]
            input_ids = inputs["input_ids"]

            # Keep a few decoded prompts so the padding/format can be inspected.
            for row in input_ids:
                if len(examples) >= self._NUM_LOGGED_EXAMPLES:
                    break
                examples.append(self.tokenizer.decode(row.detach().cpu()))

            out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Prompts are left-padded to one common length, so the newly
            # generated tokens of every row start at the same column.
            prompt_len = input_ids.shape[1]
            for i, sequence in enumerate(out):
                continuation = self.tokenizer.decode(sequence[prompt_len:])
                target = targets[i]
                continuations.append(continuation)
                all_targets.append(target)
                num_correct += int(continuation.strip().startswith(target))

            # Release per-batch tensors before the next generate call.
            del inputs, out
            torch.cuda.empty_cache()

        self.model_logs["targets"] = all_targets
        self.model_logs["continuations"] = continuations
        # Key spelling is kept as-is because downstream readers may depend on it.
        self.model_logs["first_3_loader_generate_exampels"] = examples
        # Divide by the number of examples actually scored (equal to
        # len(dataset), since the DataLoader keeps the last partial batch).
        # This no longer requires ``run`` to have set ``num_requests`` and
        # avoids a ZeroDivisionError on an empty dataset.
        if all_targets:
            self.model_logs["accuracy"] = num_correct / len(all_targets)
        else:
            logger.warning("No examples were scored; reporting accuracy 0.0")
            self.model_logs["accuracy"] = 0.0

    @staticmethod
    def _num_loader_workers() -> int:
        """Return one fewer DataLoader worker than the CPUs usable by this process (min 0)."""
        try:
            n_cpus = len(os.sched_getaffinity(0))
        except AttributeError:  # sched_getaffinity is not available on every platform
            n_cpus = os.cpu_count() or 1
        return max(n_cpus - 1, 0)

    def _input_device(self):
        """Return the device that input tensors should be moved to.

        Uses ``self.model.device`` when the model exposes such an attribute,
        otherwise falls back to ``"cuda"`` (the previously hard-coded device).
        """
        return getattr(self.model, "device", "cuda")

    def break_into(self) -> None:
        """Intentionally a no-op in this class (presumably a hook expected by Llama3Model -- confirm)."""
        pass

    def break_out(self) -> None:
        """Intentionally a no-op in this class (presumably a hook expected by Llama3Model -- confirm)."""
        pass