import warnings

warnings.filterwarnings("ignore")

from transformers import logging

logging.set_verbosity_error()

import sys
import torch
import copy
from torch import optim
import math
import pandas as pd
from .base import LLMBayesOpt
from problems.data_processor import DataProcessor
from utils.configs import LoraFSPLaplaceConfig
import tqdm
from contextlib import nullcontext
from transformers import get_scheduler

from gpytorch.likelihoods import GaussianLikelihood


from bayesopt.surrogates.fsplaplace_utils.prior import (
    GPModel,
    optimize_prior_parameters
)
from bayesopt.surrogates.fsplaplace_utils.fsplaplace_model import FSPLaplace


from typing import *

from gpytorch.means import Mean
from gpytorch.kernels import Kernel
from gpytorch.kernels import ProductKernel, AdditiveKernel, ScaleKernel


class FSPLLALoRALLMBayesOpt(LLMBayesOpt):
    """
    The Laplace approx. is applied on the regression head and the LoRA weights.
    """

    def __init__(
        self,
        get_model: Callable[[], torch.nn.Module],
        training_set: List[pd.Series],
        data_processor: DataProcessor,
        candidate_X: List[pd.Series],
        mean_fn: Mean,
        kernel_fn: Kernel,
        bnn: FSPLaplace = None,
        laplace_config: LoraFSPLaplaceConfig = None,
        device: str = "cuda",
        dtype: str = "float64",
        append_eos: bool = True
    ):
        """
        Store the surrogate's settings, then delegate to the base-class constructor.

        Args:
            get_model: Zero-argument factory returning a fresh network.
            training_set: Observations collected so far.
            data_processor: Object used to turn rows into dataloaders.
            candidate_X: Candidate pool; used as context points during training.
            mean_fn: Mean function of the GP prior.
            kernel_fn: Kernel of the GP prior.
            bnn: Optional already-fitted surrogate, forwarded to the base class.
            laplace_config: Training / Laplace hyperparameters, forwarded to the
                base class.
            device: Device identifier used for tensors and the model.
            dtype: One of "float64", "float32", "bfloat16", "float16".
            append_eos: Forwarded to ``data_processor.get_dataloader``.

        Raises:
            KeyError: If ``dtype`` is not one of the supported names.
        """
        # Everything below must be set BEFORE calling the base constructor,
        # which receives the model factory, data and config.
        self.device = device
        self.dtype = dtype

        supported_dtypes = {
            "float64": torch.float64,
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }
        self.ptdtype = supported_dtypes[dtype]

        # No autocast is used; the forward passes run inside a no-op context.
        self.ctx = nullcontext()
        # Gradient scaling is switched on only for the half-precision dtypes.
        self.enable_grad_scaler = dtype in ("float16", "bfloat16")

        self.append_eos = append_eos
        self.candidate_X = candidate_X
        self.mean_fn = mean_fn
        self.kernel_fn = kernel_fn

        super().__init__(
            get_model, training_set, data_processor, bnn, laplace_config, device
        )

    def train_model(self):
        """
        Train the LoRA network under the function-space prior and fit FSP-Laplace.

        The fitted ``FSPLaplace`` object is stored in ``self.bnn``. Stages:

        1. Build the training loader (requested with ``standardize_y=True``; the
           returned statistics are kept in ``self.label_mean`` /
           ``self.label_std``) and a loader over ``self.candidate_X`` that
           provides context points.
        2. Build the ``GPModel`` prior from ``self.mean_fn`` / ``self.kernel_fn``
           and tune it with ``optimize_prior_parameters``.
        3. Jointly train the LoRA parameters, the remaining trainable ("head")
           parameters and the observation-noise variance.
        4. Freeze the LoRA parameters and keep training the head only.
        5. Re-enable gradients on the LoRA parameters, wrap the model in
           ``FSPLaplace`` and call its ``fit``.

        Both training stages minimise a Gaussian NLL on the training batch plus
        the penalty ``0.5 * delta^T K^{-1} delta / n_samples`` evaluated on a
        batch of context points, where ``delta`` is the network output minus the
        prior mean and ``K`` is the prior covariance.

        Early stopping, best-state snapshots and progress printing are evaluated
        once per epoch on the summed batch losses. The best loss found in the
        joint stage is carried into the head-only stage, so the head-only stage
        only replaces the snapshot when it actually improves on it.
        """
        del self.bnn
        cfg = self.laplace_config

        # Get train loader
        train_loader, self.label_mean, self.label_std = self.data_processor.get_dataloader(
            pd.DataFrame(self.training_set),
            batch_size=cfg.batch_size,
            shuffle=True,
            append_eos=self.append_eos,
            standardize_y=True,
        )
        n_samples = len(train_loader.dataset)  # type: ignore

        # Get loader of context points
        context_loader = self.data_processor.get_dataloader(
            pd.DataFrame(self.candidate_X),
            batch_size=cfg.batch_size,
            shuffle=True,
            append_eos=self.append_eos,
        )
        context_iter = iter(context_loader)

        def next_context_batch():
            """Return the next context batch, restarting the loader when exhausted."""
            nonlocal context_iter
            try:
                return next(context_iter)
            except StopIteration:
                context_iter = iter(context_loader)
                return next(context_iter)

        # Initialize the GP likelihood
        likelihood = GaussianLikelihood()
        likelihood.noise = cfg.noise_var

        # Initialize the GP kernel
        train_X = torch.tensor(train_loader.dataset["features"]).to(self.device, self.ptdtype)
        train_Y = torch.tensor(train_loader.dataset["labels"]).to(self.device, self.ptdtype)
        feature_dim = train_X.shape[-1]
        self.init_kernel_params(self.kernel_fn, feature_dim)
        prior = GPModel(
            self.mean_fn,
            self.kernel_fn,
            likelihood,
            train_X,
            train_Y,
        ).to(self.device).to(self.ptdtype)

        # Fit prior
        prior, noise_var = optimize_prior_parameters(
            prior,
            prior.likelihood,
            train_X,
            train_Y,
            n_steps=cfg.prior_n_steps,
            val_frequency=cfg.prior_val_frequency,
            verbose=False,
        )
        prior.eval()
        prior.likelihood.eval()

        # Setup model (moved to the device only; it is not cast to self.ptdtype)
        model = self.get_model().to(self.device)

        # Get parameters; the noise variance starts from the value found while
        # fitting the prior and is trained together with the head.
        noise_var = torch.nn.parameter.Parameter(
            noise_var * torch.ones(1, dtype=self.ptdtype, device=self.device)
        )
        lora_params = [
            p for n, p in model.named_parameters() if p.requires_grad and "lora" in n
        ]
        head_params = [
            p for n, p in model.named_parameters() if p.requires_grad and "lora" not in n
        ]

        # Setup optimizers
        optimizer_lora = optim.AdamW(lora_params, lr=cfg.lr_lora)
        optimizer_head = optim.AdamW(head_params + [noise_var], lr=cfg.lr)

        # Setup training schedulers
        num_training_steps = cfg.n_epochs * len(train_loader)
        scheduler_lora = get_scheduler(
            name="linear",
            optimizer=optimizer_lora,
            num_warmup_steps=0,
            num_training_steps=num_training_steps,
        )
        scheduler_head = get_scheduler(
            name="cosine",
            optimizer=optimizer_head,
            num_warmup_steps=0,
            num_training_steps=num_training_steps,
        )

        # Scaler for low precision computation
        scaler = torch.amp.GradScaler(self.device, enabled=self.enable_grad_scaler)

        # Define loss function
        loss_func = torch.nn.GaussianNLLLoss(reduction="mean", full=True)

        def fsp_loss(batch, noise):
            """Return (loss, outputs, targets) of the FSP objective for one batch."""
            # Both stages cast to self.ptdtype so inputs match the prior's dtype.
            y = batch["labels"].to(self.device, self.ptdtype, non_blocking=True)
            x_context = next_context_batch()
            x_context_features = x_context["features"].to(self.device, self.ptdtype)

            with self.ctx:
                outputs = model(batch)
                output_context = model(x_context)
                prior_mean, prior_cov = prior.prior_predictive(x_context_features, jitter=1e-10)

                # Data fit + function-space regulariser on the context points
                delta = output_context.T - prior_mean
                loss = loss_func(outputs, y, noise.expand(y.shape[0], -1))
                loss += 0.5 * torch.vmap(
                    lambda _mu, _cov: torch.dot(_mu, torch.linalg.solve(_cov, _mu)),
                    chunk_size=cfg.n_chunks,
                )(delta, prior_cov).sum() / n_samples
            return loss, outputs, y

        # Early stopping
        best_epoch_loss = float("inf")
        no_improvement_count = 0
        best_net, best_noise_var = None, None

        # Joint lora-head training
        for epoch in tqdm.trange(
            cfg.n_epochs, position=1, leave=False, desc="[Training]", colour="blue", file=sys.stdout, disable=True
        ):
            epoch_loss, epoch_acc = 0., 0.
            model.train()
            for batch in train_loader:
                loss, outputs, y = fsp_loss(batch, noise_var)

                scaler.scale(loss).backward()

                if cfg.grad_clip != 0.0:
                    scaler.unscale_(optimizer_lora)
                    torch.nn.utils.clip_grad_norm_(lora_params, cfg.grad_clip)

                scaler.step(optimizer_lora)
                scaler.step(optimizer_head)
                scaler.update()
                scheduler_lora.step()
                scheduler_head.step()
                optimizer_lora.zero_grad(set_to_none=True)
                optimizer_head.zero_grad(set_to_none=True)
                # Keep the learned variance strictly positive
                noise_var.data.clamp_(min=1e-6)
                epoch_loss += loss.item()
                epoch_acc += torch.square(outputs - y).sum().item()

            if epoch % cfg.val_frequency == 0 or epoch == cfg.n_epochs - 1:
                print(f"Epoch {epoch} - loss: {epoch_loss / n_samples} - acc: {epoch_acc / n_samples}", flush=True)

            # Early stopping (per epoch; `break` leaves the epoch loop)
            if epoch_loss < best_epoch_loss:
                best_epoch_loss = epoch_loss
                best_net = copy.deepcopy(model.state_dict())
                best_noise_var = noise_var.detach().clone()
                no_improvement_count = 0
            else:
                no_improvement_count += 1
                if no_improvement_count > cfg.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch}", flush=True)
                    break

        if best_net is not None:
            model.load_state_dict(best_net)
            noise_var = best_noise_var.detach()
        else:
            # No snapshot was taken (e.g. zero epochs); the noise variance is no
            # longer optimised from here on, so drop it from the autograd graph.
            noise_var = noise_var.detach()

        # Head-only stage: freeze the LoRA weights
        for n, p in model.named_parameters():
            if "lora" in n:
                p.requires_grad = False

        optimizer_head = optim.AdamW(head_params, lr=cfg.lr)
        # The schedule must span the head-only stage's own number of steps.
        num_head_training_steps = cfg.head_n_epochs * len(train_loader)
        scheduler_head = get_scheduler(
            name="cosine",
            optimizer=optimizer_head,
            num_warmup_steps=0,
            num_training_steps=num_head_training_steps,
        )

        print("Fitting the model", flush=True)
        no_improvement_count = 0
        for epoch in tqdm.trange(
            cfg.head_n_epochs, position=1, leave=False, desc="[Training]", colour="blue", file=sys.stdout, disable=True
        ):
            epoch_loss, epoch_acc = 0., 0.
            model.train()
            for batch in train_loader:
                loss, outputs, y = fsp_loss(batch, noise_var)

                scaler.scale(loss).backward()
                scaler.step(optimizer_head)
                scaler.update()
                scheduler_head.step()
                optimizer_head.zero_grad(set_to_none=True)
                epoch_loss += loss.item()
                epoch_acc += torch.square(outputs - y).sum().item()

            if epoch % 50 == 0 or epoch == cfg.head_n_epochs-1:
                print(f"Epoch {epoch} - loss: {epoch_loss / n_samples} - acc: {epoch_acc / n_samples}", flush=True)

            # Early stopping (per epoch; `break` leaves the epoch loop)
            if epoch_loss < best_epoch_loss:
                best_epoch_loss = epoch_loss
                best_net = copy.deepcopy(model.state_dict())
                best_noise_var = noise_var.detach().clone()
                no_improvement_count = 0
            else:
                no_improvement_count += 1
                if no_improvement_count > cfg.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch}", flush=True)
                    break

        if best_net is not None:
            model.load_state_dict(best_net)
            noise_var = best_noise_var.detach()

        # So that it's considered by Laplace
        for n, p in model.named_parameters():
            if "lora" in n:
                p.requires_grad = True

        model.eval()

        sigma_noise = noise_var**0.5 * torch.ones(1, device=self.device, dtype=self.ptdtype)
        self.bnn = FSPLaplace(
            model,
            prior,
            likelihood="regression",
            sigma_noise=sigma_noise,
            params_sketch=cfg.params_sketch,
            params_sketch_dim=cfg.params_sketch_dim,
            max_rank=cfg.max_rank,
            dtype=self.ptdtype,
        )

        # Create new dataloader with larger batch size. The GPU name is only
        # queried when CUDA exists, so CPU-only hosts fall back to 128.
        gpu_name = torch.cuda.get_device_name(0).lower() if torch.cuda.is_available() else ""
        batch_size = 512 if "h100" in gpu_name else 128  # 2048 work on h100
        context_loader = self.data_processor.get_dataloader(
            pd.DataFrame(self.candidate_X),
            batch_size=batch_size,
            shuffle=True,
            append_eos=self.append_eos,
        )

        self.bnn.fit(
            train_loader,
            context_loader,
            n_chunks=cfg.n_chunks,
        )

    def posterior(self, data):
        f_mean, f_var = self.bnn(data)  # (B, 1) and (B, 1, 1)
        f_mean, f_var = f_mean.detach(), f_var.detach()
        # Rescale data
        f_mean = f_mean * self.label_std + self.label_mean
        f_var = (f_var.squeeze(-1) + self.bnn.sigma_noise**2) * self.label_std**2
        # Add observation noise
        return torch.distributions.Normal(f_mean, f_var)


    def condition_on_observations(self, obs):
        """
        Add ``obs`` to the training set and return a new surrogate built on it.

        ``self.training_set`` is extended in place and ``self.bnn`` is deleted
        from this instance. The returned object is constructed with the same
        settings as this one and with ``bnn=None``.

        Args:
            obs: New observation to append to ``self.training_set``.

        Returns:
            A new ``FSPLLALoRALLMBayesOpt`` instance.
        """
        self.training_set.append(obs)
        # The current surrogate is stale once new data has arrived
        del self.bnn

        settings = dict(
            get_model=self.get_model,
            training_set=self.training_set,  # now includes `obs`
            data_processor=self.data_processor,
            candidate_X=self.candidate_X,
            mean_fn=self.mean_fn,
            kernel_fn=self.kernel_fn,
            bnn=None,  # no fitted surrogate is handed over
            laplace_config=self.laplace_config,
            device=self.device,
            dtype=self.dtype,
            append_eos=self.append_eos,
        )
        return FSPLLALoRALLMBayesOpt(**settings)
    

    def init_kernel_params(self, kernel, feature_dim):
        """
        Recursively reset the hyperparameters of ``kernel`` to default start values.

        Sub-kernels of product/additive kernels and the base kernel of a scale
        kernel are visited first; afterwards the kernel itself is updated:
        lengthscale -> ``sqrt(feature_dim)``, outputscale -> 1, variance -> 1
        (each only when the kernel has the corresponding parameter).

        Args:
            kernel: Kernel to initialise (modified in place).
            feature_dim: Input dimensionality used for the lengthscale.
        """
        # Collect the nested kernels of composition / scale kernels
        if isinstance(kernel, (ProductKernel, AdditiveKernel)):
            children = list(kernel.kernels)
        elif isinstance(kernel, ScaleKernel):
            children = [kernel.base_kernel]
        else:
            children = []

        for child in children:
            self.init_kernel_params(child, feature_dim)

        # Lengthscale grows with the input dimension
        if kernel.has_lengthscale:
            kernel.lengthscale = math.sqrt(feature_dim)

        # Output scale and variance both start at one when present
        for raw_name, public_name in (
            ("raw_outputscale", "outputscale"),
            ("raw_variance", "variance"),
        ):
            if hasattr(kernel, raw_name):
                setattr(kernel, public_name, 1.0)
