import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
from collections import Counter
import numpy as np
from scipy.special import softmax



def train_model(args, data_loader, model, n_epochs, lr, device, verbose=0, class_weight=None):
    """Train ``model`` with Adam and cross-entropy loss, then return it.

    Args:
        args: mapping of run options. Only ``args['ntqdm']`` is read; it is
            passed to tqdm's ``disable`` flag (truthy hides the progress bar).
        data_loader: iterable yielding ``(x, y)`` batches. ``x`` is cast to
            float32 and ``y`` to int64 before the forward pass.
        model: ``torch.nn.Module`` producing class scores. Assumed to be
            (batch, n_classes), since ``output.argmax(dim=1)`` is compared
            with ``y`` -- confirm against callers.
        n_epochs: number of passes over ``data_loader``.
        lr: initial Adam learning rate. It is halved every 10 epochs by StepLR.
        device: device that the model and every batch are moved to.
        verbose: if > 0, print the training accuracy after each epoch.
        class_weight: optional per-class weights for ``nn.CrossEntropyLoss``.

    Returns:
        The same ``model`` object, trained in place and left on ``device``.
    """
    model.to(device)
    if class_weight is not None:
        print(f'Training model with active class weights {class_weight}')
        class_weight = torch.tensor(class_weight, device=device, dtype=torch.float)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
    # Halve the learning rate every 10 epochs. This only takes effect because
    # scheduler.step() is called at the end of each epoch below.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in tqdm(range(n_epochs), disable=args['ntqdm']):
        # Nothing in this loop switches the model to eval mode, so setting
        # training mode once per epoch is sufficient.
        model.train()
        num_correct = 0
        num_samples = 0
        for x, y in data_loader:
            x = x.to(device=device, dtype=torch.float)
            y = y.to(device=device, dtype=torch.long)
            output = model(x)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running training accuracy, computed from the pre-step outputs.
            num_correct += (output.argmax(dim=1) == y).sum().item()
            num_samples += y.size(0)

        scheduler.step()

        if verbose > 0:
            # Guard against an empty loader, which would otherwise divide by zero.
            accuracy = num_correct / num_samples if num_samples else float('nan')
            print(f'Epoch {epoch} : Accuracy {accuracy}')

    return model

def evaluate_model(data_loader, model, device, argmax=True):
    """Run ``model`` over ``data_loader`` in eval mode with gradients disabled.

    Args:
        data_loader: iterable of ``(x, y)`` batches. ``x`` is cast to float32
            on ``device`` and ``y`` is converted to int64 labels.
        model: module producing per-class scores. Assumed to be
            (batch, n_classes) -- confirm against callers.
        device: device that the model and the inputs are moved to.
        argmax: if truthy, return predicted class indices; otherwise return
            softmax probabilities.

    Returns:
        ``(predictions, correct_labels)``.
        ``correct_labels`` is a list of int labels in loader order.
        ``predictions`` depends on ``argmax``:
          - truthy: a list of int class indices;
          - falsy: an ndarray whose rows are softmax-normalised over axis 1
            (an empty ``(0, 0)`` array if the loader yields no batches).
    """
    model.to(device)
    model.eval()
    predictions = []
    correct_labels = []
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device=device, dtype=torch.float)
            output = model(x)
            if argmax:
                predictions.extend(output.argmax(dim=1).tolist())
            else:
                predictions.append(output.detach().cpu().numpy())
            # Labels are only read back as Python ints, so they are never
            # moved to ``device``.
            correct_labels.extend(y.to(dtype=torch.long).tolist())

    if not argmax:
        if not predictions:
            # np.concatenate raises on an empty list, so return an empty
            # array instead.
            return np.empty((0, 0)), correct_labels
        predictions = softmax(np.concatenate(predictions, axis=0), axis=1)
    return predictions, correct_labels
