import json
import torch
from tqdm import tqdm
import numpy as np
from transformers import T5ForSequenceClassification, T5TokenizerFast, RobertaTokenizerFast
# from peft import PrefixTuningConfig, get_peft_model, TaskType
from peft import LoraModel, LoraConfig
import pickle
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, balanced_accuracy_score


# MODEL_NAME = "google/t5-v1_1-small"

class CodeGenDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized inputs with integer outcome labels.

    Item ``idx`` is the tuple ``(requests[idx], att_masks[idx], label)``,
    where ``label`` is a one-element integer tensor built from
    ``int(results[idx])``.
    """

    def __init__(
        self,
        tokenized_requests: torch.Tensor,
        att_masks: torch.Tensor,
        results: torch.Tensor,
    ):
        """Store the inputs and print the share of positive outcomes.

        Args:
            tokenized_requests: Indexable collection of tokenized inputs; its
                length defines the dataset length.
            att_masks: Attention masks, indexed the same way as
                ``tokenized_requests``.
            results: Per-sample outcomes; every element must be convertible
                with ``int()``.
        """
        self.requests = tokenized_requests
        self.att_masks = att_masks
        self.results = results

        n_samples = len(results)
        n_correct = sum(int(r) for r in results)
        # Guard the division: an empty split has no meaningful ratio, so
        # report NaN instead of raising ZeroDivisionError.
        fraction = n_correct / n_samples if n_samples else float("nan")
        print(f"Fraction of correct solutions in the dataset: {fraction}")

    def __len__(self):
        """Return the number of stored tokenized inputs."""
        return len(self.requests)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for sample ``idx``."""
        return (
            self.requests[idx],
            self.att_masks[idx],
            torch.tensor([int(self.results[idx])])
        )


# CKPT_NAME = f'{MODEL_NAME}_class_mbpp.pth'
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer
# and the classifier.
MODEL_NAME = "Salesforce/codet5-base"
# Sub-directory of /app/results that holds the labels and code to score.
EVALUATED_MODEL = "qwencoder"
# Suffix of the dataset-specific files; rebound by the main dataset loop.
DATASET_PREFIX = "_he"

# Optimisation settings.
DEVICE = "cuda"
N_EPOCHS = 100
LR = 3e-5

# with open('/app/results_qwen27_7b-mbpp_evaluated_enr.json', 'r') as f:
#     dataset = json.load(f)


# ids = []
# tasks = []
# code = []
# asserts = []
# results = []
# for ident, res in dataset.items():
#     # ids.append(res['task_id'])
#     tasks.append(res['task'])
#     code.append(res['code'])
#     asserts.append(res['tests'])
#     results.append(res['result'] == 'passed')


# tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
# requests = [f"{text}\n{asserts}\n{code}" for text, assets, code in zip(tasks, asserts, code)]
# requests = [f"{text}\n{code}" for text, assets, code in zip(tasks, asserts, code)]
# task_prefix = "You are an experienced coder. Check whether the following solution for the provided tasks is correct and can pass the asserts: "
# task_prefix = "You are an experienced coder. Check whether the following solution for the provided tasks is correct: "

# Experiment driver: for every dataset and every pre-computed train/test split,
# train only the classification head of a frozen T5 classifier on the generated
# code, evaluate after each epoch, and dump all metrics collected so far to JSON.

# Dataset-independent setup, built once instead of once per dataset.
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)

# Each callback receives (predicted probabilities, targets) as numpy arrays.
# The thresholded metrics binarise the probabilities at 0.5.
metric_callbacks = {
    'Accuracy': lambda pred, targ: accuracy_score(targ, (pred > 0.5).astype(np.int32)),
    'Balanced accuracy': lambda pred, targ: balanced_accuracy_score(targ, (pred > 0.5).astype(np.int32)),
    'F1 score': lambda pred, targ: f1_score(targ, (pred > 0.5).astype(np.int32)),
    'ROC-AUC score': lambda pred, targ: roc_auc_score(targ, pred)
}

for DATASET_PREFIX in ['_mbpp', '_he']:

    # Re-seed per dataset so each dataset's run starts from the same RNG state.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
    with open(f'/app/results/splits{DATASET_PREFIX}.pickle', 'rb') as f:
        splits = pickle.load(f)

    with open(f'/app/results/{EVALUATED_MODEL}/Y{DATASET_PREFIX}.pickle', 'rb') as f:
        results = pickle.load(f)

    with open(f'/app/results/{EVALUATED_MODEL}/code{DATASET_PREFIX}.pickle', 'rb') as f:
        code = pickle.load(f)

    tokenized_requests = tokenizer(
        # [task_prefix + r for r in requests],
        code,
        padding="longest",
        max_length=8192,
        truncation=True,
        return_tensors="pt",
    )

    # Convert the labels once per dataset; every split only indexes into this.
    results_tensor = torch.tensor(results)

    # One dict of per-epoch metric histories for each split.
    metric_values = [{k: [] for k in metric_callbacks} for _ in splits]

    for split_ind, (train, test) in enumerate(splits):

        # train_quant = 0.9
        # train_size = int(len(tasks) * train_quant)
        # permut = torch.randperm(len(tasks))

        print('=' * 100)
        print(f"Running training for the {split_ind+1}'th split")
        print('=' * 100)

        train_req, test_req = (
            tokenized_requests.input_ids[train],
            tokenized_requests.input_ids[test]
        )
        train_att, test_att = (
            tokenized_requests.attention_mask[train],
            tokenized_requests.attention_mask[test]
        )
        train_res, test_res = results_tensor[train], results_tensor[test]

        # A fresh model per split, so nothing learned carries over between splits.
        model = T5ForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=1,
            torch_dtype=torch.bfloat16
        ).to(DEVICE)

        # Freeze the backbone; only the classification head is optimised below.
        for p in model.transformer.parameters():
            p.requires_grad = False

        # config = LoraConfig(
        #     task_type="SEQ_CLS",
        #     r=8,
        #     lora_alpha=32,
        #     target_modules=["q", "v"],
        #     lora_dropout=0.01,
        # )
        # model = LoraModel(model, config, 'default')

        # peft_config = PrefixTuningConfig(
        #     task_type=TaskType.SEQ_CLS,
        #     inference_mode=False,
        #     num_virtual_tokens=20
        # )
        # model = get_peft_model(model=model, peft_config=peft_config)
        print(model)
        # model.print_trainable_parameters()

        # parameters = []
        # for n, m in model.named_modules():
        #     if n.split('.')[-1].startswith('lora'):
        #         parameters += list(m.parameters())

        # parameters = parameters + list(model.classification_head.parameters())
        parameters = list(model.classification_head.parameters())
        # NOTE(review): the head is loaded in bfloat16 and updated with
        # lr=3e-5; updates this small may be lost to bf16 rounding - confirm
        # that training the head in reduced precision is intended.
        optimizer = torch.optim.AdamW(parameters, lr=LR)
        # The model emits a single logit per sample (num_labels=1).
        loss_func = torch.nn.BCEWithLogitsLoss()

        train_ds = CodeGenDataset(
            tokenized_requests=train_req,
            att_masks=train_att,
            results=train_res
        )
        test_ds = CodeGenDataset(
            tokenized_requests=test_req,
            att_masks=test_att,
            results=test_res
        )
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
        # NOTE(review): predictions and targets are collected together, so the
        # metrics do not depend on test order; shuffle is kept as-is because
        # removing it would presumably change the RNG stream of the run.
        test_dl = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=True)

        for epoch in range(N_EPOCHS):

            # ---- training pass ----
            iterator = tqdm(enumerate(train_dl), total=len(train_dl))
            model.train()
            train_acc = 0
            for i, (inp_ids, att_mask, target) in iterator:

                inp_ids = inp_ids.to(DEVICE)
                att_mask = att_mask.to(DEVICE)
                target = target.to(DEVICE).view(-1)

                output = model(input_ids=inp_ids, attention_mask=att_mask).logits
                # Compute the loss in float32 on the flattened logits.
                loss = loss_func(output.view(-1).float(), target.float())

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                predicts = torch.sigmoid(output) > 0.5
                acc = (predicts.view(-1) == target).float().mean().item()
                # Running mean of the per-batch accuracies seen this epoch.
                train_acc = (train_acc * i + acc) / (i + 1)
                iterator.set_description(
                    f'Loss: {loss.item():.3f} Acc: {train_acc:.3f}'
                )

            # ---- evaluation pass ----
            model.eval()
            with torch.no_grad():

                targets = np.array([])
                predicts = np.array([])
                iterator = tqdm(test_dl)
                for inp_ids, att_mask, target in iterator:

                    inp_ids = inp_ids.to(DEVICE)
                    att_mask = att_mask.to(DEVICE)
                    target = target.to(DEVICE).view(-1)

                    output = model(input_ids=inp_ids, attention_mask=att_mask).logits

                    targets = np.concatenate((targets, target.cpu().numpy().ravel()), axis=0)
                    # Cast to float32 before .numpy(): the logits are bfloat16.
                    predicts = np.concatenate((predicts, torch.sigmoid(output).float().cpu().numpy().ravel()), axis=0)

                print(f'Epoch: {epoch}')
                for metric_name, metric_cb in metric_callbacks.items():
                    metric_value = float(metric_cb(predicts, targets))
                    metric_values[split_ind][metric_name].append(metric_value)
                    print(f'{metric_name}: {round(metric_value, 4)}')

                # Rewrite the metrics file after every epoch with everything
                # collected so far for this dataset.
                with open(f'/app/results/{EVALUATED_MODEL}/t5_coder_metrics{DATASET_PREFIX}.json', 'w') as f:
                    json.dump(metric_values, f)

        print('=' * 100)
        print(f"Split {split_ind+1} evaluated!")
        print('=' * 100)
