import inspect
import importlib
import pickle as pkl
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

import random
import torch
import argparse
from transformers import LlamaForCausalLM, LlamaTokenizer
import os



class TrainCollater:
    """Collate a list of dataset samples into one tokenized batch.

    Each prompt may contain the placeholders ``[HistoryHere]`` and
    ``[CansHere]``; they are replaced by the sample's ``seq_name`` /
    ``cans_name`` titles, each followed by a marker token.  A simple
    curriculum decides per batch which markers are used:

    * "hard" task  -> ``[HistoryEmb]`` / ``[CansEmb]`` markers, ``flag=False``
    * "easy" task  -> ``[PH]`` markers, ``flag=True``

    The probability of the easy task is ``cur_step / max_step`` and
    ``cur_step`` is incremented on every call.

    Args:
        prompt_list: List of prompt templates (one is drawn at random per
            batch).  An entry may itself be a list holding one prompt per
            sample.  If this is not a non-empty list, every sample must
            provide its own ``"instruction_input"`` prompt.
        llm_tokenizer: Callable used to tokenize the prompts; it is called
            with Hugging Face style keyword arguments.
        train: When true, targets are appended with ``terminator`` and
            tokenized together with the prompts as (prompt, target) pairs.
        terminator: String appended to every target in training mode.
        max_step: Number of collate calls over which the easy-task
            probability grows to 1.  Non-positive values are treated as 1.
    """

    def __init__(self,
                 prompt_list=None,
                 llm_tokenizer=None,
                 train=False,
                 terminator="\n",
                 max_step=1):
        self.prompt_list = prompt_list
        self.llm_tokenizer = llm_tokenizer
        self.train = train
        self.terminator = terminator
        self.max_step = max_step
        self.cur_step = 1

    def __call__(self, batch):
        """Build the model-ready batch dictionary from a list of samples.

        Args:
            batch: List of sample dicts.  The keys read here are
                ``seq_name``, ``cans_name``, ``correct_answer``, ``seq``,
                ``cans``, ``len_seq``, ``len_cans``, ``item_id`` and,
                when no prompt list is available, ``instruction_input``.

        Returns:
            dict with ``tokens`` plus the stacked ``seq``, ``cans``,
            ``len_seq``, ``len_cans`` and ``item_id`` tensors.  In training
            mode it also carries ``flag``; otherwise it carries
            ``correct_answer`` and ``cans_name``.

        Raises:
            ValueError: If no prompt list is available and a sample has no
                ``"instruction_input"`` entry.
        """
        inputs_text = self._initial_prompts(batch)

        # Curriculum schedule: the easy task gets more likely as cur_step
        # approaches max_step.  Guard the divisor so max_step == 0 does not
        # raise ZeroDivisionError.
        total_steps = self.max_step if self.max_step > 0 else 1
        thresh_hold = self.cur_step / total_steps
        flag = random.random() <= thresh_hold
        if flag:
            # Easy task: every title is followed by the generic placeholder.
            history_tag = cans_tag = ' [PH]'
        else:
            # Hard task: titles are followed by the embedding markers.
            history_tag, cans_tag = ' [HistoryEmb]', ' [CansEmb]'

        for i, sample in enumerate(batch):
            inputs_text[i] = self._fill_placeholders(
                inputs_text[i], sample, history_tag, cans_tag)
        self.cur_step += 1

        targets_text = [sample['correct_answer'] for sample in batch]

        new_batch = {
            "seq": self._stack(batch, 'seq'),
            "cans": self._stack(batch, 'cans'),
            "len_seq": self._stack(batch, 'len_seq'),
            "len_cans": self._stack(batch, 'len_cans'),
            "item_id": self._stack(batch, 'item_id'),
        }

        if self.train:
            # Training: tokenize (prompt, target + terminator) pairs.
            targets_text = [target + self.terminator for target in targets_text]
            inputs_pair = [[prompt, target]
                           for prompt, target in zip(inputs_text, targets_text)]
            new_batch["tokens"] = self.llm_tokenizer(
                inputs_pair,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_attention_mask=True,
                return_token_type_ids=True)
            new_batch["flag"] = flag
        else:
            # Validation / test: tokenize the prompts only and keep the raw
            # answers and candidate names alongside the tensors.
            new_batch["tokens"] = self.llm_tokenizer(
                inputs_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=True,
                return_token_type_ids=True)
            new_batch["correct_answer"] = targets_text
            new_batch["cans_name"] = [sample['cans_name'] for sample in batch]

        return new_batch

    def _initial_prompts(self, batch):
        """Return a fresh list with one (unfilled) prompt per sample."""
        if isinstance(self.prompt_list, list) and self.prompt_list:
            instruction = random.choice(self.prompt_list)
            if isinstance(instruction, list):
                # Copy so the in-place placeholder filling in __call__ does
                # not overwrite the templates stored in prompt_list.
                return list(instruction)
            return [instruction] * len(batch)

        # No usable prompt list: each sample must bring its own instruction.
        prompts = []
        for sample in batch:
            if "instruction_input" not in sample:
                raise ValueError(
                    'No prompt available: prompt_list is empty or not a list '
                    'and the sample has no "instruction_input" entry')
            prompts.append(sample["instruction_input"])
        return prompts

    @staticmethod
    def _fill_placeholders(input_text, sample, history_tag, cans_tag):
        """Replace [HistoryHere]/[CansHere] with the tagged item titles."""
        if '[HistoryHere]' in input_text:
            insert_prompt = ", ".join(
                seq_title + history_tag for seq_title in sample['seq_name'])
            input_text = input_text.replace('[HistoryHere]', insert_prompt)
        if '[CansHere]' in input_text:
            insert_prompt = ", ".join(
                can_title + cans_tag for can_title in sample['cans_name'])
            input_text = input_text.replace('[CansHere]', insert_prompt)
        return input_text

    @staticmethod
    def _stack(batch, key):
        """Stack ``sample[key]`` of every sample along a new first dimension."""
        return torch.stack([torch.tensor(sample[key]) for sample in batch], dim=0)

class DInterface(pl.LightningDataModule):
    """Lightning data module that builds the datasets and their loaders.

    The dataset class is resolved dynamically from ``dataset``: the module
    ``.<dataset>`` of the current package is imported and the class whose
    name is the CamelCase form of ``dataset`` is used (for example
    ``steam_data`` -> ``SteamData``).

    Args:
        llm_tokenizer: Tokenizer handed to every ``TrainCollater``.
        num_workers: Number of DataLoader worker processes.
        dataset: snake_case name of the dataset module.
        **kwargs: Must contain ``batch_size``, ``max_epochs`` and
            ``prompt_path``.  Entries whose name matches a constructor
            argument of the dataset class are forwarded to it.
    """

    def __init__(self,
                 llm_tokenizer=None,
                 num_workers=8,
                 dataset='',
                 **kwargs):
        super().__init__()
        self.num_workers = num_workers
        self.llm_tokenizer = llm_tokenizer
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs['batch_size']
        self.max_epochs = kwargs['max_epochs']
        # Resolve the dataset class (e.g. SteamData) and read the prompts.
        self.load_data_module()
        self.load_prompt(kwargs['prompt_path'])

        # Instantiate one dataset per stage.
        self.trainset = self.instancialize(stage='train')
        self.valset = self.instancialize(stage='val')
        self.testset = self.instancialize(stage='test')
        # Step budget handed to the training collater.  The total is divided
        # by the worker count; num_workers == 0 (loading in the main process)
        # is treated as a single worker to avoid a ZeroDivisionError.
        # NOTE(review): presumably each worker process keeps its own copy of
        # the collater and therefore its own cur_step - confirm.
        workers = max(1, self.num_workers)
        self.max_steps = self.max_epochs*(len(self.trainset)//self.batch_size)//workers

    def train_dataloader(self):
        """Return the shuffled training loader (incomplete last batch dropped)."""
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=True,
                                                   max_step=self.max_steps))

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return DataLoader(self.valset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))

    def test_dataloader(self):
        """Return the test loader (no shuffling, incomplete last batch dropped)."""
        # NOTE(review): drop_last=True discards the final partial batch of the
        # test set - confirm this is intended for evaluation.
        return DataLoader(self.testset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          drop_last=True,
                          collate_fn=TrainCollater(prompt_list=self.prompt_list,
                                                   llm_tokenizer=self.llm_tokenizer,
                                                   train=False))

    def load_data_module(self):
        """Import the dataset module and store its class in ``self.data_module``.

        Raises:
            ValueError: If the module cannot be imported or does not define
                the expected CamelCase class.
        """
        name = self.dataset
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.'+name, package=__package__), camel_name)
        except Exception as err:
            # Keep the original cause attached so import errors stay debuggable.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """Instantiate the dataset class with matching entries of ``self.kwargs``.

        Every positional-or-keyword constructor argument of the dataset class
        that is present in ``self.kwargs`` is forwarded.  ``other_args`` are
        applied last and override those values.

        Returns:
            A new instance of ``self.data_module``.
        """
        # getfullargspec replaces inspect.getargspec, which was removed in
        # Python 3.11; [1:] skips ``self``.
        class_args = inspect.getfullargspec(self.data_module.__init__).args[1:]
        args1 = {arg: self.kwargs[arg] for arg in class_args if arg in self.kwargs}
        args1.update(other_args)
        return self.data_module(**args1)

    def load_prompt(self, prompt_path):
        """Read one prompt per line from ``prompt_path`` into ``self.prompt_list``.

        If ``prompt_path`` is not an existing file the prompt list is empty.
        """
        if os.path.isfile(prompt_path):
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            self.prompt_list = [p.strip() for p in raw_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            # An empty file yields no prompts; random.choice would raise.
            if self.prompt_list:
                print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []