import argparse
import torch
import os
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from heads_as_modules import *
from teneva import sample
import numpy as np
from datasets import load_dataset

import psutil
import GPUtil

def get_mem_usage():
    """Print available and total host RAM, plus free/total memory of the first GPU.

    Host figures come from ``psutil.virtual_memory()`` and are converted from
    bytes to GB. When ``GPUtil.getGPUs()`` returns at least one device, only
    the first one is reported; its values are divided by 1024 (MB -> GB, per
    the original author's note) and "free" is derived as total minus used.
    Nothing is returned.
    """
    bytes_per_gb = 1024 ** 3
    ram = psutil.virtual_memory()
    ram_free_gb = ram.available / bytes_per_gb
    ram_total_gb = ram.total / bytes_per_gb

    devices = GPUtil.getGPUs()
    if not devices:
        # No GPU visible: report host memory only.
        print(f"Available memory: {ram_free_gb:.2f} GB out of {ram_total_gb:.2f} GB total")
        return

    first_gpu = devices[0]  # only the first GPU is reported
    vram_total_gb = first_gpu.memoryTotal / 1024
    vram_free_gb = vram_total_gb - first_gpu.memoryUsed / 1024

    print(f"System Memory: {ram_free_gb:.2f} GB available out of {ram_total_gb:.2f} GB total")
    print(f"GPU Memory: {vram_free_gb:.2f} GB available out of {vram_total_gb:.2f} GB total")


def get_mem_usage():
    """Do nothing.

    This later definition rebinds the name ``get_mem_usage``, so every call
    made through the module-level name is a silent no-op (memory reporting
    is switched off).
    """
    return None

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from transformers_check import CombinedModel, batchify, shift_batch, show, init, train


def create_autoregressive_mask(seq_length):
    """
    Build a causal (lower-triangular) boolean attention mask.

    Args:
    seq_length (int): Number of positions in the sequence.

    Returns:
    torch.Tensor: Boolean tensor of shape (seq_length, seq_length). Entry
                  [i, j] is True exactly when j <= i, i.e. position i may
                  attend to itself and to earlier positions only.
    """
    positions = torch.arange(seq_length)
    # Rows index the attending position, columns the attended-to position.
    return positions.unsqueeze(0) <= positions.unsqueeze(1)

# Gate checked at the start of each drafting step in
# CombinedModelUpgraded.generate_w_speculative: when this is False that method
# raises ValueError, so True is the only usable setting in this script.
CORRECT_SAMPLING = True

class CombinedModelUpgraded(CombinedModel):
    """CombinedModel extended with a speculative-decoding generation loop.

    Relies on attributes used below: ``base_model`` and ``new_head`` (read in
    ``forward_for_check``), a forward call returning a ``(pred_dict, _)`` pair
    with keys ``'log_w'`` and ``'log_cores'``, and ``block_size`` / ``d``,
    which must be set on the instance before ``generate_w_speculative`` runs.
    """

    @torch.no_grad()
    def generate_w_speculative(
        self,
        idx,
        attention_mask=None,
        max_new_tokens=500,
        tokenizer=None,
        temperature=1.0,
        top_k=0.9,
        top_p=None,
        generation_config=None,
        return_stats=False
    ):
        """Extend ``idx`` by drafting several tokens per step and verifying them.

        Each round: (1) draft one token per entry of ``pred_dict['log_cores']``
        from a single forward pass, (2) score the drafted run with
        ``forward_for_check``, (3) accept a prefix of the draft by comparing
        verifier and draft probabilities against uniform random numbers, and
        (4) append one extra token sampled from the verifier (or from the
        clipped difference of the two distributions after a rejection).

        Args:
            idx: Token ids to continue. The stacking/concatenation below
                assumes shape (1, seq_len) - TODO confirm batch > 1 is
                unsupported.
            attention_mask: Accepted for interface compatibility; not used
                (a causal mask is rebuilt every round).
            max_new_tokens: The loop stops once at least this many tokens have
                been appended. Because every round appends
                ``accepted + 1`` tokens, the result may overshoot this value.
            tokenizer: Accepted for interface compatibility; not used.
            temperature: Divides the logits before softmax for both the draft
                and the verifier distributions.
            top_k: If truthy, passed as ``thres`` to ``f_top_k`` on the
                verifier logits only.
            top_p: Accepted for interface compatibility; not used.
            generation_config: Accepted for interface compatibility; not used.
            return_stats: If True, also return the per-round accepted counts.

        Returns:
            ``idx`` with the generated tokens appended, or the pair
            ``(idx, matched_tokens_list)`` when ``return_stats`` is True.

        Raises:
            ValueError: If the module-level ``CORRECT_SAMPLING`` flag is False.
        """
        n_generated_tokens = 0
        matched_tokens_list = []

        device = idx.device

        print(idx.shape, 'idx')

        # Longest context that still leaves room for the self.d drafted tokens.
        max_context = self.block_size - self.d

        while n_generated_tokens < max_new_tokens:
            # If the sequence context is growing too long we must crop it.
            # The slice bound is parenthesised on purpose: the previous
            # `-self.block_size-self.d` parsed as -(block_size + d) and kept
            # more tokens than the guard above allows.
            if idx.size(1) <= max_context:
                idx_cond = idx
            else:
                idx_cond = idx[:, -max_context:]
            attention_mask_cond = create_autoregressive_mask(idx_cond.shape[1]).to(idx_cond.device)

            if not CORRECT_SAMPLING:
                raise ValueError('No experiments with wrong sampling on PyCode')

            # --- Draft: one forward pass, then sample token by token on CPU.
            pred_dict, _ = self(idx_cond, attention_mask=attention_mask_cond)
            log_w = pred_dict['log_w'].cpu()
            log_cores = [l.cpu() for l in pred_dict['log_cores']]

            # Running log-weights; updated in place after every drafted token.
            unnorm_expert_probs = log_w

            samples = []
            logits_q = []
            for log_core in log_cores:
                logits = log_core + unnorm_expert_probs.expand(*log_core.shape)
                # Reduce over the last axis to get one score per candidate token.
                logits = torch.logsumexp(logits, dim=-1)
                logits_q.append(logits.unsqueeze(0))
                probs = nn.functional.softmax(logits / temperature, dim=-1).unsqueeze(0)
                samples.append(torch.multinomial(probs, num_samples=1).squeeze(-1))

                # Condition the running weights on the token just drawn.
                unnorm_expert_probs += log_core[samples[-1]].reshape(*unnorm_expert_probs.shape)

            idxs_next = torch.stack(samples, dim=1)

            get_mem_usage()
            # --- Verify: score the drafted tokens appended to the context.
            # NOTE(review): this uses the uncropped `idx`, so the verifier
            # input can exceed block_size even when the draft context was
            # cropped - confirm whether `idx_cond` was intended here.
            idx_cands = torch.cat((idx, idxs_next.to(device)), dim=1)
            autoregressive_mask = create_autoregressive_mask(idx_cands.shape[1]).to(idx_cands.device)
            logits_p, _ = self.forward_for_check(
                idx_cands,
                attention_mask=autoregressive_mask,
                check_k=self.d + 1
            )
            logits_p = [l.cpu() for l in logits_p]

            if top_k:
                logits_p = [f_top_k(l, thres=top_k) for l in logits_p]

            prob_p = [safe_div(l, temperature).softmax(dim = -1) for l in logits_p]
            prob_q = [safe_div(l, temperature).softmax(dim = -1) for l in logits_q]
            # p has one more entry than q: the distribution for the token
            # that follows the whole drafted run.
            prob_p_alligned, prob_p_last = prob_p[:-1], prob_p[-1]

            # Probability each model assigned to the token actually drafted.
            prob_p_for_token = torch.gather(
                torch.stack(prob_p_alligned).squeeze(1),
                -1,
                torch.transpose(idxs_next, 0, 1)
            ).squeeze(-1).cpu()
            prob_q_for_token = torch.gather(
                torch.stack(prob_q).squeeze(1),
                -1,
                torch.transpose(idxs_next, 0, 1)
            ).squeeze(-1).cpu()

            r = torch.rand(len(prob_p_for_token))

            # --- Accept the longest prefix that passes the ratio test.
            accepted_tokens = 0
            for i in range(len(prob_p_for_token)):
                # we accept first token always because we know, that it is correct
                if (i == 0) or (r[i] < prob_p_for_token[i] / prob_q_for_token[i]):
                    accepted_tokens += 1
                else:
                    break

            if accepted_tokens < len(prob_p_for_token):
                # Rejection: resample from the positive part of (p - q).
                # torch.multinomial accepts unnormalised weights.
                prob_last = torch.maximum(
                    prob_p[accepted_tokens] - prob_q[accepted_tokens],
                    torch.zeros_like(prob_p[accepted_tokens])
                )
            else:
                # Whole draft accepted: take the bonus token from the verifier.
                prob_last = prob_p_last
            sample_last = torch.multinomial(prob_last, num_samples=1)[0]

            accepted_samples = samples[:accepted_tokens] + [sample_last]
            accepted_samples = torch.stack(accepted_samples, dim=1)

            idx = torch.cat((idx, accepted_samples.to(device)), dim=1)

            matched_tokens_list.append(accepted_tokens)
            n_generated_tokens += accepted_tokens + 1

            torch.cuda.empty_cache()

        if return_stats:
            return idx, matched_tokens_list
        else:
            return idx

    def forward_for_check(self, idx, attention_mask=None, targets=None, with_w_norm=True, check_k=None):
        """Run the base model and pass its last hidden state to ``new_head``.

        ``targets``, ``with_w_norm`` and ``check_k`` are forwarded unchanged to
        ``self.new_head``; its return value is returned as is (the caller in
        ``generate_w_speculative`` unpacks it as a ``(logits, _)`` pair).
        """
        outputs = self.base_model(idx, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        probs = self.new_head(last_hidden_state, targets=targets, with_w_norm=with_w_norm, check_k=check_k)
        return probs


def init(device, dim=2, is_default=False, checkpoint_path=None, rank=None):
    """Build the PyCodeGPT-based combined model and its tokenizer.

    Args:
        device: Target passed to ``.to(...)`` and to ``torch.load`` as
            ``map_location``.
        dim: Forwarded to ``CPHead`` as ``n_tokens`` (ignored for the
            default head).
        is_default: Use ``DefaultHead`` when True, otherwise ``CPHead``.
        checkpoint_path: If truthy, a state dict is loaded from this path
            into the combined model.
        rank: Forwarded to ``CPHead`` as ``r``.

    Returns:
        The tuple ``(model, tokenizer)``.
    """
    hub_id = "Daoguang/PyCodeGPT"
    tokenizer = AutoTokenizer.from_pretrained(hub_id)
    backbone = AutoModelForCausalLM.from_pretrained(hub_id)

    # Head maps the backbone's hidden size onto the tokenizer vocabulary.
    hidden_size = backbone.config.hidden_size
    vocab_size = tokenizer.vocab_size
    head = (
        DefaultHead(hidden_size, vocab_size)
        if is_default
        else CPHead(hidden_size, vocab_size, n_tokens=dim, r=rank)
    )

    model = CombinedModelUpgraded(backbone, head).to(device)

    if not checkpoint_path:
        return model, tokenizer

    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Loaded model from checkpoint: {checkpoint_path}")
    return model, tokenizer

def calculate_first_head_loss(model, input_ids, attention_mask, targets):
    """Return the mean negative log-likelihood of the head's predictions.

    Runs ``model.base_model`` without gradients, feeds its last hidden state
    to ``model.new_head`` with ``check_k`` set to the input length, stacks the
    returned per-position predictions and scores them against ``targets[0]``.

    Args:
        model: Object exposing ``base_model`` and ``new_head``.
        input_ids: Token ids passed to the base model.
        attention_mask: Mask passed to the base model.
        targets: Indexable whose first element holds the target ids; it is
            expected to be of shape (1, seq_len) per the original author's
            note - TODO confirm for every caller.

    Returns:
        Scalar tensor with the NLL loss (mean reduction). The predictions are
        treated as log-probabilities.
    """
    with torch.no_grad():
        backbone_out = model.base_model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        per_position, _ = model.new_head(
            backbone_out.hidden_states[-1],
            check_k=input_ids.shape[-1],
        )

        # Stack the list entries and drop the singleton axis at position 1.
        stacked = torch.stack(per_position).squeeze(1)
        first_targets = targets[0].squeeze(0)

        return nn.functional.nll_loss(stacked, first_targets)

@torch.no_grad()
def main():
    """Evaluate a checkpoint: speculative-decoding acceptance stats and loss.

    Steps:
      1. Parse CLI arguments and pin the visible GPU.
      2. Build the model/tokenizer via ``init`` and load the checkpoint.
      3. Run ``generate_w_speculative`` on ten fixed code prefixes, ten times
         each, collecting how many drafted tokens were accepted per round.
      4. Compute ``calculate_first_head_loss`` on ``--num-train-samples``
         Python files streamed from ``codeparrot/github-code``.
      5. Print the acceptance histogram and loss statistics.
    """
    parser = argparse.ArgumentParser(description="Evaluate model on prefixes and training data")
    parser.add_argument('--from-pretrained', type=str, required=True, help='path to the checkpoint file')
    parser.add_argument('--dim', type=int, default=2, help='dimension of head')
    parser.add_argument('--rank', type=int, default=None, help='rank for CP head')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id to use')
    parser.add_argument('--num-train-samples', type=int, default=10, help='number of training samples to evaluate')
    parser.add_argument('--head', type=str, default='default', help='head type, subject of the later extension') 
    parser.add_argument('--temp', type=float, default=1.0, help='sampling temperature')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    is_default = args.head == 'default'

    # `init` already moves the model to `device`, so no further .to() is needed.
    model, tokenizer = init(
        device,
        dim=args.dim,
        is_default=is_default,
        checkpoint_path=args.from_pretrained,
        rank=args.rank,
    )
    model.eval()

    # Load the training dataset (streamed, Python files only).
    dataset = load_dataset("codeparrot/github-code", streaming=True, split="train", data_dir='../.gh_code_data', cache_dir="../.hf_cache")
    dataset = dataset.shuffle(seed=42).filter(lambda x: x['language'] == 'Python')

    prefixes = [
        "def fibonacci(n):",
        "class BinaryTree:",
        "def quicksort(arr):",
        "import numpy as np",
        "def calculate_mean(numbers):",
        "class Node:",
        "def binary_search(arr, target):",
        "import pandas as pd",
        "def merge_sort(arr):",
        "class Graph:",
    ]

    # generate_w_speculative reads these two attributes from the model.
    max_seq_length = 2048
    model.block_size = max_seq_length
    model.d = args.dim

    losses = []
    # Flat list: one accepted-token count per decoding round, over all prompts.
    matched_tokens_stats = []

    get_mem_usage()
    print("Generating text with speculative decoding...")
    for prefix in tqdm(prefixes * 10, desc="Generating"):
        inputs = tokenizer([prefix], return_tensors='pt', padding=True, truncation=True, max_length=max_seq_length)
        xs = inputs['input_ids'].to(device)
        attn_mask = inputs['attention_mask'].to(device)

        get_mem_usage()
        _, matched_tokens = model.generate_w_speculative(
            xs,
            attention_mask=attn_mask,
            max_new_tokens=100,
            tokenizer=tokenizer,
            temperature=args.temp,
            top_k=0.9,
            top_p=None,
            generation_config=None,
            return_stats=True
        )

        torch.cuda.empty_cache()

        print(matched_tokens)
        matched_tokens_stats += matched_tokens

    # Calculate unique values and their normalized counts in matched_tokens_stats
    unique_tokens, counts = np.unique(matched_tokens_stats, return_counts=True)
    normalized_counts = counts / len(matched_tokens_stats)

    print("\nUnique Matched Tokens and Their Normalized Frequencies:")
    for token, norm_count in zip(unique_tokens, normalized_counts):
        print(f"Token {token}: {norm_count:.4f}")
    print(*normalized_counts, sep=', ')

    print("\nCalculating losses on training data...")
    # The loop variable is deliberately not called `sample`: that name is
    # imported from teneva at module level.
    for example in tqdm(dataset.take(args.num_train_samples), total=args.num_train_samples, desc="Calculating losses"):
        inputs = tokenizer([example['code']], return_tensors='pt', padding=True, truncation=True, max_length=max_seq_length)
        if is_default:
            # Next-token targets: the inputs shifted left by one.
            labels = inputs['input_ids'][:, 1:].to(device)
            seq_len = labels.shape[1]
        else:
            labels, seq_len = shift_batch(inputs['input_ids'], dim=args.dim)
            labels = [elem.to(device) for elem in labels]

        # Trim the inputs so they line up with the targets.
        xs = inputs['input_ids'][:, :seq_len]
        attn_mask = inputs['attention_mask'][:, :seq_len]

        loss = calculate_first_head_loss(model, xs.to(device), attn_mask.to(device), targets=labels)
        losses.append(loss.item())

    # Print stats
    print("\nEvaluation Results:")
    print(f"Average First Head Loss on Training Data: {np.mean(losses):.4f}")
    print(f"Standard Deviation of Loss: {np.std(losses):.4f}")
    print("\nSpeculative Decoding Stats:")
    print(f"Average Matched Tokens: {np.mean(matched_tokens_stats):.2f}")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()