import os
import math
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from utils import ExponentialMovingAverage
from data.load_dataset import load_dataset
from global_config import ROOT_DIRECTORY
import argparse
import numpy as np

import scipy.optimize as sopt
import torch.nn.functional as F


class TaskModelTrainer:
    """Trainer class for MNIST Diffusion Models."""

    def __init__(self, learning_rate, batch_size, training_epoch, task_model, loss_function, dataset_name, task_name,
                 train_data_loader, test_data_loader, device, model_name="autoencoder"):
        """
        Store the training configuration, create the result directory and build the optimizer.

        :param learning_rate: Learning rate handed to the AdamW optimizer.
        :param batch_size: Stored on the instance; not used further in this constructor.
        :param training_epoch: Stored on the instance; not used further in this constructor.
        :param task_model: Model to train; it is moved to ``device``.
        :param loss_function: Callable stored as ``self.loss_fn``.
        :param dataset_name: Name of the dataset; first sub-directory of the result path.
        :param task_name: Name of the task; second sub-directory of the result path.
        :param train_data_loader: Stored as ``self.train_dataloader``.
        :param test_data_loader: Stored as ``self.test_dataloader``.
        :param device: Device on which the model and ``alpha`` live.
        :param model_name: Base name (without extension) of the default checkpoint file.
        """
        # Plain hyper-parameters and run identifiers.
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.device = device
        self.training_epoch = training_epoch
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.model_name = model_name

        # Results live under <ROOT_DIRECTORY>/results/<dataset>/<task>/.
        result_dir = os.path.join(ROOT_DIRECTORY, "results", self.dataset_name, self.task_name)
        self.base_path = result_dir
        os.makedirs(result_dir, exist_ok=True)
        self.ckpt_path = os.path.join(result_dir, model_name + ".pt")

        # Data loaders are supplied by the caller.
        self.train_dataloader = train_data_loader
        self.test_dataloader = test_data_loader

        # Training progress is printed whenever global_steps is a multiple of this value.
        self.log_freq = 10

        # Model plus the learnable scalar threshold used by the CVaR objective.
        self.model = task_model.to(self.device)
        self.alpha = torch.tensor(0.01, requires_grad=True, device=self.device)

        # One optimizer updates the model parameters and alpha together.
        trainable = list(self.model.parameters())
        trainable.append(self.alpha)
        self.optimizer = AdamW(trainable, lr=self.learning_rate)
        self.loss_fn = loss_function

        # Number of optimisation steps taken so far (drives logging).
        self.global_steps = 0

    def set_alpha(self, value):
        self.alpha = torch.tensor(value, device=self.device, requires_grad=True)
        self.optimizer = AdamW(list(self.model.parameters()) + [self.alpha], lr=self.learning_rate)

    def load_model(self, model_name=None):
        if model_name is not None:
            base_model_path = os.path.join(self.base_path, model_name + ".pt")
            self.load_checkpoint(base_model_path)
            # print("Load success from: ", base_model_path)
        else:
            self.load_checkpoint(self.ckpt_path)
            print("Load success from: ", self.ckpt_path)

    def load_checkpoint(self, checkpoint_path: str):
        """
        Load model and EMA weights from a checkpoint.

        :param checkpoint_path: Path to the checkpoint file.
        """
        if os.path.exists(checkpoint_path):
            ckpt = torch.load(checkpoint_path, map_location=self.device)
            if "model_ema" in ckpt:
                self.model_ema.load_state_dict(ckpt["model_ema"])
            if "model" in ckpt:
                self.model.load_state_dict(ckpt["model"])
            print(f"Loaded checkpoint from {checkpoint_path}.")


        else:
            print(f"Checkpoint not found at {checkpoint_path}.")
            raise FileNotFoundError

    def save_checkpoint(self, checkpoint_path: str):
        """
        Save model and EMA weights to a checkpoint.

        :param checkpoint_path: Path where the checkpoint will be saved.
        """
        ckpt = {
            "model": self.model.state_dict()
        }
        torch.save(ckpt, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}.")

    def erm(self, importance_train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name):
        self.learning_rate = learning_rate
        self.ckpt_path = os.path.join(self.base_path, model_name + ".pt")
        best_val_loss = float('inf')  # Track best validation loss
        best_model_path = None

        for epoch in range(1, epochs + 1):
            self.model.train()
            epoch_loss = 0.0

            for batch_idx, batch_triplet in enumerate(importance_train_dataloader):
                # images, weights, normalizing_constant = batch_triplet[0].to(self.device), batch_triplet[1].to(
                #     self.device), batch_triplet[2].to(self.device)
                images, weights = batch_triplet[0].to(self.device), batch_triplet[1].to(self.device)

                # Forward pass
                output = self.model(images)
                mse_loss = self.loss_fn(output["sample"], images)  # .mean(dim=1)#[1, 2, 3])
                loss = mse_loss.mean() + output["aux_loss"] * 0.000579254

                # Backward pass and optimization
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                # Logging
                if self.global_steps % self.log_freq == 0:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    print(f"Epoch [{epoch}/{epochs}], "
                          f"Step [{batch_idx + 1}/{len(importance_train_dataloader)}], "
                          f"Loss: {loss.item():.5f}, LR: {current_lr:.6f}")

                epoch_loss += loss.item()
                self.global_steps += 1

            avg_epoch_loss = epoch_loss / len(importance_train_dataloader)
            print(f"Epoch [{epoch}/{epochs}] completed with average loss: {avg_epoch_loss:.5f}")

            # Validate model
            val_loss = self.validate_one_epoch(0.0, validation_dataloader)
            print(f"Validation Loss after Epoch [{epoch}]: {val_loss:.5f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = self.ckpt_path
                self.save_checkpoint(best_model_path)
                print(f"New best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

        print(f"Training complete. Best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

    def cvar(self, importance_train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name):
        self.learning_rate = learning_rate
        self.ckpt_path = os.path.join(self.base_path, model_name + ".pt")
        best_val_loss = float('inf')  # Track best validation loss
        best_model_path = None

        for epoch in range(1, epochs + 1):
            self.model.train()
            epoch_loss = 0.0

            for batch_idx, batch_triplet in enumerate(importance_train_dataloader):
                # images, weights, normalizing_constant = batch_triplet[0].to(self.device), batch_triplet[1].to(
                #     self.device), batch_triplet[2].to(self.device)
                images, weights = batch_triplet[0].to(self.device), batch_triplet[1].to(self.device)

                # Forward pass
                output = self.model(images)
                mse_loss = self.loss_fn(output["sample"], images)  # .mean(dim=1)#[1, 2, 3])
                loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim))) + output["aux_loss"] * 0.000579254

                cvar_loss = self.alpha + 1.0 / (1.0 - beta) * nn.functional.relu(loss - self.alpha) * weights
                cvar_loss = torch.mean(cvar_loss)

                # Backward pass and optimization
                cvar_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                # Logging
                if self.global_steps % self.log_freq == 0:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    print(f"Epoch [{epoch}/{epochs}], "
                          f"Step [{batch_idx + 1}/{len(importance_train_dataloader)}], "
                          f"Loss: {cvar_loss.item():.5f}, LR: {current_lr:.6f}")

                epoch_loss += cvar_loss.item()
                self.global_steps += 1

            avg_epoch_loss = epoch_loss / len(importance_train_dataloader)
            print(f"Epoch [{epoch}/{epochs}] completed with average loss: {avg_epoch_loss:.5f}")

            # Validate model
            val_loss = self.validate_one_epoch(beta, validation_dataloader)
            print(f"Validation Loss after Epoch [{epoch}]: {val_loss:.5f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = self.ckpt_path
                self.save_checkpoint(best_model_path)
                print(f"New best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

        print(f"Training complete. Best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

    def cvar_doro(self, train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name):
        model = self.model
        alpha = 1 - beta
        eps = 0.0
        device = self.device

        self.learning_rate = learning_rate
        self.ckpt_path = os.path.join(self.base_path, model_name + ".pt")
        best_val_loss = float('inf')  # Track best validation loss
        best_model_path = None

        for epoch in range(1, epochs + 1):
            epoch_loss = 0.0
            model.train()
            gamma = eps + alpha * (1 - eps)
            for batch_idx, batch_triplet in enumerate(train_dataloader):
                images, weights = batch_triplet[0].to(self.device), batch_triplet[1].to(self.device)

                # Forward pass
                output = self.model(images)
                mse_loss = self.loss_fn(output["sample"], images)  # .mean(dim=1)#[1, 2, 3])
                loss = mse_loss.mean(dim=tuple(range(1, mse_loss.ndim))) + output["aux_loss"] * 0.000579254

                batch_size = len(images)
                n1 = int(gamma * batch_size)
                n2 = int(eps * batch_size)
                rk = torch.argsort(loss, descending=True)
                loss = loss[rk[n2:n1]].sum() / alpha / (batch_size - n2)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Validate model
            val_loss = self.validate_one_epoch(beta, validation_dataloader)
            print(f"Validation Loss after Epoch [{epoch}]: {val_loss:.5f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = self.ckpt_path
                self.save_checkpoint(best_model_path)
                print(f"New best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

        print(f"Training complete. Best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

    def chisq(self, train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name):
        model = self.model
        alpha = 1 - beta
        max_l = 10.
        C = math.sqrt(1 + (1 / alpha - 1) ** 2)

        self.learning_rate = learning_rate
        self.ckpt_path = os.path.join(self.base_path, model_name + ".pt")
        best_val_loss = float('inf')  # Track best validation loss
        best_model_path = None

        for epoch in range(1, epochs + 1):
            epoch_loss = 0.0
            model.train()
            for batch_idx, batch_triplet in enumerate(train_dataloader):
                images, weights = batch_triplet[0].to(self.device), batch_triplet[1].to(self.device)

                # Forward pass
                output = self.model(images)
                mse_loss = self.loss_fn(output["sample"], images)  # .mean(dim=1)#[1, 2, 3])
                loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim))) + output["aux_loss"] * 0.000579254

                foo = lambda eta: C * math.sqrt((F.relu(loss - eta) ** 2).mean().item()) + eta
                opt_eta = sopt.brent(foo, brack=(0, max_l))
                loss = C * torch.sqrt((F.relu(loss - opt_eta) ** 2).mean()) + opt_eta

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Validate model
            val_loss = self.validate_one_epoch(beta, validation_dataloader)
            print(f"Validation Loss after Epoch [{epoch}]: {val_loss:.5f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = self.ckpt_path
                self.save_checkpoint(best_model_path)
                print(f"New best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

        print(f"Training complete. Best model saved at {best_model_path} with validation loss: {best_val_loss:.5f}")

    def validate_one_epoch(self, beta, validation_dataloader):
        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch_triplet in validation_dataloader:
                images, weights = batch_triplet[0].to(self.device), batch_triplet[1].to(self.device)
                output = self.model(images)
                mse_loss = nn.MSELoss()(output["sample"], images)
                loss = mse_loss + output["aux_loss"] * 0.000579254
                cvar_loss = self.alpha + 1.0 / (1.0 - beta) * nn.functional.relu(loss - self.alpha) * weights
                cvar_loss = torch.mean(cvar_loss)
                val_loss += cvar_loss.item()
        return val_loss / len(validation_dataloader)

    def train(self, algorithm_name, train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name):
        if algorithm_name == 'erm':
            self.erm(train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name)
        elif algorithm_name == 'is_cvar' or algorithm_name == 'cvar':
            self.cvar(train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name)
        elif algorithm_name == 'cvar_doro':
            self.cvar_doro(train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name)
        elif algorithm_name == 'chisq':
            self.chisq(train_dataloader, validation_dataloader, epochs, learning_rate, beta, model_name)
        else:
            raise NotImplementedError

    def test(self, dataloader, distribution_name):
        """
        Evaluate the model on a data loader and save the resulting losses to disk.

        For every batch the loss from ``self.loss_fn`` is averaged over all dimensions
        except the first. The per-batch results are stacked, written to
        ``<ROOT_DIRECTORY>/results/<dataset_name>/pred_<distribution_name>_<model_name>.npy``
        and their overall mean is printed.

        :param dataloader: Iterable of batches. A batch may be a tensor, or a list/tuple
            whose first element is the input tensor.
        :param distribution_name: Label embedded in the name of the saved file.
        :return: Tuple ``(inputs, losses)`` with the concatenated inputs and the
            stacked per-batch losses.
        """
        self.model.eval()
        seen_inputs = []
        batch_losses = []

        with torch.no_grad():
            for batch in dataloader:
                # Loaders may yield (input, ...) sequences or bare tensors.
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                images = images.to(self.device)
                seen_inputs.append(images)

                # Forward pass
                output = self.model(images)
                elementwise = self.loss_fn(output["sample"], images)
                reduce_dims = tuple(range(1, elementwise.ndim))
                batch_losses.append(elementwise.mean(dim=reduce_dims))

        inputs = torch.concatenate(seen_inputs, dim=0)
        # NOTE(review): stacking requires every batch to yield a loss tensor of the
        # same shape; a smaller final batch would make this raise - confirm the
        # loader drops or pads the last batch.
        losses = torch.stack(batch_losses, dim=0)

        file_name = "pred_" + distribution_name + "_" + self.model_name + ".npy"
        savepath = os.path.join(ROOT_DIRECTORY, "results", self.dataset_name, file_name)
        np.save(savepath, losses.detach().cpu().numpy())
        print(f"loss info saved to {savepath}")

        print(f"average loss: {torch.mean(losses):.5f}")
        return inputs, losses

