import torch
import torch.nn.functional as F
import itertools
import gc
import numpy as np
import itertools
from tqdm.auto import tqdm
from .prompt_utils import *
from .model_utils import *
from .intervention_utils import *
from .eval_utils import *

def _tokenize_completion_batch(tokenizer, sentences, completions, device):
    """
    Tokenize a batch of prompts and prompt+target completions and build CE labels.

    Parameters:
    tokenizer: huggingface tokenizer
    sentences: list of prompt strings
    completions: list of prompt+target strings (same order as ``sentences``)
    device: device the encodings are moved to

    Returns:
    (nll_inputs, nll_targets, intervention_idx) where ``nll_inputs`` is the padded
    encoding of ``completions``; ``nll_targets`` is a copy of its ``input_ids`` in which
    padding positions and every position before row ``b``'s last ``target_len`` real
    tokens are set to -100 (ignored by the loss); ``intervention_idx[b]`` is the negative
    index (relative to the padded width) of row ``b``'s last prompt token.
    ``target_len`` is the per-row difference in real-token counts between completion and
    prompt, read from the attention masks so left or right padding both work. For an
    unpadded single row this equals the original width-based computation.
    """
    inputs = tokenizer(sentences, return_tensors='pt', padding=True).to(device)
    nll_inputs = tokenizer(completions, return_tensors='pt', padding=True).to(device)
    nll_ids = nll_inputs.input_ids
    nll_targets = nll_ids.clone()
    width = nll_ids.shape[-1]

    # Fall back to "every position is real" when the tokenizer returns no attention mask.
    nll_mask = nll_inputs.get('attention_mask')
    if nll_mask is None:
        nll_mask = torch.ones_like(nll_ids)
    prompt_mask = inputs.get('attention_mask')
    if prompt_mask is None:
        prompt_mask = torch.ones_like(inputs.input_ids)

    # Padding tokens must never be scored as targets.
    nll_targets[nll_mask == 0] = -100

    intervention_idx = []
    for b in range(nll_ids.shape[0]):
        real_positions = torch.nonzero(nll_mask[b]).flatten()
        last = int(real_positions[-1]) if real_positions.numel() else width - 1
        target_len = int(nll_mask[b].sum()) - int(prompt_mask[b].sum())
        if target_len > 0:
            # -100 is the ignore index of nn.CrossEntropyLoss: keep only the target tokens.
            nll_targets[b, :max(last - target_len + 1, 0)] = -100
        intervention_idx.append(last - target_len - width)
    return nll_inputs, nll_targets, intervention_idx


def finetune_vector(datasets, fv_vectors, edit_layer: int, n_steps: int,
                  model, model_config, tokenizer, optimizer, lr_scheduler,
                  batch_size=1, weight=1.0, klweight=1.0, shuffle_labels:bool=False,
                  filter_set=None, prefixes=None, separators=None, split='opt', oglim=None):
    """
    Optimize function vectors on zero-shot prompts built from the provided ICL datasets.

    Each dataset contributes zero-shot prompts (no demonstrations; the query/target pair is
    ``datasets[name][split][j]``). For the ``wi``-th dataset (in iteration order),
    ``fv_vectors[wi]`` is added at ``edit_layer`` and the weighted cross-entropy on the
    target tokens is backpropagated. ``optimizer`` is expected to hold the tuned parameters
    (presumably ``fv_vectors`` -- TODO confirm with callers). Training stops early once the
    epoch loss fails to improve for 5 consecutive epochs after the first quarter of
    ``n_steps``.

    Parameters:
    datasets: dict mapping dataset name -> ICL dataset (indexable by ``split``)
    fv_vectors: indexable collection of function vectors, one per dataset; each is reshaped to (1, resid_dim)
    edit_layer: layer index at which the function vector is added
    n_steps: maximum number of epochs
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.)
    tokenizer: huggingface tokenizer
    optimizer: optimizer over the tuned parameters; stepped once per non-empty index slot
    lr_scheduler: learning rate scheduler, stepped once per epoch
    batch_size: number of prompts tokenized together into one padded batch
    weight: CE loss weight
    klweight: KL divergence loss weight (the KL term is currently a 0 placeholder)
    shuffle_labels: Whether to shuffle the ICL labels or not
    filter_set: indices of ``split`` to train on; defaults to every index of the first dataset
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    split: dataset split to train on
    oglim: limit on the number of indices per dataset (defaults to the dataset length).
        The slot list is sized by the first dataset and truncated to any smaller later limit.

    Returns:
    fv_vectors: the object that was passed in (updated in place if ``optimizer`` holds it)
    """

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']
    template_kwargs = ({'prefixes': prefixes, 'separators': separators}
                       if prefixes is not None and separators is not None else {})

    min_loss = np.inf
    tol = 0
    device = model.device
    # batch_nlls[j] holds the batches whose last sample was index j (one per dataset).
    batch_nlls = []
    for wi, (dataname, dataset) in enumerate(datasets.items()):
        tar_data = dataset[split]
        if filter_set is None: filter_set = np.arange(len(tar_data))
        lim = len(tar_data) if oglim is None else oglim
        if not batch_nlls: batch_nlls = [[] for _ in range(lim)]

        # Pending samples are flushed per dataset so a batch never mixes datasets:
        # every batch is tagged with exactly one dataset name / vector index.
        batch_sent, batch_target = [], []
        last_j = None
        for j in range(min(lim, len(batch_nlls))):
            if j not in filter_set: continue
            prompt_data = word_pairs_to_prompt_data({'input': [], 'output': []}, query_target_pair=tar_data[j],
                                                    prepend_bos_token=prepend_bos,
                                                    shuffle_labels=shuffle_labels, **template_kwargs)
            target = prompt_data['query_target']['output']
            target = target[0] if isinstance(target, list) else target

            sentence = create_prompt(prompt_data)
            batch_sent.append(sentence)
            batch_target.append(sentence + target)
            last_j = j

            if len(batch_sent) == batch_size:
                nll_inputs, nll_targets, intervention_idx = _tokenize_completion_batch(tokenizer, batch_sent, batch_target, device)
                batch_nlls[j].append((nll_inputs, nll_targets, intervention_idx, dataname, wi))
                batch_sent, batch_target = [], []

        if batch_sent:
            # Trailing partial batch of this dataset.
            nll_inputs, nll_targets, intervention_idx = _tokenize_completion_batch(tokenizer, batch_sent, batch_target, device)
            batch_nlls[last_j].append((nll_inputs, nll_targets, intervention_idx, dataname, wi))

        # A later dataset with a smaller limit truncates the slot list.
        batch_nlls = batch_nlls[:lim]

    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(range(n_steps), total=n_steps, ncols=55)
    for pi in pbar:
        tot_loss, n = 0, 0
        pbar_trial = tqdm(range(len(batch_nlls)), total=len(batch_nlls), position=1, leave=False, ncols=55)
        for pit in pbar_trial:
            trial_batches = batch_nlls[pit]
            if not trial_batches:
                # Slots of filtered-out indices (or indices that did not close a batch)
                # are empty; torch.stack([]) below would raise on them.
                continue

            kl_loss = 0 # To be explored further

            outputs = []
            pbar_batch = tqdm(range(len(trial_batches)), total=len(trial_batches), position=2, leave=False, ncols=55)
            for pib in pbar_batch:
                nll_inputs, nll_targets, intervention_idx, dataset_name, wind = trial_batches[pib]

                # Perform Intervention
                intervention_fn = add_function_vector(edit_layer, fv_vectors[wind].reshape(1, model_config['resid_dim']), device, idx=intervention_idx)
                with TraceDict(model, layers=model_config['layer_hook_names'], retain_input=True, retain_output=True, edit_output=intervention_fn):
                    output = model(**nll_inputs, labels=nll_targets)
                    outputs.append(output.loss)
                pbar_batch.set_description("Loss: %.5f" % (outputs[-1].item(),))

            intervention_nll = (torch.mean(torch.stack(outputs)) * weight) + (kl_loss * klweight)
            trial_loss = intervention_nll.item()
            tot_loss += trial_loss
            intervention_nll.backward()
            optimizer.step()
            n += 1

            pbar_trial.set_description("Batch Loss: %.5f" % (trial_loss,))
            optimizer.zero_grad(set_to_none=True)
            del outputs

        lr_scheduler.step()
        # Guard against an epoch in which every slot was empty.
        tot_loss /= max(n, 1)
        # Early stopping only starts counting after the first quarter of the epochs.
        if pi >= n_steps / 4:
            if tot_loss < min_loss:
                min_loss, tol = tot_loss, 0
            else: tol += 1
        gc.collect()
        torch.cuda.empty_cache()
        pbar.set_description("Epoch Loss: %.5f" % (tot_loss,))
        if tol >= 5: break

    return fv_vectors

def finetune_affine(datasets, fv_vectors, aff, edit_layer: int, n_steps: int,
                  model, model_config, tokenizer, optimizer, lr_scheduler,
                  weight=1.0, shuffle_labels:bool=False, filter_set=None,
                  prefixes=None, separators=None, split='opt', postpath="", splitpath=""):
    """
    Train an affine map that mixes the function vectors into one steering vector.

    Setup, per dataset:
    - A random split of up to 5 query indices (``qis``) and 5 response indices (``ris``) is
      drawn, or loaded from ``splitpath``.
    - For every query, each row of ``fv_vectors`` is applied at ``edit_layer`` on a zero-shot
      prompt via ``function_vector_intervention``. The mean of
      log_softmax(intervention output)[target_token_id] is recorded.
    - Those per-vector scores form the query's "posterior". They are cached to, or loaded
      from, ``postpath``.

    Training builds one-shot prompts (demonstration = query pair, test = response pair). It
    adds ``aff(softmax(posterior)) @ fv_vectors`` at ``edit_layer`` and minimizes the
    weighted cross-entropy on the target tokens. Training stops early once the epoch loss
    fails to improve for 5 consecutive epochs after the first quarter of ``n_steps``.

    Parameters:
    datasets: dict mapping dataset name -> ICL dataset (indexable by ``split``)
    fv_vectors: tensor of function vectors stacked along dim 0, or a list concatenated with torch.cat
    aff: affine transformation to train (maps a posterior over vectors to mixing weights)
    edit_layer: layer index at which the mixed vector is added
    n_steps: maximum number of epochs to train the affine map
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.)
    tokenizer: huggingface tokenizer
    optimizer: affine transformation optimizer, stepped once per non-empty pair slot
    lr_scheduler: learning rate scheduler, stepped once per epoch
    weight: CE loss weight
    shuffle_labels: Whether to shuffle the ICL labels or not
    filter_set: ignored (the split is drawn at random); kept for interface compatibility
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    split: dataset split to train on
    postpath: CSV path used to cache the posterior scores ("" disables writing)
    splitpath: CSV path used to cache the query/response split ("" disables writing)

    Returns:
    aff: the affine transformation that was passed in (updated in place by ``optimizer``)
    """
    import os
    import pandas as pd

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']
    template_kwargs = ({'prefixes': prefixes, 'separators': separators}
                       if prefixes is not None and separators is not None else {})

    min_loss = np.inf
    tol = 0
    device = model.device
    # batch_nlls[j]: entries for the j-th (query, response) pair, one per dataset.
    # Grown on demand: loaded splits may hold more than 5 indices per side.
    batch_nlls = []
    dn = len(datasets)
    lim = np.inf
    blank = {'input':[], 'output':[]}
    posts, splits = {}, {}
    postrows, splitrows = [], []
    if os.path.exists(postpath) and os.path.exists(splitpath):
        pdf = pd.read_csv(postpath)
        postrows = pdf.values.tolist()
        # Columns after "Dataset" and "QI" hold one score per function vector.
        pdf['Distribution'] = pdf.iloc[:, 2:].values.tolist()
        posts = pdf.pivot(columns='Dataset', index="QI", values="Distribution").to_dict()

        with open(splitpath, 'r') as sf:
            for row in sf:
                spdata = list(filter(None, row.rstrip().split(',')))
                spdata[1:] = [int(float(sp)) for sp in spdata[1:]]
                ais = spdata[1:]
                # First half are query indices, second half response indices.
                splits[spdata[0]] = [ais[:len(ais) // 2], ais[len(ais) // 2:]]
                splitrows.append(spdata)
    if isinstance(fv_vectors, list): fv_vectors = torch.cat(fv_vectors)
    for dataname, dataset in tqdm(datasets.items(), total=dn, ncols=60, desc="Dist. Setup", leave=False):
        tar_data = dataset[split]
        if dataname not in splits:
            order = np.arange(len(tar_data), dtype=int)
            np.random.shuffle(order)
            half = min(len(tar_data) // 2, 5)
            qis, ris = list(order[:half]), list(order[half:(half*2)])
            splitrows.append([dataname] + qis + ris)
            if splitpath != "":
                splitdf = pd.DataFrame(splitrows)
                splitdf.to_csv(splitpath, index=False, header=False)
        else:
            qis, ris = splits[dataname]
            half = len(qis)
        if dataname not in posts:
            posts[dataname] = {}
            for qi in tqdm(qis, total=half, ncols=40, desc=dataname, leave=False):
                # Get posterior distribution from query
                prompt_data = word_pairs_to_prompt_data(blank, query_target_pair=tar_data[int(qi)], prepend_bos_token=prepend_bos,
                                                        shuffle_labels=shuffle_labels, **template_kwargs)

                target = prompt_data['query_target']['output']
                target = target[0] if isinstance(target, list) else target

                sentence = [create_prompt(prompt_data)]
                target_token_id = get_answer_id(sentence[0], target, tokenizer)
                inter_logps = []

                with torch.no_grad():
                    for fv_vector in fv_vectors:
                        clean_output, ffv_intervention_output = function_vector_intervention(sentence, target=[target], edit_layer=edit_layer,
                                                                                            function_vector=fv_vector,
                                                                                            model=model, model_config=model_config, tokenizer=tokenizer,
                                                                                            compute_nll=False)

                        intervention_logp = torch.mean(F.log_softmax(ffv_intervention_output.squeeze(), dim=0)[target_token_id])
                        inter_logps.append(intervention_logp.detach())
                        del clean_output, ffv_intervention_output
                postrows.append([dataname, qi, *[ip.item() for ip in inter_logps]])
                posts[dataname][qi] = torch.stack(inter_logps).to(device=model.device, dtype=model.dtype)
            if postpath != "":
                postdf = pd.DataFrame(postrows, columns=["Dataset", "QI", *datasets.keys()])
                postdf.to_csv(postpath, index=False)

        perms = list(itertools.product(qis, ris))
        # Every dataset contributes at most as many pairs as the smallest seen so far.
        lim = min(lim, len(perms))
        for j in range(lim):
            qi, ri = perms[j]
            qi = int(qi)
            word_pairs, word_pairs_test = tar_data[qi:qi+1], tar_data[int(ri)]

            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                    shuffle_labels=shuffle_labels, **template_kwargs)

            target = prompt_data['query_target']['output']
            target = target[0] if isinstance(target, list) else target

            sentence = [create_prompt(prompt_data)]
            target_completion = sentence[0] + target
            inputs = tokenizer(sentence, return_tensors='pt', padding=True).to(device)
            nll_inputs = tokenizer(target_completion, return_tensors='pt', padding=True).to(device)
            nll_targets = nll_inputs.input_ids.clone()
            # shape[-1] rather than len(squeeze()): squeeze() of a 1-token sequence is 0-d.
            target_len = nll_targets.shape[-1] - inputs.input_ids.shape[-1]
            if target_len > 0:
                nll_targets[:, :-target_len] = -100  # This is the accepted value to skip indices when computing loss (see nn.CrossEntropyLoss default)
            intervention_idx = -1 - target_len

            while len(batch_nlls) <= j:
                batch_nlls.append([])
            batch_nlls[j].append((nll_inputs, nll_targets, intervention_idx, dataname, qi))

    gc.collect()
    torch.cuda.empty_cache()
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(range(n_steps), total=n_steps, ncols=55)
    for pi in pbar:
        tot_loss, n = 0, 0
        pbar_trial = tqdm(range(len(batch_nlls)), total=len(batch_nlls), position=1, leave=False, ncols=55)
        for pit in pbar_trial:
            trial_batches = batch_nlls[pit]
            if not trial_batches:
                # Nothing to train on in this slot; torch.stack([]) would raise.
                continue
            outputs = []
            pbar_batch = tqdm(range(len(trial_batches)), total=len(trial_batches), position=2, leave=False, ncols=55)
            for pib in pbar_batch:
                nll_inputs, nll_targets, intervention_idx, dataset_name, qind = trial_batches[pib]
                # Posteriors loaded from CSV are plain lists; convert them lazily once.
                if not torch.is_tensor(posts[dataset_name][qind]):
                    posts[dataset_name][qind] = torch.tensor(posts[dataset_name][qind]).to(device=model.device, dtype=model.dtype)
                cfv_post = aff(F.softmax(posts[dataset_name][qind], dim=0))
                cfv = cfv_post @ fv_vectors

                # Perform Intervention
                intervention_fn = add_function_vector(edit_layer, cfv.reshape(1, model_config['resid_dim']), device, idx=intervention_idx)
                with TraceDict(model, layers=model_config['layer_hook_names'], retain_input=True, retain_output=True, edit_output=intervention_fn):
                    output = model(**nll_inputs, labels=nll_targets)
                    outputs.append(output.loss)
                pbar_batch.set_description("Loss: %.5f" % (outputs[-1].item(),))

            intervention_nll = (torch.mean(torch.stack(outputs)) * weight)
            trial_loss = intervention_nll.item()
            tot_loss += trial_loss
            intervention_nll.backward()
            optimizer.step()
            n += 1

            pbar_trial.set_description("Batch Loss: %.5f" % (trial_loss,))
            optimizer.zero_grad(set_to_none=True)
            del outputs

        lr_scheduler.step()
        # Guard against an epoch in which every slot was empty.
        tot_loss /= max(n, 1)
        # Early stopping only starts counting after the first quarter of the epochs.
        if pi >= n_steps / 4:
            if tot_loss < min_loss:
                min_loss, tol = tot_loss, 0
            else: tol += 1
        gc.collect()
        torch.cuda.empty_cache()
        pbar.set_description("Epoch Loss: %.5f" % (tot_loss,))
        if tol >= 5: break

    return aff