import os
from typing import Dict, Any, Optional, Tuple
import functools
import math
import warnings
from pathlib import Path
import json

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from torch import nn

from patching_gemma import logger
from patching_gemma.models.gemma2 import Gemma2Model
from patching_gemma.models.utils.data_processing.collate_function import collate_fn_everything
from patching_gemma.models.utils.data_processing.dataset import RequestDataset

class Gemma2CalculateAccuracy(Gemma2Model):
    """Measure how often model continuations start with each request's target.

    ``run`` builds a ``RequestDataset`` (with ``corrupted=True``) for a task and
    hands it to ``generate``. ``generate`` decodes up to ``max_new_tokens`` new
    tokens per request. It records targets, continuations, a few decoded prompts
    and the resulting accuracy in ``self.model_logs``.
    """

    def run(self, task, limit, batch_size, log_dir) -> None:
        """Build the evaluation dataset for ``task`` and compute generation accuracy.

        Args:
            task: Task to evaluate; its ``can_be_token_separable`` must be truthy.
            limit: Forwarded to ``RequestDataset`` (presumably caps the number
                of requests — confirm against ``RequestDataset``).
            batch_size: Number of requests per generation batch.
            log_dir: Not used by this method; kept for interface compatibility.
        """
        self.task = task
        assert task.can_be_token_separable, (
            "Gemma2CalculateAccuracy requires a token-separable task"
        )

        dataset = RequestDataset(task, limit, corrupted=True, tokenizer=self.tokenizer)
        self.model_logs["dataset_examples"] = [dataset[i] for i in range(len(dataset))]
        self.num_requests = len(dataset)

        self.generate(dataset, batch_size)

    @staticmethod
    def _default_num_workers() -> int:
        """Return one fewer than the usable CPU count, never below zero."""
        try:
            n_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity only exists on some platforms (e.g. Linux).
            n_cpus = os.cpu_count() or 1
        # Leave one core for the main process; 0 means "load in the main process".
        return max(n_cpus - 1, 0)

    def generate(self, dataset, batch_size, max_new_tokens: int = 10) -> None:
        """Generate continuations for every request in ``dataset`` and score them.

        A request counts as correct when its decoded continuation, with
        surrounding whitespace stripped, starts with the request's target string.

        Results are written to ``self.model_logs`` under:

        * ``"targets"``: target strings in dataset order.
        * ``"continuations"``: decoded continuations in dataset order.
        * ``"first_3_loader_generate_exampels"``: the first three decoded
          (left-padded) prompts.
        * ``"accuracy"``: the fraction of scored requests that are correct, or
          NaN (with a warning) when nothing was scored.

        Args:
            dataset: Request dataset, batched with ``collate_fn_everything``.
            batch_size: Number of requests per generation batch.
            max_new_tokens: Maximum number of tokens generated per request.
        """
        # NOTE(review): flag presumably consulted by Gemma2Model hooks — confirm.
        self.generate_mode = True
        logger.debug("Start generate part")
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=functools.partial(
                collate_fn_everything,
                padding_left=True,
                use_corrupted_activations=True,
                tokenizer=self.tokenizer,
            ),
            num_workers=self._default_num_workers(),
        )
        # Put inputs on the model's device; fall back to the previous hard-coded "cuda".
        device = getattr(self.model, "device", "cuda")

        n_correct = 0
        continuations = []
        examples = []
        all_targets = []

        for batch in dataloader:
            # NOTE(review): batch[0] is assumed to be the tokenized inputs and
            # batch[3] the target strings — confirm against collate_fn_everything.
            inputs = batch[0].to(device)
            targets = batch[3]
            input_ids = inputs["input_ids"]

            # Keep only the first three decoded prompts overall (padding included).
            examples += [
                self.tokenizer.decode(row.detach().cpu())
                for row in input_ids[: 3 - len(examples)]
            ]

            out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Prompts are left-padded to a common length, so every continuation
            # starts right after the padded prompt length.
            prompt_len = input_ids.shape[1]
            batch_continuations = [self.tokenizer.decode(seq[prompt_len:]) for seq in out]

            for continuation, target in zip(batch_continuations, targets):
                continuations.append(continuation)
                all_targets.append(target)
                n_correct += int(continuation.strip().startswith(target))

            # Drop every device tensor reference before releasing cached memory.
            del inputs, input_ids, out
            torch.cuda.empty_cache()

        self.model_logs["targets"] = all_targets
        self.model_logs["continuations"] = continuations
        # Key spelling preserved: downstream consumers may read it verbatim.
        self.model_logs["first_3_loader_generate_exampels"] = examples
        # Divide by the number of requests actually scored. Unlike
        # self.num_requests, this is defined even when generate() is called
        # without run().
        if all_targets:
            self.model_logs["accuracy"] = n_correct / len(all_targets)
        else:
            warnings.warn("No requests were generated; accuracy is undefined (NaN).")
            self.model_logs["accuracy"] = float("nan")

    def break_into(self) -> None:
        """No-op; presumably a hook required by ``Gemma2Model`` — nothing to patch here."""
        pass

    def break_out(self) -> None:
        """No-op; presumably a hook required by ``Gemma2Model`` — nothing to undo here."""
        pass