
import torch as th
from tqdm.auto import tqdm
import numpy as np
import random
import os
import sklearn.metrics as metrics
import transformers
import time
from typing import List
from datasets import load_dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from locolms.dataset.beliefbank import *
from locolms.utils.eval import *
from locolms.utils.training import save_checkpoint
from torch.optim.lr_scheduler import CosineAnnealingLR

class Trainer():
    """Fine-tune a language model on logical constraints and evaluate its beliefs.

    Runs training epochs over constraint batches (the loss comes from
    ``constraint_mg.sl``), periodically scores the model on factuality, logical
    consistency and, for the default prompt format, WikiText-2 perplexity, and
    keeps a checkpoint of the best-scoring model.
    """

    def __init__(
            self, model, 
            constraint_mg,
            lr,
            wandb,
            optimizer,
            checkpoints_path,
            gpu_id,
            config,
            run_parallel,
            dataset,
            val_interval = 1,
        ):
        """
        Args:
            model: wrapped language model. This class calls its ``model_hf_name``,
                ``model``, ``is_decoder()``, ``is_seq2seq()``, ``prob_formula()``,
                ``answer()`` and ``get_perplexity()``.
            constraint_mg: constraint manager providing ``need_negations()`` and ``sl()``.
            lr: learning rate (only printed here; the optimizer is built by the caller).
            wandb: wandb module/run, used when ``config["wandb"]`` is truthy.
            optimizer: torch optimizer stepped during training.
            checkpoints_path: directory where the best checkpoint is written.
            gpu_id: device index; also compared to 0 to identify the main process
                when ``run_parallel`` is set.
            config: run configuration mapping (keys read below).
            run_parallel: wrap the model in DistributedDataParallel when truthy.
            dataset: ``"beliefbank"`` or ``"conceptnet"``.
            val_interval: validate every ``val_interval`` epochs.
        """
        self.lr = lr
        self.constraint_mg = constraint_mg
        self.wandb = wandb
        self.logging = config["wandb"]
        self.val_interval = val_interval
        self.checkpoints_path = checkpoints_path
        self.optimizer = optimizer

        self.patience = config["patience"]
        self.config = config
        self.num_epochs = config["epochs"]
        self.accumulation_steps = config["accumulation_steps"]
        # Size of each forward-pass batch: the configured batch is split across
        # accumulation steps, except for macaw-3b which uses a fixed 128.
        self.batch_size = 128 if config["model"] == 'allenai/macaw-3b' else config["batch_size"] // config["accumulation_steps"]
        self.constraint_type = config["constraint_type"]
        self.run_parallel = run_parallel

        if config["lr_scheduler"]:
            print("[-] Using lr scheduler: CosineAnnealingLR")
            # Stepped once per epoch in run_train, hence T_max = number of epochs.
            self.LRscheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=self.num_epochs)
        else:
            self.LRscheduler = None

        self.gpu_id = gpu_id
        if self.run_parallel:
            self.model = DDP(model, device_ids=[gpu_id])
            self.model_net = self.model.module
        else:
            self.model = model
            self.model_net = model

        self.templates, self.uncountables = Facts.get_language_templates(
            templates_path=os.path.join("data", "beliefbank", "templates.json"), 
            uncountables_path=os.path.join("data", "beliefbank", "non_countable.txt"))

        print(f"[-] Model: {self.model_net.model_hf_name}")
        print(f"[-] Learning rate: {self.lr}")
        print(f"[-] Batch size: {self.batch_size}")
        print(f"[-] Accumulation steps: {self.accumulation_steps}")

        # Setting filenames. One shared timestamp keeps the checkpoint and the
        # outputs log paired; taking the last path component of the model name
        # also works for names without an "org/" prefix.
        timestamp = time.strftime('%Y%m%d-%H%M%S')
        model_short_name = self.model_net.model_hf_name.split('/')[-1]
        self.path_checkpointing = os.path.join(self.checkpoints_path, f"{self.constraint_type}_{timestamp}_{model_short_name}.pth.tar")
        print(f"[!] Setting checkpoint save path to: {self.path_checkpointing}")

        self.path_outputs_log = f"outputs_{timestamp}_{model_short_name}.log"
        print(f"[!] Setting outputs log to: {self.path_outputs_log}")

        # Dataset: resolve the split-building function; data is built lazily in
        # run_train / run_eval via _load_data.
        self.prepare_data = None
        if dataset == "beliefbank":
            print("[+] Loading BeliefBank")
            from locolms.dataset.beliefbank import prepare_data
            self.prepare_data = prepare_data
        elif dataset == "conceptnet":
            print("[+] Loading ConceptNet")
            from locolms.dataset.conceptnet import prepare_data
            self.prepare_data = prepare_data
        else:
            print(f"[!] Unknown dataset {dataset!r}: run_train/run_eval will not be able to load data.")

        # Prompting
        self.prompt_format = FORMATS[config["prompt_format"]]

    def _load_data(self):
        """Build the data splits with the loader selected in __init__.

        Raises:
            ValueError: if no loader was configured (unknown ``dataset`` name).
        """
        if self.prepare_data is None:
            raise ValueError("No dataset loader configured: `dataset` must be 'beliefbank' or 'conceptnet'.")
        return self.prepare_data(batch_size=self.batch_size)

    def is_verbose(self):
        """Return the truthy ``config["verbose"]`` value, or False when it is absent."""
        return "verbose" in self.config and self.config["verbose"]

    def log_outputs(self, line):
        """Append ``line`` verbatim (no newline is added) to the outputs log file."""
        with open(self.path_outputs_log, "a", encoding="utf-8") as f:
            f.write(line)

    @staticmethod
    def _format_value(val):
        """Format a value with two decimals, falling back to ``str`` for non-numbers (e.g. None)."""
        try:
            return f"{val:.2f}"
        except (TypeError, ValueError):
            return str(val)

    def log_step(self, log:object, progressbar:object=None):
        """Log scores to wandb (main process only) and display them.

        Args:
            log: mapping of metric name -> value.
            progressbar: optional tqdm bar. When given it is advanced by one and its
                description shows the scores; otherwise the scores are printed.
        """
        if self.logging and (self.gpu_id == 0 or not self.run_parallel):
            self.wandb.log(log)
        printout = "".join(f"{key}: {self._format_value(val)}; " for key, val in log.items())
        if progressbar is not None:
            progressbar.update(1)
            progressbar.set_description(printout)
        else:
            print(printout)

    def run_eval(self):
        """Score the model in "test" mode for both splits and both prompt formats (main process only)."""
        print(f"[+] Evaluating model.")
        data = self._load_data()
        # Test
        if self.gpu_id == 0 or not self.run_parallel:
            for split in ("dist_train", "dist_test"):
                for prompt_idx in (0, 1):
                    self.score(data=data, mode="test", split=split, prompt_idx=prompt_idx)

    def run_train(self):
        """Train for ``num_epochs`` epochs, validating every ``val_interval`` epochs.

        Validation and checkpointing only happen on the main process. The checkpoint
        at ``self.path_checkpointing`` is overwritten whenever the average validation
        score improves (higher is better).
        """
        data = self._load_data()
        patience = self.patience  # only needed by the disabled early-stopping logic below
        best_scores = None
        is_main_process = self.gpu_id == 0 or not self.run_parallel
        for e in range(self.num_epochs):
            # Trailing newline so the marker does not merge with the next logged line.
            self.log_outputs(line=f"===== Epoch {e} =====\n")

            # Training
            print("\n[-] Training...")
            self.epoch(constraints=data["constraints"]["train"], epoch=e)
            if self.LRscheduler is not None:
                self.LRscheduler.step()

            # Validation. Reset every epoch so that a stale score (or, on ranks that
            # never validate, an undefined one) is never used for model selection.
            avg_scores = None
            if (e % self.val_interval == 0) and is_main_process:
                print(f"[-] Evaluating...")
                avg_scores = self.epoch_eval(data=data, epoch=e)
            if self.run_parallel:
                th.distributed.barrier()

            # Model selection: save whenever the validation score improves.
            if avg_scores is not None and (best_scores is None or avg_scores > best_scores):
                best_scores = avg_scores
                save_checkpoint(
                    epoch=e,  # the epoch that produced this best model
                    model=self.model_net.model, 
                    optimizer=self.optimizer, 
                    filename=self.path_checkpointing
                )
                print(f"[-] Checkpoint correctly stored in: {self.path_checkpointing}")
            # elif avg_scores <= best_scores:
            #     if patience == 0: break
            #     patience -= 1

        # Test
        # It doesn't make sense to test an early stopped model
        # print("\n[-] Final test of the model...")
        # if(self.gpu_id == 0 or not self.run_parallel):
        #     self.score(data=data, mode="test", split="dist_train", prompt_idx=0)
        #     self.score(data=data, mode="test", split="dist_test", prompt_idx=0)
        #     self.score(data=data, mode="test", split="dist_train", prompt_idx=1)
        #     self.score(data=data, mode="test", split="dist_test", prompt_idx=1)   
        # if self.run_parallel: th.distributed.barrier()

    def epoch_eval(self, data:object, epoch:int):
        """Log validation scores for both prompt formats and return their average.

        Returns:
            Mean of the average-logic and average-factuality scores of prompt 0 and
            prompt 1 (higher is better); used by run_train for model selection.
        """
        scores1 = self.score(data=data, mode="val", split="dist_train", prompt_idx=0)
        scores2 = self.score(data=data, mode="val", split="dist_train", prompt_idx=1)
        return np.mean([
            scores1["val/prompt0_dist_train_avg_logic"],
            scores1["val/prompt0_dist_train_avg_factuality"],
            scores2["val/prompt1_dist_train_avg_logic"],
            scores2["val/prompt1_dist_train_avg_factuality"],
        ])

    def epoch(self, constraints, epoch=None):
        """Run one training epoch over constraint batches with gradient accumulation.

        Args:
            constraints: sized iterable of batches (e.g. a DataLoader). Each batch is
                read for "antecedent", "consequent", "neg_antecedent", "neg_consequent",
                "s_antecedent" and, for non-negation constraints, "s_consequent",
                "g_antecedent" and "g_consequent".
            epoch: epoch index, only used for logging.

        Returns:
            0-dim tensor with the mean of the per-batch training losses (already
            divided by ``accumulation_steps``).

        Raises:
            ValueError: if the model is neither a decoder nor a seq2seq model.
        """
        e_loss = []
        num_batches = len(constraints)
        statement_keys = ("antecedent", "consequent", "neg_antecedent", "neg_consequent")
        progressbar = tqdm(constraints)
        for batch_idx, batch in enumerate(progressbar):

            prompt_idx = random.choice([0, 1]) # randomize prompt format at each fwd pass
            prompt = self.prompt_format[prompt_idx]

            # Preprocess: decoders get each statement inside the prompt template,
            # seq2seq models get the raw statements (B entries per list).
            if self.model_net.is_decoder():
                template = prompt["prompt"]
                premises, hypothesis, neg_premises, neg_hypothesis = (
                    [template.format(fact=p) for p in batch[key]] for key in statement_keys
                )
            elif self.model_net.is_seq2seq():
                premises, hypothesis, neg_premises, neg_hypothesis = (
                    list(batch[key]) for key in statement_keys
                )
            else:
                raise ValueError("Unsupported model type: expected a decoder or a seq2seq model.")

            # Model Inference
            # groundings, symbols and inferred beliefs vary depending on the constraint
            # so the losses change: negation constraint implies applying on both premise 
            # and hypothesis (one at a time)
            # NOTE(review): calls go through self.model_net (the unwrapped module when
            # run_parallel), not the DDP wrapper's forward; DDP normally only syncs
            # gradients when forward runs through the wrapper — confirm this is intended.
            log = {"epoch": epoch}
            label = prompt["label"]
            p1, p2 = self.model_net.prob_formula(s1=premises, s2=hypothesis, label=label)

            # Preparing inputs
            if self.constraint_type == "negation":
                ground_facts = None
                constraint_symbols = th.ones((batch["s_antecedent"].shape[0], 2)).to(self.gpu_id)
            else:
                ground_facts = list(zip(batch["g_antecedent"].tolist(), batch["g_consequent"].tolist())) # B, 2
                constraint_symbols = th.stack((batch["s_antecedent"], batch["s_consequent"]), dim=1) # B, 2

            # Supporting all constraints combined
            if self.constraint_mg.need_negations(self.constraint_type):
                p1_not, p2_not = self.model_net.prob_formula(s1=neg_premises, s2=neg_hypothesis, label=label)
            else:
                p1_not, p2_not = None, None

            loss = self.constraint_mg.sl(
                p1=p1, p2=p2, p1_not=p1_not, p2_not=p2_not,
                batch_symbols=constraint_symbols, 
                batch_facts=ground_facts,
                constraint=self.constraint_type
            )

            # Backprop: scale so the accumulated gradient averages over the steps.
            loss = loss / self.accumulation_steps
            loss.backward()
            loss_value = loss.item()
            log["train/loss"] = loss_value

            # Gradient accumulation: step every `accumulation_steps` batches and on the
            # last batch, so leftover gradients are not carried into the next epoch.
            if ((batch_idx + 1) % self.accumulation_steps == 0) or (batch_idx + 1 == num_batches):
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.log_step(log=log, progressbar=progressbar)
            e_loss.append(loss_value)

        return th.mean(th.tensor(e_loss))

    def get_beliefs_facts(self, facts, prompt_idx=0):
        """Query the model on every fact and its negation and collect its beliefs.

        Args:
            facts: iterable of batches with "fact", "negated_fact", "belief" (anything
                with ``.tolist()``), "predicate" and "subject" entries.
            prompt_idx: index into ``self.prompt_format``.

        Returns:
            dict with
              - "beliefs": model answers for the facts, in iteration order;
              - "negated_beliefs": model answers for the negated facts;
              - "ground_beliefs": reference labels (``batch["belief"]``);
              - "dict_beliefs": {subject: {predicate: model answer}};
              - "ref_beliefs": {subject: {predicate: reference label}}.
        """
        print(f"\n[-] Querying facts for prompt {prompt_idx}...")

        model_beliefs = dict()
        ref_beliefs = dict()

        answ_model_beliefs = []
        answ_model_negated_beliefs = []
        answ_ground_beliefs = []

        iterator = tqdm(facts) if self.is_verbose() else facts

        with th.no_grad():
            for batch in iterator:

                # Pre-processing (named `prompts` so the `facts` argument is not shadowed)
                prompts = preprocess_batch(batch=batch["fact"], prompt_format=self.prompt_format, prompt_idx=prompt_idx)
                negated_prompts = preprocess_batch(batch=batch["negated_fact"], prompt_format=self.prompt_format, prompt_idx=prompt_idx)

                # Query
                answers = self.model_net.answer(prompts)
                answers_negations = self.model_net.answer(negated_prompts)

                # Post-processing
                raw_statements, statements = postprocess_answers(prompts=prompts, answers=answers, prompt_format=self.prompt_format, prompt_idx=prompt_idx)
                raw_neg_statements, neg_statements = postprocess_answers(prompts=negated_prompts, answers=answers_negations, prompt_format=self.prompt_format, prompt_idx=prompt_idx)

                # Printing outputs (best-effort: a logging failure must not abort scoring)
                for row in zip(prompts, raw_statements, statements, negated_prompts, raw_neg_statements, neg_statements):
                    try:
                        self.log_outputs(line=f"{row}\n")
                    except (OSError, UnicodeError):
                        pass

                # Creating a dictionary of {key: property, value: statement}
                batch_beliefs = batch["belief"].tolist()
                batch_predicates = batch["predicate"]
                for idx, subj in enumerate(batch["subject"]):
                    model_beliefs.setdefault(subj, dict())[batch_predicates[idx]] = statements[idx]
                    ref_beliefs.setdefault(subj, dict())[batch_predicates[idx]] = batch_beliefs[idx]

                # Storing beliefs
                answ_model_beliefs += statements
                answ_model_negated_beliefs += neg_statements
                answ_ground_beliefs += batch_beliefs

        print(f"\n[-] Distribution:")
        print(f"# \"{self.prompt_format[prompt_idx]['label']}\": {sum(answ_model_beliefs)}\n# total: {len(answ_model_beliefs)}")

        return {
            "beliefs": answ_model_beliefs, 
            "negated_beliefs": answ_model_negated_beliefs, 
            "ground_beliefs": answ_ground_beliefs, 
            "dict_beliefs": model_beliefs, 
            "ref_beliefs": ref_beliefs
        }

    def score_facts(self, facts) -> dict:
        """Test the model's beliefs against the knowledge base.

        Args:
            facts: dict returned by ``get_beliefs_facts``.

        Returns:
            {"f1": F1 of model beliefs vs. reference labels,
             "negation_consistency": Metrics.negation_consistency(beliefs, negated_beliefs)}
        """
        f1 = metrics.f1_score(facts["ground_beliefs"], facts["beliefs"])
        negation_consistency = Metrics.negation_consistency(facts["beliefs"], facts["negated_beliefs"])
        return {"f1": f1, "negation_consistency": negation_consistency}

    def score_logic(self, facts, constraints) -> dict:
        """Test the model's beliefs against logical constraints.

        The ``self_*`` metrics use the model's own beliefs as the reference; the
        others use the dataset's reference beliefs.

        Args:
            facts: dict returned by ``get_beliefs_facts``.
            constraints: constraints passed through to the ``Metrics`` functions.
        """
        own = facts["dict_beliefs"]
        ref = facts["ref_beliefs"]
        self_consistency = Metrics.consistency(model_beliefs=own, ref_beliefs=own, constraints=constraints)
        self_inverse_consistency = Metrics.inverse_consistency(model_beliefs=own, ref_beliefs=own, constraints=constraints)
        self_multihop_consistency = Metrics.multihop_consistency(model_beliefs=own, ref_beliefs=own, constraints=constraints)
        consistency = Metrics.consistency(model_beliefs=own, ref_beliefs=ref, constraints=constraints)
        inverse_consistency = Metrics.inverse_consistency(model_beliefs=own, ref_beliefs=ref, constraints=constraints)
        multihop_consistency = Metrics.multihop_consistency(model_beliefs=own, ref_beliefs=ref, constraints=constraints)

        satisfiability = Metrics.satisfiability(model_beliefs=own, constraints=constraints)

        return {
            "self_consistency": self_consistency, 
            "self_inverse_consistency": self_inverse_consistency, 
            "self_multihop_consistency": self_multihop_consistency,
            "satisfiability": satisfiability,
            "consistency": consistency, 
            "inverse_consistency": inverse_consistency, 
            "multihop_consistency": multihop_consistency
        }

    def score(self, data:dict, mode:str, split:str, prompt_idx:int=DEFAULT_PROMPT_FORMAT):
        """Overall model benchmark (logic + facts), logged and returned.

        Args:
            data: output of the dataset's ``prepare_data``; ``data["facts"][mode]`` and
                ``data["constraints"][mode]`` are used.
            mode: "val" or "test" (selects the data and prefixes metric names).
            split: label used in metric names.
                NOTE(review): it does not select different data here — the same
                facts/constraints are scored for every split; confirm intended.
            prompt_idx: prompt format index.

        Returns:
            dict of "{mode}/prompt{prompt_idx}_{split}_<metric>" -> value, including
            "_avg_factuality" / "_avg_logic" means and, for the default prompt in
            val/test mode, "{mode}/perplexity".
        """
        print(f"\n[-] Computing scores on {split} split in {mode} mode.")
        print(f"[-] Using prompt format: {prompt_idx}")

        beliefs = self.get_beliefs_facts(facts=data["facts"][mode], prompt_idx=prompt_idx)

        print(f"\n[{mode}-{split}] Scoring factuality...")
        factuality = self.score_facts(facts=beliefs)

        print(f"\n[{mode}-{split}] Scoring constraints...")
        logic = self.score_logic(facts=beliefs, constraints=data["constraints"][mode])

        prefix = f"{mode}/prompt{prompt_idx}_{split}"
        scores = {f"{prefix}_{k}": v for k, v in factuality.items()}
        scores[f"{prefix}_avg_factuality"] = np.mean(list(factuality.values()))
        scores.update({f"{prefix}_{k}": v for k, v in logic.items()})
        scores[f"{prefix}_avg_logic"] = np.mean(list(logic.values()))

        # Compute perplexity only during evaluation (and only once per prompt set).
        # The key carries no split, so repeated calls log under the same name.
        if mode in ("val", "test") and prompt_idx == DEFAULT_PROMPT_FORMAT:
            print(f"\n[{mode}-{split}] Scoring perplexity...")
            test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
            ppl = self.model_net.get_perplexity(data=test["text"], window_size=4096)
            scores[f"{mode}/perplexity"] = ppl

        # Compute ants/cons F1 only at test time
        # if mode == "test":
        #     print(f"\n [{mode}-{split}] Scoring antecedents-consequents")
        #     beliefs_ants = self.get_beliefs_facts(facts=data["facts"][mode][f"{split}_antecedents"], prompt_idx=prompt_idx)
        #     antecedents_f1 = self.score_facts(facts=beliefs_ants)["f1"]
        #     beliefs_cons = self.get_beliefs_facts(facts=data["facts"][mode][f"{split}_consequents"], prompt_idx=prompt_idx)
        #     consequents_f1 = self.score_facts(facts=beliefs_cons)["f1"]
        #     scores[f"{mode}/prompt{prompt_idx}_{split}_antecedents_f1"] = antecedents_f1
        #     scores[f"{mode}/prompt{prompt_idx}_{split}_consequents_f1"] = consequents_f1

        # Log all scores
        self.log_step(log=scores)
        return scores

    