import torch
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
from utils.data.data_utils import create_prompt_dataset
from utils.data.data_collator import DataCollator
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.nn.functional as F
import json
import os
import time
from evaluations import eval_FOMC
from transformers import GenerationConfig
# Module-level sampling configuration: sampled decoding at a low temperature,
# returning one sequence per prompt. Not referenced in this module itself;
# presumably imported by callers / subclasses — confirm before changing.
generation_config = GenerationConfig(
    do_sample=True,
    num_return_sequences=1,
    temperature=0.1,
)


class CL_Base_Model:
    """Base trainer for continual learning.

    Trains ``model`` sequentially on each task in ``train_task_list`` and
    checkpoints after every task (see ``train_continual``). ``model`` is
    driven through an engine-style API (``backward``/``step``/
    ``gradient_accumulation_steps``), as provided by a DeepSpeed engine.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 optimizer,
                 train_task_list,
                 eval_task_list,
                 test_task_list,
                 args):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        # Task-keyed collections of dataloaders: iterating ``train_task_list``
        # yields task keys, and ``train_task_list[task]`` is that task's
        # training dataloader (presumably a dict — confirm against the caller).
        self.train_task_list = train_task_list
        # Not read by this base class; presumably used by subclasses.
        self.eval_task_list = eval_task_list
        self.test_task_list = test_task_list
        self.args = args

    def perplexity_evaluation(self, eval_dataloader, device):
        """Compute the perplexity of ``self.model`` on ``eval_dataloader``.

        Every batch has its ``'sources'`` entry (if any) dropped, is moved to
        ``device`` and run through the model without gradients. The per-batch
        ``outputs.loss`` values are averaged. When a ``torch.distributed``
        process group is initialised, that mean loss is first averaged across
        ranks. The perplexity is then ``exp(mean_loss)``, i.e. exponentiate
        after averaging rather than averaging per-rank perplexities.

        Args:
            eval_dataloader: iterable of mapping batches accepted by
                ``self.model(**batch)``.
            device: target device for the batch tensors.

        Returns:
            float: the perplexity; ``inf`` if ``exp`` overflows.

        Raises:
            ValueError: if ``eval_dataloader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                # 'sources' is not passed to the model. Build a filtered copy
                # instead of ``del`` so the caller's batch is not mutated and a
                # missing key is not an error.
                model_inputs = {k: v for k, v in batch.items() if k != 'sources'}
                model_inputs = to_device(model_inputs, device)
                outputs = self.model(**model_inputs, use_cache=False)
                total_loss += outputs.loss.float()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("eval_dataloader yielded no batches; cannot compute perplexity")

        mean_loss = total_loss / num_batches
        # Only reduce across ranks when a process group actually exists, so
        # single-process runs still work without swallowing real errors.
        if dist.is_available() and dist.is_initialized():
            mean_loss = get_all_reduce_mean(mean_loss)
        # torch.exp on a tensor saturates to inf on overflow (it never raises
        # OverflowError), so no exception handling is needed here.
        return torch.exp(mean_loss).item()

    def train_one_task(self, task, i_task, epochs):
        """Fine-tune ``self.model`` on a single task.

        Args:
            task: key into ``self.train_task_list`` selecting the training
                dataloader.
            i_task: index of the task in the continual sequence. It is not
                used here; it is kept for call-site / subclass compatibility.
            epochs: number of passes over the task's training dataloader.
        """
        if self.args.local_rank == -1:
            device = torch.device("cuda")
        else:
            torch.cuda.set_device(self.args.local_rank)
            device = torch.device("cuda", self.args.local_rank)

        train_dataloader = self.train_task_list[task]
        micro_batches_per_epoch = len(train_dataloader)
        # Hoisted out of the loop (assumed constant while training this task).
        grad_accum_steps = self.model.gradient_accumulation_steps()
        is_main_rank = self.args.global_rank == 0

        # Context manager guarantees the bar is closed even if training raises.
        with tqdm(total=epochs * micro_batches_per_epoch, leave=True,
                  disable=not is_main_rank) as progress_bar:
            for epoch in range(epochs):
                print_rank_0(
                    f"Beginning of Epoch {epoch+1}/{epochs}, Total Micro Batches {micro_batches_per_epoch}",
                    self.args.global_rank)
                # Re-enter training mode every epoch in case an evaluation
                # (e.g. perplexity_evaluation) switched the model to eval().
                self.model.train()

                for step, batch in enumerate(train_dataloader):
                    model_inputs = {k: v for k, v in batch.items() if k != 'sources'}
                    model_inputs = to_device(model_inputs, device)

                    outputs = self.model(**model_inputs, use_cache=False)
                    loss = outputs.loss

                    # Gradient accumulation is handled inside the engine's
                    # backward()/step() calls.
                    self.model.backward(loss)
                    self.model.step()

                    if is_main_rank:
                        # One host sync per micro-batch instead of two.
                        loss_value = loss.item()
                        progress_bar.update(1)
                        progress_bar.set_description(
                            f"Epoch {epoch+1}, Micro-batch {step+1}, Loss: {loss_value:.4f}",
                            refresh=True)

                        # Log once per real optimizer step, i.e. after every
                        # ``grad_accum_steps`` micro-batches. The reported loss
                        # is that of the last micro-batch.
                        if (step + 1) % grad_accum_steps == 0:
                            real_step = (step + 1) // grad_accum_steps
                            print(f"Epoch {epoch+1}, Step {real_step}, Loss: {loss_value:.4f}")

    def train_continual(self):
        """Train on every task in ``self.train_task_list`` in order.

        The epoch count for the i-th task is ``int(self.args.num_train_epochs[i])``.
        The model is saved after each task under sub-folder ``str(i)``.
        """
        for i_task, task in enumerate(self.train_task_list):
            epochs = int(self.args.num_train_epochs[i_task])
            self.train_one_task(task, i_task, epochs)
            self.save_model(i_task)

    def save_model(self, round):
        """Save the current model under sub-folder ``str(round)`` of ``args.output_dir``.

        Does nothing when ``args.output_dir`` is None. Rank 0 calls
        ``save_hf_format``. With ZeRO stage 3 every rank also calls
        ``save_zero_three_model``, because each rank holds only part of the
        model.

        Args:
            round: round/task index used as the checkpoint sub-folder name.
                It shadows the builtin ``round``; the name is kept for
                keyword-argument compatibility.
        """
        if self.args.output_dir is None:
            # Nothing to save to; the save helpers need an output directory.
            return

        sub_folder = str(round)
        print_rank_0('saving model to ' + self.args.output_dir + "/" + sub_folder + '...',
                     self.args.global_rank)

        if self.args.global_rank == 0:
            save_hf_format(self.model, self.tokenizer, self.args, sub_folder=sub_folder)

        if self.args.zero_stage == 3:
            # For zero stage 3, each gpu only has a part of the model, so we
            # need a special save function that every rank participates in.
            save_zero_three_model(self.model,
                                  self.args.global_rank,
                                  self.args.output_dir,
                                  zero_stage=self.args.zero_stage,
                                  sub_folder=sub_folder)
        print_rank_0('Successfully saving model after round {}'.format(round), self.args.global_rank)
