from transformers import AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup, AutoTokenizer, Adafactor
import logging
import pandas as pd
from torch.utils.data import DataLoader
import wandb

from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from deepspeed.ops.adam import DeepSpeedCPUAdam
import pytorch_lightning as pl
from IPython import embed
from collections import defaultdict

import generate_data

class AdditionFlanT5(pl.LightningModule):
    """Lightning wrapper around a seq2seq model (Flan-T5-small by default) that is
    validated by generating solutions and checking them with ``generate_data`` helpers.

    Validation dataloader ``idx`` is scored as the ``idx + 1``-digit task. Dataloaders
    with ``idx + 1 > cur_num_digits`` are "flash" evaluations: their answers are parsed
    with the ``"flash"`` solution format and their metrics are logged under
    ``val_flash_*`` keys, numbered relative to ``cur_num_digits``.
    """

    def __init__(self, type, silent=False, batch_size=2, lr=1e-4, accuracy_check_threshold=0.5):
        """
        Args:
            type: Solution format passed to ``generate_data.extract_answer_from_solution``
                for non-flash validation examples. (Shadows the builtin ``type``; the
                name is kept so existing ``type=`` callers keep working.)
            silent: If True, suppress validation info logging.
            batch_size: Recorded in ``self.hparams``; not read elsewhere in this class.
            lr: Learning rate used by ``configure_optimizers``.
            accuracy_check_threshold: Generation-based accuracy is only computed for a
                validation batch whose loss is below this value.
        """
        super().__init__()

        self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
        self.type = type
        self.save_hyperparameters()
        self.accuracy_check_threshold = accuracy_check_threshold
        # dataloader_idx -> list of (loss, thought_process, numerical_answer) tuples,
        # appended by validation_step and reset in on_validation_epoch_end.
        self.validation_step_outputs = defaultdict(list)

    def set_cur_num_digits(self, num_digits):
        """Set how many leading validation dataloaders are scored as regular (non-flash).

        Must be called before the first validation epoch ends; the epoch-end hooks read it.
        """
        self.cur_num_digits = num_digits

    def initialize(self, model, tokenizer, wandb_logger):
        """Replace the wrapped model and attach the tokenizer and W&B logger.

        The tokenizer is required by ``think_through_answer`` during validation.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.wandb_logger = wandb_logger

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model; the returned outputs carry ``.loss`` when labels are given."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def common_step(self, batch, batch_idx):
        """Return the model loss for a batch with input_ids / attention_mask / labels."""
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (logged as ``loss``)."""
        loss = self.common_step(batch, batch_idx)
        self.log("loss", loss, sync_dist=True)
        return loss

    def think_through_answer(self, batch):
        """Generate a continuation for every prompt in the batch.

        Returns a list with one string per example: the decoded prompt followed by the
        decoded generated tokens (special tokens stripped from both).
        """
        # Max tokens is set to prevent the model from getting stuck in an infinite loop.
        raw_output = self.model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=1000)
        thought_process = []

        for i, tensor in enumerate(raw_output):
            # The generated sequence does not repeat the prompt here, so the prompt is
            # prepended to give the answer-extraction helpers the full problem text.
            decoded_output = self.tokenizer.decode(batch["input_ids"][i, :], skip_special_tokens=True) + self.tokenizer.decode(tensor, skip_special_tokens=True)
            thought_process.append(decoded_output)

        return thought_process

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the loss and, if it is low enough, generated solutions for a batch.

        The ``(loss, thought_process, numerical_answer)`` tuple is buffered per
        dataloader for ``on_validation_epoch_end``; ``thought_process`` holds ``None``
        per example when generation was skipped.
        """
        loss = self.common_step(batch, batch_idx)
        batch_sz = batch["input_ids"].shape[0]

        # Don't bother checking accuracy until loss is low: generation is expensive.
        if loss < self.accuracy_check_threshold:
            thought_process = self.think_through_answer(batch)
        else:
            thought_process = [None] * batch_sz

        output = (loss, thought_process, batch["numerical_answer"])
        self.validation_step_outputs[dataloader_idx].append(output)
        return output

    def on_validation_epoch_end(self):
        """Log per-dataloader metrics and ``min_val_acc`` (the minimum over dataloaders),
        then clear the buffered validation outputs."""
        min_acc = 1.0
        for idx, val_out in self.validation_step_outputs.items():
            num_digits = idx + 1
            acc = self._process_validation_epoch_end_for_idx(val_out, num_digits=num_digits)

            # Flash accuracy is only needed up to 75%. We boost it to make the averaging
            # work out; otherwise it jumps discontinuously to 100%.
            if num_digits > self.cur_num_digits:
                acc += .25

            min_acc = min(min_acc, acc)

        self.log("min_val_acc", min_acc, sync_dist=True)
        self.validation_step_outputs = defaultdict(list)

    def _extract_candidate_answer(self, thought, flash):
        """Parse the model's answer from ``thought``, warning when the presence of the
        'fast' marker disagrees with whether this example is a flash example."""
        if flash:
            if "fast" not in thought:
                logging.warning("Thought process did not contain 'fast' but is supposed to flash")
            return generate_data.extract_answer_from_solution(thought, "flash")

        if "fast" in thought:
            logging.warning("Thought process contained 'fast' but is not supposed to flash")
        return generate_data.extract_answer_from_solution(thought, self.type)

    def _process_validation_epoch_end_for_idx(self, val_outs, num_digits):
        """Score one dataloader's buffered outputs, log its metrics, return its accuracy.

        Args:
            val_outs: List of ``(loss, thought_process, numerical_answer)`` tuples as
                produced by ``validation_step``.
            num_digits: ``dataloader_idx + 1``. Values above ``cur_num_digits`` mark a
                flash dataloader, whose metrics are logged relative to ``cur_num_digits``.

        Returns:
            Fraction of examples whose extracted answer matches the prompt's answer
            (examples skipped because the loss was too high count as wrong), or 1.0 when
            no batches were recorded so the dataloader does not constrain the minimum.
        """
        flash = num_digits > self.cur_num_digits
        if flash:
            num_digits -= self.cur_num_digits

        if not val_outs:
            return 1.0

        total_cnt, correct = 0, 0
        overall_loss = 0.
        printed = False

        for loss, thought_arr, numerical_answers in val_outs:
            overall_loss += loss

            for thought, _ in zip(thought_arr, numerical_answers):
                total_cnt += 1

                if thought is None:
                    if not self.hparams.silent and not printed:
                        printed = True
                        logging.info("Val loss was too high to generate thought")
                    continue

                candidate_answer = self._extract_candidate_answer(thought, flash)
                actual_answer = generate_data.extract_answer_from_prompt(thought)

                if candidate_answer == actual_answer:
                    correct += 1
                elif not self.hparams.silent:
                    logging.info("Validation generated thought: {}".format(thought))
                    logging.info("Validation generated answer: {}".format(candidate_answer))
                    logging.info("Correct answer was {}".format(actual_answer))

        # Exact division (no epsilon) so a perfect score is reported as exactly 1.0.
        accuracy = correct / total_cnt if total_cnt else 0.
        mean_loss = overall_loss / len(val_outs)

        prefix = "val_flash" if flash else "val"
        self.log("{}_acc_{}_digits".format(prefix, num_digits), accuracy, sync_dist=True)
        self.log("{}_loss_{}_digits".format(prefix, num_digits), mean_loss, sync_dist=True)

        if not self.hparams.silent:
            label = "flash " if flash else ""
            logging.info("{}-digit {}validation accuracy: {}".format(num_digits, label, accuracy))
            logging.info("{}-digit {}validation loss: {}".format(num_digits, label, mean_loss))

        return accuracy

    def configure_optimizers(self):
        """Build a DeepSpeedCPUAdam optimizer whose lr is halved every 5 StepLR steps
        (epochs, under Lightning's default scheduler interval)."""
        # TODO: CHANGE BACK IMMEDIATELY -- Adafactor alternatives:
        #   Adafactor(self.parameters(), lr=self.hparams.lr)
        #   Adafactor(self.parameters(), scale_parameter=False, relative_step=False,
        #             warmup_init=False, lr=self.hparams.lr)
        optimizer = DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.lr)
        scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}