import os
import sys
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, matthews_corrcoef
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from warmup_scheduler import GradualWarmupScheduler
from scipy.stats import pearsonr, spearmanr
from torch.optim.lr_scheduler import OneCycleLR

# torch_xla is an optional dependency. Probe for it once at import time;
# generation_trainer checks `_torch_tpu_available` before using `xm` / `pl`.
try:
    import torch_xla
    import torch_xla.utils.utils as xu
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.data_parallel as dp
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    _torch_tpu_available = True
except Exception:
    # Any failure while importing torch_xla is treated as "no TPU support".
    # Catching Exception (rather than a bare except) still lets
    # KeyboardInterrupt and SystemExit propagate.
    _torch_tpu_available = False

sys.path.append("..")

from models.utils import create_masks, ScheduledOptim

def classification_trainer(args, model, train_loader, val_loader, model_checkpoint_name):
    """Train a classifier with per-epoch validation, checkpointing and early stopping.

    The model is called as ``model(src=...)`` and must return a tuple
    ``(outputs, recon_loss)``. The training objective is the classification
    loss plus ``args.lambda_ * recon_loss``; the validation loss is the
    classification loss alone. Whenever the validation loss improves, the
    model's ``state_dict`` is written to
    ``os.path.join(args.model_save_path, model_checkpoint_name)``. Training
    stops once ``args.early_stopping_rounds`` consecutive epochs fail to
    improve the validation loss.

    Args:
        args: Namespace-like object providing ``lr``, ``d_model``,
            ``n_warmup``, ``max_lr``, ``lambda_``, ``max_grad_norm``,
            ``classification_type`` (``'binary'`` or ``'multiclass'``),
            ``epochs``, ``early_stopping_rounds``, ``model_save_path`` and
            ``wandb_logging``.
        model: ``torch.nn.Module`` to train; moved to CUDA when available,
            otherwise to CPU.
        train_loader: Sized iterable of batches. Each batch is a mapping with
            ``'input_ids'`` and ``'out'`` tensors.
        val_loader: Sized iterable with the same batch format.
        model_checkpoint_name: File name of the checkpoint inside
            ``args.model_save_path``.

    Returns:
        None. Results are reported through stdout, the checkpoint file and,
        when ``args.wandb_logging`` is truthy, ``wandb.log``.

    Raises:
        ValueError: If ``args.classification_type`` is not supported.
    """
    # Fail fast on a bad configuration instead of hitting an unbound
    # `loss_fn` in the middle of the first batch.
    if args.classification_type == 'binary':
        loss_fn = nn.BCEWithLogitsLoss()
    elif args.classification_type == 'multiclass':
        loss_fn = nn.CrossEntropyLoss()
    else:
        raise ValueError(
            "Unsupported classification_type: {!r} (expected 'binary' or 'multiclass')".format(
                args.classification_type))
    is_multiclass = args.classification_type == 'multiclass'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09)
    # ScheduledOptim wraps the optimizer; zero_grad()/step() go through it.
    sched = ScheduledOptim(optimizer, d_model=args.d_model, n_warmup_steps=args.n_warmup, lr_mul=args.max_lr)

    lambda_ = args.lambda_
    max_grad_norm = args.max_grad_norm
    # inf guarantees that the first epoch always produces a checkpoint.
    best_val_loss = float('inf')
    bad_epochs = 0

    for epoch in range(args.epochs):
        if bad_epochs >= args.early_stopping_rounds:
            # Early stopping: nothing can reset bad_epochs any more.
            break

        ################################### training ###################################
        model.train()
        training_loss = 0.0

        for batch in tqdm(train_loader):
            src = batch['input_ids'].to(device)
            trg = batch['out'].to(device)
            if is_multiclass and trg.dim() == 2:
                # CrossEntropyLoss takes 1-D class indices; keep column 0.
                trg = trg[:, 0]

            outputs, recon_loss = model(src=src)
            total_loss = loss_fn(outputs, trg) + lambda_ * recon_loss

            sched.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            sched.step()

            training_loss += total_loss.item()

        training_loss /= len(train_loader)

        ################################## validation ##################################
        model.eval()
        validation_loss = 0.0
        all_targets = []
        all_preds = []

        with torch.no_grad():
            for batch in tqdm(val_loader):
                src = batch['input_ids'].to(device)
                trg = batch['out'].to(device)
                if is_multiclass and trg.dim() == 2:
                    trg = trg[:, 0]

                outputs, _ = model(src=src)
                if is_multiclass:
                    preds = outputs.argmax(-1)
                else:
                    # Binary: sigmoid scores, rounded to 0/1 further down.
                    preds = torch.sigmoid(outputs)

                all_targets.append(trg.cpu().numpy())
                all_preds.append(preds.cpu().numpy())
                validation_loss += loss_fn(outputs, trg).item()

        validation_loss /= len(val_loader)
        rounded_targets = np.concatenate(all_targets, 0).round()
        rounded_preds = np.concatenate(all_preds, 0).round()
        validation_acc = accuracy_score(rounded_targets, rounded_preds)

        print ("Epoch:{}, Train loss:{} Validation loss:{}, Validation Acc:{}".format(
            epoch+1, training_loss, validation_loss, validation_acc))
        print ("Confusion: {}".format(confusion_matrix(rounded_targets, rounded_preds)))

        if validation_loss < best_val_loss:
            torch.save(model.state_dict(), os.path.join(args.model_save_path, model_checkpoint_name))
            best_val_loss = validation_loss
            bad_epochs = 0
            if args.wandb_logging:
                wandb.log({"best_validation_loss": validation_loss})
                wandb.log({"best_validation_acc": validation_acc})
        else:
            bad_epochs += 1

        if args.wandb_logging:
            # One wandb.log call per value, so the sequence of logged records
            # stays one-value-per-call.
            wandb.log({"training_loss": training_loss})
            wandb.log({"validation_loss": validation_loss})
            wandb.log({"validation_acc": validation_acc})

def generation_trainer(args, model, train_loader, val_loader, model_checkpoint_name):
    """Train a sequence-generation model with validation, checkpointing and early stopping.

    The model is called as ``model(encoder_src=..., decoder_src=...)`` and must
    return a tuple whose first element holds the logits; the second element is
    ignored here. For every batch the decoder input is
    ``output_ids[:, 1:-1]`` and the target is ``output_ids[:, 2:]`` (the
    decoder input shifted by one position). The loss is token-level
    cross-entropy with ``ignore_index=0`` (presumably the padding id; confirm
    against the tokenizer).

    Adam is driven by a ``OneCycleLR`` schedule sized for
    ``args.epochs * len(train_loader)`` steps and stepped once per batch.
    Whenever the validation loss improves, the model's ``state_dict`` is
    written to ``os.path.join(args.model_save_path, model_checkpoint_name)``.

    When ``args.use_tpu`` is truthy and torch_xla imported successfully, the
    XLA device is used, batches are fed through a ``ParallelLoader``, and
    ``args.lr`` is multiplied IN PLACE by ``xm.xrt_world_size()``.

    Args:
        args: Namespace-like object providing ``use_tpu``, ``lr``, ``epochs``,
            ``early_stopping_rounds``, ``model_save_path`` and
            ``wandb_logging``.
        model: ``torch.nn.Module`` to train.
        train_loader: Sized, re-iterable loader of batches. Each batch is a
            mapping with ``'input_ids'`` and ``'output_ids'`` tensors.
        val_loader: Sized, re-iterable loader with the same batch format.
        model_checkpoint_name: File name of the checkpoint inside
            ``args.model_save_path``.

    Returns:
        None. Results are reported through stdout, the checkpoint file and,
        when ``args.wandb_logging`` is truthy, ``wandb.log``.
    """
    use_tpu = args.use_tpu and _torch_tpu_available

    if use_tpu:
        device = xm.xla_device()
        args.lr = args.lr * xm.xrt_world_size()
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def device_batches(loader):
        """Return an iterable over `loader` suitable for one pass on `device`."""
        if use_tpu:
            # A ParallelLoader's per-device iterator can be consumed only
            # once, so a fresh one has to be built for every pass.
            return pl.ParallelLoader(loader, [device]).per_device_loader(device)
        return loader

    def unpack(batch):
        """Split a batch into encoder input, decoder input and flat targets."""
        src = batch['input_ids']
        trg = batch['output_ids']
        if not use_tpu:
            # On TPU the ParallelLoader already delivers device tensors.
            src = src.to(device)
            trg = trg.to(device)
        return src, trg[:, 1:-1], trg[:, 2:].contiguous().view(-1)

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09)
    scheduler = OneCycleLR(optimizer,
                           args.lr,
                           div_factor=10.0,
                           final_div_factor=50.0,
                           epochs=args.epochs,
                           steps_per_epoch=len(train_loader))

    # inf guarantees that the first epoch always produces a checkpoint.
    best_val_loss = float('inf')
    bad_epochs = 0

    print ("Inside trainer")

    for epoch in range(args.epochs):
        if bad_epochs >= args.early_stopping_rounds:
            # Early stopping: nothing can reset bad_epochs any more.
            break

        ################################### training ###################################
        model.train()
        training_loss = 0.0

        for batch in tqdm(device_batches(train_loader)):
            src, decoder_src, target = unpack(batch)

            optimizer.zero_grad()
            outputs, _ = model(encoder_src=src, decoder_src=decoder_src)
            total_loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), target, ignore_index=0)
            total_loss.backward()

            if use_tpu:
                xm.optimizer_step(optimizer)
            else:
                optimizer.step()
            scheduler.step()

            training_loss += total_loss.item()

        training_loss /= len(train_loader)

        ################################## validation ##################################
        model.eval()
        validation_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(device_batches(val_loader)):
                src, decoder_src, target = unpack(batch)
                outputs, _ = model(encoder_src=src, decoder_src=decoder_src)
                batch_loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), target, ignore_index=0)
                validation_loss += batch_loss.item()

        validation_loss /= len(val_loader)

        summary = "Epoch:{}, Train loss:{} Validation loss:{}".format(epoch+1, training_loss, validation_loss)
        if use_tpu:
            xm.master_print(summary)
        else:
            print (summary)

        if validation_loss < best_val_loss:
            checkpoint_path = os.path.join(args.model_save_path, model_checkpoint_name)
            if use_tpu:
                xm.save(model.state_dict(), checkpoint_path)
            else:
                torch.save(model.state_dict(), checkpoint_path)
            best_val_loss = validation_loss
            bad_epochs = 0
            if args.wandb_logging:
                wandb.log({"best_validation_loss": validation_loss})
        else:
            bad_epochs += 1

        if args.wandb_logging:
            wandb.log({"training_loss": training_loss})
            wandb.log({"validation_loss": validation_loss})

def _glue_metric_values(eval_metrics, targets, preds):
    """Compute each requested validation metric once.

    Args:
        eval_metrics: Iterable of metric names. Recognised names are
            ``'accuracy'``, ``'f1'``, ``'mathews'``, ``'pearson'`` and
            ``'spearman'``; any other name is skipped.
        targets: Numpy array of validation targets.
        preds: Numpy array of validation predictions.

    Returns:
        List of ``(suffix, value)`` pairs in the order of ``eval_metrics``,
        where ``suffix`` is the tail of the reporting key (``'acc'``,
        ``'f1'``, ``'mathews'``, ``'pearson'``, ``'spearman'``).
    """
    values = []
    for metric in eval_metrics:
        if metric == 'accuracy':
            values.append(('acc', accuracy_score(targets.round(), preds.round())))
        elif metric == 'f1':
            values.append(('f1', f1_score(targets.round(), preds.round())))
        elif metric == 'mathews':
            values.append(('mathews', matthews_corrcoef(targets.round(), preds.round())))
        elif metric == 'pearson':
            # Correlations use column 0, so both arrays must be 2-D here.
            values.append(('pearson', pearsonr(targets[:, 0], preds[:, 0])[0]))
        elif metric == 'spearman':
            values.append(('spearman', spearmanr(targets[:, 0], preds[:, 0])[0]))
    return values


def glue_trainer(args, model, train_loader, val_loader, model_checkpoint_name, eval_metrics):
    """Train a model on a GLUE-style task with validation, checkpointing and early stopping.

    The model is called as ``model(src=...)`` and must return a tuple
    ``(outputs, recon_loss)``. The training objective is the task loss plus
    ``args.lambda_ * recon_loss``; the validation loss is the task loss alone.
    Whenever the validation loss improves, the model's ``state_dict`` is
    written to ``os.path.join(args.model_save_path, model_checkpoint_name)``.
    Training stops once ``args.early_stopping_rounds`` consecutive epochs
    fail to improve the validation loss.

    Args:
        args: Namespace-like object providing ``lr``, ``lambda_``,
            ``max_grad_norm``, ``classification_type`` (``'binary'``,
            ``'multiclass'`` or ``'regression'``), ``epochs``,
            ``early_stopping_rounds``, ``model_save_path`` and
            ``wandb_logging``.
        model: ``torch.nn.Module`` to train; moved to CUDA when available,
            otherwise to CPU.
        train_loader: Sized iterable of batches. Each batch is a mapping with
            ``'input_ids'`` and ``'out'`` tensors.
        val_loader: Sized iterable with the same batch format.
        model_checkpoint_name: File name of the checkpoint inside
            ``args.model_save_path``.
        eval_metrics: Iterable of metric names to report each epoch:
            ``'accuracy'``, ``'f1'``, ``'mathews'``, ``'pearson'``,
            ``'spearman'``. Unrecognised names are ignored.

    Returns:
        None. Results are reported through stdout, the checkpoint file and,
        when ``args.wandb_logging`` is truthy, ``wandb.log``.

    Raises:
        ValueError: If ``args.classification_type`` is not supported.
    """
    task_type = args.classification_type
    # Fail fast on a bad configuration instead of hitting an unbound
    # `loss_fn` in the middle of the first batch.
    if task_type == 'binary':
        loss_fn = nn.BCEWithLogitsLoss()
    elif task_type == 'multiclass':
        loss_fn = nn.CrossEntropyLoss()
    elif task_type == 'regression':
        loss_fn = nn.MSELoss()
    else:
        raise ValueError(
            "Unsupported classification_type: {!r} (expected 'binary', 'multiclass' or 'regression')".format(
                task_type))
    is_multiclass = task_type == 'multiclass'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09)

    lambda_ = args.lambda_
    max_grad_norm = args.max_grad_norm
    # inf guarantees that the first epoch always produces a checkpoint.
    best_val_loss = float('inf')
    bad_epochs = 0

    for epoch in range(args.epochs):
        if bad_epochs >= args.early_stopping_rounds:
            # Early stopping: nothing can reset bad_epochs any more.
            break

        ################################### training ###################################
        model.train()
        training_loss = 0.0

        for batch in tqdm(train_loader):
            src = batch['input_ids'].to(device)
            trg = batch['out'].to(device)
            if is_multiclass and trg.dim() == 2:
                # CrossEntropyLoss takes 1-D class indices; keep column 0.
                trg = trg[:, 0]

            optimizer.zero_grad()
            outputs, recon_loss = model(src=src)
            total_loss = loss_fn(outputs, trg) + lambda_ * recon_loss
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            training_loss += total_loss.item()

        training_loss /= len(train_loader)

        ################################## validation ##################################
        model.eval()
        validation_loss = 0.0
        all_targets = []
        all_preds = []

        with torch.no_grad():
            for batch in tqdm(val_loader):
                src = batch['input_ids'].to(device)
                trg = batch['out'].to(device)
                if is_multiclass and trg.dim() == 2:
                    trg = trg[:, 0]

                outputs, _ = model(src=src)
                if task_type == 'binary':
                    # Sigmoid scores, rounded to 0/1 by the metrics.
                    preds = torch.sigmoid(outputs)
                elif is_multiclass:
                    preds = outputs.argmax(-1)
                else:
                    # Regression: raw model outputs are the predictions.
                    preds = outputs

                all_targets.append(trg.cpu().numpy())
                all_preds.append(preds.cpu().numpy())
                validation_loss += loss_fn(outputs, trg).item()

        validation_loss /= len(val_loader)
        all_targets = np.concatenate(all_targets, 0)
        all_preds = np.concatenate(all_preds, 0)

        if task_type != 'regression':
            rounded_targets = all_targets.round()
            rounded_preds = all_preds.round()
            print ("Epoch:{}, Train loss:{} Validation loss:{}, Validation Acc:{}".format(
                epoch+1, training_loss, validation_loss,
                accuracy_score(rounded_targets, rounded_preds)))
            print ("Confusion: {}".format(confusion_matrix(rounded_targets, rounded_preds)))
        else:
            print ("Epoch:{}, Train loss:{} Validation loss:{}".format(
                epoch+1, training_loss, validation_loss))

        improved = validation_loss < best_val_loss
        if improved:
            torch.save(model.state_dict(), os.path.join(args.model_save_path, model_checkpoint_name))
            best_val_loss = validation_loss
            bad_epochs = 0
        else:
            bad_epochs += 1

        # Evaluated after the checkpoint is written, so a metric that raises
        # (e.g. a correlation on 1-D arrays) cannot prevent the save.
        metric_values = _glue_metric_values(eval_metrics, all_targets, all_preds)

        if args.wandb_logging:
            if improved:
                wandb.log({"best_validation_loss": validation_loss})
                for suffix, value in metric_values:
                    wandb.log({"best_validation_" + suffix: value})
            wandb.log({"training_loss": training_loss})
            wandb.log({"validation_loss": validation_loss})
            for suffix, value in metric_values:
                wandb.log({"validation_" + suffix: value})

        for suffix, value in metric_values:
            print({"validation_" + suffix: value})
