import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
import math
from einops import rearrange
import torch.nn.functional as F
from scipy.stats import norm

import torch


# Select the non-interactive Agg backend so figures are rendered straight to
# files and no display/GUI toolkit is needed when this module is imported.
plt.switch_backend('agg')


def adjust_learning_rate(optimizer, scheduler, epoch, args, printout=True):
    """Set the learning rate of every parameter group for the given epoch.

    The policy is selected by ``args.lradj``:

    * ``'type1'``     - halve ``args.learning_rate`` every epoch.
    * ``'type2'``     - fixed table of learning rates applied at specific epochs
      only; other epochs leave the optimizer untouched.
    * ``'type3'``     - keep ``args.learning_rate`` for epochs below 3, then
      decay by a factor of 0.9 per epoch.
    * ``'constant'``  - always ``args.learning_rate``.
    * ``'cosine'``    - half-cosine decay over ``args.train_epochs``.
    * ``'OneCircle'`` - copy the first value of ``scheduler.get_last_lr()``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a torch optimizer.
        scheduler: Only used by the ``'OneCircle'`` policy.
        epoch: Current epoch number.
        args: Namespace providing ``lradj``, ``learning_rate`` and, for the
            cosine policy, ``train_epochs``.
        printout: When true, print the learning rate that was applied.

    Raises:
        ValueError: If ``args.lradj`` names no known policy. Previously this
            surfaced as an ``UnboundLocalError`` on ``lr_adjust``.
    """
    if args.lradj == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == 'type2':
        lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    elif args.lradj == 'type3':
        lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.9 ** ((epoch - 3) // 1))}
    elif args.lradj == 'constant':
        lr_adjust = {epoch: args.learning_rate}
    elif args.lradj == "cosine":
        lr_adjust = {epoch: args.learning_rate /2 * (1 + math.cos(epoch / args.train_epochs * math.pi))}
    elif args.lradj == 'OneCircle':
        lr_adjust = {epoch: scheduler.get_last_lr()[0]}
    else:
        # Fail loudly on a misspelled policy instead of crashing later on an
        # unbound local variable.
        raise ValueError('Unknown lradj policy: {!r}'.format(args.lradj))

    # Table-driven policies only define some epochs; skip the others.
    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        if printout:
            print('Updating learning rate to {}'.format(lr))



class EarlyStopping:
    """Track validation loss, checkpoint on improvement, and flag when to stop.

    ``early_stop`` becomes True once ``patience`` consecutive calls have failed
    to beat the best score by at least ``delta``.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        # Configuration.
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        # Running state.
        self.best_score = None
        self.val_loss_min = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model, path):
        """Record one validation result; save a checkpoint if it improved."""
        # Scores are negated losses, so larger is better.
        score = -val_loss
        is_first = self.best_score is None

        if not is_first and score < self.best_score + self.delta:
            # No sufficient improvement: count the miss and maybe stop.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(val_loss, model, path)
        if not is_first:
            # The miss counter is only reset on a later improvement,
            # not on the very first call.
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Write ``model.state_dict()`` to ``<path>/checkpoint.pth``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        target = path + '/' + 'checkpoint.pth'
        torch.save(model.state_dict(), target)
        print('Model Saved')
        self.val_loss_min = val_loss


class dotdict(dict):
    """Dictionary whose entries can also be read, set and deleted as attributes."""
    # Attribute reads that miss normal lookup fall back to dict.get, so an
    # absent key yields None instead of raising AttributeError.
    __getattr__ = dict.get
    # Attribute assignment stores the value under the attribute name as key.
    __setattr__ = dict.__setitem__
    # Attribute deletion removes the key (dict.__delitem__ raises KeyError
    # when the key is absent).
    __delattr__ = dict.__delitem__


class StandardScaler():
    """Affine scaler driven by a fixed, caller-supplied mean and std."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Return ``data`` shifted by the mean and divided by the std."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo ``transform``: multiply by the std, then add the mean back."""
        rescaled = data * self.std
        return rescaled + self.mean


def visual(true, preds=None, name='./pic/test.pdf'):
    """Plot the ground truth (and optionally predictions) and save to ``name``.

    Args:
        true: Sequence plotted as the 'GroundTruth' curve.
        preds: Optional sequence plotted as the 'Prediction' curve.
        name: Output path passed to ``plt.savefig``.
    """
    fig = plt.figure()
    try:
        plt.plot(true, label='GroundTruth', linewidth=2)
        if preds is not None:
            plt.plot(preds, label='Prediction', linewidth=2)
        plt.legend()
        plt.savefig(name, bbox_inches='tight')
    finally:
        # Release the figure even if plotting/saving fails; otherwise every
        # call leaves another open figure behind.
        plt.close(fig)


def adjustment(gt, pred):
    """Expand each detected anomaly to cover its whole ground-truth segment.

    Whenever a point is labelled 1 in both ``gt`` and ``pred``, every point of
    the surrounding run of non-zero ``gt`` labels is marked 1 in ``pred``.

    Args:
        gt: Indexable sequence of ground-truth labels (0 = normal).
        pred: Indexable, mutable sequence of predicted labels; modified in
            place.

    Returns:
        The tuple ``(gt, pred)`` with ``pred`` adjusted.
    """
    anomaly_state = False
    for i in range(len(gt)):
        if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
            anomaly_state = True
            # Walk back to the start of the segment. The stop value is -1 so
            # index 0 is included; stopping at 0 skipped the first point of a
            # segment that begins at the very start of the series.
            for j in range(i, -1, -1):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
            # Walk forward to the end of the segment.
            for j in range(i, len(gt)):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
        elif gt[i] == 0:
            anomaly_state = False
        if anomaly_state:
            pred[i] = 1
    return gt, pred


def cal_accuracy(y_pred, y_true):
    """Return the mean of the element-wise comparison ``y_pred == y_true``."""
    matches = y_pred == y_true
    return np.mean(matches)


def set_seed(seed=1):
    """Seed Python's ``random``, torch (CPU and all CUDA devices) and NumPy."""
    seeders = (
        random.seed,
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return None


def cal_relative_w(x):
    """Replace every element of ``x`` by its rank, min-max scaled to [0, 1].

    The smallest element maps to 0 and the largest to 1; the result has the
    same shape as ``x``.
    """
    flat = x.flatten()
    # Index 1 of the sort result is the permutation that orders the values;
    # the argsort of that permutation gives each element's 0-based rank.
    order = torch.sort(flat)[1]
    ranks = torch.argsort(order).float() + 1
    lowest, highest = ranks.min(), ranks.max()
    scaled = (ranks - lowest) / (highest - lowest)
    return scaled.view(x.shape)


def cal_rand(metric_list, metric_columns=['MSE','MAE'], setting=None):
    """Summarise repeated-run metrics as ``mean±std`` and optionally log them.

    Args:
        metric_list: Sequence where ``metric_list[i]`` holds the values of
            ``metric_columns[i]`` across runs.
        metric_columns: Names of the metrics, in the order of ``metric_list``.
        setting: Optional underscore-separated experiment tag. When given, its
            7th and 4th fields (indices 6 and 3) are used as model and dataset
            name to build the log file name, and a summary is appended to
            ``result_long_term_forecast_<model>_<dataset>.txt`` in the current
            working directory.

    Returns:
        Dict mapping each metric name to its ``'{mean:.3f}±{std:.4f}'`` string.
    """
    result = dict()
    print('Total Evaluation \n')
    eva_res_dict = {}
    for i, metric in enumerate(metric_columns):
        metri_avg = round(np.mean(metric_list[i]), 3)
        metri_std = round(np.std(metric_list[i]), 4)
        result[metric] = (metri_avg, metri_std)
        eva_res_dict[metric] = '{:.3f}±{:.4f}'.format(metri_avg, metri_std)
        print('{}:{:.3f}±{:.4f}'.format(metric, metri_avg, metri_std))

    if setting is not None:
        fields = setting.split('_')
        model_name = fields[6]
        dataset_name = fields[3]
        # Log every metric that was evaluated. For the default columns this is
        # the same 'MSE:..., MAE:...' line as before; the summary used to
        # hard-code those two names and failed for any other metric set.
        summary = ', '.join(
            '{}:{:.3f}±{:.4f}'.format(metric, result[metric][0], result[metric][1])
            for metric in metric_columns
        )
        # The context manager closes the file even if a write fails.
        with open("result_long_term_forecast_{}_{}.txt".format(model_name, dataset_name), 'a') as f:
            f.write('Multiple_run_Results:' + setting + "  \n")
            f.write(summary)
            f.write('\n')
            f.write('\n')

    return eva_res_dict

def set_result_table(path, pred_lens=[96, 192, 336, 720], metrics=['MSE','MAE'], 
                 dataset=['ETTh1','ETTh2','ETTm1','ETTm2','electricity', 
                          'traffic', 'exchange_rate','weather']):
    """Create an empty result table CSV at ``path`` unless it already exists.

    The table has one column per entry of ``dataset`` and a two-level row
    index (``Metric``, ``Pred_len``): for every metric there is one row per
    prediction length plus a trailing ``'Average'`` row.

    Args:
        path: CSV file to create; left untouched when it already exists.
        pred_lens: Prediction lengths used as row labels (not modified).
        metrics: Metric names used as the outer row level.
        dataset: Column names.
    """
    if os.path.exists(path):
        return

    # Build the labels in a new list: appending 'Average' to ``pred_lens``
    # itself would alter the shared default (and the caller's list), adding a
    # duplicate 'Average' row on every later call.
    row_labels = [str(p) for p in pred_lens] + ['Average']
    index_0 = [m for m in metrics for _ in row_labels]
    index_1 = [label for _ in metrics for label in row_labels]
    multi_index = pd.MultiIndex.from_arrays([index_0, index_1], names=['Metric', 'Pred_len'])
    result_table = pd.DataFrame(index=multi_index, columns=dataset)
    result_table.to_csv(path, index=True)
    
def send_table(eva_res_dict, path, data_name, pred_len, metrics=['MSE','MAE']):
    """Write one experiment's results into the result table CSV at ``path``.

    For each metric, the ``'mean±std'`` string from ``eva_res_dict`` is stored
    in row ``(metric, str(pred_len))`` of column ``data_name``. The column's
    ``'Average'`` row is then recomputed as the mean of the values before
    ``'±'`` over all prediction-length rows (empty cells are ignored). The
    table is written back to ``path``.

    Args:
        eva_res_dict: Mapping from metric name to its ``'mean±std'`` string.
        path: CSV with a two-level row index and a single header row.
        data_name: Column (dataset) to update.
        pred_len: Prediction length identifying the row.
        metrics: Metrics to update.
    """
    df = pd.read_csv(path, header=[0], index_col=[0, 1])
    # The column holds both 'mean±std' strings and float averages. An empty
    # column is read back as float64 and a filled one as a string type, so
    # make it object dtype up front rather than relying on pandas silently
    # upcasting on assignment (deprecated behaviour).
    if data_name in df.columns:
        df[data_name] = df[data_name].astype(object)

    pred_lens = [str(x) for x in df.index.levels[1] if x != 'Average']
    for metric in metrics:
        df.loc[pd.IndexSlice[metric, str(pred_len)], data_name] = eva_res_dict[metric]
        # Keep only the mean part of 'mean±std'; other cells pass through.
        res_values = df.loc[pd.IndexSlice[metric, pred_lens], data_name].apply(
            lambda x: x.split('±')[0] if isinstance(x, str) and '±' in x else x
        )
        average_ = res_values.astype('float').mean()
        df.loc[pd.IndexSlice[metric, 'Average'], data_name] = average_

    df.to_csv(path, index=True)




class sax_transform:
    """Symbolic Aggregate approXimation (SAX) of batched time series.

    Each series is split into ``n_segments`` equal-length segments, every
    segment is averaged, and the mean is bucketed against the
    ``alphabet_size - 1`` quantiles of the standard normal distribution at
    ``1/alphabet_size, ..., 1 - 1/alphabet_size``. The bucket indices are
    returned either directly or mapped to tokenizer ids.

    NOTE(review): no normalisation is applied here, so the input is presumably
    expected to be z-normalised already - confirm against the callers.
    """

    def __init__(self, n_segments, alphabet_size, tokenizer, device='cuda', output_token=True):
        """Store the configuration and, if requested, build the token table.

        Args:
            n_segments: Number of segments each series is split into.
            alphabet_size: Number of SAX symbols (buckets).
            tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``
                and ``pad_token_id``; only used when ``output_token`` is true.
            device: Device for the breakpoints, the token table, and any
                non-tensor input converted in ``transform``.
            output_token: When true ``transform`` returns token ids, otherwise
                raw symbol indices.
        """
        self.n_segments = n_segments
        self.alphabet_size = alphabet_size
        self.device = device
        self.output_token = output_token

        if self.output_token:
            # Build the table on the configured device. Without passing it the
            # table always landed on sax_tokenizer's 'cuda' default, which
            # broke CPU use.
            self.token_dict = self.sax_tokenizer(
                num=alphabet_size, tokenizer=tokenizer, device=device
            )[1]

    def sax_tokenizer(self, num, tokenizer, device='cuda'):
        """Map symbol indices ``0..num`` to tokenizer ids.

        Symbols are the characters starting at ``'a'`` when ``num < 27`` and
        the decimal string of the index otherwise.

        Returns:
            ``(num_to_token_ids, token_map_tensor)``: a dict from index to its
            list of token ids, and a ``(num + 1, max_len)`` tensor on
            ``device`` holding the same ids, right-padded with
            ``tokenizer.pad_token_id``.
        """
        use_letters = num < 27
        num_to_token_ids = {}
        for n in range(num + 1):
            symbol = chr(ord('a') + n) if use_letters else str(n)
            tokens = tokenizer.tokenize(symbol)
            num_to_token_ids[n] = tokenizer.convert_tokens_to_ids(tokens)
        max_tokens_per_num = max(
            (len(ids) for ids in num_to_token_ids.values()), default=0
        )

        token_map_tensor = torch.full(
            (num + 1, max_tokens_per_num),
            tokenizer.pad_token_id,
            device=device
        )
        for n, ids in num_to_token_ids.items():
            token_map_tensor[n, :len(ids)] = torch.tensor(ids, device=device)

        return num_to_token_ids, token_map_tensor

    def add_special_tokens(self, token_ids, bos_token_id=0, eos_token_id=2):
        """Return ``token_ids`` with BOS prepended and EOS appended on the last axis."""
        padded = F.pad(token_ids, (1, 1), value=0)
        padded[..., 0] = bos_token_id
        padded[..., -1] = eos_token_id
        return padded

    def transform(self, time_series, prompt=None):
        """Convert a batch of series to SAX symbols or token ids.

        Args:
            time_series: 3-D tensor (or array-like convertible to one) of shape
                ``[batch_size, variates_num, time_length]``.
            prompt: Optional tensor of token ids placed before the SAX tokens
                of every series; only used when ``output_token`` is true.
                ``len(prompt)`` is used as its length, so presumably 1-D.

        Returns:
            With ``output_token``: token ids of the optional prompt followed by
            the SAX tokens, wrapped in BOS (0) and EOS (2). Otherwise the
            symbol indices as a half-precision tensor of shape
            ``[batch_size, variates_num, n_segments]``.

        Raises:
            ValueError: If the input is not 3-D or ``time_length`` is not
                divisible by ``n_segments``.
        """
        if not isinstance(time_series, torch.Tensor):
            time_series = torch.tensor(time_series, dtype=torch.float32, device=self.device)
        if time_series.dim() != 3:
            # Message: "input must be a 3-D tensor [batch_size, variates_num,
            # time_length], but got N dimensions".
            raise ValueError(f"输入必须是三维张量 [batch_size, variates_num, time_length], 但得到{time_series.dim()}维")

        B, M, L = time_series.shape

        if L % self.n_segments != 0:
            # Message: "time length L must be divisible by the segment count".
            raise ValueError(f"时间长度{L}必须能被分段数{self.n_segments}整除")

        segment_length = L // self.n_segments

        # Split the time axis into (segment, position-in-segment), then average
        # within each segment (piecewise aggregate approximation).
        segmented = time_series.reshape(B, M, self.n_segments, segment_length)
        paa_values = torch.mean(segmented, dim=-1)

        # Standard-normal quantiles that cut the real line into alphabet_size buckets.
        breakpoints = torch.tensor(
            norm.ppf(np.linspace(1/self.alphabet_size, 1-1/self.alphabet_size, self.alphabet_size-1)),
            dtype=torch.float32,
            device=self.device
        )

        sax_output = torch.bucketize(paa_values, breakpoints)

        if self.output_token:
            # NOTE(review): torch.take indexes the *flattened* token table, so
            # symbol k maps to row k only when the table has a single column
            # (each symbol is one token). Confirm for the tokenizer in use.
            token_id = torch.take(self.token_dict, sax_output)
            if prompt is not None:
                l = len(prompt)
                prompt = prompt.expand(B, M, l)
                token_id = torch.concat([prompt, token_id], dim=-1)
            token_id = self.add_special_tokens(token_id, bos_token_id=0, eos_token_id=2)
            return token_id
        return sax_output.half()
    

if __name__ == '__main__':
    # Smoke run of the SAX transform on random data.

    batch_size = 2000
    variates_num = 300
    time_length = 1200

    # Fall back to the CPU when no GPU is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    time_series = torch.randn(batch_size, variates_num, time_length, device=device)
    n_segments = 100
    alphabet_size = 50
    # sax_transform is a class configured with (n_segments, alphabet_size,
    # tokenizer, ...); the series goes to .transform(). No tokenizer is
    # available here, so request raw symbol indices instead of token ids.
    sax = sax_transform(n_segments, alphabet_size, tokenizer=None, device=device, output_token=False)
    sax_output = sax.transform(time_series)
    print(sax_output.shape)