import copy
import csv

import torch
from loguru import logger
from peft import LoraConfig, get_peft_model, IA3Config, prepare_model_for_kbit_training
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from vds_load import NeoLoader, Metric
from vds_module import MarkedDataset
from vds_shared import DEVICE, DRYRUN_SAMPLE_NUM, CFG_MODE_DRYRUN, REPORT_OUTS_DIR
from vds_util import get_attr, set_attr
from utils.dataset import load_dataset, get_context_allowed_shots
from utils.template import make_prompt


class Finetune:
    def __init__(self, cfg=None):
        """Load tokenizer and model, attach the PEFT adapter, and build the corpus.

        Args:
            cfg: configuration object. This method reads ``MODEL_NAME``,
                ``PEFT_ALGO``, ``DATA_CODE`` and ``MODEL_CODE`` from it
                (directly or through the helpers it calls). The ``None``
                default is kept for signature compatibility only; a real
                configuration is required.

        Raises:
            ValueError: if ``cfg`` is ``None``.
            NotImplementedError: if ``cfg.MODEL_NAME`` matches none of the
                supported families, or ``cfg.PEFT_ALGO`` is neither ``'ia3'``
                nor ``'lora'`` (raised by ``hack_model``).

        Attributes set here: ``cfg``, ``tokenizer``, ``config``, ``model``,
        ``attrs``, ``vocab_size``, ``max_context_len``, ``identifier``,
        ``train_inputs``, ``train_outputs``, ``test_inputs``, ``test_outputs``
        and ``matrix``. ``load_corpus`` additionally sets ``vocab_labels`` and
        ``vocab_label2id``.
        """
        if cfg is None:
            # Fail early with a clear message instead of an AttributeError on
            # the first ``cfg.MODEL_NAME`` access below.
            raise ValueError('Finetune requires a configuration object, got cfg=None')

        self.cfg = cfg
        self.tokenizer = NeoLoader.load_tokenizer(cfg.MODEL_NAME)
        self.config, self.model, self.attrs = NeoLoader.load_model(cfg.MODEL_NAME)

        # The name of the context-length field differs between GPT-2 and the
        # other supported families; vocab_size is read the same way for all.
        model_name = cfg.MODEL_NAME
        model_config = self.model.config
        if 'gpt2' in model_name:
            self.vocab_size = model_config.vocab_size
            self.max_context_len = model_config.n_positions
        elif any(family in model_name for family in ('pythia', 'gemma', 'Qwen', 'llama')):
            self.vocab_size = model_config.vocab_size
            self.max_context_len = model_config.max_position_embeddings
        else:
            raise NotImplementedError

        # TODO ...
        # On CUDA, gemma / Qwen / llama models go through PEFT's k-bit
        # training preparation before the adapter is attached.
        if DEVICE == 'cuda' and any(family in model_name for family in ('gemma', 'Qwen', 'llama')):
            self.model = prepare_model_for_kbit_training(self.model)

        # Attach the IA3 / LoRA adapter (raises NotImplementedError for any
        # other PEFT_ALGO). hack_model holds the single copy of that logic.
        self.hack_model()

        self.print_trainable_parameters()

        # ...
        self.identifier = f'{self.cfg.DATA_CODE}.{self.cfg.MODEL_CODE}'

        # Needs self.tokenizer and self.max_context_len, so it must run after
        # the model-family branch above.
        self.train_inputs, self.train_outputs, self.test_inputs, self.test_outputs = self.load_corpus()

        # Transposed LM-head weight, used by lm_project as the right-hand
        # operand of the matmul.
        lm_head_matrix = self.get_lm_head_matrix()
        self.matrix = lm_head_matrix.transpose(0, 1).to(DEVICE)

    def hack_model(self):
        if self.cfg.PEFT_ALGO == 'ia3':
            peft_config = IA3Config(
                fan_in_fan_out=True if 'gpt2' in self.cfg.MODEL_NAME else False,
                task_type="CAUSAL_LM",
            )
            # it freezes all parameters except the specified modules
            self.model = get_peft_model(self.model, peft_config)
        elif self.cfg.PEFT_ALGO == 'lora':
            peft_config = LoraConfig(
                fan_in_fan_out=True if 'gpt2' in self.cfg.MODEL_NAME else False,
                task_type="CAUSAL_LM",
            )
            # it freezes all parameters except the specified modules
            self.model = get_peft_model(self.model, peft_config)
        else:
            raise NotImplementedError

    def print_trainable_parameters(self):
        """Log trainable vs. total parameter counts of ``self.model``.

        Element counts (``numel``) are summed over ``named_parameters()``; the
        trainable count covers parameters with ``requires_grad`` set. The
        ratio is logged as a percentage at INFO level.
        """
        sizes = [
            (param.numel(), param.requires_grad)
            for _, param in self.model.named_parameters()
        ]
        # These three names are echoed by the ``{name=}`` f-string specifiers
        # below, so they must not be renamed.
        all_param = sum(size for size, _ in sizes)
        trainable_params = sum(size for size, is_trainable in sizes if is_trainable)
        trainable_ratio = 100 * trainable_params / all_param
        logger.info(f"{trainable_params=} || {all_param=} || {trainable_ratio} %")

    def get_lm_head_matrix(self):
        """Return the detached weight tensor of the model's LM head.

        The head module is looked up on ``self.model`` through ``get_attr``
        using the path stored under ``self.attrs['lm_head']``.
        """
        head = get_attr(self.model, self.attrs['lm_head'])
        return head.weight.detach()

    def load_corpus(self):
        """Build prompt and label lists for the train and dev splits.

        Side effects:
            Sets ``self.vocab_labels`` (one token id per entry of
            ``train_data.id2verb``: the last id produced by encoding the
            space-prefixed verbaliser) and ``self.vocab_label2id`` (token id
            -> position in ``vocab_labels``; if two verbalisers share a token
            id, the later position wins).

        Returns:
            ``(train_comments, train_labels, test_comments, test_labels)``.
            Comments are the ``mode='inference'`` prompts of each instance and
            labels are the matching entries of ``self.vocab_labels``. When
            ``CFG_MODE_DRYRUN`` is truthy the train lists are cut to
            ``4 * DRYRUN_SAMPLE_NUM`` items and the test lists to
            ``DRYRUN_SAMPLE_NUM`` items.
        """
        data_code = self.cfg.DATA_CODE
        train_data, dev_data = load_dataset(dataset=data_code)
        n_demo_shot = get_context_allowed_shots(dataset=data_code, context_len=self.max_context_len)

        self.vocab_labels = [
            self.tokenizer.encode(' ' + label_verb)[-1]
            for label_verb in train_data.id2verb
        ]
        self.vocab_label2id = {
            vocab_label: label_id
            for label_id, vocab_label in enumerate(self.vocab_labels)
        }

        def collect(dataset):
            """Return (prompts, label token ids) for every instance in ``dataset``."""
            label2id = dataset.label2id
            comments, labels = [], []
            for ins in dataset.data:
                comments.append(make_prompt(ins, data_code, mode='inference'))
                labels.append(self.vocab_labels[label2id[ins['label']]])
            return comments, labels

        # NOTE(review): the few-shot prefix is built here and again below, but
        # never prepended to any prompt (the prefixed variants were disabled).
        # The calls stay, in their original order, because subsamplebyshot and
        # make_prompt are defined elsewhere and may have side effects (e.g.
        # consuming random state) -- TODO confirm and remove if they are pure.
        copy_data = copy.deepcopy(train_data)
        copy_data.subsamplebyshot(n_demo_shot)
        make_prompt(copy_data, data_code, mode='train')

        # The prompt used to be chosen by testing `ins in copy_data.data`, but
        # both branches produced the same prompt, so that per-instance list
        # scan is gone.
        train_comments, train_labels = collect(train_data)

        # Must run after the training prompts are collected: this subsamples
        # train_data itself.
        train_data.subsamplebyshot(n_demo_shot)
        make_prompt(train_data, data_code, mode='train')

        # NOTE(review): dev labels are resolved through dev_data.label2id but
        # indexed into the vocab_labels built from train_data.id2verb --
        # assumes both splits share one label ordering; confirm.
        test_comments, test_labels = collect(dev_data)

        if CFG_MODE_DRYRUN:
            train_limit = DRYRUN_SAMPLE_NUM * 4
            test_limit = DRYRUN_SAMPLE_NUM * 1
            train_comments = train_comments[:train_limit]
            train_labels = train_labels[:train_limit]
            test_comments = test_comments[:test_limit]
            test_labels = test_labels[:test_limit]

        return train_comments, train_labels, test_comments, test_labels

    def preprocess(self, comments):
        iis = list()
        ams = list()
        for comment in comments:
            inputs = self.tokenizer.encode_plus(comment, return_tensors="pt", padding=True)  # .to(device=DEVICE)
            # inputs = self.tokenizer.encode_plus(comment, return_tensors="pt", padding="max_length")  # .to(device=DEVICE)
            if inputs['input_ids'].shape[1] > self.max_context_len:
                inputs['input_ids'] = inputs['input_ids'][:, -self.max_context_len:]
                inputs['attention_mask'] = inputs['attention_mask'][:, -self.max_context_len:]
            iis.append(inputs['input_ids'].squeeze(0))
            ams.append(inputs['attention_mask'].squeeze(0))
        return iis, ams

    def parse_response(self, gen_logits):
        # logger.critical(f'{gen_logits.shape=}')
        gen_prob = torch.softmax(gen_logits, dim=-1)
        prob_per_cls = []
        for vocab_label in self.vocab_labels:
            if self.cfg.PEFT_VAI:
                vocab_label_id = self.vocab_label2id[vocab_label]
                prob_per_cls.append(gen_prob[:, vocab_label_id])
            else:
                prob_per_cls.append(gen_prob[:, vocab_label])
        # we filter the global_labels, to keep consist with the local_labels
        label_id = torch.argmax(torch.cat(prob_per_cls, dim=0)).tolist()
        pred = self.vocab_labels[label_id]
        # logger.debug(f'{pred=}')
        return pred

    def lm_project(self, representation):
        # logger.debug(f'{labels.shape=}')  # labels.shape=torch.Size([1])
        logits = torch.matmul(representation, self.matrix).squeeze(1)
        # logger.debug(f'{logits.shape=}')  # scores.shape=torch.Size([1, 50257])
        return logits

    def train(self, data_loader, optimizer):
        """Run one optimisation pass over ``data_loader``.

        For each batch, the last layer's hidden state at the final sequence
        position is projected with ``lm_project`` and scored against
        ``batch["labels"]`` with cross-entropy, followed by one optimizer
        step.

        Args:
            data_loader: iterable of batches exposing ``"input_ids"`` and
                ``"labels"`` tensors.
            optimizer: optimizer driving ``self.model``'s parameters.

        NOTE(review): the ``unsqueeze(0)`` / ``squeeze(1)`` pairing looks
        written for single-example batches -- confirm before raising
        ``batch_size``.
        """
        self.model.train()
        loss_fn = nn.CrossEntropyLoss()
        for batch in tqdm(data_loader, leave=False):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = self.model(input_ids=input_ids, output_hidden_states=True, return_dict=True)
            # Last layer, final position of every sequence in the batch.
            last_position = outputs.hidden_states[-1][:, -1, :]
            logits = self.lm_project(last_position.unsqueeze(0))

            loss_fn(logits, labels).backward()
            optimizer.step()

    def predict(self, data_loader):
        """Predict one label token per batch of ``data_loader``.

        The model generates a single new token with hidden states returned.
        The final-position state is projected with ``lm_project`` and decoded
        by ``parse_response``.

        Args:
            data_loader: iterable of batches exposing an ``"input_ids"``
                tensor.

        Returns:
            List with one ``parse_response`` result per batch.
        """
        self.model.eval()
        predictions = []
        for batch in tqdm(data_loader, leave=False):
            input_ids = batch["input_ids"].to(DEVICE)
            generation = self.model.generate(
                inputs=input_ids,
                max_new_tokens=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # Presumably indexed [generation step][layer]: entry 0 is the only
            # step (max_new_tokens=1) and [-1] its last layer -- confirm
            # against the transformers generate-output docs.
            first_step = generation.hidden_states[0]
            last_position = first_step[-1][:, -1, :]
            logits = self.lm_project(last_position.unsqueeze(0))
            predictions.append(self.parse_response(logits))
        return predictions

    def optimize_model(self):
        """Fine-tune on the training split, evaluate on the test split, log a CSV row.

        Steps:
            1. Build ``MarkedDataset`` / ``DataLoader`` pairs with
               ``batch_size=1``. When ``cfg.PEFT_VAI`` is truthy the dataset
               targets are class positions from ``vocab_label2id``, otherwise
               the label token ids.
            2. Call ``hack_model`` and train for ``cfg.PEFT_EPOCH_NUM`` epochs
               with ``AdamW`` at its default settings.
            3. Predict on the test loader and score against the token-id test
               labels, whichever targets the datasets used.
            4. Append ``[MODEL_CODE, DATA_CODE, accuracy]`` to
               ``REPORT_OUTS_DIR / glance_<PEFT_ALGO>.csv``, writing a header
               row first if the file did not exist yet.
        """
        train_comments, train_labels = self.train_inputs, self.train_outputs
        test_comments, test_labels = self.test_inputs, self.test_outputs

        # Only the dataset targets depend on PEFT_VAI.
        if self.cfg.PEFT_VAI:
            train_targets = [self.vocab_label2id[label] for label in train_labels]
            test_targets = [self.vocab_label2id[label] for label in test_labels]
        else:
            train_targets, test_targets = train_labels, test_labels

        train_dataset = MarkedDataset(train_comments, train_targets, self.tokenizer, self.max_context_len)
        test_dataset = MarkedDataset(test_comments, test_targets, self.tokenizer, self.max_context_len)

        # TODO ...
        train_loader = DataLoader(train_dataset, batch_size=1)
        test_loader = DataLoader(test_dataset, batch_size=1)

        # NOTE(review): __init__ has already applied the PEFT wrapper once;
        # this call wraps the model again -- confirm that is intended.
        self.hack_model()
        optimizer = optim.AdamW(self.model.parameters())
        for _ in tqdm(range(self.cfg.PEFT_EPOCH_NUM)):
            self.train(train_loader, optimizer)

        # Evaluation. ``post_peft_acc`` is echoed by the ``{name=}`` f-string
        # specifier, so it must not be renamed.
        post_labels = self.predict(test_loader)
        post_peft_acc = Metric.same_accuracy(post_labels, test_labels)
        logger.warning(f'{post_peft_acc=}')
        Metric.general_gen_scoring(post_labels, test_labels)

        # Append to the per-algorithm summary CSV.
        REPORT_OUTS_DIR.mkdir(parents=True, exist_ok=True)
        save_results_file = REPORT_OUTS_DIR / f'glance_{self.cfg.PEFT_ALGO}.csv'
        needs_header = not save_results_file.exists()
        with open(save_results_file, 'a+', newline='') as csvfile:
            writer = csv.writer(csvfile)
            if needs_header:
                writer.writerow(['llm', 'dataset', f'acc_{self.cfg.PEFT_ALGO}'])
            writer.writerow([self.cfg.MODEL_CODE, self.cfg.DATA_CODE, post_peft_acc])
