import json
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from transformers import T5ForSequenceClassification, RobertaTokenizerFast, T5EncoderModel
from peft import LoraModel, LoraConfig
import pickle
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, balanced_accuracy_score


# MODEL_NAME = "google/t5-v1_1-small"

# class CodeGenDataset(torch.utils.data.Dataset):

#     def __init__(
#         self,
#         # ids: torch.tensor[int],
#         tokenized_requests: torch.tensor,
#         att_masks: torch.tensor,
#         results: torch.tensor
#     ):

#         # self.ids = ids
#         self.requests = tokenized_requests
#         self.att_masks = att_masks
#         self.results = results
#         sum_res = 0
#         for r in results:
#             sum_res += int(r)
#         print(f"Number of correct solutions in the dataset: {sum_res / len(results)}")

#     def __len__(self):
#         return len(self.requests)

#     def __getitem__(self, idx):
#         return (
#             self.requests[idx],
#             self.att_masks[idx],
#             torch.tensor([int(self.results[idx])])
#         )

class CodeGenDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing a request representation with a 0/1 label.

    Item ``idx`` is ``(tokenized_requests[idx], torch.tensor([label]))``
    where ``label`` is ``int(results[idx])``.  The length of the dataset is
    the length of ``tokenized_requests``.

    Args:
        tokenized_requests: Indexable collection of per-sample inputs.
        results: Indexable collection of per-sample outcomes; every element
            must be convertible with ``int()``.
    """

    def __init__(
        self,
        tokenized_requests: torch.Tensor,
        results: torch.Tensor
    ):
        self.requests = tokenized_requests
        self.results = results

        n_results = len(results)
        n_correct = sum(int(r) for r in results)
        # Despite the wording of the message, the printed value is the
        # *fraction* of positive labels.  An empty dataset used to raise
        # ZeroDivisionError here; report NaN instead.
        fraction = n_correct / n_results if n_results else float('nan')
        print(f"Number of correct solutions in the dataset: {fraction}")

    def __len__(self) -> int:
        """Return the number of samples (length of the request collection)."""
        return len(self.requests)

    def __getitem__(self, idx):
        """Return ``(request, label)`` with the label as a 1-element tensor."""
        return (
            self.requests[idx],
            torch.tensor([int(self.results[idx])])
        )


# Hugging Face checkpoint name; passed to ``from_pretrained`` for both the
# tokenizer and the encoder.
MODEL_NAME = 'Salesforce/codet5-base'
# Number of training epochs; the test split is evaluated after each one.
N_EPOCHS = 100
# Learning rate handed to the AdamW optimizer.
LR = 3e-5
# Batch size used for encoding as well as for the train/test data loaders.
BATCH_SIZE = 32
# Prefer the GPU, but fall back to the CPU so the script still starts on a
# machine without CUDA instead of failing on the first ``.to(DEVICE)``.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file, so only point this script at result files you produced yourself.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _encode_last_token(encoder, input_ids, attention_mask, batch_size, device):
    """Embed each sequence as the encoder state of its last non-padded token.

    The inputs are padded to a common length, so simply reading position
    ``-1`` returns the state of a padding slot for every sequence shorter
    than the longest one when the tokenizer pads on the right.  The
    attention mask is used to locate the last real token instead, which
    works for either padding side.

    Args:
        encoder: Model whose output exposes ``last_hidden_state``.
        input_ids: Tensor of token ids, one row per sequence.
        attention_mask: Tensor of the same shape, 1 for real tokens.
        batch_size: Number of sequences sent through the encoder at once.
        device: Device the encoder lives on.

    Returns:
        A float32 CPU tensor with one feature row per input sequence.
    """
    features = []
    # The encoder is frozen; no autograd bookkeeping is needed here.
    with torch.no_grad():
        for start in range(0, len(input_ids), batch_size):
            ids = input_ids[start:start + batch_size].to(device)
            mask = attention_mask[start:start + batch_size].to(device)
            hidden = encoder(input_ids=ids, attention_mask=mask).last_hidden_state

            # mask * (position + 1) is largest at the last position whose
            # mask is 1, so argmax yields the index of the last real token.
            positions = torch.arange(mask.shape[1], device=mask.device) + 1
            last_idx = (mask * positions).argmax(dim=1)
            rows = torch.arange(hidden.shape[0], device=hidden.device)

            # Cast to float32 so the classification head can be trained in
            # full precision.
            features.append(hidden[rows, last_idx].float().cpu())
    return torch.cat(features, dim=0)


def _binarize(pred):
    """Threshold predicted probabilities at 0.5 into int32 class labels."""
    return (pred > 0.5).astype(np.int32)


# Each callback receives (predicted probabilities, targets) as numpy arrays.
metric_callbacks = {
    'Accuracy': lambda pred, targ: accuracy_score(targ, _binarize(pred)),
    'Balanced accuracy': lambda pred, targ: balanced_accuracy_score(targ, _binarize(pred)),
    'F1 score': lambda pred, targ: f1_score(targ, _binarize(pred)),
    'ROC-AUC score': lambda pred, targ: roc_auc_score(targ, pred)
}

# The tokenizer and the frozen encoder are identical for every run, so they
# are loaded once instead of once per (model, dataset) combination.
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)

model = T5EncoderModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16
).to(DEVICE)
for p in model.parameters():
    p.requires_grad = False
model.eval()
print(model)

for EVALUATED_MODEL in ['qwencoder', 'starcoder2', 'llama', 'magicoder', 'deepseekcoder']:
    # Train on one benchmark and test on the other, in both directions.
    for DATASET_PREFIX in [['_mbpp', '_he'], ['_he', '_mbpp']]:

        # Re-seed per run so head initialisation and shuffling are repeatable.
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)

        TRAIN_DS_PREFIX, TEST_DS_PREFIX = DATASET_PREFIX
        results_dir = f'/app/results/{EVALUATED_MODEL}'

        train_results = _load_pickle(f'{results_dir}/Y{TRAIN_DS_PREFIX}.pickle')
        train_code = _load_pickle(f'{results_dir}/code{TRAIN_DS_PREFIX}.pickle')
        test_results = _load_pickle(f'{results_dir}/Y{TEST_DS_PREFIX}.pickle')
        test_code = _load_pickle(f'{results_dir}/code{TEST_DS_PREFIX}.pickle')

        # Train and test snippets are tokenized together and split again
        # after encoding, using train_len as the boundary.
        train_len = len(train_code)
        tokenized_requests = tokenizer(
            train_code + test_code,
            padding="longest",
            max_length=8192,
            truncation=True,
            return_tensors="pt",
        )

        all_res = _encode_last_token(
            model,
            tokenized_requests.input_ids,
            tokenized_requests.attention_mask,
            BATCH_SIZE,
            DEVICE,
        )
        print(all_res.shape)

        # The head is kept in float32: with lr=3e-5 most AdamW updates are
        # smaller than the bfloat16 spacing of the weights and would be
        # rounded away if the parameters themselves were bfloat16.
        hidden_size = model.config.d_model
        classification_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1)
        ).to(DEVICE)

        metric_values = {k: [] for k in metric_callbacks}

        train_req, test_req = all_res[:train_len], all_res[train_len:]
        train_res = torch.tensor(train_results)
        test_res = torch.tensor(test_results)

        parameters = list(classification_head.parameters())
        optimizer = torch.optim.AdamW(parameters, lr=LR)
        loss_func = torch.nn.BCEWithLogitsLoss()

        train_ds = CodeGenDataset(
            tokenized_requests=train_req,
            results=train_res
        )
        test_ds = CodeGenDataset(
            tokenized_requests=test_req,
            results=test_res
        )
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        # Evaluation metrics do not depend on sample order; no shuffling.
        test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

        metrics_path = f'{results_dir}/t5_transfer{TRAIN_DS_PREFIX}{TEST_DS_PREFIX}.json'

        for epoch in range(N_EPOCHS):

            classification_head.train()
            train_acc = 0
            iterator = tqdm(enumerate(train_dl), total=len(train_dl))
            for i, (inp, target) in iterator:

                inp = inp.to(DEVICE)
                target = target.to(DEVICE).view(-1)

                output = classification_head(inp).view(-1)
                loss = loss_func(output.float(), target.float())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                predicts = torch.sigmoid(output) > 0.5
                acc = (predicts == target).float().mean().item()
                # Running mean of the per-batch accuracy for the progress bar.
                train_acc = (train_acc * i + acc) / (i + 1)
                iterator.set_description(
                    f'Loss: {loss.item():.3f} Acc: {train_acc:.3f}'
                )

            classification_head.eval()
            target_chunks = []
            predict_chunks = []
            with torch.no_grad():
                for inp, target in tqdm(test_dl):
                    output = classification_head(inp.to(DEVICE))
                    target_chunks.append(target.view(-1).numpy())
                    predict_chunks.append(
                        torch.sigmoid(output).float().cpu().numpy().ravel()
                    )
            targets = np.concatenate(target_chunks, axis=0)
            predicts = np.concatenate(predict_chunks, axis=0)

            print(f'Epoch: {epoch}')
            for metric_name, metric_cb in metric_callbacks.items():
                metric_value = float(metric_cb(predicts, targets))
                metric_values[metric_name].append(metric_value)
                print(f'{metric_name}: {round(metric_value, 4)}')

            # Rewritten after every epoch so that the history collected so
            # far is on disk even if the run is interrupted.
            with open(metrics_path, 'w') as f:
                json.dump(metric_values, f)

    print('=' * 100)
    print(f"Model {EVALUATED_MODEL} evaluated!")
    print('=' * 100)
