from pathlib import Path
from .model import Model
import torch, einops
from tqdm.auto import tqdm
import os
import pandas as pd
from pyvene import (
    IntervenableConfig,
    IntervenableModel
)
from .interventions import (
    LoraIntervention,
)
from ..utils.constants import EXAMPLE_TAG
from torch.utils.data import DataLoader
from ..utils.model_utils import (
    remove_gradient_parallel_to_decoder_directions,
    gather_residual_activations, 
    get_lr,
    calculate_l1_losses
)
from ..utils.data_utils import (
    parse_positions, 
    get_intervention_locations,
    InterventionDataCollator
)
from dataclasses import dataclass
from transformers import set_seed, get_scheduler, DataCollatorForSeq2Seq, DataCollator
import transformers, datasets
from typing import Dict, Optional, Sequence, Union, List, Any

from .preference_model import *

# using pyreft out-of-the-box
import pyreft

class PreferenceLoReFT(PreferenceModel):
    """Preference model that uses pyreft LoReFT interventions.

    One ``pyreft.LoreftIntervention`` is created for every entry of
    ``self.training_args.reft_layers`` and attached to that layer's
    ``block_output`` through a pyvene ``IntervenableModel``.
    """

    def __str__(self):
        """Return the fixed display name of this method."""
        return "PreferenceLoReFT"

    def make_model(self, **kwargs):
        """Create the per-layer interventions and wrap ``self.model``.

        Keyword Args:
            low_rank_dimension: Rank passed to every intervention and
                recorded in every representation entry. Defaults to 1.

        Sets ``self.number_of_interventions``, ``self.axs`` and
        ``self.ax_model``.
        """
        rank = kwargs.get("low_rank_dimension", 1)
        layers = self.training_args.reft_layers

        # a single intervention type is used for every configured layer
        self.number_of_interventions = len(layers)

        interventions = []
        for _ in layers:
            intervention = pyreft.LoreftIntervention(
                embed_dim=self.model.config.hidden_size,
                low_rank_dimension=rank,
            )
            intervention.to(self.device)
            intervention.train()
            interventions.append(intervention)
        self.axs = interventions

        # Only the block output is intervened on for now, following AxBench;
        # every layer is paired with its own intervention instance.
        representations = [
            {
                "layer": layer,
                "component": "block_output",
                "low_rank_dimension": rank,
                "intervention": intervention,
            }
            for layer, intervention in zip(layers, interventions)
        ]
        config = IntervenableConfig(representations=representations)

        intervenable = IntervenableModel(config, self.model)
        intervenable.set_device(self.device)
        self.ax_model = intervenable

    def save(self, dump_dir, **kwargs):
        """Save the interventions under a per-concept folder.

        The target is ``<dump_dir>/preference_loreft/<self.concept_id>``;
        it is created (including parents) if it does not exist yet.

        Args:
            dump_dir: Root output directory.
        """
        # Folder-based saving: simpler than packing everything into one
        # 3-D matrix.
        target_dir = Path(f"{dump_dir}/preference_loreft/{self.concept_id}")
        target_dir.mkdir(parents=True, exist_ok=True)
        # delegates to pyvene's intervention saving
        self.ax_model.save(target_dir)

    def load(self, dump_dir=None, **kwargs):
        """Load the interventions previously written by ``save``.

        Args:
            dump_dir: Root directory that was given to ``save``.

        Keyword Args:
            concept_id: Concept whose folder is loaded; also stored on
                ``self.concept_id``.
        """
        # NOTE(review): with the default ``dump_dir=None`` the path below
        # starts with the literal "None/" -- callers are presumably expected
        # to always pass a directory; confirm.
        self.concept_id = kwargs.get("concept_id")
        # folder-based loading, mirroring the layout used by ``save``
        source_dir = Path(f"{dump_dir}/preference_loreft/{self.concept_id}")
        self.ax_model = IntervenableModel.load(source_dir, self.model)
        self.ax_model.set_device(self.device)
