from torch.nn import CrossEntropyLoss
from transformers import get_scheduler
from tqdm import tqdm
from pathlib import Path
from torch.optim import AdamW
import torch
import os
import wandb
import math
from test_utils import graph_evaluation, phonebook_evaluation

import itertools
## Accelerate handles device placement and distributed setup for train()
from accelerate import Accelerator
from accelerate.logging import get_logger

# Module-level logger obtained through accelerate.logging.get_logger.
logger = get_logger(__name__)



def load_balancing_loss_func(
    gate_logits: "torch.Tensor | tuple[torch.Tensor, ...] | None",
    num_experts: "int | None" = None,
    top_k=2,
) -> "torch.Tensor | int":
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Tuple[torch.Tensor]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts]. Anything that is not a tuple
            (including ``None`` or a single tensor) yields a loss of ``0``.
        num_experts (`int`, *optional*):
            Number of experts. When omitted it is inferred from the last dimension of the logits.
        top_k (`int`, *optional*, defaults to 2):
            Number of experts each token is routed to.

    Returns:
        The auxiliary loss as a 0-d tensor, or the integer ``0`` when ``gate_logits`` is not a tuple.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # Concatenate every layer's logits on the device of the first layer.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    if num_experts is None:
        # one_hot below needs an explicit class count; the gate emits one logit per expert.
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # Shape: [tokens, top_k, num_experts]
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    # Fraction of tokens routed to each expert, per top-k slot: [top_k, num_experts]
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

    # Average router probability assigned to each expert: [num_experts]
    router_prob_per_expert = torch.mean(routing_weights, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or ``None`` if it has no groups."""
    first_group_lrs = (group['lr'] for group in optimizer.param_groups)
    return next(first_group_lrs, None)

def ce_loss(inputs, logits, mask, TO_TOKEN):
    """Masked mean cross-entropy between ``logits`` and target ids ``inputs``.

    No shifting happens here: position ``i`` of ``logits`` is scored against position
    ``i`` of ``inputs``, so callers must pass already-aligned tensors.

    Args:
        inputs: Integer target token ids (any shape; flattened).
        logits: Tensor whose last dimension is the vocabulary, or a mapping with a
            ``'logits'`` entry (e.g. a model output object).
        mask: Per-target weights with as many elements as ``inputs``; positions with a
            zero mask do not contribute.
        TO_TOKEN: Vocabulary mapping. Targets equal to ``TO_TOKEN['*']`` get zero loss
            (``ignore_index``) but still count in the denominator when their mask is non-zero.

    Returns:
        0-d tensor ``sum(loss * mask) / sum(mask)``. An all-zero mask yields ``0``
        rather than a NaN that would poison the gradients.
    """
    if not isinstance(logits, torch.Tensor):
        logits = logits['logits']
    flat_labels = inputs.reshape(-1)
    flat_logits = logits.reshape(-1, logits.size(-1))

    # Per-token loss so the mask can weight it.
    loss_fct = CrossEntropyLoss(ignore_index=TO_TOKEN['*'], reduction='none')
    loss = loss_fct(flat_logits, flat_labels)

    flat_mask = mask.reshape(-1).to(loss.dtype)
    # Clamp only protects against 0/0; any real mask sum is far above `tiny`.
    denominator = flat_mask.sum().clamp(min=torch.finfo(loss.dtype).tiny)
    return torch.sum(loss * flat_mask) / denominator

def ce_loss_mlp(logits, labels):
    """Mean cross-entropy over every position, used for the ``gmlp`` model.

    Args:
        logits: Tensor whose last dimension is the number of classes; all leading
            dimensions are flattened together.
        labels: Integer class ids with as many elements as ``logits`` has rows
            after flattening.

    Returns:
        0-d tensor with the mean loss (``CrossEntropyLoss`` default ``ignore_index``).
    """
    loss_fct = CrossEntropyLoss(reduction='mean')
    # reshape (not view) so non-contiguous inputs such as sequence slices are accepted.
    return loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

 
def _validation_loss(model, val_dataloader, accelerator, args, TO_TOKEN):
    """Compute the masked next-token validation loss and its perplexity.

    Each batch is shifted like the training data (``input_ids[:, :-1]`` predicts
    ``input_ids[:, 1:]``) and scored with ``ce_loss``. The batch loss is repeated
    ``args.eval_batch_size`` times and collected with ``accelerator.gather_for_metrics``.

    Returns:
        ``(eval_loss, eval_perplexity)``: the mean gathered loss (0-d tensor) and
        ``exp(eval_loss)`` (``inf`` on overflow). Both are ``nan`` when the
        dataloader yields no batches.
    """
    device = accelerator.device
    val_losses = []
    with torch.no_grad():
        for val_batch in val_dataloader:
            x_val = val_batch['input_ids'][:, :-1].to(device)
            y_val = val_batch['input_ids'][:, 1:].to(device)
            mask_val = val_batch['mask'][:, 1:].to(device)
            logits_val = model(x_val)
            val_loss = ce_loss(y_val, logits_val, mask_val, TO_TOKEN)
            # Gather this batch's *validation* loss, not the last training loss.
            val_losses.append(accelerator.gather_for_metrics(val_loss.repeat(args.eval_batch_size)))

    if not val_losses:
        return float("nan"), float("nan")

    eval_loss = torch.mean(torch.cat(val_losses))
    try:
        eval_perplexity = math.exp(eval_loss)
    except OverflowError:
        eval_perplexity = float("inf")
    return eval_loss, eval_perplexity


def train(args, model, train_dataloader, val_dataloader, test_dataloader, tokenizer, TO_TOKEN):
    """Train ``model`` with AdamW and a linear warmup/decay schedule under HF Accelerate.

    Weight decay applies to every parameter except those whose name contains
    ``"bias"`` or ``"layer_norm.weight"``. Warmup covers the first 20% of the
    optimization steps. Roughly ten times per epoch (counted in dataloader steps),
    the running train loss is printed and logged to wandb, and the model is evaluated:

    * task ``"graph"``: validation loss/perplexity (non-``gmlp`` models only), plus
      ``graph_evaluation`` string/char accuracy on the train and test loaders;
    * task ``"phone"``: ``phonebook_evaluation`` string/char accuracy on the test loader.

    Only tasks ``"graph"`` and ``"phone"`` pass their objects through ``accelerator.prepare``.
    Note that ``args.num_epochs`` is overwritten with the recomputed epoch count.

    Raises:
        ValueError: if an epoch is too short for the 1/10-epoch log/eval interval to
            exceed ``args.gradient_accumulation_steps``.
    """
    # Optimizer: split weights into two groups, one with weight decay and the other without.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch
    # Linear warmup over the first 20% of optimization steps, then linear decay.
    max_warmup_steps = int(0.2 * max_train_steps)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=max_warmup_steps,
        num_training_steps=max_train_steps,
    )

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    # Use the accelerator's device instead of a hard-coded 'cuda' so CPU/multi-device runs work.
    device = accelerator.device

    # Prepare everything with our `accelerator`.
    if args.task == "graph":
        model, optimizer, train_dataloader, val_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, test_dataloader, lr_scheduler
        )
    elif args.task == "phone":
        model, optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, test_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch

    # Log and evaluate about ten times per epoch, counted in dataloader steps (not optimizer steps).
    num_update_steps_per_epoch_without_grad_acc = len(train_dataloader)
    num_log_steps = int(num_update_steps_per_epoch_without_grad_acc / 10)
    num_eval_steps = int(num_update_steps_per_epoch_without_grad_acc / 10)

    # Explicit checks rather than `assert` (stripped under `python -O`): a zero interval
    # would otherwise raise ZeroDivisionError in the `step % ...` tests below.
    if num_log_steps <= args.gradient_accumulation_steps or num_eval_steps <= args.gradient_accumulation_steps:
        raise ValueError(
            f"Log/eval interval ({num_log_steps} steps = 1/10 of {num_update_steps_per_epoch_without_grad_acc} "
            f"batches per epoch) must exceed gradient_accumulation_steps ({args.gradient_accumulation_steps})."
        )

    # Afterwards we recalculate our number of training epochs
    args.num_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    print("***** Running training *****", flush=True)
    print(f"  Num examples = {len(train_dataloader.dataset)}", flush=True)
    print(f"  Num Epochs = {args.num_epochs}", flush=True)
    print(f"  Instantaneous batch size per device = {args.train_batch_size}", flush=True)
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", flush=True)
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", flush=True)
    print(f"  Total optimization steps = {max_train_steps}", flush=True)
    print(f"  Total warmup steps = [{max_warmup_steps}/{max_train_steps}]", flush=True)
    print(f"  Frequency of eval steps = [{num_eval_steps}/{num_update_steps_per_epoch_without_grad_acc}]", flush=True)
    print(f"  Frequency of logging steps = [{num_log_steps}/{num_update_steps_per_epoch_without_grad_acc}]", flush=True)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)

    # The router auxiliary loss is only requested from (and added for) sparse models.
    use_aux_loss = args.model == "sparse" and args.router_aux_loss_coef != 0

    for epoch in range(starting_epoch, args.num_epochs):
        model.train()

        total_loss = 0
        total_balancing_loss = 0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                if args.model == "gmlp":
                    x = batch['input_ids'].to(device)
                    y = batch['label_ids'].to(device)
                else:
                    # Next-token prediction: position i of x predicts position i of y.
                    x = batch['input_ids'][:, :-1].to(device)
                    y = batch['input_ids'][:, 1:].to(device)
                    mask = batch['mask'][:, 1:].to(device)

                if use_aux_loss:
                    outputs = model(x, output_router_logits=True)
                    logits = outputs['logits']
                    balancing_loss = outputs['aux_loss']
                else:
                    logits = model(x)

                if args.model == "gmlp":
                    loss = ce_loss_mlp(logits, y)
                else:
                    loss = ce_loss(y, logits, mask, TO_TOKEN)

                # Track the task loss only; the auxiliary loss is tracked separately.
                total_loss += loss.detach().float()
                if use_aux_loss:
                    total_balancing_loss += balancing_loss.detach().float()
                    loss = loss + args.router_aux_loss_coef * balancing_loss

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            current_lr = get_lr(optimizer)
            wandb.log({
                "learning_rate": current_lr,
            })

            if step % num_log_steps == 0:
                avg_train_loss = total_loss / (step + 1)
                print(f"step [{completed_steps}/{max_train_steps}]: train loss: {avg_train_loss}, lr: {current_lr}", flush=True)
                wandb.log({
                    "completed_step": completed_steps,
                    "train/loss": avg_train_loss,
                    "learning_rate": current_lr,
                })
                if use_aux_loss:
                    wandb.log({
                        "train/balancing_loss": total_balancing_loss / (step + 1),
                    })

            if step % num_eval_steps == 0 or completed_steps == max_train_steps:
                model.eval()

                # Validation loss is only computed for non-gmlp models on the graph task;
                # keep the names defined so the graph branch below never hits a NameError.
                eval_loss, eval_perplexity = None, None
                if args.task == "graph" and args.model != "gmlp":
                    eval_loss, eval_perplexity = _validation_loss(model, val_dataloader, accelerator, args, TO_TOKEN)

                if completed_steps == max_train_steps:
                    print(f"ARGS\n{args}\n\n", flush=True)

                if args.task == "graph":
                    train_str_acc, train_char_acc = graph_evaluation(args, model, train_dataloader, tokenizer, TO_TOKEN)
                    test_str_acc, test_char_acc = graph_evaluation(args, model, test_dataloader, tokenizer, TO_TOKEN)
                    print(f"step [{completed_steps}/{max_train_steps}]: eval loss: {eval_loss}, eval ppl {eval_perplexity}, train accuracy {train_str_acc}, test accuracy {test_str_acc}, test character {test_char_acc}", flush=True)
                    metrics = {
                        "completed_step": completed_steps,
                        "train/string_accuracy": train_str_acc,
                        "train/char_accuracy": train_char_acc,
                        "test/string_accuracy": test_str_acc,
                        "test/char_accuracy": test_char_acc,
                    }
                    if eval_loss is not None:
                        metrics["eval/loss"] = eval_loss
                        metrics["eval/perplexity"] = eval_perplexity
                    wandb.log(metrics)

                elif args.task == "phone":
                    test_str_acc, test_char_acc = phonebook_evaluation(args, model, test_dataloader, tokenizer, TO_TOKEN)
                    print(f"step [{completed_steps}/{max_train_steps}]: test accuracy {test_str_acc}, test character {test_char_acc}", flush=True)
                    wandb.log({
                        "completed_step": completed_steps,
                        "test/string_accuracy": test_str_acc,
                        "test/char_accuracy": test_char_acc,
                    })
                model.train()

            if completed_steps >= max_train_steps:
                break

    
    
