import os
import csv
import pickle
from dataclasses import dataclass
from contextlib import nullcontext
from tqdm import tqdm
import math

import time
import tiktoken
import torch
import numpy as np
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.eval_utils import read_txt, read_csv, finalize
from utils.agg_utils import AverageMeter
from utils.num_utils import logmeanexp, scatter_normmeanexp
from scorer import ScorerManager

import subprocess

def show_gpu(msg):
    """Print *msg* followed by the memory usage of the first GPU reported by nvidia-smi.

    The output line looks like ``<msg> 42.0% (used out of total)`` where ``used`` and
    ``total`` are the raw ``memory.used`` / ``memory.total`` values from nvidia-smi
    (queried with ``nounits``). Raises whatever ``subprocess.check_output`` raises when
    nvidia-smi is unavailable or fails.

    ref: https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4
    """
    def first_gpu_value(field):
        # nvidia-smi prints one line per GPU; only the first line is used.
        raw = subprocess.check_output(
            ['nvidia-smi', f'--query-gpu={field}', '--format=csv,nounits,noheader'],
            encoding='utf-8',
        )
        first_line = raw.strip().split('\n')[0]
        return int(first_line)

    used, total = (first_gpu_value(name) for name in ('memory.used', 'memory.total'))
    ratio = used / total
    print('\n' + msg, f'{100*ratio:2.1f}% ({used} out of {total})')

@dataclass
class MuOptimParams:
    """Sampling and optimization settings read by ``DaemonDecoding.compute_optimal_mu_match_prefix``."""
    lr: float = 0.5  # Adam learning rate used to fit mu
    batch_size: int = 128  # batch size for scoring data / samples and for proposal sampling
    max_iters: int = 1000  # maximum number of Adam steps on mu
    max_len: int = 512  # data truncation length; generation uses max_length = max_len + 1 (BOS included)
    min_len: int = 0  # generation uses min_length = min_len + 1
    num_target_samples: int = 512  # number of data sequences used for the target moments
    num_proposal_samples: int = 512  # total proposal samples; each data prefix gets num_proposal_samples // num_target_samples
    min_err: float = 1e-3  # stop the mu optimization once the error drops below this
    weight_decay: float = 0  # NOTE(review): not read by compute_optimal_mu_match_prefix in this file
    cond_len: int = 0  # number of leading data tokens used as the sampling prefix
    top_k: int = 0  # top-k passed to generate()
    top_p: float = 1.0  # top-p passed to generate()
    temperature: float = 1.0  # temperature passed to generate()

@dataclass
class ProposalParams:
    """Settings bundle for a proposal model.

    NOTE(review): not referenced in the visible code; the field meanings below are
    presumed from their names and from the matching ``DaemonDecoding`` arguments — confirm.
    """
    device: str  # presumably a torch device string, like DaemonDecoding's ``device``
    dtype: str  # presumably 'float32' / 'bfloat16' / 'float16', like DaemonDecoding's ``dtype``
    ckpt_path: str  # presumably the checkpoint passed to from_pretrained
    top_k: int  # presumably the top-k sampling parameter
    top_p: float  # presumably the nucleus (top-p) sampling parameter
    temperature: float  # presumably the sampling temperature


class DaemonDecoding:
    """Decode from a causal LM reweighted by a log-linear model over scorer features.

    The Hugging Face causal LM loaded from ``ckpt_path`` acts as a proposal ``q``.
    A weight vector ``mu`` over the features ``f`` returned by ``ScorerManager`` adds a
    log-weight ``f(x) · mu`` to each sequence ``x``. The class can

    * fit ``mu`` by matching weighted feature moments of proposal samples to the data
      (``compute_optimal_mu_match_prefix``),
    * report sequence losses / token perplexities with and without that correction
      (``compute_perplexity``), and
    * decode by sampling-importance-resampling (``sample_importance_resample``).
    """

    def __init__(self, model_name, ckpt_path, device, dtype):
        """Load the model and tokenizer from ``ckpt_path`` and move the model to ``device``.

        Args:
            model_name: label used in cache file names.
            ckpt_path: path or hub id passed to ``from_pretrained``.
            device: torch device string; autocast is enabled when it contains ``'cuda'``.
            dtype: one of ``'float32'``, ``'bfloat16'``, ``'float16'``.
        """
        self.device = device
        self.model_name = model_name

        self.dtype = dtype

        device_type = 'cuda' if 'cuda' in device else 'cpu'
        self.ptdtype_map = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
        self.ptdtype = self.ptdtype_map[dtype]
        # Autocast only on GPU; on CPU the model runs in its own dtype.
        self.ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)

        # TODO: also support the custom (non-Hugging Face) GPT checkpoints.
        self.model = self._load_model(AutoModelForCausalLM, ckpt_path)
        # The tokenizer is the Hugging Face tokenizer stored with the checkpoint.
        self.tokenizer = self._load_tokenizer(AutoTokenizer, ckpt_path)
        self.eos_id = self.tokenizer.eos_token_id
        self.bos_id = self.tokenizer.bos_token_id

        self.model.eval()
        if self.dtype == "float16":
            self.model.half()
        self.model.to(self.device)

    def _load_model(self, model_cls: "AutoModelForCausalLM", ckpt_path: str):
        """Instantiate ``model_cls`` from ``ckpt_path`` via ``from_pretrained``."""
        return model_cls.from_pretrained(ckpt_path)

    def _load_tokenizer(self, tokenizer_cls: "AutoTokenizer", ckpt_path: str):
        """Instantiate ``tokenizer_cls`` from ``ckpt_path`` via ``from_pretrained``."""
        return tokenizer_cls.from_pretrained(ckpt_path)

    def _load_openai_tokenizer(self, bpe_type="gpt2"):
        """Return the tiktoken encoding ``bpe_type`` (``encode(str)`` / ``decode(list_of_ints)``)."""
        tokenizer = tiktoken.get_encoding(bpe_type)
        return tokenizer

    def _read_and_tokenize(self, data, max_length=None):
        """Return ``data`` as a list of token-id lists, each truncated to ``max_length``.

        ``data`` may be:
          * a path ending in ``.csv`` / ``.txt`` (read with ``read_csv`` / ``read_txt``)
            or ``.bin`` (a pickled list),
          * a list of strings (tokenized with ``self.tokenizer.encode``), or
          * a list of token-id lists (used as is).

        Raises:
            ValueError: if ``data`` is a path with an unsupported extension.
            AssertionError: if ``data`` (after loading) is not a list.
        """
        if isinstance(data, str):
            # Dispatch on the file suffix (the old ``data[:-3]`` compared the prefix, so no path ever matched).
            if data.endswith(".csv"):
                data = read_csv(data)
            elif data.endswith(".txt"):
                data = read_txt(data)
            elif data.endswith(".bin"):
                # NOTE: pickle can execute arbitrary code; only load trusted files.
                with open(data, "rb") as f:
                    data = pickle.load(f)
            else:
                raise ValueError(f"unsupported data file (expected .csv, .txt or .bin): {data}")

        assert isinstance(data, list), f"expected list of strings or token ids list but get {type(data)}"

        if data and isinstance(data[0], str):
            tokenized_data = [self.tokenizer.encode(x) for x in tqdm(data, desc="tokenizing...")]
        else:
            tokenized_data = data

        if max_length is not None:
            tokenized_data = [x[:max_length] for x in tokenized_data]

        return tokenized_data

    def process_output_scores(self, scores, padded_sequences):
        """Sum the log-probabilities of ``padded_sequences`` under per-step ``scores``.

        Args:
            scores: sequence of per-step logits, stacked along dim 1 — presumably
                ``generate(..., output_scores=True).scores``, each (N, vocab).
            padded_sequences: (N, steps) generated token ids; positions equal to
                ``self.eos_id`` contribute 0.

        Returns:
            (N,) tensor of summed log-probabilities.
        """
        scores = torch.stack(scores, 1).to(self.device)
        # (N, steps, vocab)

        lprobs = torch.nn.LogSoftmax(-1)(scores).gather(-1, padded_sequences.unsqueeze(-1)).squeeze(-1)
        lprobs = lprobs.masked_fill(padded_sequences.eq(self.eos_id), 0.)
        # (N, steps)

        return lprobs.sum(-1)

    def compute_optimal_mu_match_prefix(self, data, scorer_configs, optim_params: "MuOptimParams", data_name=None, cache_dir="import_samp_cache", err_fn="rmsre", no_prefix_weight=False, mu_sign_mask=None):
        """Fit the feature weights ``mu`` by moment matching against data.

        Steps:
          1. Score the first ``num_target_samples`` data sequences (BOS + tokens, right
             padded with EOS to ``max_len + 1``) and average the features: target moments.
          2. For each of those sequences sample ``num_proposal_samples // n_eval``
             continuations, conditioned on BOS + its first ``cond_len`` tokens. Samples
             and the log-probs of their prefixes are cached under ``cache_dir``
             (``.bin``), together with a decoded ``.csv`` copy.
          3. Run Adam on ``mu`` so that the self-normalized weighted feature mean of the
             samples, with log-weights ``f(x) · mu`` (+ prefix log-prob unless
             ``no_prefix_weight``), matches the target moments.

        Args:
            data: path or list accepted by ``_read_and_tokenize``.
            scorer_configs: configuration passed to ``ScorerManager``.
            optim_params: sampling and optimization settings.
            data_name: optional prefix for the cache file name.
            cache_dir: directory holding the sample cache.
            err_fn: ``"rmsre"`` (error relative to the target) or ``"rse"`` (absolute error).
            no_prefix_weight: if True, do not add prefix log-probs to the log-weights.
            mu_sign_mask: optional per-feature ints. 0 penalizes any moment error; +1 / -1
                penalize only a negative / positive error, plus ``mu`` of the opposite sign.

        Returns:
            dict mapping scorer name to its fitted ``mu`` value.

        Raises:
            ValueError: for an unknown ``err_fn``, empty data, or fewer proposal samples
                than target samples when sampling is needed.
        """
        if err_fn not in ("rmsre", "rse"):
            raise ValueError(f"unknown err_fn {err_fn!r}; expected 'rmsre' or 'rse'")

        tokenized_data = self._read_and_tokenize(data, max_length=optim_params.max_len)

        manager = ScorerManager(scorer_configs, self.tokenizer)
        scorers, names = manager.get_scorers_with_names()

        bsz = optim_params.batch_size
        max_len = optim_params.max_len
        min_len = optim_params.min_len
        cond_len = optim_params.cond_len

        # Never run past the available data: empty batches would break scoring and generation.
        n_eval = min(optim_params.num_target_samples, len(tokenized_data))
        if n_eval == 0:
            raise ValueError("no data available to compute target moments")
        eval_data = tokenized_data[:n_eval]

        target_moment = []
        for i in tqdm(range(0, n_eval, bsz), desc="calc gt scores"):
            # BOS + tokens, right-padded with EOS to length max_len + 1.
            batched_samples = torch.LongTensor([[self.bos_id] + x[:max_len] + [self.eos_id] * (max_len - len(x)) for x in eval_data[i:i + bsz]])
            target_moment.append(manager.get_scores_batch(scorers, batched_samples))
        target_moment = torch.cat(target_moment, dim=0).mean(0)

        print("target moment: \n" + "\n" .join([f"\t{names[_i]}: {target_moment[_i].item():.4f}" for _i in range(len(names))]))

        n_sample = optim_params.num_proposal_samples
        samples_per_prompt = n_sample // n_eval

        os.makedirs(cache_dir, exist_ok=True)
        cache_name = f"{self.model_name}-cond{cond_len}-min{min_len}-max{max_len}-n_sample{n_sample}-n_eval{n_eval}_k{optim_params.top_k}p{optim_params.top_p}t{optim_params.temperature}.bin"
        if data_name is not None:
            cache_name = data_name + "-" + cache_name
        # The cache always lives in cache_dir (it used to land in the CWD when data_name was None).
        cache_name = os.path.join(cache_dir, cache_name)

        # seed for sample and mu optimization
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        print(cache_name)

        if os.path.exists(cache_name):
            all_completions, prefix_lprobs = torch.load(cache_name, map_location=self.device)
        else:
            if samples_per_prompt < 1:
                raise ValueError(f"num_proposal_samples ({n_sample}) must be >= the number of target samples ({n_eval})")
            # nvidia-smi is only meaningful (and usually only present) on GPU runs.
            on_gpu = "cuda" in str(self.device)
            all_completions = []
            prefix_lprobs = []
            for i in range(0, n_eval, bsz):
                if on_gpu:
                    show_gpu(f"gpu usage before {i} step: ")
                print(f"sampling {i // bsz} / {n_eval // bsz} batches...")
                # NOTE: assumes every sequence has >= cond_len tokens; shorter ones make this batch ragged.
                inputs = torch.LongTensor([[self.bos_id] + x[:cond_len] for x in eval_data[i:i + bsz]]).to(self.device)
                attention_mask = torch.ones_like(inputs, dtype=torch.bool)

                with torch.no_grad(), self.ctx:
                    outs = self.model.generate(inputs,
                                attention_mask = attention_mask,
                                max_length = max_len + 1,
                                min_length = min_len + 1,
                                do_sample = True,
                                top_k = optim_params.top_k,
                                top_p = optim_params.top_p,
                                temperature = optim_params.temperature,
                                use_cache = True,
                                pad_token_id = self.tokenizer.pad_token_id,
                                bos_token_id = self.tokenizer.bos_token_id,
                                eos_token_id = self.tokenizer.eos_token_id,
                                early_stopping = True,
                                output_scores = False,
                                return_dict_in_generate = True,
                                num_return_sequences = samples_per_prompt)

                    if on_gpu:
                        show_gpu(f"gpu usage after {i} step: ")

                    padded_sequences = self.pad_sequences(outs.sequences, pad_idx=self.eos_id, eos_idx=self.eos_id)
                    # generate() stops once every row has finished, so batches can differ in
                    # length; right-pad with EOS to max_len + 1 (the target-moment layout) so
                    # the batches can be concatenated.
                    pad_amount = max(0, max_len + 1 - padded_sequences.size(1))
                    if pad_amount:
                        padded_sequences = torch.nn.functional.pad(padded_sequences, (0, pad_amount), value=self.eos_id)
                    all_completions.append(padded_sequences)

                    # generate() returns each prompt's samples contiguously, so per-prompt
                    # values must be repeated element-wise (repeat_interleave), not tiled.
                    if inputs.size(1) > 1:
                        prefix_lprobs_seq = self.get_seq_lprobs(self.model, inputs, 1)
                        prefix_lprobs.append(prefix_lprobs_seq.repeat_interleave(samples_per_prompt))
                    else:
                        prefix_lprobs.append(torch.zeros(inputs.size(0) * samples_per_prompt, device=self.device))

            all_completions = torch.cat(all_completions, dim=0)
            prefix_lprobs = torch.cat(prefix_lprobs, dim=0)

            print(f"Saving cache to {cache_name}...")
            torch.save([all_completions, prefix_lprobs], cache_name)

            with open(cache_name[:-4] + ".csv", "w", newline="") as f:
                writer = csv.writer(f, delimiter=',')
                for line in all_completions:
                    writer.writerow([self.tokenizer.decode(finalize(line.tolist(), eos_idx=self.eos_id, bos_idx=self.bos_id))])

        all_scores = []
        for batched_completions in tqdm(all_completions.split(bsz, dim=0), desc="calc sample scores"):
            all_scores.append(manager.get_scores_batch(scorers, batched_completions).to(self.device))
        all_scores = torch.cat(all_scores, dim=0)  # (num samples, num features)

        print("sample moment: \n" + "\n" .join([f"\t{names[_i]}: {all_scores.mean(0)[_i].item():.4f}" for _i in range(len(names))]))

        target_moment = target_moment.to(self.device)
        mu = torch.zeros(len(scorers), device=self.device, requires_grad=True)

        max_iters = optim_params.max_iters
        min_err = optim_params.min_err

        optimizer = torch.optim.Adam([mu], lr=optim_params.lr, betas=(0.9, 0.95))

        if mu_sign_mask is not None:
            mu_sign_mask = torch.LongTensor(mu_sign_mask).to(self.device)
        else:
            mu_sign_mask = torch.zeros_like(mu).long()

        relu = torch.nn.ReLU()
        has_sign_constraints = bool(mu_sign_mask.ne(0).any())

        for iteration in range(max_iters):
            # Log importance weight of every sample.
            logws = all_scores @ mu
            if not no_prefix_weight:
                logws = logws + prefix_lprobs
            # Self-normalized weighted feature mean over all samples.
            sample_moment = (torch.nn.Softmax(-1)(logws).unsqueeze(-1) * all_scores).sum(-2)

            if err_fn == "rmsre":
                delta_mu = (sample_moment - target_moment) / target_moment
            else:  # "rse"
                delta_mu = sample_moment - target_moment

            # Inequality features (mask != 0) keep only the error of the penalized sign;
            # equality features (mask == 0) keep the full error.
            sign = mu_sign_mask.to(delta_mu.dtype)
            ineq_delta = relu(delta_mu * -sign)
            eq_delta = delta_mu * mu_sign_mask.eq(0).to(delta_mu.dtype)
            delta_mu = ineq_delta + eq_delta

            # Penalize mu whose sign disagrees with a nonzero mask entry.
            if has_sign_constraints:
                mu_loss = relu(mu * -sign).sum()
            else:
                mu_loss = 0

            error = delta_mu.pow(2).mean().sqrt() + mu_loss

            error.backward()
            optimizer.step()
            optimizer.zero_grad()

            if iteration % 1000 == 0:
                print(f"iter: {iteration})")
                print(f"\ttotal error: {error:.4f} | mu sign error = {mu_loss:.6f}")
                print(f"\tmu: {mu.tolist()}")
                print(f"\t|mu|: {torch.norm(mu).item()}")
                print(f"\tapprox moment: {sample_moment.tolist()}")
                print(f"\ttarget moment: {target_moment.tolist()}")

            if error < min_err:
                print(f"optimization finished (error < {min_err})")
                print("\nFinal:\n")
                print(f"iter: {iteration})")
                print(f"\tmu: {mu.tolist()}")
                print(f"\t|mu|: {torch.norm(mu).item()}")
                print(f"\tapprox moment: {sample_moment.tolist()}")
                print(f"\ttarget moment: {target_moment.tolist()}")
                break

        return {name: mu_i.item() for name, mu_i in zip(names, mu)}

    def compute_perplexity(self, eval_data, batch_size, max_length: int, scorer_configs: dict, optimal_mu: dict, temperature: float = None):
        """Compare LM losses with and without the ``optimal_mu`` feature correction.

        For each sequence ``x`` the original loss is the summed token NLL of ``x`` under the
        LM, with logits divided by ``temperature``. The improved loss is that value minus
        ``f(x) · mu``. A normalizer ``logsumexp(-improved_loss)`` is computed over the
        evaluated sequences and added: divided by N for the sentence average, and whole for
        the token-level value.

        Args:
            eval_data: path or list accepted by ``_read_and_tokenize``.
            batch_size: sequences per forward pass.
            max_length: truncation / padding length.
            scorer_configs: configuration passed to ``ScorerManager``.
            optimal_mu: scorer name -> weight, e.g. from ``compute_optimal_mu_match_prefix``.
            temperature: logit temperature (default 1.0).

        Returns:
            dict with ``ori_sent_loss``, ``imp_sent_loss``, ``ori_tok_ppl``, ``imp_tok_ppl``.

        Raises:
            ValueError: if there is nothing to evaluate.
        """
        if temperature is None:
            temperature = 1.0
        tokenized_data = self._read_and_tokenize(eval_data, max_length=max_length)
        print(f"{len(tokenized_data)} samples to be evaluated")
        if not tokenized_data:
            raise ValueError("no evaluation data")

        torch.manual_seed(42)

        manager = ScorerManager(scorer_configs, self.tokenizer)
        scorers, names = manager.get_scorers_with_names()

        optimal_mu = torch.FloatTensor([optimal_mu[name] for name in names]).to(self.device)
        loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

        original_loss, improved_loss, seq_lengths = [], [], []
        for i in tqdm(range(0, len(tokenized_data), batch_size)):
            print(f"evaluating {i // batch_size} / {len(tokenized_data) // batch_size} batch...")
            batch = tokenized_data[i:i + batch_size]
            # LM input: BOS + x[:-1], right-padded with EOS to max_length.
            inputs = torch.LongTensor([[self.bos_id] + x[:-1] + (max_length - len(x)) * [self.eos_id] for x in batch]).to(self.device)
            # Targets: x, padded with -100 so the padding is ignored by the loss.
            labels = torch.LongTensor([x + (max_length - len(x)) * [-100] for x in batch]).to(self.device)
            # Built from the tokenized batch (the raw eval_data may be a file path).
            attention_mask = torch.ones_like(inputs, dtype=torch.bool)

            # Features are computed on the full sequence BOS + x + EOS padding (length
            # max_length + 1), the layout used for the target moments when fitting mu;
            # the shifted LM input above lacks the last token of x.
            full_seqs = torch.LongTensor([[self.bos_id] + x + (max_length - len(x)) * [self.eos_id] for x in batch]).to(self.device)
            batch_scores = manager.get_scores_batch(scorers, full_seqs).to(self.device)

            with torch.no_grad():
                with self.ctx:
                    outs = self.model(inputs, attention_mask=attention_mask)
                    logits = outs.logits / temperature

                    loss_per_pos = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)).view(logits.size(0), -1)
                    seq_loss = loss_per_pos.sum(-1).float()
                    seq_length = labels.ne(-100).long().sum(-1)

                print(seq_loss)
                original_loss.append(seq_loss)
                seq_lengths.append(seq_length)

            # -log q(x) - f(x)·mu (unnormalized).
            ebm_logits = (batch_scores * optimal_mu.unsqueeze(0)).sum(1)
            improved_loss.append(seq_loss - ebm_logits)

        original_loss = torch.cat(original_loss, dim=0)
        improved_loss = torch.cat(improved_loss, dim=0)
        seq_lengths = torch.cat(seq_lengths, dim=0)

        # Normalizer estimated over the evaluated sequences themselves.
        partition = torch.logsumexp(-improved_loss, -1)

        avg_ori_loss = original_loss.mean().item()
        avg_imp_loss = improved_loss.mean().item() + partition.item() / improved_loss.size(0)

        print(f"Original sentence loss: {(avg_ori_loss):.4f}")
        print(f"Improved sentence loss: {(avg_imp_loss):.4f}")

        n_tokens = seq_lengths.sum().float().item()
        avg_ori_loss_tok = original_loss.sum().item() / n_tokens
        avg_imp_loss_tok = (improved_loss.sum().item() + partition.item()) / n_tokens

        print(f"Original token ppl: {math.exp(avg_ori_loss_tok):.4f} nats")
        print(f"Improved token ppl: {math.exp(avg_imp_loss_tok):.4f} nats")

        return {"ori_sent_loss": avg_ori_loss, "imp_sent_loss": avg_imp_loss, "ori_tok_ppl": math.exp(avg_ori_loss_tok), "imp_tok_ppl": math.exp(avg_imp_loss_tok)}

    def pad_sequences(self, token_ids, pad_idx, eos_idx=50256):
        """Replace every token from the second ``eos_idx`` occurrence onward with ``pad_idx``.

        The first occurrence is kept because the sequences built in this class start with
        BOS, which equals EOS for GPT-2-style tokenizers (the 50256 default is presumably
        GPT-2's ``<|endoftext|>`` — confirm for other tokenizers). The old ``eq(2)`` mask
        only padded up to the third EOS; ``ge(2)`` pads the whole tail.
        """
        eos_count = token_ids.eq(eos_idx).long().cumsum(-1)
        return token_ids.masked_fill(eos_count.ge(2), pad_idx)

    def get_seq_lprobs(self, model, token_ids, cond_len, attention_mask=None):
        """Return the per-row summed log-probability of ``token_ids`` under ``model``.

        Position ``t`` predicts token ``t + 1``. Targets equal to ``self.eos_id`` are
        excluded, as are tokens ``1 .. cond_len - 1`` (the conditioning prefix after the
        first token).

        Args:
            model: callable returning an object with ``.logits`` of shape (N, L - 1, vocab).
            token_ids: (N, L) token ids.
            cond_len: length of the conditioning prefix, including the first token.
            attention_mask: optional (N, L - 1) mask; defaults to all ones.

        Returns:
            (N,) tensor of summed log-probabilities.
        """
        inputs = token_ids[:, :-1].contiguous()
        labels = token_ids.masked_fill(token_ids.eq(self.eos_id), -100)[:, 1:].contiguous()
        # Clamp so cond_len <= 0 masks nothing (a negative slice end would mask almost all labels).
        labels[:, :max(cond_len - 1, 0)] = -100
        if attention_mask is None:
            attention_mask = torch.ones_like(inputs, dtype=torch.bool)

        logits = model(inputs, attention_mask=attention_mask).logits

        lprobs_per_pos = - torch.nn.CrossEntropyLoss(reduction="none")(logits.view(-1, logits.size(-1)), labels.view(-1)).view(logits.size(0), -1)

        return lprobs_per_pos.sum(-1)

    def sample_importance_resample(self, prompts, scorer_configs: dict, optimal_mu: dict, max_len: int, mc_num: int, mc_len: int, top_k, top_p, temperature, corr_reduce = False, seq_top_k = 0, no_IS = True, **model_kwargs):
        """Decode each prompt by sampling ``mc_num`` candidates and resampling one.

        For each prompt, ``mc_num`` continuations are sampled with ``generate`` (total
        length ``max_len`` including BOS and the prompt). Each candidate gets the
        log-weight ``f(x) · mu``. Unless ``no_IS``, the weight also adds
        ``log p(x) - log q(x)``: ``p`` is the model's own log-probability of the
        continuation (``get_seq_lprobs``), and ``q`` uses the per-step scores returned by
        ``generate`` (per HF docs these are the processed sampling scores). One candidate
        is then drawn from the softmax of the weights.

        Args:
            prompts: path or list accepted by ``_read_and_tokenize``.
            scorer_configs: configuration passed to ``ScorerManager``.
            optimal_mu: scorer name -> weight.
            max_len: ``max_length`` passed to ``generate``.
            mc_num: candidates sampled per prompt.
            mc_len: unused; kept for API compatibility.
            top_k, top_p, temperature: sampling parameters for ``generate``.
            corr_reduce: if True, replace each resampling probability ``p`` by ``p / (1 - p)``
                and renormalize.
            seq_top_k: if > 0, resample only among the ``seq_top_k`` best-weighted candidates.
            no_IS: if True, skip the ``log p - log q`` correction.
            **model_kwargs: accepted for API compatibility; currently ignored.

        Returns:
            list of decoded strings, one per prompt.
        """
        tokenized_prompts = self._read_and_tokenize(prompts)

        torch.manual_seed(42)

        manager = ScorerManager(scorer_configs, self.tokenizer)
        scorers, names = manager.get_scorers_with_names()

        optimal_mu = torch.FloatTensor([optimal_mu[name] for name in names]).to(self.device, self.ptdtype)

        durations = []
        durations_gen = []
        total_outputs = []
        for i, prompt in enumerate(tqdm(tokenized_prompts)):
            start = time.time()

            inputs = torch.LongTensor([[self.bos_id] + prompt]).to(self.device)
            attention_mask = torch.ones_like(inputs, dtype=torch.bool)

            cur_len = inputs.size(1)

            with torch.no_grad():
                with self.ctx:
                    # 1. sample candidates from the proposal model q(x)
                    proposal_outs = self.model.generate(inputs,
                                    attention_mask=attention_mask,
                                    max_length = max_len,
                                    num_return_sequences = mc_num,
                                    do_sample = True,
                                    top_k = top_k,
                                    top_p = top_p,
                                    temperature = temperature,
                                    use_cache = True,
                                    pad_token_id = self.tokenizer.pad_token_id,
                                    bos_token_id = self.tokenizer.bos_token_id,
                                    eos_token_id = self.tokenizer.eos_token_id,
                                    early_stopping = True,
                                    return_dict_in_generate = True,
                                    output_scores = not no_IS)

                    proposal_seqs = proposal_outs.sequences
                    print(proposal_seqs.size())
                    padded_sequences = self.pad_sequences(proposal_seqs, pad_idx=self.eos_id, eos_idx=self.eos_id).to(self.device)
                    padded_completions = padded_sequences[:, cur_len:]
                    # (mc_num, generated length)

                    end_gen = time.time()
                    duration_gen = end_gen - start
                    print(f"sampling {i} / {len(tokenized_prompts)} takes: {duration_gen} s")
                    durations_gen.append(duration_gen)

                    # 2. feature log-weights f(x) · mu
                    flat_fs = manager.get_scores_batch(scorers, padded_sequences).to(self.device, self.ptdtype)
                    # (mc_num, num features)
                    scores = flat_fs @ optimal_mu
                    # (mc_num,)

                    end = time.time()
                    duration = end - start
                    print(f"sampling + scoring {i} / {len(tokenized_prompts)} takes: {duration} s")
                    durations.append(duration)

                    if not no_IS:
                        # 3. importance correction log p(x) - log q(x)
                        proposal_lprobs_seq = self.process_output_scores(proposal_outs.scores, padded_completions)
                        print(proposal_lprobs_seq)

                        attention_mask = torch.cat([attention_mask, torch.ones(attention_mask.size(0), padded_sequences.size(1) - cur_len - 1).bool().to(self.device)], dim=-1)
                        attention_mask = attention_mask.repeat(1, mc_num).view(-1, attention_mask.size(1))
                        target_lprobs_seq = self.get_seq_lprobs(self.model, padded_sequences, cur_len, attention_mask)
                        print(target_lprobs_seq)

                        scores = scores + target_lprobs_seq - proposal_lprobs_seq

                    reweighted_scores = scores.view(-1, mc_num)

                    if seq_top_k > 0:
                        _seq_top_k = min(reweighted_scores.size(-1), seq_top_k)
                        topk_scores, _ = reweighted_scores.topk(k=_seq_top_k, dim=-1)
                        reweighted_scores[reweighted_scores < topk_scores[:, [-1]]] = -float("inf")

                    resample_dist = torch.nn.Softmax(-1)(reweighted_scores)

                    if corr_reduce:
                        # Clamp avoids 1/0 -> inf/inf = NaN when one candidate holds all the
                        # mass (e.g. mc_num == 1 or seq_top_k == 1); otherwise a no-op.
                        denom = (1 - resample_dist).clamp_min(torch.finfo(resample_dist.dtype).tiny)
                        reduced_dist = resample_dist / denom
                        resample_dist = reduced_dist / reduced_dist.sum(-1, keepdim=True)

                    # (1, mc_num)
                    print(resample_dist)

                    resample_ids = torch.multinomial(resample_dist, num_samples=1)
                    # (1, 1)
                    print(resample_ids)

                    resampled_seqs = torch.gather(padded_sequences, 0, resample_ids.repeat(1, padded_sequences.size(1)))
                    # (1, sequence length)

            total_outputs.append(resampled_seqs[0].cpu().tolist())

        if durations:
            avg_duration = sum(durations) / len(durations)
            avg_duration_gen = sum(durations_gen) / len(durations_gen)
            print(f"average duration for sampling: {avg_duration_gen} s")
            print(f"average duration for sampling + scoring: {avg_duration} s")

        detok_outputs = []
        for token_ids in total_outputs:
            # finalize() strips the leading BOS / trailing EOS padding before decoding.
            clean_token_ids = finalize(token_ids, eos_idx=self.eos_id, bos_idx=self.bos_id)
            detok_outputs.append(self.tokenizer.decode(clean_token_ids))

        return detok_outputs



