import torch
import pdb
import torch.nn.functional as F
import pdb
from tqdm import tqdm
from collections import defaultdict
import os
import numpy as np
import re

def compression_scheduler(current_step, max_steps, kind='step'):
    """
    Schedule the target compression rate during training.

    Args:
        current_step: Index of the current training step.
        max_steps: Total number of training steps.
        kind: Scheduling strategy, one of:
            'step'   - piecewise-constant rate selected from the training
                       progress (current_step / max_steps): 0.70 up to 30%,
                       0.60 up to 60%, 0.50 up to 90% and 0.40 afterwards.
            'linear' - value at position current_step of max_steps evenly
                       spaced points going from 1.0 down to 0.40.
            'fixed'  - always 0.0.

    Returns:
        A Python float for 'step' and 'fixed'; a 0-dim torch tensor for
        'linear' (the element picked out of torch.linspace).

    Raises:
        NotImplementedError: if kind is not 'step', 'linear' or 'fixed'.
        ZeroDivisionError: for kind='step' when max_steps is 0.
        IndexError: for kind='linear' when current_step is outside the
            max_steps generated points.
    """

    if kind == 'step':
        progress = current_step / max_steps
        if progress <= 0.30:
            return 0.70  # first interval
        elif progress <= 0.60:
            return 0.60  # second interval
        elif progress <= 0.90:
            return 0.50  # third interval
        else:
            return 0.40  # final interval

    elif kind == 'linear':
        rate = torch.linspace(1.0, 0.40, max_steps)[current_step]
        return rate

    elif kind == 'fixed':
        return 0.0

    else:
        # The message names this function and lists every supported kind.
        raise NotImplementedError(f"compression_scheduler only implemented for kind=step, linear or fixed, but got: {kind}")

def configure_required_grad(model):
    """
    Set which parameters require gradients and which do not.

    Every parameter whose name contains 'E_train' (the trainable singular
    values) gets requires_grad=True; every other parameter is frozen.
    A one-line summary is printed. The counts in that summary are counts of
    parameter tensors yielded by model.named_parameters().

    Args:
        model: Object exposing named_parameters() (e.g. a torch.nn.Module).
            Modified in place.
    """
    trainable = non_trainable = 0
    for name, param in model.named_parameters():
        is_trainable = 'E_train' in name
        param.requires_grad = is_trainable
        if is_trainable:
            trainable += 1
        else:
            non_trainable += 1

    # A model without parameters would otherwise divide by zero here.
    total = trainable + non_trainable
    fraction = trainable / total if total else 0.0

    print(
        f'Layers that require gradients configure. Number of trainable layers: {trainable}, fraction: {fraction: 0.2f}')

def training_step(model, batch, batch_idx, compression_params, pad_token_id, distill_ds, args, compression_calculator):
    """
    Run one forward pass of the model and assemble the combined training loss.

    Args:
        model: Callable causal LM exposing `.device`; its output must expose
            `.logits` (and `.hidden_states` when distilling).
        batch: Mapping with 'input_ids' and 'attention_mask' tensors, indexed
            here as (batch, seq_len).
        batch_idx: Index of the batch; only forwarded to get_distillation_loss.
        compression_params: Accepted for interface compatibility; not used in
            this function.
        pad_token_id: Token id whose label positions are excluded from the
            cross-entropy (replaced by -100).
        distill_ds: Teacher data for this batch, forwarded to
            get_distillation_loss when args.distill_mode is set.
        args: Namespace read for distill_mode, compress_loss,
            scale_compression, scale_distill, target_param_ratio and tv_loss.
        compression_calculator: Object providing get_compression_loss(),
            get_sv_ratio(), get_tv_loss() and get_compression().

    Returns:
        Tuple (loss, logits_loss, distill_loss, distill_kl, hs_loss,
        compression_loss, perplexity, keep_ratio, tv_loss, current_param_ratio).
        logits_loss is the cross-entropy tensor without distillation and the
        float 0. with distillation; distill_kl is always 0.

    Raises:
        NotImplementedError: for an unsupported args.distill_mode or
            args.compress_loss.
    """
    # Fail early and explicitly: without this check an unsupported
    # compress_loss left compression_loss / keep_ratio unbound and the
    # function crashed with UnboundLocalError when summing the loss.
    if args.compress_loss != 'default':
        raise NotImplementedError(f'Unaccounted compress_loss {args.compress_loss} in training_step')

    # Next-token prediction: inputs drop the last token, labels drop the first.
    input_ids = batch['input_ids'][:, :-1].to(model.device)
    attention_mask = batch['attention_mask'][:, :-1].to(model.device)
    labels = batch['input_ids'][:, 1:].clone().to(model.device)
    labels[labels == pad_token_id] = -100  # ignored by cross_entropy below

    if not args.distill_mode:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
    elif args.distill_mode in ['hs_last']:
        # Hidden states are needed to compare against the teacher.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
    else:
        raise NotImplementedError(f'Unaccounted distill_mode {args.distill_mode} in training_step forward pass')

    logits = outputs.logits

    # Cross-entropy is always computed because perplexity is reported in both
    # modes; it only contributes to the loss when not distilling.
    cross_entropy = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction='mean', ignore_index=-100)
    with torch.no_grad():
        perplexity = torch.exp(cross_entropy)

    if not args.distill_mode:
        logits_loss = cross_entropy
        distill_loss = torch.tensor(0., device=logits.device)
        hs_loss = distill_kl = 0.
    else:
        logits_loss = 0.
        distill_loss = get_distillation_loss(attention_mask, outputs, distill_ds, args.distill_mode, batch_idx)
        hs_loss, distill_kl = distill_loss, 0  # compat

    compression_loss = compression_calculator.get_compression_loss()
    with torch.no_grad():
        keep_ratio = compression_calculator.get_sv_ratio()

    tv_loss = compression_calculator.get_tv_loss()

    scale_compression = args.scale_compression
    scale_distill = args.scale_distill
    current_param_ratio = compression_calculator.get_compression()
    # Stop pushing for more compression once the target ratio is (almost) reached.
    if (current_param_ratio - args.target_param_ratio) < 0.0005:
        scale_compression = 0.

    # NOTE(review): logits_loss is weighted by scale_distill, not by a
    # dedicated factor; kept as in the original - confirm this is intended.
    loss = logits_loss * scale_distill + compression_loss * scale_compression + distill_loss * scale_distill + args.tv_loss * tv_loss

    return loss, logits_loss, distill_loss, distill_kl, hs_loss, compression_loss, perplexity, keep_ratio, tv_loss, current_param_ratio

def get_distillation_loss(attention_mask, outputs, distill_ds, distill_mode, batch_idx):
    """
    Compute the hidden-state distillation loss between the model and a teacher.

    For any distill_mode containing the substring 'hs', two hidden-state
    layers of the model are compared with the cached teacher hidden states
    using a masked MSE, and the two errors are summed.

    Args:
        attention_mask: Mask indexed here as (batch, seq_len); a trailing
            singleton dim is added so it broadcasts over the hidden dim.
        outputs: Model output exposing `.hidden_states`.
        distill_ds: Mapping with 'layer_idx1', 'layer_idx2' (indices into
            outputs.hidden_states) and 'teacher_hidden1', 'teacher_hidden2'
            (teacher tensors, moved to the device of the model hidden states).
        distill_mode: Name of the distillation mode.
        batch_idx: Not used in this function.

    Returns:
        The summed masked MSE over the two selected layers.

    Raises:
        NotImplementedError: if distill_mode does not contain 'hs'.
    """
    if 'hs' in distill_mode:
        # get hidden states from model and teacher
        hs1 = outputs.hidden_states[distill_ds['layer_idx1']]
        hs2 = outputs.hidden_states[distill_ds['layer_idx2']]
        teacher_hs1 = distill_ds['teacher_hidden1'].to(hs1.device)
        teacher_hs2 = distill_ds['teacher_hidden2'].to(hs2.device)

        attention_mask = attention_mask[:, :, None]

        # Zero out masked positions on both sides so they add no error.
        hs1_masked = hs1 * attention_mask
        hs2_masked = hs2 * attention_mask
        teacher_hs1 = teacher_hs1 * attention_mask
        teacher_hs2 = teacher_hs2 * attention_mask

        hs_loss = mse_loss_masked(hs1_masked, teacher_hs1, attention_mask) + mse_loss_masked(hs2_masked, teacher_hs2, attention_mask)

        distill_loss = hs_loss

    else:
        # The message names this function (it used to say add_distillation_loss).
        raise NotImplementedError(f'Unsupported distill_mode {distill_mode} in get_distillation_loss')

    return distill_loss

def mse_loss_masked(h1, h2, attention_mask):
    """
    Calculate the mean squared error between h1 and h2 over unmasked positions.

    A plain F.mse_loss over masked tensors would average over every element,
    including the zeroed (masked) ones, and under-report the loss. Here the
    squared error is summed and divided only by the number of unmasked
    elements (mask sum times hidden size).

    Inputs:
        h1: masked hidden states, 3 dims (indexed as batch, seq_len, hidden)
        h2: masked hidden states, same shape as h1
        attention_mask: (bs, seq_len, 1)

    Returns:
        0-dim tensor with the masked mean squared error. If the mask selects
        no element at all, the denominator is clamped to 1 so the result is
        the plain sum (0 for properly masked inputs) instead of NaN.
    """
    assert len(h1.shape) == 3, "expected hidden dim=3 in mse_loss_masked"
    assert len(attention_mask.shape) == 3, "expected attention mask to be of dim=3"
    assert attention_mask.shape[2] == 1, 'Expected last dim of attention mask to be 1'

    squared_loss = (h1 - h2)**2
    # clamp guards against 0/0 = NaN when every position is masked out
    numel = (attention_mask.sum() * h1.shape[2]).clamp(min=1)
    mean_squared_loss = squared_loss.sum()/numel
    return mean_squared_loss

def eval_model(model, test_dl, compression_params, pad_token_id, args, compression_calculator):
    """
    Evaluate the model on every batch of test_dl and average the metrics.

    The model is put into eval mode for the duration of the loop and back
    into train mode before returning. When args.distill_mode is set, the
    teacher data for batch i is loaded from
    <args.cache_dir>/distill_cache/test_<i>.pt.

    Returns:
        Dict mapping "eval/<metric>" to the mean of that metric over batches.
    """

    def _scalar(value):
        # Tensors are reduced to Python numbers; plain numbers pass through.
        return value.item() if isinstance(value, torch.Tensor) else value

    model = model.eval()
    history = defaultdict(list)

    for batch_idx, batch in enumerate(tqdm(test_dl, desc="Evaluating", mininterval=5)):
        distill_batch = {}
        if args.distill_mode:
            cache_file = os.path.join(args.cache_dir, f"distill_cache/test_{batch_idx}.pt")
            distill_batch = torch.load(cache_file)

        with torch.no_grad():
            step_out = training_step(model, batch, batch_idx, compression_params,
                                     pad_token_id, distill_batch, args, compression_calculator)
        (loss, logits_loss, distill_loss, distill_kl, hs_loss,
         compression_loss, perplexity, keep_ratio, _, current_compression) = step_out

        history['loss'].append(loss.item())
        history['logits_loss'].append(_scalar(logits_loss))
        history['distill_loss'].append(distill_loss.item())
        history['distill_logits_loss'].append(distill_kl)
        history['compression_loss'].append(compression_loss.item())
        history['sv_keep_ratio'].append(keep_ratio)
        history['perplexity'].append(perplexity.item())
        history['hs_loss'].append(_scalar(hs_loss))
        history['compression_ratio'].append(_scalar(current_compression))

        # drop the reference to the teacher data before the next batch is loaded
        del distill_batch

    averaged = {f"eval/{key}": sum(values) / len(values) for key, values in history.items()}

    torch.cuda.empty_cache()
    model = model.train()
    return averaged


def print_nvidia_smi():
    """
    Print the output of the `nvidia-smi` command, best effort.

    If the command cannot be run or exits with a non-zero status, a short
    error line is printed instead and no exception is propagated.
    """
    import subprocess

    try:
        # Run nvidia-smi command
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8', check=True)
    except Exception:
        # Not a bare `except:` - KeyboardInterrupt / SystemExit must still
        # propagate; every ordinary failure stays non-fatal.
        print("Error: no nvidia-smi to check gpu util")
    else:
        print(result.stdout)
        
    
def schedule_distill_scale(current_step, total_steps, should_schedule):
    """
    Return the scaling factor for the current step from a cosine schedule.

    The factor is 1 when scheduling is disabled (falsy should_schedule) and
    during the first 250 steps. Afterwards a cosine wave is evaluated over
    the remaining steps (ten full periods between step 250 and total_steps),
    shifted by a per-schedule offset, capped at 1.0 and then clamped to the
    range of the selected schedule:

        '1': offset 1.3, no further clamp
        '2': offset 0,   at least 0.50
        '3': offset 0,   at least 0.30
        '4': offset 0,   within [0.25, 0.40]
        '5': offset 0,   within [0.15, 0.25]

    Raises:
        NotImplementedError: for any other truthy should_schedule once the
            first 250 steps have passed.
    """
    warmup_steps = 250

    # (schedule id, cosine offset, lower clamp, upper clamp); None = no clamp
    profiles = (
        ('1', 1.3, None, None),
        ('2', 0., 0.50, None),
        ('3', 0., 0.30, None),
        ('4', 0, 0.25, 0.40),
        ('5', 0., 0.15, 0.25),
    )

    if not should_schedule:
        return 1.

    # saturate distillation loss
    if current_step < warmup_steps:
        return 1.

    selected = None
    for schedule_id, offset, lower, upper in profiles:
        if should_schedule == schedule_id:
            selected = (offset, lower, upper)
            break

    if selected is None:
        raise NotImplementedError(f"should_schedule not supported {should_schedule}, {type(should_schedule)} in  schedule_distill_scale")

    offset, lower, upper = selected

    # One cosine period spans 10% of the steps left after the warm-up.
    cycles_per_step = 10 / (total_steps - warmup_steps)
    phase = (current_step - warmup_steps) * cycles_per_step
    scale = min(1 * np.cos(2 * np.pi * phase) + offset, 1.0)  # never above 1

    if lower is not None:
        scale = max(scale, lower)
    if upper is not None:
        scale = min(scale, upper)

    return scale
    
def count_parameters(model):
    """
    Return the total number of elements across model.parameters(), in billions.
    """
    element_count = 0
    for tensor in model.parameters():
        element_count += tensor.numel()
    return element_count / 1e9

def push_to_multi_gpu(model):
    """
    Split a model across two GPUs by module name.

    Every module whose name contains 'gate_proj' or 'up_proj' is moved to
    cuda:1; every other module is moved to cuda:0. Each placement is printed.
    The split relies on named_modules() yielding a parent before its
    children, so a child's placement is applied after its parent's.

    Returns:
        The same model object.

    Raises:
        RuntimeError: if fewer than two CUDA devices are available.
    """
    two_gpus_present = torch.cuda.is_available() and torch.cuda.device_count() >= 2
    if not two_gpus_present:
        raise RuntimeError("At least two CUDA devices are required for this operation.")

    device0 = torch.device("cuda:0")
    device1 = torch.device("cuda:1")

    # Push modules to the corresponding devices
    for name, module in model.named_modules():
        if not ('gate_proj' in name or 'up_proj' in name):
            module.to(device0)  # default placement
            print(f"{name} pushed to device 0 (cuda:0)")
            continue

        module.to(device1)  # gate/up projections live on the second GPU
        print(f"{name} pushed to device 1 (cuda:1)")

    return model

def calculate_tv_loss(x):
    """
    Computes the Total Variation (TV) loss for a 1D tensor.

    The loss is the mean absolute difference between neighbouring elements.

    Args:
        x: 1-D tensor.

    Returns:
        0-dim tensor with the TV loss. Inputs with fewer than two elements
        have no neighbouring pairs; their loss is 0 (instead of the NaN that
        the mean of an empty tensor would produce).
    """
    assert len(x.shape) == 1, 'expected input into calculate_tv_loss to be of dim 1'

    if x.numel() < 2:
        # Stays attached to x's autograd graph while evaluating to zero.
        return x.sum() * 0.0

    tv = torch.abs(x[1:] - x[:-1])
    tv_loss = tv.mean()

    return tv_loss

if __name__ == '__main__':
    # Manual smoke test: builds the dataloaders and the distillation dataset,
    # loads the model and evaluates the compression schedule for the first
    # training batch. No training step is actually executed.
    from data_utils import get_dataloaders
    from distill_utils import create_distillation_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import pdb

    model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Only the first returned dataloader is used; the second value is discarded.
    train_dl, _ = get_dataloaders(
        tokenizer,
        dataset_name="wikitext2",
        num_train_samples=256,
        num_test_samples=256,
        batch_size=4,
        random_state=42,
        debug=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): training_step only accepts distill_mode 'hs_last';
    # 'hidden_states' is interpreted by create_distillation_dataset - confirm.
    distill_dataset = create_distillation_dataset(model_name, train_dl, distill_mode='hidden_states', device=device)

    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Grab the first batch only and compute its scheduled compression target.
    for batch_idx, batch in enumerate(train_dl): 
        target_compression = compression_scheduler(batch_idx, 100, kind='step')

        break 

    # NOTE: the call below is disabled and its argument list does not match the
    # current training_step signature
    # (model, batch, batch_idx, compression_params, pad_token_id, distill_ds, args, compression_calculator).
    # training_step(model, batch, batch_idx, target_compression, compression_params, pad_token_id, distill_ds, distill_mode)

