import os
import time
import torch
import pickle
import hdf5storage
import json
from PIL import Image, ImageDraw
import cv2
import pickle
from itertools import chain
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from easydict import EasyDict
from scipy.special import comb
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau

import core.trainer as trainer
import core.evaluator as evaluator
from dataset.utils import Dataset4Supervise, DatasetntpSupervise, Dataset5Supervise
from dataset.seq_ap_dataset import SeqAPDataset
import dataset.seq_ap_dataset
from models.final_model import PretrainBertGpt, Backbone1, Backbone3, Backbone2, Backbone4
from utils.utils import write_log, SaveBestModel, save_loss_plots, save_lr_plots, save_cdf_plots

class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight plus a trainable low-rank (LoRA) update.

    The ``weight`` / ``bias`` of ``base_linear`` are cloned and frozen
    (``requires_grad=False``); only ``lora_A`` and ``lora_B`` are trainable.

        out = x @ W^T + b + scaling * (dropout(x) @ A^T @ B^T)

    Args:
        base_linear: the ``nn.Linear`` whose parameters are copied.
        r: LoRA rank. When ``r`` is not positive the low-rank branch is
            disabled and the layer behaves like the frozen base linear.
        alpha: scaling numerator; ``scaling = alpha / r`` when ``r > 0``,
            otherwise ``1.0``.
        dropout: dropout probability applied to the input of the LoRA
            branch only (the base path sees the raw input).
    """
    def __init__(self, base_linear: nn.Linear, r=4, alpha=8, dropout=0.1):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.in_features  = base_linear.in_features
        self.out_features = base_linear.out_features
        self.r = r
        self.scaling = alpha / float(r) if r > 0 else 1.0

        # Copy the base parameters and freeze them.
        self.weight = nn.Parameter(base_linear.weight.data.clone(), requires_grad=False)
        self.bias   = None
        if base_linear.bias is not None:
            self.bias = nn.Parameter(base_linear.bias.data.clone(), requires_grad=False)

        if r > 0:
            # Low-rank factors. They are allocated on the same device and with the
            # same dtype as the base weight, so wrapping a layer that already lives
            # on the GPU (or uses a non-default dtype) does not produce a
            # device/dtype mismatch in forward().
            factory_kwargs = {
                "device": base_linear.weight.device,
                "dtype": base_linear.weight.dtype,
            }
            self.lora_A = nn.Parameter(torch.empty(r, self.in_features, **factory_kwargs))
            self.lora_B = nn.Parameter(torch.empty(self.out_features, r, **factory_kwargs))
            nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
            # B starts at zero, so the wrapped layer initially reproduces the base layer.
            nn.init.zeros_(self.lora_B)
            self.dropout = nn.Dropout(dropout)
        else:
            self.lora_A = None
            self.lora_B = None
            self.dropout = nn.Identity()

    def forward(self, x):
        """Return the frozen base projection of ``x`` plus the scaled LoRA update."""
        # Frozen base projection.
        y = F.linear(x, self.weight, self.bias)
        # Low-rank increment (skipped entirely when the branch is disabled).
        if self.r > 0:
            y = y + self.scaling * (self.dropout(x) @ self.lora_A.T @ self.lora_B.T)
        return y


def inject_lora_to_transformer(backbone: nn.Module, r=8, alpha=16, dropout=0.1):
    """Wrap the linear sub-layers of every ``nn.TransformerEncoderLayer`` with LoRA.

    For each ``nn.TransformerEncoderLayer`` found in ``backbone`` the attributes
    ``linear1``, ``linear2`` and ``self_attn.out_proj`` are replaced in place by
    ``LoRALinear`` wrappers. Children that are already ``LoRALinear`` instances
    are left untouched, so calling this function twice is safe.

    Args:
        backbone: module tree to modify in place.
        r: LoRA rank forwarded to ``LoRALinear``.
        alpha: LoRA scaling numerator forwarded to ``LoRALinear``.
        dropout: LoRA-branch dropout probability forwarded to ``LoRALinear``.

    Returns:
        The same ``backbone`` object, after modification.

    NOTE(review): ``nn.MultiheadAttention.forward`` appears to hand
    ``out_proj.weight`` / ``out_proj.bias`` straight to the functional attention
    kernel instead of calling ``out_proj(...)``; if so, the LoRA branch of the
    wrapped ``out_proj`` never runs. The replacement is kept so parameter names
    stay compatible with existing checkpoints - confirm against the installed
    torch version.
    """
    replaced = 0
    # Snapshot the matching layers first so that swapping children never
    # happens while the module-tree generator is still being consumed.
    encoder_layers = [
        module for module in backbone.modules()
        if isinstance(module, nn.TransformerEncoderLayer)
    ]
    for layer in encoder_layers:
        # (owner, attribute) pairs: the two feed-forward linears, then the
        # attention output projection.
        targets = (
            (layer, "linear1"),
            (layer, "linear2"),
            (layer.self_attn, "out_proj"),
        )
        for owner, attr in targets:
            current = getattr(owner, attr)
            if isinstance(current, LoRALinear):
                # Already wrapped by a previous call; re-wrapping would fail.
                continue
            setattr(owner, attr, LoRALinear(current, r=r, alpha=alpha, dropout=dropout))
            replaced += 1
    print(f"[LoRA] injected on {replaced} Linear modules (linear1/linear2/out_proj)")
    return backbone


# time_last =['0507','0508','0509','0510','0511','0512','0513','0514','0515','0516','0517',
#            '0518','0519','0520','0521','0522','0523','0524','0525','0526','0527','0528','0529','0530',
#            '0531','0601','0602','0603','0604','0605','0606','0717']
# time_last = ['1+ace2','honor60','honormagic4p','honorX10','mate10','meizu16s','pixel4','iphone11','iphone12','iphone13','iphone13mini',
#                   'iphone14','iphone14pro','iphone15pm','redmik20p','redmiK30SUltra','redmiK40','redminote12-turbo','xiaomi8se','xiaomi13-1','xiaomi13-2','xiaomi14']
# time_last = ['iphone13mini','iphone14']
# Sweep definition: every (time_idx, area_idx) pair below is evaluated once.
time_last = ['total_4a']
area_last = ['4A']


for time_idx in time_last:
    for area_idx in area_last:
        # Input data lives under ./data/part/<time_idx>/<area_idx>.
        DATA_DIR1 = os.path.join('./data/', 'part')
        DATA_DIR2 = os.path.join(DATA_DIR1, time_idx)
        DATA_DIR = os.path.join(DATA_DIR2, area_idx)

        # Path existence check: skip combinations that have no data directory.
        if not os.path.exists(DATA_DIR):
            print(f"路径不存在，跳过: {DATA_DIR}")
            continue  # skip this (time_idx, area_idx) pair

        # Outputs go under ./finetune_unsupervised_ntp/4a/10%/<time_idx>/<area_idx>.
        SAVE_DIR1 = os.path.join('./finetune_unsupervised_ntp/4a/10%/', time_idx)
        SAVE_DIR = os.path.join(SAVE_DIR1, area_idx)

        # Override the dataset module's module-level data directory for this run
        # (presumably read by SeqAPDataset - verify in dataset/seq_ap_dataset.py).
        dataset.seq_ap_dataset.PROCESSED_DATA_DIR = DATA_DIR

        GPU_ID = "0"
        SEED = 2077
        if torch.cuda.is_available():
            torch.cuda.set_device(int(GPU_ID))

        # Options handed to evaluator.evaluate_MC_drop_out further down.
        opt_eval = EasyDict()
        opt_eval.MC_dropout = True
        opt_eval.num_sampling = 20

        # Data options. label_data_size / train_data_size are set here but not
        # read anywhere in the visible part of this script.
        opt_data = EasyDict()
        opt_data.label_data_size = 0.9
        opt_data.train_data_size = 0.9999
        opt_data.data_dim = 2
        opt_data.seed = SEED
        opt_data.is_normalize = True
        opt_data.agc_calibrate = True
        opt_data.seq_ap_num = 5                   # APs per sequence; window_ap_num is derived from it
        opt_data.window_ap_num = opt_data.seq_ap_num + 1
        # C(window_ap_num, seq_ap_num): number of seq_ap_num-sized subsets of the window (6 here).
        opt_data.ensemble_num = int(comb(opt_data.window_ap_num, opt_data.seq_ap_num))
        opt_data.time_step = 0.5
        opt_data.time_window = 0.5
        # Name of the sub-directory of DATA_DIR that matches the settings above.
        opt_data.construction_type = 'window_ap_num-{0}_seq_ap_num-{1}_time_step-{2}_time_window-{3}'.format(opt_data.window_ap_num, opt_data.seq_ap_num, opt_data.time_step, opt_data.time_window)
        opt_data.target_data_dir = os.path.join(DATA_DIR, opt_data.construction_type)
        # Every sub-directory of target_data_dir is treated as one phone name.
        opt_data.phone_names = [name for name in os.listdir(opt_data.target_data_dir) if os.path.isdir(os.path.join(opt_data.target_data_dir, name))]
        opt_data.save_data = True
        opt_data.data_type = 'covariance'
        opt_data.json_dir = './document/4a/map_4a.json'

        # Training-style options; in the visible script only device, batch_size
        # and criterion are consumed (the rest is written to the parameter log).
        opt_train = EasyDict()
        opt_train.seed = opt_data.seed
        opt_train.device = 'cuda:{0}'.format(GPU_ID) if torch.cuda.is_available() else 'cpu'
        opt_train.batch_size = 64
        opt_train.lr_rate = 5e-4
        opt_train.criterion = 'mse'
        opt_train.optimizer = 'adamW'
        opt_train.epochs = 100
        opt_train.t_max = opt_train.epochs
        opt_train.lr_min = 1e-5
        opt_train.patience = 5
        opt_train.factor = 0.9
        opt_train.freeze_epochs = 20   # number of epochs in the frozen stage
        opt_train.mlp_lr_rate = 1e-3      # learning rate of the regression head


        # Model options passed to Backbone1.
        opt_model = EasyDict()
        opt_model.in_channels = opt_data.data_dim  # same value as the data dimension (2)
        opt_model.ant_num = 2
        # in_channels * ant_num^2 + ant_num = 2 * 2 * 2 + 2 = 10 with the values above.
        opt_model.input_feature_num = opt_model.in_channels * opt_model.ant_num * opt_model.ant_num + opt_model.ant_num
        opt_model.dropout_p = 0.1
        opt_model.model_dim = 512
        opt_model.feedforward_dim = 4*opt_model.model_dim
        opt_model.n_layers = 6
        opt_model.n_heads = 8
        opt_model.seq_len = opt_data.seq_ap_num  # number of samples in a single sequence
        opt_model.decoder_type = 'FC'
        opt_model.fc_hidden_num = 1024
        opt_model.fc_dropout_p = 0.5
        opt_model.device = opt_train.device
        opt_model.seq_ap_num = opt_data.seq_ap_num
        # Matches opt_data.ensemble_num only while window_ap_num == seq_ap_num + 1.
        opt_model.ensemble_num = opt_model.seq_ap_num + 1
        opt_model.return_cls = True
        # Label used in the save name: the single phone's name, or 'multi_phone'.
        phone_num = 'multi_phone' if len(opt_data.phone_names) > 1 else opt_data.phone_names[0]

        # Output locations; results, logs and checkpoints all share one directory.
        opt_save = EasyDict()
        opt_save.save_name = 'supervise-{0}-window_ap_num:{1}-seq_ap_num:{2}-datadim:{3}'.format(phone_num, opt_data.window_ap_num, opt_data.seq_ap_num, opt_data.data_dim)
        opt_save.save_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())                           # timestamp of this run; prefixes the experiment directory name
        opt_save.checkpoints_dir = os.path.join(SAVE_DIR, "{0}-{1}".format(opt_save.save_time, opt_save.save_name))  # per-run output directory
        opt_save.results_dir = opt_save.checkpoints_dir
        opt_save.log_dir = opt_save.checkpoints_dir
        opt_save.load_dir = opt_save.checkpoints_dir
        opt_save.save_path = os.path.join(opt_save.checkpoints_dir, 'model_dict')
        # Hard-coded directory of an earlier run; best_model.pth is loaded from it below.
        opt_save.save_finetune = os.path.join('./finetune_unsupervised_ntp/4a/10%/', "2025-09-16-21:56:09-supervise-multi_phone-window_ap_num:6-seq_ap_num:5-datadim:2")
        print(opt_save.save_path)
    

        # load data
        def get_dataloader(opt_train, opt_data):
            """Build the evaluation DataLoader described by the option objects.

            Chooses the data suffix from ``opt_data.agc_calibrate``, constructs a
            ``SeqAPDataset`` for ``opt_data.phone_names``, pulls AP coordinates,
            data and labels out of it, wraps them in ``DatasetntpSupervise`` and
            returns an unshuffled ``DataLoader`` with ``opt_train.batch_size``
            and 4 worker processes.
            """
            # 'csi' only when the flag compares equal to False; anything else
            # selects 'agc_caled_csi'.
            suffix = 'csi' if opt_data.agc_calibrate == False else 'agc_caled_csi'

            source = SeqAPDataset(
                phone_names=opt_data.phone_names,
                window_ap_num=opt_data.window_ap_num,
                seq_ap_num=opt_data.seq_ap_num,
                time_step=opt_data.time_step,
                time_window=opt_data.time_window,
                data_suffix=suffix,
            )

            ap_coords, samples, targets = source.get_data_labels(
                data_dim=opt_data.data_dim,
                is_normalize=opt_data.is_normalize,
            )
            print("test_ap_coords, test_data, test_labels",  ap_coords.shape,  samples.shape,  targets.shape)

            # Evaluation order must stay fixed, hence shuffle=False.
            return DataLoader(
                DatasetntpSupervise(ap_coords, samples, targets),
                shuffle=False,
                batch_size=opt_train.batch_size,
                num_workers=4,
            )
        
        # Record every option group in the experiment parameter log.
        write_log([str(opt_train)+'\n'], "exp", log_dir=opt_save.log_dir, log_type='param')
        write_log([str(opt_data)+'\n'], "exp", log_dir=opt_save.log_dir, log_type='param')
        write_log([str(opt_model) + '\n'], "exp", log_dir=opt_save.log_dir, log_type='param')
        write_log([str(opt_save) + '\n'], "exp", log_dir=opt_save.log_dir, log_type='param')

        test_dataloader = get_dataloader(opt_train, opt_data)

        # Select the loss. Note that 'l1' maps to SmoothL1Loss, not plain L1Loss.
        # An unknown name is rejected here instead of surfacing later as a
        # NameError after the model has been built and the checkpoint loaded.
        if opt_train.criterion == 'mse':
            criterion = torch.nn.MSELoss()
        elif opt_train.criterion == 'l1':
            criterion = torch.nn.SmoothL1Loss()
        else:
            raise ValueError(
                f"Unsupported criterion {opt_train.criterion!r}; expected 'mse' or 'l1'"
            )

        model = Backbone1(opt_model).to(opt_train.device)
        # Disabled: LoRA injection. Presumably only needed when the checkpoint
        # was saved from a LoRA-injected model - TODO confirm before enabling.
        # model=inject_lora_to_transformer(model, r=4, alpha=8, dropout=0.1)
        # model = model.to(opt_train.device)

        # Load the fine-tuned weights from the hard-coded previous run directory.
        finetune_model_path = os.path.join(opt_save.save_finetune, 'best_model.pth')
        model.load_state_dict(torch.load(finetune_model_path, map_location=opt_train.device)['model_state_dict'])

        test_loss, result = evaluator.evaluate_MC_drop_out(opt_eval, model, opt_train.device, test_dataloader, criterion)

        print("Test losses:", test_loss)

        # Make sure the output directory exists before anything is written into
        # it; this is a no-op when it has already been created.
        os.makedirs(opt_save.checkpoints_dir, exist_ok=True)

        save_path = f"{opt_save.checkpoints_dir}/unsupervise_test_result.mat"
        save_cdf_plots(opt_save.checkpoints_dir, result["output_locations_fusion_error"])

        print(f"Saving file to {save_path}")

        # Dump the full result dict as a MATLAB v7.3 file, replacing any
        # existing file at save_path.
        hdf5storage.savemat(save_path,
                mdict=result,
                appendmat=True,
                format='7.3',
                truncate_existing=True)
