"""
Code for Train & Test
"""

import os
import math
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm  
import warnings
from scipy.stats import beta, dirichlet, multivariate_normal as mvn
from scipy.special import gammaln, digamma
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.gamma import Gamma
from torch.distributions.kl import kl_divergence as kl_div
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

### [1] Training & Testing for F-EDL
def compute_mu(alpha, p, tau):
    """Return the mean vector ``mu`` used by F-EDL for prediction and loss.

    Computes ``(alpha + tau * p) / (alpha.sum(dim=1) + tau)``, i.e. the
    evidence ``alpha`` blended with ``p`` at weight ``tau`` and normalised by
    the total mass.

    Args:
        alpha: Concentration tensor, reduced over dim 1 (presumably
            ``(batch, classes)`` -- confirm against the model).
        p: Tensor broadcastable against ``alpha``.
        tau: Scalar or tensor broadcastable against the ``(…, 1)`` row sums.

    Returns:
        Tensor with the broadcast shape of ``alpha + tau * p``.
    """
    evidence_total = alpha.sum(dim=1, keepdim=True)
    weighted_evidence = alpha + tau * p
    return weighted_evidence / (evidence_total + tau)

def compute_var(alpha, p, tau):
    """Return the variance term that pairs with ``compute_mu`` in the F-EDL loss.

    With ``s = alpha.sum(dim=1) + tau`` and ``mu`` as in ``compute_mu``::

        var = mu * (1 - mu) / (s + 1) + tau**2 * p * (1 - p) / (s * (s + 1))

    The first part has the form of a Dirichlet variance around ``mu``; the
    second adds a contribution from ``p`` scaled by ``tau**2``.

    Args:
        alpha: Concentration tensor, reduced over dim 1 (presumably
            ``(batch, classes)`` -- confirm against the model).
        p: Tensor broadcastable against ``alpha``.
        tau: Scalar or tensor broadcastable against the ``(…, 1)`` row sums.

    Returns:
        Element-wise variance tensor, broadcast like ``alpha``.
    """
    denom = alpha.sum(dim=1, keepdim=True) + tau
    mean = (alpha + tau * p) / denom
    dirichlet_part = mean * (1 - mean) / (denom + 1)
    mixing_part = (tau**2) * p * (1 - p) / (denom * (denom + 1))
    return dirichlet_part + mixing_part



def train_fedl(model, learning_rate, weight_decay, step_size, num_epochs, trainloader, validloader, num_classes, fix_tau, fix_p, device):
    """Train an F-EDL model and stop early when validation accuracy stalls.

    The model is called as ``model(x, fix_tau, fix_p)`` and must return
    ``(alpha, p, tau)``. Each batch minimises
    ``sum((y - mu)**2) + sum(var) + sum((y - p)**2)``, where ``y`` holds the
    one-hot labels and ``mu``/``var`` come from ``compute_mu``/``compute_var``.
    Optimisation uses Adam (with ``weight_decay``) and StepLR (x0.1 every
    ``step_size`` epochs). Automatic mixed precision is used on CUDA devices.

    Every 10th epoch (excluding epoch 0), the same loss and the accuracy of
    ``argmax(mu)`` are computed on ``validloader``. Training stops after 3
    consecutive checks that fail to beat the best accuracy by more than
    0.01 percentage points.

    Args:
        model: Module returning ``(alpha, p, tau)``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        step_size: StepLR period in epochs.
        num_epochs: Maximum number of epochs.
        trainloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        validloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        num_classes: Number of classes for one-hot encoding.
        fix_tau: Forwarded to the model unchanged.
        fix_p: Forwarded to the model unchanged.
        device: Device the model and batches are moved to.
    """
    # GradScaler/autocast from torch.cuda.amp only make sense for CUDA tensors.
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)

    VAL_ACC, VAL_LOSS = [], []
    patience, best_acc, epochs_no_improve = 3, 0, 0

    model.to(device)

    for epoch in tqdm(range(num_epochs)):
        # Validation below switches to eval mode, so re-enter train mode every
        # epoch; otherwise dropout/BatchNorm stay in eval mode after the first check.
        model.train()
        for x, y in trainloader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            y_oh = F.one_hot(y, num_classes).float()

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                alpha, p, tau = model(x, fix_tau, fix_p)
                mu = compute_mu(alpha, p, tau)
                var = compute_var(alpha, p, tau)

                loss_cls = torch.sum((y_oh - mu) ** 2) + torch.sum(var)
                loss_reg_p = torch.sum((y_oh - p) ** 2)
                loss = loss_cls + loss_reg_p

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        if epoch % 10 == 0 and epoch > 0:
            model.eval()
            val_loss, correct_v, total_v = 0.0, 0, 0

            with torch.no_grad():
                for x_v, y_v in validloader:
                    x_v, y_v = x_v.to(device), y_v.to(device)
                    y_oh_v = F.one_hot(y_v, num_classes).float()

                    alpha_v, p_v, tau_v = model(x_v, fix_tau, fix_p)
                    mu_v = compute_mu(alpha_v, p_v, tau_v)
                    var_v = compute_var(alpha_v, p_v, tau_v)

                    loss_cls_v = torch.sum((y_oh_v - mu_v) ** 2) + torch.sum(var_v)
                    loss_reg_p_v = torch.sum((y_oh_v - p_v) ** 2)
                    val_loss += (loss_cls_v + loss_reg_p_v).item()

                    correct_v += (mu_v.argmax(1) == y_v).sum().item()
                    total_v += y_v.size(0)

            val_acc = 100 * correct_v / total_v
            VAL_ACC.append(val_acc)
            VAL_LOSS.append(val_loss)

            print(f"Epoch {epoch}: Val Loss = {val_loss:.3f}, Val Acc = {val_acc:.2f}%")

            if val_acc > best_acc + 0.01:
                best_acc = val_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print("Early stopping triggered.")
                break


def eval_fedl(model, testloader, fix_tau, fix_p, device):
    """Return the test accuracy (in percent) of an F-EDL model.

    The predicted class is ``argmax`` of the mean from ``compute_mu``. The
    accuracy is also printed.

    Args:
        model: Module called as ``model(x, fix_tau, fix_p)``, returning
            ``(alpha, p, tau)`` -- the same call used by ``train_fedl``.
        testloader: Iterable of ``(x, y)`` batches with integer labels ``y``.
        fix_tau: Forwarded to the model unchanged.
        fix_p: Forwarded to the model unchanged.
        device: Device batches are moved to.

    Returns:
        Accuracy as a percentage.
    """
    model.eval()
    total_t, correct_t = 0, 0

    with torch.no_grad():
        for x_t, y_t in testloader:
            x_t, y_t = x_t.to(device), y_t.to(device)
            # Pass the same flags as training; without them the model may
            # fall back to a different configuration (or reject the call).
            alpha_t, p_t, tau_t = model(x_t, fix_tau, fix_p)

            mu_t = compute_mu(alpha_t, p_t, tau_t)
            y_pred_t = torch.argmax(mu_t, dim=1)

            total_t += y_t.size(0)
            correct_t += (y_pred_t == y_t).sum().item()

    test_acc = 100 * correct_t / total_t
    print("Test Accuracy:", test_acc)

    return test_acc

### [2] Training & Testing for EDL Methods (EDL, I-EDL, R-EDL, DAEDL)

# EDL variants understood by train_edl; anything else is rejected up front.
_EDL_TYPES = ("EDL", "DAEDL", "I-EDL", "R-EDL")


def _edl_alpha(logits, edl_type, lamb):
    """Map raw logits to Dirichlet concentrations ``alpha`` for ``edl_type``."""
    if edl_type == "EDL":
        return 1 + torch.relu(logits)
    if edl_type == "DAEDL":
        return 1e-6 + torch.exp(logits)
    if edl_type == "I-EDL":
        return 1 + F.softplus(logits)
    if edl_type == "R-EDL":
        return 1e-6 + F.softplus(logits) + lamb
    raise ValueError(f"Unknown edl_type {edl_type!r}; expected one of {_EDL_TYPES}")


def _edl_loss(alpha, y, num_classes, edl_type, reg_param_kl, reg_param_fisher):
    """Return the batch-summed objective of the given EDL variant."""
    alpha0 = alpha.sum(1).reshape(-1, 1)
    y_oh = F.one_hot(y, num_classes)
    # Reset the true-class concentration to 1 (the value of the uniform prior),
    # so the KL term only penalises evidence placed on wrong classes.
    alpha_tilde = alpha * (1 - y_oh) + y_oh

    loss_cls = torch.sum((y_oh - alpha / alpha0) ** 2)
    loss_kl = kl_div(Dirichlet(1e-6 + alpha_tilde), Dirichlet(torch.ones_like(alpha_tilde))).sum()

    if edl_type == "R-EDL":
        # R-EDL omits the variance term.
        return loss_cls + reg_param_kl * loss_kl

    loss_var = torch.sum((alpha * (alpha0 - alpha)) / ((alpha0 ** 2) * (alpha0 + 1)))
    loss = loss_cls + loss_var + reg_param_kl * loss_kl

    if edl_type == "I-EDL":
        # -mean log det of diag(psi1(alpha)) - psi1(alpha0) * 11^T, expanded
        # with the matrix determinant lemma.
        gamma1_alp = torch.polygamma(1, alpha)
        gamma1_alp0 = torch.polygamma(1, alpha0)
        loss_fisher = -(torch.log(gamma1_alp).sum(-1) + torch.log(1.0 - (gamma1_alp0 / gamma1_alp).sum(-1))).mean()
        loss = loss + reg_param_fisher * loss_fisher

    return loss


def train_edl(model, edl_type, learning_rate, step_size, reg_param_kl, reg_param_fisher, lamb, num_epochs, trainloader, validloader, num_classes, device):
    """Train an EDL-family classifier (EDL, DAEDL, I-EDL or R-EDL).

    Logits are mapped to Dirichlet concentrations (``_edl_alpha``) and the
    variant-specific loss (``_edl_loss``) is minimised. Optimisation uses Adam
    and StepLR (x0.1 every ``step_size`` epochs). Automatic mixed precision is
    used on CUDA devices.

    Every 10th epoch (excluding epoch 0), the summed validation loss and
    accuracy are computed. From the third check onward, a check whose loss
    does not drop by more than 0.01% relative to the previous one increments
    a counter; an improvement resets it. Training stops once the counter
    exceeds 3.

    Args:
        model: Module mapping a batch ``x`` to logits.
        edl_type: One of "EDL", "DAEDL", "I-EDL", "R-EDL".
        learning_rate: Adam learning rate.
        step_size: StepLR period in epochs.
        reg_param_kl: Weight of the KL regulariser.
        reg_param_fisher: Weight of the Fisher term (I-EDL only).
        lamb: Constant added to the R-EDL concentrations.
        num_epochs: Maximum number of epochs.
        trainloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        validloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        num_classes: Number of classes for one-hot encoding.
        device: Device the model and batches are moved to.

    Raises:
        ValueError: If ``edl_type`` is not recognised.
    """
    if edl_type not in _EDL_TYPES:
        raise ValueError(f"Unknown edl_type {edl_type!r}; expected one of {_EDL_TYPES}")

    # GradScaler/autocast from torch.cuda.amp only make sense for CUDA tensors.
    use_amp = torch.device(device).type == "cuda"
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    VAL_ACC = []
    VAL_LOSS = []
    cnt = 0

    model.to(device)

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for x, y in trainloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                alpha = _edl_alpha(model(x), edl_type, lamb)
                loss = _edl_loss(alpha, y, num_classes, edl_type, reg_param_kl, reg_param_fisher)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        if epoch % 10 == 0 and epoch > 0:
            total, correct, val_loss = 0, 0, 0.0

            model.eval()
            with torch.no_grad():
                for x_v, y_v in validloader:
                    x_v, y_v = x_v.to(device), y_v.to(device)
                    alpha_v = _edl_alpha(model(x_v), edl_type, lamb)
                    val_loss += _edl_loss(alpha_v, y_v, num_classes, edl_type, reg_param_kl, reg_param_fisher).item()

                    total += y_v.size(0)
                    correct += (alpha_v.argmax(1) == y_v).sum().item()

            val_acc = 100 * correct / total
            VAL_LOSS.append(val_loss)
            VAL_ACC.append(val_acc)

            if len(VAL_ACC) > 2:
                prev_loss = VAL_LOSS[-2]
                delta = VAL_LOSS[-1] - prev_loss
                # Normalise by |prev| so a negative loss (possible with the
                # I-EDL Fisher term) does not flip the improvement test, and
                # never divide by zero.
                r_loss = delta / abs(prev_loss) if prev_loss != 0 else delta
                cnt = cnt + 1 if r_loss > -0.0001 else 0

            if cnt > 3:
                break

            print(f'Epoch {epoch}, Validation loss = {val_loss:.3f}')
            print(f'Validation Accuracy = {val_acc:.3f}')


            
def eval_edl(model, edl_type, testloader, device):
    """Return the test accuracy (in percent) of an EDL-family model.

    Logits are mapped to Dirichlet concentrations with the rule for
    ``edl_type``. The predicted class is ``argmax(alpha)``. The accuracy is
    also printed.

    Args:
        model: Module mapping a batch ``x`` to logits.
        edl_type: One of "EDL", "DAEDL", "I-EDL", "R-EDL".
        testloader: Iterable of ``(x, y)`` batches with integer labels ``y``.
        device: Device batches are moved to.

    Returns:
        Accuracy as a percentage.

    Raises:
        ValueError: If ``edl_type`` is not recognised.
    """
    to_alpha = {
        "EDL": lambda z: 1 + torch.relu(z),
        "DAEDL": lambda z: 1e-6 + torch.exp(z),
        "I-EDL": lambda z: 1 + F.softplus(z),
        # train_edl adds ``lamb`` here. Any constant shared by all classes
        # leaves the argmax unchanged, so a fixed 1 is used.
        "R-EDL": lambda z: 1e-6 + F.softplus(z) + 1,
    }.get(edl_type)
    if to_alpha is None:
        raise ValueError(f"Unknown edl_type {edl_type!r}; expected one of 'EDL', 'DAEDL', 'I-EDL', 'R-EDL'")

    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = to_alpha(model(x)).argmax(1)

            total += y.size(0)
            correct += (y_pred == y).sum().item()

    test_acc = 100 * correct / total
    print("Test Accuracy:", test_acc)

    return test_acc


### [3] Training & Testing for Softmax Models (Softmax, Dropout, DDU)
def train_softmax(model, learning_rate, step_size, num_epochs, trainloader, validloader, sm_type, device):
    """Train a softmax classifier with cross-entropy, Adam and StepLR.

    Every 10th epoch (excluding epoch 0), the mean per-sample validation loss
    and the accuracy are computed. From the third check onward, a check whose
    loss does not drop by more than 0.01% relative to the previous one
    increments a counter; an improvement resets it. Training stops once the
    counter exceeds 3.

    Args:
        model: Module mapping a batch ``x`` to logits.
        learning_rate: Adam learning rate.
        step_size: StepLR period in epochs (x0.1 decay).
        num_epochs: Maximum number of epochs.
        trainloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        validloader: Iterable of ``(x, y)`` batches; ``y`` are integer labels.
        sm_type: Model family label. Kept for symmetry with ``eval_softmax``;
            not used during training.
        device: Device the model and batches are moved to.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    VAL_ACC = []
    VAL_LOSS = []
    cnt = 0

    model.to(device)

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for x, y in trainloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)

            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 10 == 0 and epoch > 0:
            total, correct, val_loss = 0, 0, 0.0

            model.eval()
            with torch.no_grad():
                for x_v, y_v in validloader:
                    x_v, y_v = x_v.to(device), y_v.to(device)

                    logits_v = model(x_v)
                    # criterion returns a batch mean; re-weight by batch size
                    # so that dividing by ``total`` gives a per-sample mean.
                    val_loss += criterion(logits_v, y_v).item() * x_v.size(0)

                    total += y_v.size(0)
                    correct += (logits_v.argmax(1) == y_v).sum().item()

            mean_val_loss = val_loss / total
            val_acc = 100 * correct / total
            VAL_LOSS.append(mean_val_loss)
            VAL_ACC.append(val_acc)

            if len(VAL_ACC) > 2:
                prev_loss = VAL_LOSS[-2]
                delta = VAL_LOSS[-1] - prev_loss
                # Guard the relative change against a zero previous loss.
                r_loss = delta / abs(prev_loss) if prev_loss != 0 else delta
                cnt = cnt + 1 if r_loss > -0.0001 else 0

            if cnt > 3:
                break

            print('Epoch {}, Validation loss = {:.3f}'.format(epoch, mean_val_loss))
            print('Validation Accuracy = {:.3f}'.format(val_acc))


def eval_softmax(model, testloader, sm_type, num_passes, device):
    """Return the test accuracy (in percent) of a softmax-based classifier.

    For ``sm_type == "Dropout"``, the prediction is the argmax of the mean
    softmax over ``num_passes`` stochastic forward passes (MC dropout).
    Otherwise a single pass in eval mode is used. The accuracy is also
    printed.

    Args:
        model: Module mapping a batch ``x`` to logits.
        testloader: Iterable of ``(x, y)`` batches with integer labels ``y``.
        sm_type: "Dropout" enables MC dropout; any other value uses eval mode.
        num_passes: Number of stochastic passes (used only for "Dropout").
        device: Device batches are moved to.

    Returns:
        Accuracy as a percentage.

    Raises:
        ValueError: If MC dropout is requested with ``num_passes < 1``.
    """
    mc_dropout = sm_type == "Dropout"
    if mc_dropout:
        if num_passes < 1:
            raise ValueError(f"num_passes must be >= 1 for MC dropout, got {num_passes}")
        # train() keeps dropout stochastic, but it would also make BatchNorm
        # normalise with batch statistics and overwrite its running averages
        # with test data. Put only the BatchNorm layers back into eval mode.
        model.train()
        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()
    else:
        model.eval()

    total = 0
    correct = 0

    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)

            if mc_dropout:
                probabilities = torch.stack(
                    [F.softmax(model(x), dim=1) for _ in range(num_passes)]
                ).mean(dim=0)
            else:
                probabilities = F.softmax(model(x), dim=1)

            y_pred = probabilities.argmax(1)
            total += y.size(0)
            correct += (y_pred == y).sum().item()

    test_acc = 100 * correct / total
    print("Test Accuracy:", test_acc)

    return test_acc
