import copy
import json
import logging
import math
import os
import time
from typing import Dict, List, Optional, Tuple
import pickle

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import savgol_filter

from surrogate_models.ours import Ours

class OurAlgorithm:
    """Sequential hyperparameter-optimisation loop driven by the ``Ours`` surrogate.

    The object is used in a suggest/observe cycle: :meth:`suggest` returns
    the configuration to evaluate for one more step together with a stop
    flag, and :meth:`observe` records the score that was obtained.  Each
    observation costs one unit of ``total_budget``.

    Candidate selection (:meth:`hpo`) samples learning curves from the
    surrogate for every candidate, scores them with the user-supplied
    utility ``U(budget, score)`` and picks the candidate with the largest
    expected utility improvement over the present utility.
    """

    def __init__(
        self,
        init_index: Optional[int],
        mean: bool,
        eps: float,
        U,
        y_0: float,
        config_ckpt: Optional[str],
        model_ckpt: str,
        hp_candidates: np.ndarray,
        seed: int = 11,
        max_benchmark_epochs: int = 52,
        total_budget: int = 500,
        device: Optional[str] = None,
        dataset_name: str = 'unknown',
        output_path: str = '.',
        benchmark=None
    ):
        """Seed the RNGs, build the surrogate and cache the test tensors.

        Args:
            init_index: Candidate returned by the very first ``suggest``
                call; when ``None`` the first suggestion comes from ``hpo``.
            mean: If true, ``hpo`` draws 5x more samples and averages them
                in groups of 5 before computing the acquisition.
            eps: ``hpo`` raises the stop flag when the best improvement
                probability is ``<= eps``.
            U: Utility callable, invoked as ``U(budget, score)`` both with
                Python scalars and with same-shaped tensors.
            y_0: Scalar handed to the surrogate as ``'y_0'`` (paired with
                ``t_0 = 0``) -- presumably the curve value before training;
                confirm against ``Ours``.
            config_ckpt: Path to a JSON file with the surrogate config, or
                ``None`` to use the built-in default config.
            model_ckpt: Checkpoint argument forwarded to ``Ours``.
            hp_candidates: Array of candidates, one row per configuration.
            seed: Seed for ``numpy``, ``random`` and ``torch``.
            max_benchmark_epochs: Curve length; also the divisor used to
                normalise epoch indices.
            total_budget: Number of observations allowed.
            device: Torch device string; ``None`` selects CUDA if available.
            dataset_name: Stored on the instance only.
            output_path: Directory where the JSON logs are written.
            benchmark: Object exposing ``get_curve(hp_index, num_epochs)``.
        """
        # NOTE: these calls change process-wide RNG / cuDNN state.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        self.init_index = init_index
        self.U = U
        self.eps = eps
        self.benchmark = benchmark
        self.mean = mean

        if device is None:
            self.dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.dev = torch.device(device)

        self.hp_candidates = hp_candidates
        self.seed = seed
        self.max_benchmark_epochs = max_benchmark_epochs
        self.output_path = output_path
        self.dataset_name = dataset_name
        self.t_0 = torch.FloatTensor([0.0]).to(self.dev)
        self.y_0 = torch.FloatTensor([y_0]).to(self.dev)

        self.num_hps = self.hp_candidates.shape[0]
        self.budget_spent = 0
        self.total_budget = total_budget
        # hp_index -> list of observed scores; position i holds step i + 1.
        self.performances = dict()
        self.num_mc_samples = 1000
        self.max_score = 0.
        self.best_utility = []
        self.probs = []  # best improvement probability of every hpo() call
        self.hps = []    # hp_index of every observation, in order

        if config_ckpt is None:
            config = {
                'dim_x': self.hp_candidates.shape[1],
                'd_output': 1000,
                'd_model': 512,
                'nlayers': 12,
                'dropout': 0.2
            }
        else:
            with open(config_ckpt, "r") as fb:
                config = json.load(fb)
            config['dim_x'] = self.hp_candidates.shape[1]

        self.model = Ours(
            model_ckpt,
            config,
            self.dev
        )

        self.initialized = False

        curves = [
            self.benchmark.get_curve(hp_index, max_benchmark_epochs)
            for hp_index in range(self.num_hps)
        ]

        # Go through one contiguous float32 array: building a tensor from a
        # list of ndarrays is very slow.  ``np.array`` copies, so the tensors
        # never alias the caller's data.
        self.xt = torch.from_numpy(
            np.array(self.hp_candidates, dtype=np.float32)
        ).to(self.dev)  # [num_hps, dim_x]
        epochs = torch.arange(1, self.max_benchmark_epochs + 1).float()
        self.tt = epochs[:, None].to(self.dev) / self.max_benchmark_epochs  # [max_benchmark_epochs, 1]
        self.yt = torch.from_numpy(
            np.array(curves, dtype=np.float32)
        ).to(self.dev)  # [num_hps, max_benchmark_epochs]

    def _prepare_train_dataset(self) -> Dict[str, torch.Tensor]:
        """Pack every observation made so far into the surrogate's context.

        Returns:
            Dict with ``'t_0'``, ``'y_0'``, the context tensors ``'xc'``
            ([num_context, dim_x]), ``'tc'`` ([num_context, 1], step divided
            by ``max_benchmark_epochs``) and ``'yc'`` ([num_context, 1]), plus
            ``'context_indices'``, a list of ``[hp_index, time_index]`` pairs.
            Before the first observation the three tensors are ``None`` and
            the index list is empty.
        """
        context_indices = []
        xc, tc, yc = None, None, None

        if self.initialized:
            rows, steps, scores = [], [], []
            for hp_index, curve in self.performances.items():
                hp_candidate = self.hp_candidates[hp_index]
                for time_index, performance in enumerate(curve):
                    rows.append(hp_candidate)
                    steps.append(time_index + 1)
                    scores.append(performance)
                    context_indices.append([hp_index, time_index])

            # Stack through NumPy first (see __init__) -- same float32 values.
            xc = torch.from_numpy(np.asarray(rows, dtype=np.float32)).to(self.dev)  # [num_context, dim_x]
            tc = torch.FloatTensor(steps)[:, None].to(self.dev) / self.max_benchmark_epochs  # [num_context, 1]
            yc = torch.FloatTensor(scores)[:, None].to(self.dev)  # [num_context, 1]

        return {
            'context_indices': context_indices,
            't_0': self.t_0,
            'y_0': self.y_0,
            'xc': xc,
            'tc': tc,
            'yc': yc
        }

    def _prepare_test_dataset(self) -> Dict[str, torch.Tensor]:
        """Return the cached full-grid tensors (``xt``, ``tt``, ``yt``)."""
        return {
            'xt': self.xt,
            'tt': self.tt,
            'yt': self.yt,
        }

    def suggest(self) -> Tuple[Optional[int], int, bool]:
        """Pick the next configuration and step to evaluate.

        Returns:
            ``(best_index, next_t, stop_sign)``.  ``next_t`` is one past the
            number of scores already observed for ``best_index`` (1 for an
            unseen configuration).  ``best_index`` is ``None`` only when
            ``hpo`` found nothing that can still be extended, in which case
            ``stop_sign`` is ``True``.

        Raises:
            SystemExit: with code 0 once more than ``total_budget``
                observations have been recorded.
        """
        # Exhausted hpo budget, finish.  Checked first so no surrogate pass
        # is wasted; ``raise SystemExit`` is what ``exit(0)`` does, without
        # depending on the ``site`` module having injected ``exit``.
        if self.budget_spent > self.total_budget:
            raise SystemExit(0)

        if not self.initialized and self.init_index is not None:
            best_index, stop_sign = self.init_index, False
        else:
            best_index, stop_sign = self.hpo()

        next_t = len(self.performances.get(best_index, ())) + 1

        return best_index, next_t, stop_sign

    def observe(
        self,
        hp_index: int,
        t: int,
        score: float
    ):
        """Record ``score`` as the ``t``-th observation of ``hp_index``.

        Spends one unit of budget, updates the running best score, appends
        to the candidate's curve and rewrites the JSON logs.

        Raises:
            AssertionError: if ``t`` is not the next step for ``hp_index``.
        """
        self.budget_spent += 1
        self.hps.append(hp_index)

        if score > self.max_score:
            self.max_score = score

        self.performances.setdefault(hp_index, []).append(score)

        assert len(self.performances[hp_index]) == t

        self.log_performances()
        self.initialized = True

    @torch.no_grad()
    def hpo(
        self
    ) -> Tuple[Optional[int], bool]:
        """Select the candidate with the largest expected utility improvement.

        For every candidate the sampled future curve is floored at the best
        score seen so far, made monotone with a running max and truncated to
        the remaining budget.  The acquisition is the maximum over future
        steps of the sample-mean of ``relu(U(budget, curve) - U_now)``.

        Returns:
            ``(best_index, stop_sign)``.  ``stop_sign`` is ``True`` when no
            candidate can be extended (then ``best_index`` is ``None``) or
            when the best improvement probability is ``<= eps``.  The best
            probability (``0.`` if nothing was selectable) is appended to
            ``self.probs``.
        """
        num_repeats = 5 if self.mean else 1
        sampled_graphs, _ = self.model.predict_pipeline(
            train_data=self._prepare_train_dataset(),
            test_data=self._prepare_test_dataset(),
            num_mc_samples=self.num_mc_samples * num_repeats
        )  # expected: [num_hps, num_mc_samples * num_repeats, max_benchmark_epochs]
        if self.mean:
            # Average groups of 5 draws into num_mc_samples smoother curves.
            sampled_graphs = sampled_graphs.reshape(
                self.num_hps, 5, self.num_mc_samples, self.max_benchmark_epochs
            ).mean(dim=1)

        present_utility = self.U(self.budget_spent, self.max_score)
        # Clamp at zero: a negative stop index would slice from the END of the
        # curve instead of producing an empty horizon.
        remaining_budget = max(self.total_budget - self.budget_spent, 0)

        # Defaults describe "nothing selectable": stop, with probability 0.
        best_index = None
        best_acq = -1000.
        best_prob = 0.
        stop_sign = True

        for hp_index in range(self.num_hps):
            sampled_graph = sampled_graphs[hp_index]  # [mc_samples, max_benchmark_epochs]

            # we don't need the observed values
            if hp_index in self.performances:
                sampled_graph = sampled_graph[:, len(self.performances[hp_index]):]

            # we can only observe values for possible budgets (truncating
            # before the running max below does not change the kept prefix)
            sampled_graph = sampled_graph[:, :remaining_budget]
            postfix_len = sampled_graph.shape[-1]

            # nothing left to observe for this candidate: its future utility
            # equals the present utility, so it can never be selected
            if postfix_len == 0:
                continue

            # we only consider cumulative best performance so far; work on a
            # clone so the tensor returned by the surrogate is left untouched
            sampled_graph = sampled_graph.clone()
            sampled_graph[sampled_graph < self.max_score] = self.max_score
            sampled_graph = torch.cummax(sampled_graph, dim=-1)[0]

            # budget that will have been spent at each future step,
            # one row per sample: [mc_samples, postfix_len]
            steps = torch.arange(
                self.budget_spent + 1, self.budget_spent + postfix_len + 1
            )
            budget = steps[None, :].repeat(sampled_graph.shape[0], 1).float().to(self.dev)
            u = self.U(budget, sampled_graph)  # [mc_samples, postfix_len]

            acq = torch.max(torch.mean(F.relu(u - present_utility), dim=0), dim=0)[0].item()
            prob = torch.max(torch.mean((u > present_utility).float(), dim=0)).item()

            if acq > best_acq:
                best_acq = acq
                best_index = hp_index
                best_prob = prob
                stop_sign = False

        stop_sign = stop_sign or (best_prob <= self.eps)
        self.probs.append(best_prob)

        return best_index, stop_sign

    @staticmethod
    def _json_default(value):
        """Turn NumPy / torch scalars and arrays into builtin JSON values."""
        if hasattr(value, 'tolist'):
            return value.tolist()
        raise TypeError(
            f'Object of type {type(value).__name__} is not JSON serializable'
        )

    def log_performances(self):
        """Write the JSON logs into ``output_path``.

        ``hps.json`` and ``probs.json`` are rewritten on every call;
        ``<budget>_performances.json`` is written only on the call where the
        spent budget equals ``total_budget``.
        """
        # An empty path means the current directory; makedirs('') would raise.
        if self.output_path:
            os.makedirs(self.output_path, exist_ok=True)

        if self.budget_spent == self.total_budget:
            with open(os.path.join(self.output_path, f'{self.budget_spent}_performances.json'), 'w') as fp:
                json.dump(self.performances, fp, default=self._json_default)
        with open(os.path.join(self.output_path, 'hps.json'), 'w') as fp:
            json.dump(self.hps, fp, default=self._json_default)
        with open(os.path.join(self.output_path, 'probs.json'), 'w') as fp:
            json.dump(self.probs, fp, default=self._json_default)