import os
import time
import torch
import pickle
import argparse
import random
import numpy as np
import pandas as pd

from csv import writer
from tqdm import tqdm
from copy import deepcopy
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM
from dataset_utils.fever import FEVER
from laser.LaserWrapper import LaserWrapper
from study_utils.log_utils import Logger
from study_utils.metric_utils import Metrics, DatasetMetrics, ContextAnswerLogProb
from study_utils.time_utils import elapsed_from_str, Progress

class Results:
    """Container for validation/test accuracy and log-loss with a text renderer."""

    def __init__(self, val_acc, val_logloss, test_acc, test_logloss):
        # Validation-split numbers
        self.val_acc, self.val_logloss = val_acc, val_logloss
        # Test-split numbers
        self.test_acc, self.test_logloss = test_acc, test_logloss

    def to_str(self, only_test=False):
        """Return the metrics as one line, formatted to three decimals.

        When `only_test` is true only the test metrics are included; otherwise
        the validation metrics come first, followed by the test metrics.
        """
        test_part = f"Test acc {self.test_acc:.3f}, Test logloss {self.test_logloss:.3f}"
        if only_test:
            return test_part

        val_part = f"Validation acc {self.val_acc:.3f}, Validation logloss {self.val_logloss:.3f}, "
        return val_part + test_part
        
class GPTJExperiment:
    """Edits a causal LM with LaserWrapper and scores it on true/false claims.

    `intervene` applies the edit, evaluates every example of a dataset with a
    single-token (" true" / " false") look-up and hands the per-example
    predictions to `terminate_and_save`, which pickles them together with the
    aggregated metrics under `save_dir`.
    """

    def __init__(self, save_dir, logger):
        self.save_dir = save_dir
        self.logger = logger

        # Object to measure progress (as in time taken and time left to complete)
        self.progress = Progress(logger=logger)

        # Object to compute metrics. We set whether we should consider whitespace and lowercase when evaluating
        self.case_sensitive = False
        self.strip = True
        self.metrics = Metrics(case_sensitive=self.case_sensitive, strip=self.strip)

        # Object to aggregate performance over a dataset
        self.dataset_metric = DatasetMetrics(logger=logger)

        # Device for the experiment
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Run configuration recorded by `intervene` and used by `terminate_and_save`
        self._run_args = None
        self._run_llm_name = None

    @staticmethod
    def _build_prompt(question):
        """Wrap a claim in the true/false prompt, adding a final period when it lacks '.' or '?'."""
        claim = question.strip()
        if not claim.endswith((".", "?")):
            claim += "."
        return "Consider the following claim: " + claim + " Is this claim true or false. The claim is"

    @staticmethod
    def _single_token_id(tokenizer, text):
        """Return the id of `text`, which must tokenize to exactly one token."""
        token_ids = tokenizer(text)["input_ids"]
        assert len(token_ids) == 1, f"Expected {text!r} to be a single token, got {len(token_ids)}"
        return int(token_ids[0])

    def intervene(self, model, tokenizer, dataset, args, llm_name, lnums, k_clusters=1):
        """Edit `model`, evaluate it on `dataset` and save the results.

        :param model: model handed to LaserWrapper (edited with in_place=True)
        :param tokenizer: tokenizer used for prompts and for the answer tokens
        :param dataset: indexable collection of dicts with keys "question" and
            "answer" (0 for false, 1 for true)
        :param args: parsed arguments; reads lname, lnum, rate, intervention,
            num_clusters, group_via, shuffle and batch_size here
        :param llm_name: model name used in the names of the saved files
        :param lnums: a list selects the multi-layer edit; anything else selects
            the single-layer edit, which takes its layer from `args.lnum`
        :param k_clusters: forwarded unchanged to LaserWrapper
        :return: the dictionary returned by `terminate_and_save`
        """
        dataset_size = len(dataset)
        self.logger.log(f"Starting a new intervention with rate {args.rate}. "
                        f"Dataset size {dataset_size}. Batch size {args.batch_size}")

        # Remember the configuration so that saving does not depend on module-level globals
        self._run_args = args
        self._run_llm_name = llm_name

        time_edit_start = time.time()
        edit_kwargs = dict(model=model,
                           lname=args.lname,
                           rate=args.rate,
                           intervention=args.intervention,
                           logger=self.logger,
                           in_place=True,
                           num_clusters=args.num_clusters,
                           group_via=args.group_via,
                           shuffle=args.shuffle,
                           k_clusters=k_clusters)
        if isinstance(lnums, list):
            model_edit, _ = LaserWrapper.get_edited_model_multiple_layers(lnums=lnums, **edit_kwargs)
        else:
            # The single-layer path reads the layer from args.lnum, not from `lnums`
            model_edit, _ = LaserWrapper.get_edited_model(lnum=args.lnum, **edit_kwargs)

        model_edit.to(self.device)
        self.logger.log(f"Edited and put model on {model_edit.device} in time {elapsed_from_str(time_edit_start)}")

        predictions = []

        # Reset dataset metrics and set progress timestamp
        self.dataset_metric.reset()
        self.progress.start()

        # Answer tokens: true and false.
        # The leading space is important, otherwise we will get the wrong token_id
        true_token_id = self._single_token_id(tokenizer, " true")
        false_token_id = self._single_token_id(tokenizer, " false")

        for i in tqdm(range(0, dataset_size)):

            if (i - 1) % 100 == 0 and i > 1:
                # Print partial performance and telemetry data
                self.dataset_metric.print()
                self.progress.print(ex_done=i, ex_left=(dataset_size - i))

            example = dataset[i]
            question = example["question"]

            # Answer is either 0 (False) or 1 (True)
            answer_ix = example["answer"]
            assert answer_ix in [0, 1]

            # Given that we do 1-token look up we do the following:
            # - Compute log-prob of the gold token
            # - Compute top-1, top-5 and top-10 accuracies
            prompted_question = self._build_prompt(question)

            inputs = tokenizer(prompted_question, return_tensors="pt").to(self.device)

            with torch.no_grad():
                # Compute log probability of question
                results = model_edit(inputs.input_ids)
                logits = results.logits[0]                                      # question length x vocab
                log_prob = torch.nn.functional.log_softmax(logits, dim=1)       # question length x vocab

                last_token_logprob = log_prob[-1]                               # vocab

                true_logprob = last_token_logprob[true_token_id].item()
                false_logprob = last_token_logprob[false_token_id].item()

                if answer_ix == 1:     # Answer is True
                    answer_log_prob = true_logprob
                    is_correct = true_logprob > false_logprob
                    answer = "true"
                else:                  # Answer is False
                    answer_log_prob = false_logprob
                    is_correct = true_logprob < false_logprob
                    answer = "false"

                sorted_logprob, sorted_indices = torch.sort(last_token_logprob, descending=True)

                top_k_logprob = sorted_logprob[:10].detach().cpu().numpy()
                top_k_indices = sorted_indices[:10].detach()

                top_k_tokens = list(tokenizer.batch_decode(top_k_indices))
                assert len(top_k_tokens) == 10

                # Normalise once, then test membership in the top-1/5/10 prefixes
                gold = answer.lower().strip()
                normalised_tokens = [token.lower().strip() for token in top_k_tokens]
                top_1_acc = float(gold in normalised_tokens[:1])
                top_5_acc = float(gold in normalised_tokens[:5])
                top_10_acc = float(gold in normalised_tokens[:10])

                # Compute log-prob of question and answer: position t scores token t + 1
                selected_log_prob = log_prob[:-1, :]  # question - 1 x vocab
                indices = inputs.input_ids[0, 1:].unsqueeze(1)  # question - 1 x 1

                selected_log_prob = torch.gather(selected_log_prob,
                                                 index=indices,
                                                 dim=1)  # question - 1 x 1
                question_log_prob = selected_log_prob.sum().item()
                total_log_prob = question_log_prob + answer_log_prob

                logprob_results = ContextAnswerLogProb(total_log_prob=total_log_prob,
                                                       answer_log_prob=answer_log_prob,
                                                       answer_len=1)

            self.dataset_metric.accept(is_correct=is_correct,
                                       f1pr_score=None,
                                       log_prob_results=logprob_results,
                                       top_k_acc={1: top_1_acc, 5: top_5_acc, 10: top_10_acc})

            if i % 10 == 0:
                print(f"Question: {question} and gold answer {answer}. Predicted top 10 tokens {top_k_tokens}.")

            predictions_ = {
                "ix": i,
                "question": question,
                "prompted-question": prompted_question,
                "gold-answer": answer,
                "gold-answer-ix": answer_ix,
                "generation": top_k_tokens[0],      # We can view the top token as the 1-step generation
                "correct": is_correct,
                "true_logprob": true_logprob,
                "false_logprob": false_logprob,
                "top_1_acc": top_1_acc,
                "top_5_acc": top_5_acc,
                "top_10_acc": top_10_acc,
                "top_10_logprob": top_k_logprob,
                "top_10_tokens": top_k_tokens,
                "f1_score": None,
                "precision": None,
                "recall": None,
                "case-sensitive": self.case_sensitive,        # We ignore case when checking answer
                "white-space-strip": self.strip,              # We ignore white space when checking answer
                "total_logprob": total_log_prob,
                "question_logprob": question_log_prob,
                "answer_logprob": answer_log_prob,
                "answer_length": 1,
                "question_answer_length": inputs.input_ids.shape[1] + 1
            }
            predictions.append(predictions_)

        # Save results and terminate
        return self.terminate_and_save(predictions, args=args, llm_name=llm_name)

    def terminate_and_save(self, predictions, args=None, llm_name=None):
        """Finalize the metrics, pickle predictions and summary, and return the summary.

        :param predictions: list of per-example prediction dictionaries
        :param args: parsed arguments; defaults to those of the last `intervene` call
        :param llm_name: model name for the file names; defaults to the one of
            the last `intervene` call
        :return: aggregated metrics with every argument added under "args/<name>"
        :raises ValueError: if no run configuration was given or recorded
        """
        if args is None:
            args = getattr(self, "_run_args", None)
        if llm_name is None:
            llm_name = getattr(self, "_run_llm_name", None)
        if args is None or llm_name is None:
            raise ValueError("terminate_and_save needs `args` and `llm_name`: "
                             "pass them explicitly or call intervene first.")

        self.logger.log("Saving results. Final Performance is given below:")
        self.dataset_metric.terminate()
        self.dataset_metric.print()

        time_start = time.time()

        # NOTE: the file names only encode llm_name, rate, dtpts and lnum, so runs that differ
        # in other settings but share save_dir write to the same paths.
        save_pred_fname = f"{self.save_dir}/{llm_name}-predictions-{args.rate}-{args.dtpts}-{args.lnum}.p"
        with open(save_pred_fname, "wb") as f:
            pickle.dump(predictions, f)

        # Save the summary
        save_summary_fname = f"{self.save_dir}/{llm_name}-result-summary-{args.rate}-{args.dtpts}-{args.lnum}.pkl"

        results = self.dataset_metric.agg_to_dict()
        for k, v in vars(args).items():
            results["args/%s" % k] = v

        with open(save_summary_fname, "wb") as f:
            pickle.dump(results, f)

        # Print final numbers and return
        self.logger.log(f"Time taken to store all results {elapsed_from_str(time_start)}")

        return results


if __name__ == '__main__':
    # Sweep script: evaluates GPT-J on FEVER for every combination of layer, layer type, rate and
    # number of clusters, and rewrites a CSV summary after each run.
    #
    # NOTE: this stays a flat script (no main() wrapper) so that names such as `args`, `llm_name`
    # and `logger` remain module-level globals for any code that looks them up at module scope.
    from itertools import product

    def _parse_bool(text):
        """Parse a command-line boolean; with plain `type=bool` the string "False" would be True."""
        lowered = str(text).strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {text!r}")

    # Step 1: Command line argument
    parser = argparse.ArgumentParser(description='Process Arguments for experiments with GPTJ LLM on FEVER')

    parser.add_argument('--rate', type=float, default=1, help='rates for intervention')
    parser.add_argument('--dtpts', type=int, default=22000, help='# samples per instruction')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size for evaluation')
    parser.add_argument('--max_len', type=int, default=1, help='maximum length for generation')
    parser.add_argument('--k', type=int, default=10, help='top k for evaluation')
    parser.add_argument('--intervention', type=str, default="rank-reduction",
                        choices=['dropout', 'rank-reduction', 'quantize'], help="what type of intervention to perform")
    parser.add_argument('--lname', type=str, default="None",
                        choices=['k_proj', 'q_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_up', 'fc_out', 'None', 'dont',
                                 "all", "mlp", "attn"],
                        help="provided which type of parameters to effect")
    parser.add_argument('--lnum', type=int, default=24, help='Layers to edit', choices=list(range(-1, 28)))
    parser.add_argument('--model_path',
                        type=str,
                        default="/mnt/data/Llama2/Llama-2-7b-hf",
                        help="Place where model weights are stored")
    parser.add_argument('--home_dir', type=str,
                        default="/mnt/data/iclr2024/fever/gptj_results",
                        help='Directory where the data is')
    parser.add_argument('--dataset_file', type=str,
                        default="/mnt/data/counterfact",
                        help='Directory where the data is')
    parser.add_argument('--num_clusters', type=int,
                        default=1,
                        help='Number of clusters')
    parser.add_argument('--group_via', type=str,
                        default="rows",
                        help='Group k-SVD via rows or cols')
    parser.add_argument('--shuffle', type=_parse_bool,
                        default=False,
                        help='Shuffle rows before rank reduction')

    args = parser.parse_args()

    # Step 2: Load model and tokenizer
    # NOTE: the checkpoint is fixed here; --model_path is parsed but not used by this script.
    llm_name = "GPTJ"
    llm_path = "EleutherAI/gpt-j-6B"
    tokenizer = AutoTokenizer.from_pretrained(llm_path)
    model = GPTJForCausalLM.from_pretrained(
        llm_path,
        revision="float16",
        torch_dtype=torch.float16
    )

    # Step 3: Create save directory and logger
    home_dir = args.home_dir

    save_dir = f"{home_dir}/{llm_name}/{args.intervention}/{args.lname}"
    os.makedirs(save_dir, exist_ok=True)

    logger = Logger(save_dir=save_dir, fname=f"{llm_name}-log-{args.lnum}-{args.lname}-{args.rate}.txt")

    # Step 4: Create an experiment
    experiment = GPTJExperiment(save_dir=save_dir, logger=logger)

    logger.log("=" * 50)
    logger.log(f"Created a new Experiment. Model {llm_name}")
    logger.log("=" * 50)

    for k, v in vars(args).items():
        logger.log(f">>>> Command line argument {k} => {v}")
    logger.log("=" * 50)

    # Step 5: Read the dataset
    dataset_util = FEVER()
    dataset = dataset_util.get_dataset_no_logger()

    # Step 6: Run intervention

    # Layers to sweep, from the last (27) down to the first (0). An entry may also be a
    # two-element list of layers, or -1 for a run that uses lname "dont".
    lnums_list = list(range(27, -1, -1))

    column_titles = ["lname", "lnum 1", "lnum 2", "rate", "num_clusters", "k_clusters", "accuracy", "loss", "n_rows", "d", "groups", "first group shape", "last group shape", "desired_rank", "k", "approx error"]

    # Only the first eight columns are filled in below; the first row of `output` is the header
    output = [column_titles[0:8]]

    # Every run is scored on a random subset of at most 100 examples
    sample_size = min(100, len(dataset))

    for lnums in lnums_list:

        is_multi_layer = isinstance(lnums, list)
        first_lnum = lnums[0] if is_multi_layer else lnums
        second_lnum = lnums[1] if is_multi_layer else -1

        if first_lnum == -1:
            # Single run with lname "dont"
            lnames = ["dont"]
            rates = [9.9]
            k_clusters_list = [1]
            num_clusters_list = [1]
            shuffle_list = [False]
        else:
            lnames = ["fc_in", "fc_out"]
            rates = [1.0, 2.0, 4.0, 6.0, 8.0, 9.0, 9.5, 9.9, 9.95]
            k_clusters_list = [1]
            num_clusters_list = [1, 2, 4, 8, 16]
            shuffle_list = [False]

        # Same iteration order as nested loops: lname outermost, k_clusters innermost
        sweep = product(lnames, reversed(rates), num_clusters_list, shuffle_list, k_clusters_list)
        for lname, rate, num_clusters, shuffle, k_clusters in sweep:

            if not is_multi_layer:
                args.lnum = lnums

            args.lname = lname
            args.rate = rate
            args.num_clusters = num_clusters
            args.shuffle = shuffle

            # The edit is applied in place, so every run works on a fresh copy of the model.
            # A new random subset is drawn for every configuration.
            model_copy = deepcopy(model)
            results = experiment.intervene(model=model_copy,
                                           tokenizer=tokenizer,
                                           dataset=random.sample(dataset, sample_size),
                                           args=args,
                                           llm_name=llm_name,
                                           lnums=lnums,
                                           k_clusters=k_clusters)
            # Drop the edited copy before the next deepcopy allocates another one
            del model_copy

            output.append([lname, first_lnum, second_lnum, rate, num_clusters, k_clusters,
                           results[DatasetMetrics.CORRECTNESS], results[DatasetMetrics.MeanLogProb]])

            # Rewrite the summary after every run so partial sweeps are kept
            arr = np.array(output)
            df = pd.DataFrame(arr)
            df.to_csv("gptj_fever_100_points.csv")

    logger.log("Experimented Completed.")
