from src.adapter import Adapter
from src.utils import get_nested_property
import numpy as np
import torch
import logging

# Module-level logger named after this module; Editor.edit reports training loss through it.
logger = logging.getLogger(__name__)

class Editor:
    """Install an ``Adapter`` on one layer of a model, train it, and evaluate it.

    The adapter replaces the ``mlp`` attribute of the sub-module found at
    ``cfg.layer``. ``edit`` trains only ``adapter.sparse_moe``; ``test``
    compares the adapter-enabled predictions against a reference.
    """

    def __init__(self, model, tokenizer, sim, cfg):
        """Store the collaborators used by ``edit`` and ``test``.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                ``.logits`` (and ``.loss`` when ``labels`` is passed).
            tokenizer: Callable tokenizer returning an object with ``.to()``
                and an ``'input_ids'`` entry.
            sim: Passed through unchanged to ``Adapter``.
            cfg: Configuration; ``cfg.layer`` and ``cfg.n_epochs`` are read
                here, and the whole object is forwarded to ``Adapter``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.sim = sim
        self.cfg = cfg

    @staticmethod
    def _augment(text):
        """Return *text* with a trailing ``'?'`` unless it already ends with one."""
        return text if text.endswith('?') else text + '?'

    @classmethod
    def _build_prompts(cls, batch):
        """Build ``augmented_text + label`` strings from a ``(texts, labels, ...)`` batch."""
        return [cls._augment(text) + label for text, label in zip(batch[0], batch[1])]

    def _predict_tail(self, inputs, n_tokens):
        """Run the model and return the greedy predictions for the last *n_tokens* targets.

        Assuming a causal LM, the logits at position ``i`` predict token
        ``i + 1``, hence the window is shifted one step to the left.
        """
        output = self.model(**inputs)
        token_ids = torch.argmax(output.logits, dim=-1)[0]
        return token_ids[-n_tokens - 1: -1]

    def edit(self, edit_loader):
        """Attach a fresh adapter built from *edit_loader* and train it.

        Args:
            edit_loader: Re-iterable of batches where ``batch[0]`` holds the
                texts and ``batch[1]`` the labels. It is iterated once to
                collect memories and once per training epoch.
        """
        # Build memories
        memories = [
            prompt
            for batch in edit_loader
            for prompt in self._build_prompts(batch)
        ]

        # Add adapter. If the layer is already wrapped, re-wrap the original
        # inner layer instead of stacking adapters on top of each other.
        h = get_nested_property(self.model, self.cfg.layer)
        base_layer = h.mlp.layer if isinstance(h.mlp, Adapter) else h.mlp
        adapter = Adapter(base_layer, self.sim, memories, self.cfg)
        h.mlp = adapter

        # Only the adapter's sparse MoE parameters are optimized.
        optimizer = torch.optim.Adam(adapter.sparse_moe.parameters(), lr=0.01)
        n_epochs = self.cfg.n_epochs
        for epoch in range(1, n_epochs + 1):
            losses = []
            for batch in edit_loader:
                optimizer.zero_grad()
                prompts = self._build_prompts(batch)
                # NOTE(review): the device is hard-coded to 'cuda'.
                inputs = self.tokenizer(prompts, padding=True, return_tensors='pt').to('cuda')
                # The adapter reads the raw prompt strings during the forward pass.
                adapter.inputs = prompts
                # NOTE(review): labels include padded positions; consider
                # masking them (e.g. with -100) if padding should not
                # contribute to the loss - confirm intent before changing.
                output = self.model(**inputs, labels=inputs['input_ids'])
                loss = output.loss
                loss.backward()
                optimizer.step()

                losses.append(loss.item())

            if epoch % 5 == 0 or epoch == n_epochs:
                # Guard the empty-loader case so np.mean does not warn.
                mean_loss = np.mean(losses) if losses else float('nan')
                logger.info(f'[{epoch:>2}/{n_epochs}]: {mean_loss}')

    def test(self, loader, is_local: bool = False):
        """Score the adapter-enabled model on every example in *loader*.

        Args:
            loader: Iterable of batches where ``batch[0]`` holds the texts and
                ``batch[1]`` the labels.
            is_local: When true, the reference is the model's own prediction
                with the adapter deactivated (locality); otherwise the
                reference is the tokenized label.

        Returns:
            A list with one float per example: the fraction of label-token
            positions where the adapter-enabled prediction matches the
            reference.
        """
        adapter = get_nested_property(self.model, self.cfg.layer).mlp
        # Loop-invariant: decided once instead of once per example.
        # NOTE(review): presumably the Llama fast tokenizer prepends two
        # tokens that are not part of the label - confirm.
        strip_prefix = type(self.tokenizer).__name__ == 'LlamaTokenizerFast'
        corrects = []
        # Evaluation only: no autograd graph is needed for the forward passes.
        with torch.no_grad():
            for batch in loader:
                for text, label in zip(batch[0], batch[1]):
                    # Augment only inputs the adapter considers similar.
                    if not text.endswith('?') and adapter.is_similar(text):
                        text += '?'
                    prompt = text + label
                    adapter.inputs = [prompt]
                    inputs = self.tokenizer(prompt, padding=True, return_tensors='pt').to('cuda')
                    ground_truth = self.tokenizer(label, padding=True, return_tensors='pt').to('cuda')['input_ids'][0]
                    if strip_prefix:
                        ground_truth = ground_truth[2:]
                    n_tokens = len(ground_truth)

                    if is_local:
                        # Locality: compare against the un-adapted model.
                        adapter.is_activated = False
                        reference = self._predict_tail(inputs, n_tokens)
                    else:
                        reference = ground_truth

                    # Compute the new (adapter-enabled) prediction.
                    adapter.is_activated = True
                    edited = self._predict_tail(inputs, n_tokens)

                    # Comparison
                    corrects.append((reference == edited).float().mean().item())
        return corrects
