from typing import Dict
from dataclasses import asdict
import logging
import copy

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch import optim

from bbo.algorithms.np.transformer.transformer import Transformer
from bbo.algorithms.np.transformer import bar_distribution
from bbo.algorithms.np.transformer.definitions import (
    TransformerConfig,
    AcqfConfig,
    PretrainConfig,
)
from bbo.algorithms.np.transformer.single_eval_pos_sampler import (
    weighted_single_eval_pos_sampler,
    uniform_single_eval_pos_sampler,
)
from bbo.datasets.base import SimpleDataset
from bbo.algorithms.utils import latin_hypercube, from_unit_cube, timer_wrapper
from tqdm import tqdm

log = logging.getLogger(__name__)


class TransformerOpt:
    def __init__(
        self,
        dim: int,
        lb: Tensor,
        ub: Tensor,
        name: str = 'TransformerOpt',
        n_init: int = 3,
        q: int = 1,
        transformer_config: dict = None,
        acqf_config: dict = None,
        pretrain_config: dict = None,
        device: str = 'cpu', # @TODO: pass device parameter
    ):
        """
        Build the optimizer state, the bar-distribution criterion and the transformer.

        Args:
            dim: Dimensionality of the search space.
            lb: 1-D lower bounds; must be elementwise smaller than ``ub``.
            ub: 1-D upper bounds with the same shape as ``lb``.
            name: Name stored on the instance.
            n_init: Number of initial points produced by ``init``.
            q: Stored on the instance; not read elsewhere in this class.
            transformer_config: Keyword arguments for ``TransformerConfig``
                (its defaults are used when ``None``).
            acqf_config: Keyword arguments for ``AcqfConfig``
                (its defaults are used when ``None``).
            pretrain_config: Keyword arguments for ``PretrainConfig``
                (its defaults are used when ``None``).
            device: Requested device; it is only honoured when CUDA is
                available, otherwise the CPU is used.
        """
        assert lb.ndim == 1 and ub.ndim == 1
        assert lb.shape == ub.shape
        assert (lb < ub).all()
        self.dim = dim
        # Keep private copies of the bounds. `as_tensor(...).detach().clone()`
        # accepts both arrays and tensors without the copy-construct warning
        # that `torch.tensor(tensor)` raises.
        self.lb = torch.as_tensor(lb).detach().clone()
        self.ub = torch.as_tensor(ub).detach().clone()
        self.name = name
        self.n_init = n_init
        self.q = q
        self.transformer_config = (
            TransformerConfig(**transformer_config) if transformer_config is not None else TransformerConfig()
        )
        self.acqf_config = AcqfConfig(**acqf_config) if acqf_config is not None else AcqfConfig()
        # Treat a missing pretrain config like the other two configs instead of
        # failing on `**None`.
        self.pretrain_config = (
            PretrainConfig(**pretrain_config) if pretrain_config is not None else PretrainConfig()
        )
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        log.info('Device: {}'.format(self.device))

        num_borders = 1000 # 10000
        self.criterion = bar_distribution.FullSupportBarDistribution(
            bar_distribution.get_bucket_limits(num_borders, full_range=(-3, 3)) # set ys if you have large data
        )
        self.transformer = Transformer(dim, self.criterion.num_bars, **asdict(self.transformer_config)).to(self.device)
        # Set to True by `ask` / `ask_by_mcts` once pretrained weights are loaded.
        self.load_flag = False

        # Observation history, grown by `tell`.
        self.X = torch.zeros((0, dim))
        self.Y = torch.zeros((0, 1))

    def init(self):
        """
        Produce the initial design of ``n_init`` points.

        Samples come from ``latin_hypercube`` and are mapped into the
        ``[lb, ub]`` box by ``from_unit_cube``; the result is returned as a
        tensor built from the resulting NumPy array.
        """
        lower = self.lb.detach().cpu().numpy()
        upper = self.ub.detach().cpu().numpy()
        unit_samples = latin_hypercube(self.n_init, self.dim)
        scaled = from_unit_cube(unit_samples, lower, upper)
        return torch.from_numpy(scaled)

    def optimize_acqf(self, context_x, context_y, num_restart=1):
        best_y = torch.max(context_y)
        cand_x, cand_acqf_val = [], []
        for _ in range(num_restart):
            x = torch.rand(1, 1, self.dim)
            x.requires_grad_(True)
            optimizer = optim.Adam([x], lr=self.acqf_config.lr)
            for _ in range(self.acqf_config.epochs):
                logits = self.transformer.predict(context_x.detach(), context_y.detach(), x)
                acqf_val = self.criterion.ei(logits, best_y)
                loss = - acqf_val

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            cand_x.append(x)
            cand_acqf_val.append(acqf_val.item())
        idx = np.argmax(cand_acqf_val)
        best_x = cand_x[idx]
        
        return best_x.reshape(1, self.dim)
    
    def optimize_acqf_given_x(self, X, context_x, context_y, dims,  num_restart=1):
        best_y = torch.max(context_y)
        
        X = torch.tensor(X).reshape(-1,1,dims)
        logits = self.transformer.predict(context_x.detach(), context_y.detach(), X)
        acqf_val = self.criterion.ei(logits, best_y)
        
        idx = torch.argmax(acqf_val)
        best_x = X[idx,:]
        return best_x.reshape(1, self.dim)

    def load_pretrain(self, path):
        path = path if path.endswith('.pth') else path + '.pth'
        
        # Load the state dict on a specific device
        device = 'cuda:0'  # Choose an appropriate device
        state_dict = torch.load(path, map_location=device)
        
        self.transformer.load_state_dict(state_dict, strict=False)
        self.transformer.requires_grad_(False)
        
        # Ensure the model is on the correct device
        # self.transformer.to(device)
        
        # Log information
        logging.info(f'Loaded model from {path} onto {device}')

    def preprocess(self):
        self.lb = self.lb.to(self.device)
        self.ub = self.ub.to(self.device)
        train_X = (self.X - self.lb) / (self.ub - self.lb)
        Y_std = self.Y.std()
        if torch.isnan(Y_std):
            Y_std = 1e-6
        train_Y = (self.Y - self.Y.mean()) / (Y_std + 1e-6)
        train_X, train_Y = train_X.to(self.device), train_Y.to(self.device)
        
        return train_X, train_Y

    def postprocess(self, next_X):
        next_X = torch.tensor(next_X).to('cpu')
        next_X = self.lb + next_X * (self.ub - self.lb)
        return next_X

    def ask(self, model_path) -> Tensor:
        if self.load_flag == False:
            self.load_pretrain(model_path)
            self.load_flag = True

        if len(self.X) == 0:
            next_X = self.init()
        else:
            train_X, train_Y = self.preprocess()
            context_x = train_X.unsqueeze(1)
            context_y = train_Y.unsqueeze(-1)
            assert context_x.shape == (len(self.X), 1, self.dim)
            assert context_y.shape == (len(self.X), 1, 1)
            next_X = self.optimize_acqf(context_x, context_y)
            next_X = self.postprocess(next_X)
        return next_X

    def ask_by_mcts(self, X, model_path) -> Tensor:
        # assert X.shape[0] == 10000
        if self.load_flag == False:
            self.load_pretrain(model_path)
            self.load_flag = True
        train_X, train_Y = self.preprocess()
        context_x = train_X.unsqueeze(1)
        context_y = train_Y.unsqueeze(-1)
        assert context_x.shape == (len(self.X), 1, self.dim)
        assert context_y.shape == (len(self.X), 1, 1)
        
        X = (torch.tensor(X) - self.lb) / (self.ub - self.lb)
        next_X = self.optimize_acqf_given_x(X, context_x, context_y, self.dim)
        next_X = self.postprocess(next_X)
        return next_X
            
    def tell(self, X: Tensor, Y: Tensor) -> Tensor:
        X, Y = torch.tensor(X).to(self.X), torch.tensor(Y).to(self.Y)
        self.X = torch.vstack((self.X, X))
        self.Y = torch.vstack((self.Y, Y))

    def train(self, train_id2dataset: Dict[str, SimpleDataset], val_id2dataset: Dict[str, SimpleDataset]=None):
        """
        Pretrain the transformer on sequences sampled from the given datasets.

        Assuming dataset is normalized. X in [0, 1] and Y is mean 0 and std 1

        Every step samples one training dataset, a sequence length from
        ``pretrain_config.seq_len_range`` and a random shift added to Y, builds
        a batch of ``pretrain_config.bs`` sequences (sampled with replacement)
        and minimizes the criterion on the targets after a sampled
        ``single_eval_pos``. Every ``eval_intervals`` steps the model is
        evaluated by predicting the last point of each chunk of every
        evaluation dataset from the points before it.

        Args:
            train_id2dataset: Mapping from dataset id to training dataset.
            val_id2dataset: Optional mapping of validation datasets; the
                training datasets are used for evaluation when omitted.

        Returns:
            Tuple ``(transformer, mean_mse_loss, best_model, best_val)`` where
            ``mean_mse_loss`` is the last evaluation result and ``best_model``
            is a deep copy of the model at the lowest evaluation loss. The
            last three entries are ``None`` when no evaluation was run.
        """
        config = self.pretrain_config
        optim_config = config.optim_config
        device = self.device
        sampler = weighted_single_eval_pos_sampler
        optimizer = optim.AdamW(self.transformer.parameters(), lr=optim_config.lr, weight_decay=optim_config.weight_decay)
        # Linear warm-up to the base learning rate; guard against warmup_steps == 0.
        warmup_steps = max(optim_config.warmup_steps, 1)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 1))

        # The key list does not change during training; build it once.
        train_keys = list(train_id2dataset.keys())

        best_model, best_val = None, None
        # Stays None when no evaluation interval is reached, so the final
        # return cannot hit an unbound local.
        mean_mse_loss = None

        bar = tqdm(range(config.epochs), desc="train progress")
        for epoch in bar:
            self.transformer.train()

            # sample the dataset
            dataset_idx = np.random.randint(low=0, high=len(train_keys))
            dataset = train_id2dataset[train_keys[dataset_idx]]

            seq_len = np.random.randint(config.seq_len_range[0], config.seq_len_range[1])
            # Uniform shift in [-config.shifting, config.shifting], shared by the batch.
            shifting = 2 * config.shifting * torch.rand(1).item() - config.shifting

            # sample the data
            batch_X, batch_Y = [], []
            for _ in range(config.bs):
                idx = np.random.choice(len(dataset), seq_len, replace=True)
                X, Y = dataset[idx]
                assert (X >= 0).all() and (X <= 1).all()
                batch_X.append(X)
                batch_Y.append(Y + shifting)
            # Stack to (bs, seq_len, ...) and switch to sequence-first layout.
            batch_X = torch.stack(batch_X).transpose(0, 1) # batch_first=False
            batch_Y = torch.stack(batch_Y).transpose(0, 1)
            assert batch_X.shape == (seq_len, config.bs, self.dim)
            assert batch_Y.shape == (seq_len, config.bs, 1)
            batch_X, batch_Y = batch_X.to(device).float(), batch_Y.to(device).float()

            # sample single_eval_pos
            single_eval_pos = sampler(len(batch_X))
            out = self.transformer(batch_X, batch_Y, single_eval_pos)
            targets = batch_Y[single_eval_pos: ].squeeze(-1)

            loss = self.criterion(out.reshape(-1, self.transformer.n_out), targets.flatten())
            loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # eval
            if (epoch + 1) % config.eval_intervals == 0:
                self.transformer.eval()
                if val_id2dataset is not None:
                    id2dataset = val_id2dataset
                else:
                    id2dataset = train_id2dataset
                    log.info('Eval on training dataset')

                mean_mse_loss_list = []
                for dataset_id in id2dataset:
                    dataset = id2dataset[dataset_id]
                    mse_loss_list = []
                    # Consecutive chunks of at most config.bs points.
                    # NOTE(review): a trailing chunk of size 1 is predicted
                    # with an empty context -- confirm this is intended.
                    for idx in np.split(np.arange(len(dataset)), np.arange(config.bs, len(dataset), config.bs)):
                        val_X, val_Y = dataset[idx]
                        val_X, val_Y = val_X.float(), val_Y.float()
                        # Add the batch axis at position 1 (sequence-first layout).
                        val_X, val_Y = val_X.unsqueeze(1), val_Y.unsqueeze(1)
                        val_X, val_Y = val_X.to(device), val_Y.to(device)
                        single_eval_pos = len(val_X) - 1 # here we predict the last y

                        # inference
                        with torch.no_grad():
                            logits = self.transformer.predict(val_X[: single_eval_pos], val_Y[: single_eval_pos], val_X[single_eval_pos: ])
                            mean = self.criterion.mean(logits)
                        loss = nn.functional.mse_loss(mean, val_Y[single_eval_pos: ])
                        mse_loss_list.append(loss.item())

                    mse_loss = np.mean(mse_loss_list)
                    mean_mse_loss_list.append(mse_loss)
                    log.info('Epoch: {}, dataset id: {}, loss: {}'.format(epoch, dataset_id, mse_loss))

                mean_mse_loss = np.mean(mean_mse_loss_list)
                print('Epoch: {}, mean loss: {}'.format(epoch, mean_mse_loss))

                # record best model
                if best_val is None or best_val > mean_mse_loss:
                    best_model = copy.deepcopy(self.transformer)
                    best_val = mean_mse_loss

        return self.transformer, mean_mse_loss, best_model, best_val
