import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from sklearn.preprocessing import SplineTransformer, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score
from itertools import combinations
from tqdm.auto import tqdm
from pathlib import Path
import time
import copy
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple
import warnings
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from dataclasses import dataclass, field, asdict
from tqdm.auto import tqdm
import os
import numpy as np
import argparse

# Silence every warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Module-level compute device: CUDA when available, otherwise CPU.
# NOTE(review): nothing else in the visible module appears to read `device`;
# the trainer chooses its own device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

@dataclass
class NAMConfig:
    """Hyper-parameters shared by the NAM model and its trainer.

    Checkpoints store this object via ``dataclasses.asdict`` and rebuild it
    with ``NAMConfig(**d)``, so the field names are part of the checkpoint
    format and must not be renamed.
    """

    input_dim: int  # number of input columns (NOTE(review): not read by NAM itself)
    case: str  # dataset name; used as the checkpoint filename prefix
    index_list: list  # groups of column indices; one FeatureNN per group
    hidden_sizes: list = field(default_factory=lambda: [64, 32])  # FeatureNN hidden widths
    use_exu: bool = True  # ExU activations when True, ReLU otherwise
    dropout: float = 0.1  # dropout probability after each activation (0 disables)
    use_batchnorm: bool = False  # BatchNorm1d after each hidden Linear layer
    batch_size: int = 128  # DataLoader batch size
    lr: float = 3e-4  # AdamW learning rate
    l2: float = 1e-4  # AdamW weight_decay
    n_epochs: int = 200  # maximum number of training epochs
    grad_clip: float = 1.0  # max total gradient norm per optimisation step
    early_stopping_patience: int = 20  # non-improving validation checks tolerated
    ckpt_dir: str = "checkpoints"  # directory for the best-model checkpoint
    tqlm: bool = True  # show the tqdm progress bar during training

class ExU(nn.Module):
    """Learned per-unit scaling of an ELU activation.

    Computes ``(exp(w) - 1) * elu(x)`` where ``w`` holds one weight per
    feature and broadcasts over the last dimension of ``x``. ``w`` is clamped
    to [-10, 10] before exponentiation so the scale stays finite.
    """

    def __init__(self, in_features):
        super().__init__()
        # One scale parameter per feature, drawn from a standard normal.
        self.weight = nn.Parameter(torch.randn(in_features))

    def forward(self, x):
        # expm1(w) == exp(w) - 1, but is more accurate when |w| is small.
        scale = torch.expm1(self.weight.clamp(-10, 10))
        # F.elu(x) equals where(x > 0, x, exp(x) - 1) (alpha=1). The
        # torch.where form also evaluated exp() for large positive x; the
        # overflow to inf turned into NaN gradients (0 * inf) in backward.
        return scale * F.elu(x)

class FeatureNN(nn.Module):
    """Small MLP mapping ``input_dim`` input columns to one output column.

    Each hidden width ``h`` contributes ``Linear -> [BatchNorm1d] ->
    (ExU | ReLU) -> [Dropout]``; a final ``Linear(last, 1)`` produces the
    output. Every ``Linear`` weight is then re-initialised with
    Kaiming-normal (fan_in).
    """

    def __init__(self, input_dim=1, hidden_sizes=None, use_exu=True, dropout=0.1, use_batchnorm=False):
        """Build the layer stack.

        Args:
            input_dim: Number of input columns.
            hidden_sizes: Hidden layer widths. ``None`` (the default) means
                ``(64, 32)``; a sentinel avoids a shared mutable default list.
            use_exu: Use ``ExU`` activations instead of ``ReLU``.
            dropout: Dropout probability after each activation (0 disables).
            use_batchnorm: Insert ``BatchNorm1d`` after each hidden Linear.
        """
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = (64, 32)
        self.layers = nn.ModuleList()
        last_dim = input_dim
        for h in hidden_sizes:
            self.layers.append(nn.Linear(last_dim, h))
            if use_batchnorm:
                self.layers.append(nn.BatchNorm1d(h))
            self.layers.append(ExU(h) if use_exu else nn.ReLU())
            if dropout > 0:
                self.layers.append(nn.Dropout(dropout))
            last_dim = h
        self.layers.append(nn.Linear(last_dim, 1))
        # Replace PyTorch's default Linear weight init; biases keep their defaults.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")

    def forward(self, x):
        """Apply the layer stack; a 1-D ``x`` is treated as a single column.

        Returns a tensor whose last dimension is 1 (final layer is
        ``Linear(..., 1)``).
        """
        if x.dim() == 1:
            x = x.unsqueeze(1)
        for layer in self.layers:
            x = layer(x)
        return x

class NAM(nn.Module):
    """Neural Additive Model.

    The prediction for each row is a learned scalar bias plus the sum of one
    ``FeatureNN`` output per column group in ``config.index_list``.
    """

    def __init__(self, config: NAMConfig):
        super().__init__()
        self.index_list = config.index_list
        # One shape function per column group; groups with more than one
        # column model an interaction between those columns.
        subnets = [
            FeatureNN(
                len(group),
                config.hidden_sizes,
                config.use_exu,
                config.dropout,
                config.use_batchnorm,
            )
            for group in self.index_list
        ]
        self.feature_nns = nn.ModuleList(subnets)
        self.bias = nn.Parameter(torch.zeros(1))
        self.config = config

    def forward(self, x):
        """Return one prediction per row of ``x`` (shape ``(batch,)``).

        If a sub-network's output contains NaN/inf anywhere in the batch, that
        sub-network contributes zeros for the whole batch (and no gradient).
        Any remaining non-finite entries in the final sum are also zeroed.
        """
        contributions = []
        for group, subnet in zip(self.index_list, self.feature_nns):
            if len(group) > 1:
                inputs = x[:, group]
            else:
                inputs = x[:, group[0]].unsqueeze(1)
            term = subnet(inputs)
            has_bad_value = torch.isnan(term).any() or torch.isinf(term).any()
            contributions.append(torch.zeros_like(term) if has_bad_value else term)
        total = self.bias + torch.cat(contributions, dim=1).sum(1)
        non_finite = torch.isnan(total) | torch.isinf(total)
        return torch.where(non_finite, torch.zeros_like(total), total)

    def save_checkpoint(self, path, optimizer=None, scheduler=None, epoch=None, best_val_loss=None):
        """Write weights, the config dict and any supplied training state to ``path``."""
        state = {
            "model_state_dict": self.state_dict(),
            "config": asdict(self.config),
        }
        if optimizer:
            state["optimizer_state_dict"] = optimizer.state_dict()
        if scheduler:
            state["scheduler_state_dict"] = scheduler.state_dict()
        extras = (("epoch", epoch), ("best_val_loss", best_val_loss))
        state.update({key: value for key, value in extras if value is not None})
        torch.save(state, path)
        print(f"Checkpoint saved to {path}")

    @staticmethod
    def load_checkpoint(path, device="cpu"):
        """Rebuild a NAM from a ``save_checkpoint`` file.

        Returns:
            ``(model, config)`` with the model's weights restored on ``device``.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
        # only load checkpoint files from a trusted source.
        ckpt = torch.load(path, weights_only=False, map_location=device)
        restored_config = NAMConfig(**ckpt["config"])
        restored = NAM(restored_config).to(device)
        restored.load_state_dict(ckpt["model_state_dict"])
        print(f"Checkpoint loaded from {path}")
        return restored, restored_config

class NAMTrainer:
    """Fit, evaluate, checkpoint and reload a NAM model.

    Only improvements in validation loss are checkpointed, to
    ``<config.ckpt_dir>/<config.case>_best_model.pt``. ``fit`` leaves the
    last-epoch weights in ``self.model``; call ``load_best_model`` afterwards
    to restore the best checkpoint.
    """

    def __init__(self, config: "NAMConfig", model=None, device=None):
        """Prepare the trainer and create the checkpoint directory.

        Args:
            config: A ``NAMConfig`` with model and optimisation settings.
            model: Optional pre-built model; ``NAM(config)`` is built when
                ``None``.
            device: Optional ``torch.device`` or device string to train on.
                Defaults to CPU.
        """
        self.config = config
        if device is None:
            device = "cpu"
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        # Explicit None check: container modules (e.g. an empty Sequential)
        # are falsy, so `model or ...` could silently discard them.
        self.model = model if model is not None else NAM(config)
        self.model.to(self.device)
        self.best_val_loss = float("inf")
        self.best_epoch = None
        self.ckpt_dir = config.ckpt_dir
        self.case = config.case
        os.makedirs(self.ckpt_dir, exist_ok=True)
        # One (train_loss, val_loss or None) tuple per completed epoch.
        self.history = []

    def _best_ckpt_path(self):
        """Path of the best-validation checkpoint for this trainer's case."""
        return os.path.join(self.ckpt_dir, self.case + "_best_model.pt")

    def create_loader(self, X, y, is_train=True):
        """Wrap ``X``/``y`` in a DataLoader, shuffled when ``is_train``.

        BatchNorm cannot compute batch statistics from a single sample in
        training mode. When ``config.use_batchnorm`` is set and the trailing
        training batch would hold exactly one sample, that batch is dropped
        (a different sample each epoch, because of shuffling) rather than
        crashing the run. Evaluation loaders never drop data.
        """
        dataset = TensorDataset(X, y)
        batch_size = self.config.batch_size
        n_samples = len(dataset)
        drop_last = (
            is_train
            and self.config.use_batchnorm
            and n_samples > batch_size
            and n_samples % batch_size == 1
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            drop_last=drop_last,
        )

    def _train_epoch(self, loader, optimizer):
        """Run one optimisation pass over ``loader``.

        Returns:
            The mean per-batch training loss, or ``None`` as soon as a batch
            produces a NaN/inf loss (that batch is not applied).
        """
        self.model.train()
        batch_losses = []
        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            optimizer.zero_grad()
            loss = F.mse_loss(self.model(xb), yb)
            if not torch.isfinite(loss):
                return None
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
            optimizer.step()
            batch_losses.append(loss.item())
        return np.mean(batch_losses)

    def _evaluate(self, loader):
        """Return the mean squared error over every sample in ``loader``.

        Errors are summed and divided by the sample count, so a short final
        batch does not skew the average. Returns NaN for an empty loader.
        """
        self.model.eval()
        total_sq_err, n_samples = 0.0, 0
        with torch.no_grad():
            for xb, yb in loader:
                pred = self.model(xb.to(self.device))
                total_sq_err += F.mse_loss(pred, yb.to(self.device), reduction="sum").item()
                n_samples += yb.numel()
        return total_sq_err / n_samples if n_samples else float("nan")

    def fit(self, X_train, y_train, X_val, y_val, eval_every=10):
        """Train with AdamW and cosine-annealed learning rate.

        Validation runs every ``eval_every`` epochs and always on the final
        epoch, so at least one validation check happens whenever
        ``n_epochs >= 1``. Early stopping counts validation checks, not
        epochs: training stops once more than
        ``config.early_stopping_patience`` consecutive checks fail to beat
        the best validation loss. A NaN/inf training loss prints a message
        and ends training immediately.

        Args:
            X_train, y_train: Training tensors.
            X_val, y_val: Validation tensors.
            eval_every: Validation interval in epochs (default 10).

        Raises:
            ValueError: If ``eval_every`` is smaller than 1.
        """
        if eval_every < 1:
            raise ValueError(f"eval_every must be >= 1, got {eval_every}")
        n_epochs = self.config.n_epochs
        train_loader = self.create_loader(X_train, y_train, is_train=True)
        val_loader = self.create_loader(X_val, y_val, is_train=False)
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.config.lr, weight_decay=self.config.l2
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        patience = 0

        pbar = tqdm(range(1, n_epochs + 1), disable=not self.config.tqlm)
        for epoch in pbar:
            avg_train_loss = self._train_epoch(train_loader, optimizer)
            if avg_train_loss is None:
                print(f"Loss nan/inf at epoch {epoch}")
                return
            scheduler.step()

            # Only validate (and consider early stopping) every `eval_every`
            # epochs and on the final epoch.
            if epoch % eval_every != 0 and epoch != n_epochs:
                self.history.append((avg_train_loss, None))
                pbar.set_postfix({"train_loss": f"{avg_train_loss:.5f}", "val_loss": "N/A"})
                continue

            avg_val_loss = self._evaluate(val_loader)
            self.history.append((avg_train_loss, avg_val_loss))
            pbar.set_postfix({"train_loss": f"{avg_train_loss:.5f}", "val_loss": f"{avg_val_loss:.5f}"})
            # Save a checkpoint only when the validation loss reaches a new low.
            if avg_val_loss < self.best_val_loss:
                self.best_val_loss = avg_val_loss
                self.best_epoch = epoch
                self.model.save_checkpoint(
                    self._best_ckpt_path(), optimizer, scheduler, epoch, avg_val_loss
                )
                patience = 0
            else:
                patience += 1
            if patience > self.config.early_stopping_patience:
                print(f"Early stopping at epoch {epoch} (best val loss: {self.best_val_loss:.6f})")
                break

    def predict(self, X, y=None):
        """Return model outputs for ``X`` as a NumPy array.

        Outputs are concatenated along the batch dimension. ``y`` is accepted
        for API compatibility and otherwise ignored. An empty ``X`` yields an
        empty float32 array.
        """
        dummy_y = y if y is not None else torch.zeros(len(X))
        loader = self.create_loader(X, dummy_y, is_train=False)
        self.model.eval()
        with torch.no_grad():
            preds = [self.model(xb.to(self.device)).cpu() for xb, _ in loader]
        if not preds:
            return np.empty(0, dtype=np.float32)
        return torch.cat(preds).numpy()

    def load_best_model(self, path=None):
        """Replace ``self.model`` with the checkpoint at ``path``.

        Defaults to this trainer's best-validation checkpoint.
        """
        path = path or self._best_ckpt_path()
        self.model, _ = NAM.load_checkpoint(path, device=self.device)
        self.model.to(self.device)
		
def main():
    """Train a NAM on one real-data case and report wall time and test RMSE.

    Command-line arguments:
        --case      Dataset name (default ``wine``). Selects the files
                    ``<data-dir>/<case>_{Train,Valid,Test}.csv`` and the
                    hard-coded interaction pairs below.
        --data-dir  Directory holding those CSV files.

    Each CSV must contain a target column named ``y``; every other column is
    used as a feature, in file order.

    Raises:
        ValueError: If an interaction index is out of range for the data.
    """
    parser = argparse.ArgumentParser(description="case")
    parser.add_argument('--case', default='wine', type=str, help='type of the data to use')
    parser.add_argument(
        '--data-dir',
        default='../../Dataset/Real-Data-Application/',
        type=str,
        help='directory containing <case>_Train.csv, <case>_Valid.csv and <case>_Test.csv',
    )
    args = parser.parse_args()

    case = args.case
    data_dir = Path(args.data_dir)

    def load_split(split):
        # Read one CSV split and return (features, target) as float32 arrays.
        frame = pd.read_csv(data_dir / f"{case}_{split}.csv")
        features = frame.drop(columns="y").to_numpy(dtype=np.float32)
        target = frame["y"].to_numpy(dtype=np.float32)
        return features, target

    X_train, y_train = load_split("Train")
    X_val, y_val = load_split("Valid")
    X_test, y_test = load_split("Test")

    # Standardise with statistics from the training split only.
    scaler = StandardScaler()
    X_train = torch.tensor(scaler.fit_transform(X_train), dtype=torch.float32)
    X_val = torch.tensor(scaler.transform(X_val), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32)
    y_val = torch.tensor(y_val, dtype=torch.float32)
    y_test = torch.tensor(y_test, dtype=torch.float32)

    # Feature-index pairs with known interactions for each dataset. Every
    # feature appearing in any pair is merged into ONE extra group, which adds
    # a single joint shape function on top of the per-feature ones. Cases
    # without listed pairs get per-feature groups only.
    interaction_pairs = {
        'bike': [[3, 6], [3, 8], [1, 3], [3, 10], [3, 7], [0, 3], [2, 3], [6, 10], [8, 9], [6, 9]],
        'wine': [[3, 7], [1, 7], [5, 6], [2, 5], [1, 4], [0, 7], [4, 5]],
        'ca': [[6, 7], [0, 5], [0, 1], [3, 5]],
        'fico': [[15, 18], [4, 11], [6, 15], [4, 5], [1, 3]],
    }
    pairs = interaction_pairs.get(case, [])
    interaction = [sorted({idx for pair in pairs for idx in pair})] if pairs else []

    n_features = X_train.shape[1]
    if interaction and interaction[0][-1] >= n_features:
        raise ValueError(
            f"interaction index {interaction[0][-1]} out of range for {n_features} features"
        )
    index_list = [[i] for i in range(n_features)] + interaction

    config = NAMConfig(
        input_dim=n_features,
        index_list=index_list,
        hidden_sizes=[128, 64, 32, 16],
        use_exu=False,
        lr=1e-2,
        dropout=0.1,
        use_batchnorm=True,
        batch_size=2048,
        n_epochs=1000,
        tqlm=True,
        case=case,
    )
    trainer = NAMTrainer(config)
    start_time = time.perf_counter()
    trainer.fit(X_train, y_train, X_val, y_val)
    trainer.load_best_model()
    test_pred = trainer.predict(X_test)
    elapsed_time = time.perf_counter() - start_time
    print('Case: ', case)
    print('Time: ', elapsed_time)
    print("Testing RMSE:", np.sqrt(np.mean((test_pred - y_test.numpy()) ** 2)))

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()

