import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler
import torch.distributed as dist

import random
import json
from tqdm import tqdm

from model import AutoLLM
import os

from param import parse_args
from utils.data_utils import *
from utils.eval import *

import copy

from accelerate import Accelerator
from utils.train_utils import clear_gpu_cache, setup, setup_environ_flags, create_logger
from utils.process_utils import replace_harm, clean_reply, construct_unlearning_dataset
from transformers import default_data_collator

from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import GradScaler

class Finetuner:
    """Fine-tune ``auto_llm.model`` on a dataset and generate replies for test prompts.

    When ``args.unlearn_alg == "kg_grad_asc"`` training performs gradient
    *ascent* on the loss (unlearning) instead of descent.

    ``args`` fields read here: ``seed``, ``lr``, ``batch_size``, ``root_path``,
    ``model_name``, ``test_dataset``, ``unlearning_type``, ``unlearn_alg``,
    ``enable_fsdp``, ``ft_epochs`` and ``save_model_interval``.
    ``auto_llm`` must expose ``model``, ``tokenizer``, ``process_fn`` and
    ``generate``. ``replace_gpt`` is only stored on the instance; no method
    here reads it.
    """

    def __init__(self, args, auto_llm, replace_gpt):
        self.args = args
        self.rng = random.Random(args.seed)

        # load model
        self.llm = auto_llm
        self.replace_gpt = replace_gpt

        # init optimizer
        self.optimizer, self.scheduler, self.scaler = self.init_optimizer()

        self.batch_size = args.batch_size

    def construct_out_file(self, cyl_epo, ft_epoch):
        """Return the JSON result path for (cycle epoch, fine-tune epoch).

        The ``result/finetune/<model_name>`` directory is created if missing.
        """
        out_dir = f"{self.args.root_path}/result/finetune/{self.args.model_name}"
        # exist_ok avoids the check-then-create race (e.g. several ranks at once).
        os.makedirs(out_dir, exist_ok=True)
        return f"{out_dir}/{self.args.test_dataset}_{self.args.unlearning_type}_ft_epoch_{ft_epoch}_epoch_{cyl_epo}.json"

    def construct_log_file(self, cyl_epo):
        """Return the log-file path for a cycle epoch, creating its directory if missing."""
        out_dir = f"{self.args.root_path}/log/finetune/{self.args.model_name}"
        os.makedirs(out_dir, exist_ok=True)
        return f"{out_dir}/{self.args.test_dataset}_{self.args.unlearning_type}_epoch_{cyl_epo}.log"

    def init_optimizer(self):
        """Build ``(optimizer, scheduler, scaler)`` for training.

        - AdamW over all model parameters, with no weight decay.
        - StepLR that multiplies the lr by 0.9 every 100000 scheduler steps,
          which is effectively a constant lr for shorter runs.
        - A GradScaler for mixed-precision loss scaling.
        """
        optimizer = optim.AdamW(
            self.llm.model.parameters(),
            lr=self.args.lr,
            weight_decay=0.0,
        )
        scheduler = StepLR(optimizer, step_size=100000, gamma=0.9)
        scaler = GradScaler()
        return optimizer, scheduler, scaler

    def query_llm_reply(self, llm, prompts):
        """Wrap each prompt via ``llm.process_fn`` and generate replies in one call.

        Generation runs with ``temperature=0`` and ``max_tokens=800``.
        Returns whatever ``llm.generate`` returns.
        """
        messages = {'message': []}
        for prompt in prompts:
            example = llm.process_fn(
                {'prompt': prompt}, prompt_construct_fn=lambda x: x['prompt']
            )
            messages['message'].append(example['message'])

        return llm.generate(messages, temperature=0, max_tokens=800)

    def save_model(self, epoch, cyl_epo):
        """Save the adapter-merged model and the tokenizer.

        ``epoch`` is the fine-tune epoch number, or ``'final'`` for the
        end-of-training checkpoint.
        """
        save_path = f"{self.args.root_path}/save_model/finetune_{self.args.test_dataset}_{self.args.unlearning_type}/{self.args.model_name}_cyl_epo_{cyl_epo}_ft_epo_{epoch}"
        os.makedirs(save_path, exist_ok=True)

        if epoch == 'final':
            # Training is finished, so merging the live model in place is safe
            # and avoids holding a second copy in memory.
            self.llm.model.merge_and_unload().save_pretrained(save_path)
        else:
            # merge_and_unload() replaces the adapter layers in place. Doing that
            # mid-training would detach the parameters the optimizer updates, so
            # merge a throwaway copy instead. This temporarily doubles model memory.
            snapshot = copy.deepcopy(self.llm.model)
            snapshot.merge_and_unload().save_pretrained(save_path)
            del snapshot
        self.llm.tokenizer.save_pretrained(save_path)

    def _build_dataloader(self, dataset):
        """Return ``(dataloader, sampler, device)`` used to train on ``dataset``.

        With ``enable_fsdp`` the data is sharded with a DistributedSampler and
        batches go to device ``LOCAL_RANK``. Otherwise a RandomSampler is used
        and batches go to the model's device.
        """
        if self.args.enable_fsdp:
            device = int(os.environ["LOCAL_RANK"])
            sampler = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=True,
            )
        else:
            device = self.llm.model.device
            sampler = torch.utils.data.RandomSampler(dataset)

        dataloader = DataLoader(
            dataset,
            # NOTE(review): always 1 regardless of args.batch_size. Confirm the
            # dataset items are pre-batched, or that this is intended.
            batch_size=1,
            collate_fn=default_data_collator,
            pin_memory=True,
            sampler=sampler,
        )
        return dataloader, sampler, device

    def _training_objective(self, loss):
        """Return the quantity to minimise.

        This is ``-loss`` for gradient-ascent unlearning (``kg_grad_asc``) and
        ``loss`` otherwise.
        """
        if self.args.unlearn_alg == "kg_grad_asc":
            return -loss
        return loss

    @staticmethod
    def _mean_loss(losses):
        """Return the arithmetic mean of ``losses``, or NaN when it is empty."""
        return sum(losses) / len(losses) if losses else float('nan')

    def train(self, dataset, cyl_epo):
        """Train on ``dataset`` for ``args.ft_epochs`` epochs, then save the model.

        Checkpoints are saved every ``save_model_interval`` epochs (except the
        last one), and a ``'final'`` checkpoint is saved at the end.
        ``cyl_epo`` is 0-based; file names use ``cyl_epo + 1``.
        """
        logger = create_logger(self.construct_log_file(cyl_epo + 1))
        dataloader, sampler, device = self._build_dataloader(dataset)

        # generate() switches the model to eval mode, and HF models usually
        # load in eval mode too; make sure dropout etc. are active here.
        self.llm.model.train()

        for epoch in range(self.args.ft_epochs):
            if isinstance(sampler, DistributedSampler):
                # Without this every epoch reuses the same shuffle order.
                sampler.set_epoch(epoch)
            loss_list = []

            with tqdm(
                total=len(dataloader), desc=f'Epoch {epoch + 1}/{self.args.ft_epochs}', unit='batch'
            ) as pbar:
                for batch in dataloader:
                    batch = {key: value.to(device) for key, value in batch.items()}

                    output = self.llm.model(**batch)
                    # Log the raw (un-negated) loss.
                    loss_list.append(output.loss.item())

                    # Negate *before* backward, otherwise gradient ascent is a no-op.
                    objective = self._training_objective(output.loss)
                    self.scaler.scale(objective).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    # Scheduler after the optimizer step (PyTorch's documented AMP order).
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                    pbar.update(1)

                mean_loss = self._mean_loss(loss_list)
                logger.info(f'[epoch: {epoch}] Loss: {mean_loss}')
                print(f'[epoch: {epoch}] Loss: {mean_loss}')

            # NOTE(review): under FSDP every rank reaches save_model; confirm
            # whether only rank 0 should write checkpoints.
            interval = self.args.save_model_interval
            if interval and (epoch + 1) % interval == 0 and (epoch + 1) != self.args.ft_epochs:
                self.save_model(epoch + 1, cyl_epo + 1)

        self.save_model('final', cyl_epo + 1)
        logger.info('Final model saved!')

    def construct_generate_inputs(self, llm, prompt):
        """Wrap one prompt via ``llm.process_fn``, tokenize it and move it to CUDA.

        Returns the tokenizer output (a batch of size 1, padded), with every
        tensor moved to the current CUDA device.
        """
        message = llm.process_fn(
            {'prompt': prompt}, prompt_construct_fn=lambda x: x['prompt']
        )
        batch_inputs = llm.tokenizer(
            [message['message']], return_tensors='pt', padding=True
        )
        for key in batch_inputs:
            batch_inputs[key] = batch_inputs[key].cuda()
        return batch_inputs

    def generate(self):
        """Generate a reply for every prompt of the configured test dataset.

        Returns a dict mapping each prompt to ``self.llm.generate``'s output.

        Raises:
            NameError: ``args.test_dataset`` is not a known dataset name.
        """
        dataset_files = {
            'harmfulRespDemo': 'harmful_resp_demo.json',
            'harmfulResp': 'final_harmful_resp.json',
        }
        file_name = dataset_files.get(self.args.test_dataset)
        if file_name is None:
            raise NameError(f"Unknown test_dataset: {self.args.test_dataset!r}")
        with open(f"{self.args.root_path}/data/{file_name}", encoding="utf-8") as f:
            resps_data = json.load(f)

        test_prompts = list(resps_data)

        self.llm.model.eval()

        result = {}
        for prompt in tqdm(test_prompts, total=len(test_prompts)):
            batch_inputs = self.construct_generate_inputs(self.llm, prompt)
            outputs = self.llm.generate(batch_inputs)
            print(outputs)
            result[prompt] = outputs

        return result