import sys
import os
from disentangle.dataset import CFDataset, DisDataset
from disentangle.utils import Accumulator, Saver
from baselines.rome import repr_tools

import torch.utils
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import torch

from tqdm import tqdm
import hydra
from omegaconf import DictConfig
import random
import json

from baselines.rome.compute_v import find_fact_lookup_idx
from util import nethook

@hydra.main(version_base=None, config_path=".", config_name="config_l")
def train(config: DictConfig):
    """Train the auto-encoder against a frozen causal LM, evaluating after every epoch.

    The function seeds every RNG, builds train/validation samples from the JSON
    file at ``config['data_path']``, loads the LM and tokenizer from
    ``config['model_path']``, and then optimises ``ae_model`` with AdamW and a
    cosine-annealed learning rate. After each epoch the model is evaluated on
    the validation loader and handed to ``Saver.save`` together with a summary
    string of the four averaged losses.

    Args:
        config: Hydra configuration. Keys read directly in this function:
            ``seed``, ``data_path``, ``val_ratio``, ``val_cnt``, ``train_cnt``,
            ``batch_size``, ``model_path``, ``module_tmp``, ``layer_tmp``,
            ``from_layer``, ``to_layer``, ``relation_layer``, ``alpha1``,
            ``alpha2``, ``dec_lr``, ``lr``, ``epoch``, ``print_steps``,
            ``orth_alpha`` and ``recon_alpha``. The whole object is also passed
            to ``Saver``.
    """
    # set seed
    random.seed(config['seed'])
    np.random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # load dataset
    with open(config['data_path'], 'r') as f:
        raw = json.load(f)
    random.shuffle(raw)

    def generate_data(raw, cnt):
        """Build ``cnt`` samples, cycling through ``raw`` as many times as needed.

        Every sample draws two distinct relations of one entry: the first
        supplies ``prompt``/``answer`` and (from an independent draw, so it may
        coincide with the first) ``para_prompt``; the second supplies
        ``neighbor_prompt``/``neighbor_answer``.
        """
        data, idx = [], 0
        while len(data) < cnt:
            entry = raw[idx]
            subjects = [entry['subject']['str']] + entry['subject'].get('aliases', [])
            # NOTE: the order of the random draws below is kept fixed so that a
            # given seed keeps producing the same dataset.
            r1, r2 = random.sample(entry['relations'], 2)
            d1 = random.sample(r1['data'], 1)[0]
            d2 = random.sample(r2['data'], 1)[0]
            d3 = random.sample(r1['data'], 1)[0]
            data.append({
                'subject': random.sample(subjects, 1)[0],
                'prompt': d1['prompt'],
                'answer': d1['answer'],
                'neighbor_prompt': d2['prompt'],
                'neighbor_answer': d2['answer'],
                'para_prompt': d3['prompt']
            })
            idx = (idx + 1) % len(raw)
        return data

    # The first `val_ratio` share of the shuffled entries feeds validation, the rest training.
    val_split = int(len(raw) * config['val_ratio'])
    test_data = generate_data(raw[:val_split], config['val_cnt'])
    train_data = generate_data(raw[val_split:], config['train_cnt'])
    train_loader = DataLoader(DisDataset(train_data), batch_size=config['batch_size'], shuffle=True, drop_last=True)
    test_loader = DataLoader(DisDataset(test_data), batch_size=config['batch_size'], drop_last=True)

    # load model
    print('load model from', config['model_path'])
    model = AutoModelForCausalLM.from_pretrained(config['model_path']).cuda()
    tok = AutoTokenizer.from_pretrained(config['model_path'])
    tok.pad_token = tok.eos_token
    # The language model stays frozen; only the auto-encoder is optimised.
    nethook.set_requires_grad(False, model)

    _, ae_model = Saver.init_ae_model(config)
    ae_model = ae_model.cuda()
    layers = {
        'edit': [config['module_tmp'].format(l) for l in range(config['from_layer'], config['to_layer']+1)],
        'relation': config['layer_tmp'].format(config['relation_layer']),
        'relation_tmp': config['layer_tmp'],
        'relation_layer': config['relation_layer'],
    }
    alpha = {'loss1': config['alpha1'], 'loss2': config['alpha2']}

    # train
    # Parameters whose name contains 'dec' get their own learning rate.
    dec_p, else_p = [], []
    for n, p in ae_model.named_parameters():
        (dec_p if 'dec' in n else else_p).append(p)
    optimizer = torch.optim.AdamW([
        {'params': dec_p, 'lr': config['dec_lr']},
        {'params': else_p, 'lr': config['lr']},
    ])
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config['epoch'])

    # `print_steps` < 1 is interpreted as a fraction of an epoch, otherwise as a
    # step count. Clamp the derived value to at least 1: a small fraction of a
    # short loader would otherwise round down to 0 and make the modulo in the
    # training loop raise ZeroDivisionError.
    print_steps = config['print_steps']
    if print_steps < 1:
        print_steps = max(1, int(len(train_loader) * print_steps))
    saver = Saver(config)

    def print_acc(acc: Accumulator):
        """Show the accumulator's current output on the active progress bar."""
        loss1, loss2, orth_loss, recon_loss = acc.output()
        pbar.set_postfix({
            'L1': loss1,
            'L2': loss2,
            'LO': orth_loss,
            'LR': recon_loss,
        })

    for ep in range(config['epoch']):
        train_acc = Accumulator(4)
        ae_model = ae_model.train()
        with tqdm(train_loader, desc=f"Train {ep+1}/{config['epoch']}", ncols=100) as pbar:
            for i, data in enumerate(pbar):
                loss1, loss2, orth_loss, recon_loss = train_a_step(model, tok, layers, ae_model, data, alpha)
                loss = loss1 + loss2 + config['orth_alpha'] * orth_loss + config['recon_alpha'] * recon_loss
                # loss = loss1 * config['alpha1'] + loss2 * config['alpha2'] + orth_loss * config['orth_alpha'] + recon_loss * config['recon_alpha']
                train_acc.add(loss1.item(), loss2.item(), orth_loss.item(), recon_loss.item())

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if (i + 1) % print_steps == 0:
                    print_acc(train_acc)

        test_acc = Accumulator(4)
        ae_model = ae_model.eval()
        with tqdm(test_loader, desc="Evaluate", ncols=100) as pbar:
            for data in pbar:
                with torch.no_grad():
                    loss1, loss2, orth_loss, recon_loss = train_a_step(model, tok, layers, ae_model, data, alpha)
                # Only the individual terms are reported; no combined loss is needed here.
                test_acc.add(loss1.item(), loss2.item(), orth_loss.item(), recon_loss.item())
        loss1, loss2, orth_loss, recon_loss = test_acc.output()
        eval_str_loss = f'loss1: {loss1}, loss2: {loss2}, orth_loss: {orth_loss}, recon_loss: {recon_loss}'
        print(eval_str_loss)
        saver.save(ae_model, ep, eval_str_loss)

        # The cosine schedule advances once per epoch.
        scheduler.step()

def process_inputs(tok, prompts, answers, subjects):
    """Tokenise teacher-forcing inputs for a batch of (prompt, subject, answer) triples.

    Each context is the prompt template filled with its subject, followed by
    the decoded answer with its last token removed.

    Args:
        tok: tokenizer; called on strings / lists of strings and used for ``decode``.
        prompts: prompt templates, each formatted with the matching subject.
        answers: answer strings.
        subjects: subject strings.

    Returns:
        A tuple ``(input_tok, lookup_idx, targets)``:
        the padded batch encoding on CUDA, the per-example index returned by
        ``find_fact_lookup_idx(..., 'subject_last')``, and an integer tensor
        shaped like ``input_ids`` holding answer token ids at the supervised
        positions and ``-100`` elsewhere.
    """
    answer_ids = []
    for answer in answers:
        encoded = tok(answer, return_tensors="pt").to("cuda")
        answer_ids.append(encoded["input_ids"][0])

    contexts = []
    for template, subject, ids in zip(prompts, subjects, answer_ids):
        # Drop the final answer token from the input; it is only ever a target.
        contexts.append(template.format(subject) + tok.decode(ids[:-1]))
    input_tok = tok(contexts, return_tensors="pt", padding=True).to("cuda")

    target_shape = (len(contexts), *input_tok["input_ids"].shape[1:])
    targets = torch.full(target_shape, -100, dtype=torch.long, device="cuda")
    for row, ids in zip(range(len(contexts)), answer_ids):
        # Count of attended tokens in this row; the slice below treats them as
        # occupying the leading positions (i.e. right padding) — TODO confirm
        # against the tokenizer's padding side.
        seq_len = input_tok["attention_mask"][row].sum()
        targets[row, seq_len - len(ids):seq_len] = ids

    lookup_idx = []
    for template, subject in zip(prompts, subjects):
        lookup_idx.append(
            find_fact_lookup_idx(template, subject, tok, 'subject_last', verbose=False)
        )

    return input_tok, lookup_idx, targets


def train_a_step(model, tok, layers, ae_model, data, alpha):
    """Compute the four loss terms for one batch, averaged over all edit layers.

    For every module name in ``layers['edit']`` the function:

    1. runs the LM on the original prompts and, via a ``nethook.Trace`` output
       hook, reads the hidden state at each example's lookup position and
       feeds it (with the relation representation) to ``ae_model``, which
       returns ``(s1, s2, s3)``;
    2. re-runs the LM on the same prompts with ``s1`` written into that
       position and on the neighbour prompts with ``s2`` written in, and
       measures the answer NLL of each;
    3. keeps only the examples whose patched NLL exceeds ``alpha`` times the
       un-patched NLL (if none do, the term is detached so it contributes no
       gradient);
    4. adds an MSE reconstruction loss between ``s3`` and the original hidden
       state, and a symmetric contrastive loss between ``s1`` and ``s2``.

    Args:
        model: the language model the hooks are attached to.
        tok: its tokenizer.
        layers: dict with ``'edit'`` (module names to hook), ``'relation_layer'``
            and ``'relation_tmp'`` (passed to ``repr_tools.get_reprs_at_idxs``).
        ae_model: callable ``(hidden, relation_repr) -> (s1, s2, s3)``.
        data: batch of six sequences, unpacked in this order: prompts,
            subjects, answers, neighbour prompts, neighbour answers,
            paraphrase prompts.
        alpha: dict with the ``'loss1'`` / ``'loss2'`` NLL thresholds factors.

    Returns:
        ``(loss1, loss2, info_loss, recon_loss)`` — in the order the caller
        unpacks them as ``(loss1, loss2, orth_loss, recon_loss)``.
    """
    dis_prompts, subjects, answers, neighbor_prompts, neighbor_answers, para_prompts = [list(item) for item in data]
    bs = len(dis_prompts)

    input0, lookup0, target0 = process_inputs(tok, dis_prompts, answers, subjects)
    # The "same relation" branch currently reuses the original prompts, so the
    # tokenised batch is shared instead of being built a second time. To train
    # on paraphrases instead, replace this line with:
    #   input1, lookup1, target1 = process_inputs(tok, para_prompts, answers, subjects)
    input1, lookup1, target1 = input0, lookup0, target0
    input2, lookup2, target2 = process_inputs(tok, neighbor_prompts, neighbor_answers, subjects)

    relation_args = dict(
        model=model,
        tok=tok,
        layer=layers['relation_layer'],
        module_template=layers['relation_tmp'],
        track="out",
    )
    questions = [prompt.format(subject) for prompt, subject in zip(dis_prompts, subjects)]
    # Index of the last attended token of every question.
    relation_idx = [[(s-1).item()] for s in tok(questions, return_tensors="pt", padding=True)['attention_mask'].sum(-1)]
    relation_repr = repr_tools.get_reprs_at_idxs(
        contexts=questions,
        idxs=relation_idx,
        **relation_args,
    )

    # Filled in by forward0 on every edit layer and consumed by forward1/forward2.
    s1 = s2 = s3 = repr_init = None

    def forward0(repr, _):
        # Read the hidden state at the lookup position and encode it; the
        # module output itself is passed through unchanged.
        nonlocal s1, s2, s3, repr_init
        repr_init = (repr[0] if isinstance(repr, tuple) else repr)[range(bs), lookup0]
        s1, s2, s3 = ae_model(repr_init, relation_repr)
        return repr

    def forward1(repr, _):
        # Overwrite the lookup position with s1 (in place).
        (repr[0] if isinstance(repr, tuple) else repr)[range(bs), lookup1] = s1
        return repr

    def forward2(repr, _):
        # Overwrite the lookup position with s2 (in place).
        (repr[0] if isinstance(repr, tuple) else repr)[range(bs), lookup2] = s2
        return repr

    # Baseline NLLs of the un-patched model. No hook is active here, so the
    # values do not depend on the edit layer and are computed once up front
    # rather than once per layer.
    with torch.no_grad():
        nll_loss_pre1 = compute_nll_loss(model(**input1).logits, target1)
        nll_loss_pre2 = compute_nll_loss(model(**input2).logits, target2)

    LOSS1, LOSS2, INFO_LOSS, RECON_LOSS = [], [], [], []
    for edit_layer in layers['edit']:
        s1, s2, s3, repr_init = None, None, None, None

        with nethook.Trace(module=model, layer=edit_layer, retain_input=False, retain_output=True, edit_output=forward0):
            model(**input0)

        with nethook.Trace(module=model, layer=edit_layer, retain_input=False, retain_output=True, edit_output=forward1):
            logits1 = model(**input1).logits
            nll_loss1 = compute_nll_loss(logits1, target1)
            idx1 = nll_loss1 > nll_loss_pre1 * alpha['loss1']
            # Only examples above the threshold are trained on; otherwise report
            # the value without gradient.
            loss1 = (nll_loss1[idx1] if idx1.any() else nll_loss1.detach()).mean()

        with nethook.Trace(module=model, layer=edit_layer, retain_input=False, retain_output=True, edit_output=forward2):
            logits2 = model(**input2).logits
            nll_loss2 = compute_nll_loss(logits2, target2)
            idx2 = nll_loss2 > nll_loss_pre2 * alpha['loss2']
            loss2 = (nll_loss2[idx2] if idx2.any() else nll_loss2.detach()).mean()

        recon_loss = torch.nn.functional.mse_loss(s3, repr_init)
        info_loss = compute_info_loss_1(s1, repr_init, s2) + compute_info_loss_1(s2, repr_init, s1)

        LOSS1.append(loss1)
        LOSS2.append(loss2)
        # Each term goes into its own list. These two were previously crossed,
        # which swapped the values (and therefore the orth/recon weights and
        # progress-bar labels) on the caller's side.
        INFO_LOSS.append(info_loss)
        RECON_LOSS.append(recon_loss)

    def avg(values):
        return sum(values) / len(values)

    return avg(LOSS1), avg(LOSS2), avg(INFO_LOSS), avg(RECON_LOSS)

def compute_nll_loss(logits, targets):
    """Return the per-example mean negative log-likelihood of the target tokens.

    Args:
        logits: float tensor indexed as (example, position, vocabulary entry).
        targets: integer tensor indexed as (example, position); positions equal
            to ``-100`` are ignored.

    Returns:
        1-D tensor with one value per example: the NLL averaged over that
        example's non-ignored positions.
    """
    valid = targets != -100
    # Ignored positions need some in-range index for gather; their
    # contribution is zeroed out by the mask afterwards.
    safe_targets = targets.masked_fill(~valid, 0)

    log_probs = torch.log_softmax(logits, dim=2)
    token_log_prob = log_probs.gather(2, safe_targets.unsqueeze(2)).squeeze(2)

    weights = valid.float()
    return -(token_log_prob * weights).sum(1) / weights.sum(1)

# Cached (B, B-1) index tensor: row i lists every batch index except i.
INFO_IDX = None


def compute_info_loss_1(d, s, end):
    """Contrastive loss with ``d`` as anchor and ``s`` as its positive.

    The negatives for example ``i`` are the rows of ``s`` that belong to the
    other examples of the batch, plus ``end[i]``.

    Args:
        d: anchor vectors, one row per example.
        s: positive vectors, one row per example.
        end: one extra negative vector per example.

    Returns:
        Scalar loss from ``CTR`` with ``tau=0.1``.
    """
    global INFO_IDX
    B = d.shape[0]
    # Rebuild the cached index whenever the batch size or device changes; a
    # cache built for a different batch size would index the wrong rows (or
    # fall out of range).
    if INFO_IDX is None or INFO_IDX.shape[0] != B or INFO_IDX.device != s.device:
        off_diagonal = ~torch.eye(B, dtype=torch.bool, device=s.device)
        # Selecting the off-diagonal entries of a (B, B) grid of column indices
        # leaves, for each row i, all j != i in ascending order. This also
        # yields a valid (1, 0) index for a batch of one.
        INFO_IDX = torch.arange(B, device=s.device).expand(B, B)[off_diagonal].view(B, B - 1)

    negatives = torch.concat([s[INFO_IDX], end.unsqueeze(1)], dim=1)
    return CTR(d, s, negatives, tau=0.1)

def CTR(s, s_pos, s_neg, tau=1):
    """Contrastive loss based on cosine similarity.

    Computes ``-log(sum_pos exp(sim/tau) / sum_{pos+neg} exp(sim/tau))`` per
    example and returns the mean.

    Args:
        s: anchor vectors; a 2-D input gets a singleton axis inserted at dim 1.
        s_pos: positive vectors, same convention as ``s``.
        s_neg: negative vectors, same convention as ``s``; candidates are
            stacked along dim 1.
        tau: temperature dividing the similarities.

    Returns:
        Scalar tensor with the mean loss.
    """
    def with_candidate_axis(t):
        return t if t.ndim != 2 else t.unsqueeze(1)

    anchor = with_candidate_axis(s)
    positives = with_candidate_axis(s_pos)
    negatives = with_candidate_axis(s_neg)
    cosine = torch.nn.functional.cosine_similarity

    positive_mass = torch.exp(cosine(anchor, positives, dim=-1) / tau).sum(-1)
    candidates = torch.cat([positives, negatives], dim=1)
    total_mass = torch.exp(cosine(anchor, candidates, dim=-1) / tau).sum(-1)

    return -torch.log(positive_mass / total_mass).mean()


# Script entry point. train() is called without arguments because the
# @hydra.main decorator on it supplies the `config` parameter.
if __name__ == "__main__":
    train()
