import os
import time
import torch
import pickle
import hdf5storage
import math

from tqdm import tqdm
from easydict import EasyDict
from scipy.special import comb
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

import core.trainer as trainer
from dataset.utils import Dataset4Unsupervise
from dataset.pretrain_seq_ap_dataset import SeqAPDataset
from models.final_model import PretrainBertGpt, Backbone2, Singlemlm, Singlentp
from utils.utils import write_log, SaveBestModel, save_pretrain_loss_plots, save_lr_plots


class TeacherForcingScheduler:
    """Linearly anneal a teacher-forcing probability over training steps.

    The probability moves from ``p_start`` at step 0 to ``p_end`` at
    ``total_steps`` and is held at ``p_end`` for every later step.
    """

    def __init__(self, p_start=1.0, p_end=0.1, total_steps=50000):
        """Store the schedule endpoints.

        Args:
            p_start: Probability returned at step 0.
            p_end: Probability returned at ``total_steps`` and beyond.
            total_steps: Length of the ramp in steps. Values below 1 are
                raised to 1 so the division in ``prob`` is always defined.
        """
        self.p_start, self.p_end, self.total_steps = p_start, p_end, max(1, total_steps)

    def prob(self, step):
        """Return the teacher-forcing probability for ``step`` as a float.

        Progress is clamped to [0, 1] on both sides, so a negative ``step``
        yields ``p_start`` instead of extrapolating past the configured
        endpoints (which could produce a probability above 1).
        """
        t = max(0.0, min(step / self.total_steps, 1.0))
        return float(self.p_start + (self.p_end - self.p_start) * t)


class LambdaCosineScheduler:
    """Ramp a scalar weight from ``lam_start`` to ``lam_end`` along a half cosine.

    The value equals ``lam_start`` at step 0, reaches ``lam_end`` at
    ``total_steps`` and is held there for every later step.
    """

    def __init__(self, lam_start=0.0, lam_end=1.0, total_steps=50000):
        """Store the schedule endpoints.

        Args:
            lam_start: Value returned at step 0 (stored as ``self.s``).
            lam_end: Value returned at ``total_steps`` and beyond (``self.e``).
            total_steps: Length of the ramp in steps (``self.N``). Values
                below 1 are raised to 1 so the division in ``value`` is
                always defined.
        """
        self.s, self.e, self.N = lam_start, lam_end, max(1, total_steps)

    def value(self, step):
        """Return the scheduled weight for ``step`` as a float.

        Progress is clamped to [0, 1] on both sides. Without the lower clamp
        a negative ``step`` would mirror the ramp (cosine is an even
        function) and could return ``lam_end`` before training has started.
        """
        t = max(0.0, min(step / self.N, 1.0))
        return float(self.s + 0.5 * (self.e - self.s) * (1 - math.cos(math.pi * t)))



def main(opt_train, opt_data, opt_model, opt_save, opt_eval, opt_mask, cfg_train):
    """Run one pretraining experiment end to end.

    Logs the option groups, builds the dataloader, the ``Singlentp`` model,
    the optimizer, a cosine-annealing learning-rate scheduler and the loss,
    then hands everything to ``train``.

    Args:
        opt_train: Training options; this function reads ``device``,
            ``optimizer``, ``lr_rate``, ``t_max``, ``lr_min`` and ``criterion``.
        opt_data: Dataset options, forwarded to ``get_dataloader``.
        opt_model: Model options, passed to the ``Singlentp`` constructor.
        opt_save: Output options; ``log_dir`` is read here, the rest in ``train``.
        opt_eval: Evaluation options. Accepted for interface compatibility;
            not read by this function.
        opt_mask: Masking options, forwarded to ``get_dataloader``.
        cfg_train: Extra training configuration, forwarded to ``train``.

    Raises:
        ValueError: If ``opt_train.optimizer`` is not one of ``'adam'``,
            ``'adamW'``, ``'sgd'``, or ``opt_train.criterion`` is not one of
            ``'mse'``, ``'l1'``.
    """
    # Log param
    for opt in (opt_train, opt_data, opt_model, opt_mask):
        write_log([str(opt) + '\n'], "exp", log_dir=opt_save.log_dir, log_type='param')

    # Get dataloader
    print("* Get dataloader")
    pretrain_dataloader = get_dataloader(opt_train, opt_data, opt_mask)

    # Initialize model and loss
    print("* Get model")
    model = Singlentp(opt_model).to(opt_train.device)
    if opt_train.optimizer == 'adam':
        optimizer = torch.optim.Adam(params=model.parameters(), lr=opt_train.lr_rate)
    elif opt_train.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(params=model.parameters(), lr=opt_train.lr_rate, betas=(0.9, 0.95), weight_decay=5e-2)
    elif opt_train.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params=model.parameters(), lr=opt_train.lr_rate, momentum=0.9)
    else:
        # Without this branch an unknown name surfaced later as an
        # UnboundLocalError on `optimizer`, hiding the real cause.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam', 'adamW' or 'sgd'".format(opt_train.optimizer)
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt_train.t_max, eta_min=opt_train.lr_min)

    if opt_train.criterion == 'mse':
        criterion = torch.nn.MSELoss()
    elif opt_train.criterion == 'l1':
        criterion = torch.nn.L1Loss()
    else:
        # Same reasoning as for the optimizer: fail with a clear message.
        raise ValueError(
            "Unsupported criterion {!r}; expected 'mse' or 'l1'".format(opt_train.criterion)
        )

    # Train!!!
    print("* Train")
    train(pretrain_dataloader, model, criterion, optimizer, opt_train, opt_save, scheduler, cfg_train)

    return

def train(pretrain_dataloader, model, criterion, optimizer, opt_train, opt_save, scheduler, cfg_train):
    """Run the epoch loop of pretraining and write logs, checkpoints and plots.

    Each epoch delegates the actual optimisation to ``trainer.pretrain``,
    which returns the epoch loss and the updated global step counter. The
    loss is logged, the ``SaveBestModel`` helper is invoked, and the
    learning-rate scheduler is stepped once per epoch. After the last epoch
    the loss and learning-rate histories are plotted into
    ``opt_save.checkpoints_dir``.

    Args:
        pretrain_dataloader: Iterable of pretraining batches.
        model: Model being trained; its ``name`` attribute labels the log.
        criterion: Loss module forwarded to ``trainer.pretrain``.
        optimizer: Optimizer whose first parameter group supplies the
            learning rate that is reported and recorded.
        opt_train: Training options; ``epochs`` and ``device`` are read here.
        opt_save: Output options; ``checkpoints_dir`` and ``log_dir`` are read.
        scheduler: Learning-rate scheduler, stepped at the end of every epoch.
        cfg_train: Extra configuration forwarded untouched to ``trainer.pretrain``.
    """

    def current_lr():
        # Learning rate of the first parameter group, as the optimizer reports it.
        return optimizer.state_dict()['param_groups'][0]['lr']

    checkpoint_saver = SaveBestModel(opt_save.checkpoints_dir)

    # Schedule length is dataset specific: 2500 for 5b5c, 4900 for 4a, 12950 for 2a.
    tf_sched = TeacherForcingScheduler(p_start=1.0, p_end=0.2, total_steps=4900)
    lam_sched = LambdaCosineScheduler(lam_start=0.0, lam_end=1.0, total_steps=4900)

    loss_history, lr_history = [], []
    step_counter = 0
    num_epochs = opt_train.epochs

    for epoch in tqdm(range(num_epochs)):
        lr_at_start = current_lr()
        epoch_loss, step_counter = trainer.pretrain(
            model,
            opt_train.device,
            pretrain_dataloader,
            criterion,
            optimizer,
            tf_sched,
            lam_sched,
            cfg_train,
            step_counter,
        )

        write_log([str(epoch_loss)], model.name, log_dir=opt_save.log_dir, log_type='pretrain_loss')

        # The printed rate is the one sampled before the epoch ran.
        print("Epoch:{}/{} AVG Pretraining Loss:{:.3f} Learning Rate:{:6f}".format(epoch + 1, num_epochs, epoch_loss, lr_at_start))

        # The recorded rate is sampled again after the epoch, before scheduler.step().
        lr_at_end = current_lr()

        checkpoint_saver(epoch, epoch_loss, model, optimizer, criterion)

        loss_history.append(epoch_loss)
        lr_history.append(lr_at_end)

        scheduler.step()

    save_pretrain_loss_plots(opt_save.checkpoints_dir, loss_history)
    save_lr_plots(opt_save.checkpoints_dir, lr_history)

def get_dataloader(opt_train, opt_data, opt_mask):
    """Build the shuffled pretraining ``DataLoader``.

    Creates a ``SeqAPDataset`` from the data options, extracts the AP
    coordinates and data with ``get_data_labels``, wraps them in
    ``Dataset4Unsupervise`` together with the masking options and returns a
    shuffling ``DataLoader`` over the result.

    Args:
        opt_train: Training options. ``batch_size`` is required;
            ``num_workers`` is optional and defaults to 8, the value that
            was previously hard-coded.
        opt_data: Dataset options; reads ``agc_calibrate``, ``phone_names``,
            ``window_ap_num``, ``seq_ap_num``, ``time_step``, ``time_window``,
            ``data_dim`` and ``is_normalize``.
        opt_mask: Masking options handed to ``Dataset4Unsupervise``.

    Returns:
        The ``DataLoader`` over the pretraining dataset.
    """
    # Any falsy `agc_calibrate` (including None) selects the uncalibrated files.
    data_suffix = 'agc_caled_csi' if opt_data.agc_calibrate else 'csi'
    seq_dataset = SeqAPDataset(
        phone_names=opt_data.phone_names,
        window_ap_num=opt_data.window_ap_num,
        seq_ap_num=opt_data.seq_ap_num,
        time_step=opt_data.time_step,
        time_window=opt_data.time_window,
        data_suffix=data_suffix,
    )

    ap_coords, data = seq_dataset.get_data_labels(data_dim=opt_data.data_dim, is_normalize=opt_data.is_normalize)

    pretrain_dataset = Dataset4Unsupervise(ap_coords, data, opt_mask)

    # Let callers lower the worker count (e.g. 0 for debugging) without editing this file.
    num_workers = getattr(opt_train, 'num_workers', 8)
    pretrain_dataloader = DataLoader(
        pretrain_dataset,
        shuffle=True,
        batch_size=opt_train.batch_size,
        num_workers=num_workers,
    )

    return pretrain_dataloader

if __name__ == "__main__":
    # Script entry point: builds every option group for one pretraining run
    # and hands them to main().
    GPU_ID = "2"
    # NOTE(review): SEED is only copied into opt_data.seed / opt_train.seed in
    # this file; nothing visible here feeds it to a random number generator —
    # confirm it is consumed elsewhere.
    SEED = 2077
    SAVE_DIR = './pretrain_runs/4a_ntp'
    DATA_DIR = "./data/processed_pretrain/4a"
    # Select the GPU given by GPU_ID as the current CUDA device when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.set_device(int(GPU_ID))


    # Extra training configuration; main() forwards it to train(), which
    # forwards it to trainer.pretrain(). The per-key notes below are
    # translated from the original author's comments.
    default_train_cfg = {
        'detach_ntp_input': False,      # changed dynamically in train_step according to warmup_detach_steps
        'noise_sigma': 0.01,           # small noise added to the NTP input (for robustness)
        'k_step_ratio': 0.10,          # 10% of the batches perform a rollout
        'k_max': 3,                    # roll out at most 3 steps
        'alpha_var_weight': 1.0,       # strength of the uncertainty weighting
        # During the warm-up steps only MLM->NTP is trained (NTP backprop frozen).
        # The original note mentions "the first 5k steps", which does not match
        # the value below. Per-dataset values: 800 for 5b5c, 1500 for 4a, 4000 for 2a.
        'warmup_detach_steps': 1500
    }

    # Evaluation options. They are passed to main(), which does not read them.
    opt_eval = EasyDict()
    opt_eval.MC_dropout = True
    opt_eval.num_sampling = 20

    # Dataset options.
    opt_data = EasyDict()
    opt_data.train_data_size = 0.91
    opt_data.data_dim = 2
    opt_data.seed = SEED
    opt_data.is_normalize = True
    opt_data.agc_calibrate = True
    opt_data.seq_ap_num = 5
    # The window holds exactly one AP more than a sequence uses.
    opt_data.window_ap_num = opt_data.seq_ap_num + 1
    # Number of ways to choose seq_ap_num APs out of the window.
    opt_data.ensemble_num = int(comb(opt_data.window_ap_num, opt_data.seq_ap_num))
    opt_data.time_step = 0.5
    opt_data.time_window = 0.5
    # Name of the sub-directory under DATA_DIR that encodes the construction settings.
    opt_data.construction_type = 'window_ap_num-{0}_seq_ap_num-{1}_time_step-{2}_time_window-{3}'.format(opt_data.window_ap_num, opt_data.seq_ap_num, opt_data.time_step, opt_data.time_window)
    opt_data.target_data_dir = os.path.join(DATA_DIR, opt_data.construction_type)
    # Every sub-directory of target_data_dir is treated as one phone.
    opt_data.phone_names = [name for name in os.listdir(opt_data.target_data_dir) if os.path.isdir(os.path.join(opt_data.target_data_dir, name))]
    opt_data.save_data = True
    opt_data.data_type = 'covariance'
    opt_data.area = '4a'

    # Masking options, handed to Dataset4Unsupervise by get_dataloader().
    opt_mask = EasyDict()
    opt_mask.mask_frame_ratio = 0.4
    opt_mask.mask_ratio = 0.3
    opt_mask.mask_prob = 0.8
    opt_mask.replace_prob = 0.1
    opt_mask.max_gram = 3
    opt_mask.mask_alpha = 0.2

    # Training options.
    opt_train = EasyDict()
    opt_train.seed = opt_data.seed
    # Falls back to the CPU when CUDA is not available.
    opt_train.device = 'cuda:{0}'.format(GPU_ID) if torch.cuda.is_available() else 'cpu'
    opt_train.batch_size = 128
    opt_train.lr_rate = 4e-4
    opt_train.criterion = 'mse'
    opt_train.optimizer = 'adamW'
    opt_train.epochs = 100
    # T_max of the cosine-annealing scheduler equals the epoch count, so the
    # learning rate reaches lr_min at the end of training.
    opt_train.t_max = opt_train.epochs
    opt_train.lr_min = 1e-5


    # Model options, passed to the Singlentp constructor in main().
    opt_model = EasyDict()
    opt_model.in_channels = opt_data.data_dim
    opt_model.ant_num = 2
    # in_channels * ant_num^2 + ant_num
    opt_model.input_feature_num = opt_model.in_channels * opt_model.ant_num * opt_model.ant_num + opt_model.ant_num
    opt_model.dropout_p = 0.1
    opt_model.model_dim = 512
    opt_model.feedforward_dim = 4*opt_model.model_dim
    opt_model.n_layers = 6
    opt_model.n_heads = 8
    opt_model.seq_len = opt_data.seq_ap_num  # sample num of a single sequence
    opt_model.decoder_type = 'FC'
    opt_model.fc_hidden_num = 1024
    opt_model.fc_dropout_p = 0.5
    opt_model.device = opt_train.device
    opt_model.output_embed = False
    opt_model.return_cls = False
    # Label used in the run name: the single phone's name, or 'multi_phone'.
    # Raises IndexError when no phone directory was found.
    phone_num = 'multi_phone' if len(opt_data.phone_names) > 1 else opt_data.phone_names[0]


    # Output options: checkpoints, results and logs all share one run directory.
    opt_save = EasyDict()
    # NOTE(review): the prefix reads 'supervise' although this script runs
    # pretraining — confirm this is intended.
    opt_save.save_name = 'supervise-{0}-window_ap_num:{1}-seq_ap_num:{2}-datadim:{3}'.format(phone_num, opt_data.window_ap_num, opt_data.seq_ap_num, opt_data.data_dim)
    # NOTE(review): the ':' characters in save_time and save_name become part
    # of the directory name, which some filesystems (e.g. on Windows) reject.
    opt_save.save_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())                           # timestamp taken when the script is run; prefixes the run directory name
    opt_save.checkpoints_dir = os.path.join(SAVE_DIR, "{0}-{1}".format(opt_save.save_time, opt_save.save_name))  # models are saved here
    opt_save.results_dir = opt_save.checkpoints_dir
    opt_save.log_dir = opt_save.checkpoints_dir
    opt_save.load_dir = opt_save.checkpoints_dir
    opt_save.save_path = os.path.join(opt_save.checkpoints_dir, 'model_dict')
    print(opt_save.save_path)

    main(opt_train, opt_data, opt_model, opt_save, opt_eval, opt_mask, default_train_cfg)