from matplotlib import pyplot as plt
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.metrics import metric
import torch
import torch.nn as nn
import os
import time
import warnings
import numpy as np
from tqdm import tqdm
from utils.dtw_metric import dtw, accelerated_dtw
from ...ITS.ITS import ITS

warnings.filterwarnings('ignore')


class Exp_Long_Term_Forecast(Exp_Basic):
    """Experiment that evaluates segment-wise ITS-corrected forecasts on the test split.

    The forecasting model is run ``args.n_samples`` times per batch with dropout
    active, the forecast horizon is cut into fixed-size segments, and each
    segment is handed to ``ITS.run_inference`` together with a reference
    sequence.  The stitched result is scored against the ground truth with
    ``utils.metrics.metric``.
    """

    def __init__(self, args):
        super().__init__(args)

    def _build_model(self):
        """Instantiate the model named by ``args.model`` as float32.

        Returns:
            The model, wrapped in ``nn.DataParallel`` when both
            ``args.use_multi_gpu`` and ``args.use_gpu`` are set.
        """
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(data_set, data_loader)`` from ``data_provider`` for ``flag``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def enable_dropout_only(self, model):
        """Put every ``torch.nn.Dropout`` sub-module of ``model`` into training mode.

        Only instances of ``torch.nn.Dropout`` are touched; other modules keep
        whatever mode they are currently in.
        """
        for m in model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()

    def test(self, setting):
        """Evaluate ITS-corrected forecasts on the test split and log MSE/MAE.

        Loads ``./checkpoints/<task_name>/<model>/<setting>/checkpoint.pth``,
        draws ``args.n_samples`` stochastic forecasts per batch (dropout on),
        corrects them segment by segment through ``ITS.run_inference`` and
        compares the stitched result with the last ``args.pred_len`` steps of
        ``batch_y``.

        Args:
            setting: Run identifier used to build the checkpoint and result
                folder paths; also written to the results file.

        Returns:
            None.  Metrics are printed and appended to
            ``result_long_term_forecast.txt``.

        Raises:
            ValueError: If the test loader yields no batches.
        """
        test_data, test_loader = self._get_data(flag='test')
        self.device = self.args.device

        print('loading model')
        checkpoint_path = os.path.join('./checkpoints/' + self.args.task_name + '/' + self.args.model + '/' + setting, 'checkpoint.pth')
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)

        its = ITS(
            args=self.args,
            backward_checkpoints=self.args.backward_pretrain_model_path,
            device=self.args.device,
        )

        n_samples = self.args.n_samples
        pred_len = self.args.pred_len
        # With 'MS' features only the last channel is scored, otherwise all of them.
        f_dim = -1 if self.args.features == 'MS' else 0
        trues = []
        preds_segmentReason = []

        folder_path = './test_results/' + self.args.task_name + '/' + self.args.model + '/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        # ------------------------------------------------------------------
        # Segment plan.  It depends only on pred_len and segment_size, so it is
        # computed once instead of once per batch.
        # ------------------------------------------------------------------
        segment_size = 96
        num_segments = (pred_len + segment_size - 1) // segment_size
        segment_indices = np.arange(num_segments)
        start_indices = segment_indices * segment_size
        end_indices = np.minimum(start_indices + segment_size, pred_len)
        current_segment_sizes = end_indices - start_indices
        # A short trailing segment is widened backwards so that it still spans
        # segment_size steps; the borrowed prefix is dropped again on write-back.
        need_borrow = (current_segment_sizes < segment_size) & (segment_indices > 0)
        borrow_needed = np.where(need_borrow, segment_size - current_segment_sizes, 0)
        adjusted_starts = np.maximum(0, start_indices - borrow_needed)
        borrow_sizes = np.where(need_borrow, start_indices - adjusted_starts, 0)
        final_start_indices = np.where(need_borrow, adjusted_starts, start_indices)

        # Freeze everything (e.g. normalisation layers) in eval mode first, then
        # re-enable dropout alone so repeated forward passes yield distinct samples.
        self.model.eval()
        self.enable_dropout_only(self.model)
        with torch.no_grad():
            segment_bar = tqdm(test_loader, desc="ITS")
            for batch_x, batch_y, batch_x_mark, batch_y_mark in segment_bar:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                batch_size = batch_x.shape[0]
                feat_dim = batch_x.shape[-1]

                # Decoder input: the label_len prefix of batch_y followed by zeros
                # over the forecast horizon.
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # Tile the batch n_samples times along dim 0 so a single forward
                # pass produces every stochastic sample.
                batch_x_expanded = batch_x.repeat(n_samples, 1, 1)
                batch_x_mark_expanded = batch_x_mark.repeat(n_samples, 1, 1)
                batch_y_mark_expanded = batch_y_mark.repeat(n_samples, 1, 1)
                dec_inp_expanded = dec_inp.repeat(n_samples, 1, 1)

                if self.args.use_amp:
                    with torch.amp.autocast('cuda'):
                        outputs_expanded = self.model(batch_x_expanded, batch_x_mark_expanded, dec_inp_expanded, batch_y_mark_expanded)
                else:
                    outputs_expanded = self.model(batch_x_expanded, batch_x_mark_expanded, dec_inp_expanded, batch_y_mark_expanded)

                outputs_expanded = outputs_expanded[:, -pred_len:, :]
                # repeat() stacks whole copies of the batch, so the sample index
                # is the leading axis after the reshape.
                noauto_samples_tensor = outputs_expanded.reshape(n_samples, batch_size, pred_len, feat_dim)

                segment_corrected_full = torch.zeros((batch_size, pred_len, feat_dim), device=self.device)

                # The first segment is referenced against the tail of the input window.
                ref_seq_len = min(segment_size, batch_x.shape[1])
                orig_ref_seq = batch_x[:, -ref_seq_len:, :]
                current_ref = orig_ref_seq
                prev_ref = None

                for segment_idx in range(num_segments):
                    start_idx = int(final_start_indices[segment_idx])
                    end_idx = int(end_indices[segment_idx])
                    borrow_size = int(borrow_sizes[segment_idx])
                    out_start = int(start_indices[segment_idx])
                    out_len = int(current_segment_sizes[segment_idx])

                    segment_samples = noauto_samples_tensor[:, :, start_idx:end_idx, :]

                    if segment_idx == 0:
                        ref_seq = current_ref
                    elif borrow_size > 0:
                        # prev_ref is always set here: it is assigned at the end of
                        # every iteration and borrowing never happens on segment 0.
                        borrowed_part = prev_ref[:, -borrow_size:, :]
                        remaining_part = current_ref[:, :-borrow_size, :]
                        ref_seq = torch.cat([borrowed_part, remaining_part], dim=1)
                        if ref_seq.shape[1] > segment_size:
                            ref_seq = ref_seq[:, -segment_size:, :]
                    else:
                        ref_len = min(segment_samples.shape[2], current_ref.shape[1])
                        ref_seq = current_ref[:, -ref_len:, :]

                    # Never hand ITS an empty reference; fall back to the last known step.
                    if ref_seq.shape[1] == 0:
                        ref_seq = orig_ref_seq[:, -1:, :] if orig_ref_seq.shape[1] > 0 else batch_x[:, -1:, :]

                    # NOTE(review): the result is indexed below as
                    # (batch, time, feature) -- confirm against ITS.run_inference.
                    segment_corrected = its.run_inference(self.args, segment_samples, ref_seq)

                    # Skip the borrowed prefix (if any) and keep only the steps that
                    # belong to this segment's own slot in the horizon.
                    segment_corrected_full[:, out_start:out_start + out_len, :] = \
                        segment_corrected[:, borrow_size:borrow_size + out_len, :]

                    prev_ref = current_ref
                    current_ref = segment_corrected

                segment_corrected_cpu = segment_corrected_full.detach().cpu().numpy()
                # Ground truth for the forecast horizon; previously never collected,
                # which made the final np.concatenate(trues) fail on an empty list.
                true_cpu = batch_y[:, -pred_len:, :].detach().cpu().numpy()

                if test_data.scale and self.args.inverse:
                    target_channels = batch_y.shape[-1]
                    if segment_corrected_cpu.shape[-1] != target_channels:
                        segment_corrected_cpu = np.tile(segment_corrected_cpu, [1, 1, int(target_channels / segment_corrected_cpu.shape[-1])])
                    segment_corrected_cpu = test_data.inverse_transform(
                        segment_corrected_cpu.reshape(-1, target_channels)
                    ).reshape(segment_corrected_cpu.shape)
                    true_cpu = test_data.inverse_transform(
                        true_cpu.reshape(-1, target_channels)
                    ).reshape(true_cpu.shape)

                preds_segmentReason.append(segment_corrected_cpu[:, :, f_dim:])
                trues.append(true_cpu[:, :, f_dim:])

        if not preds_segmentReason:
            raise ValueError('test loader yielded no batches; nothing to evaluate')

        trues = np.concatenate(trues, axis=0)
        preds_segmentReason = np.concatenate(preds_segmentReason, axis=0)

        print('========================')

        preds_segmentReason = preds_segmentReason.reshape(-1, preds_segmentReason.shape[-2], preds_segmentReason.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        mae_segment, mse_segment, rmse_segment, mape_segment, mspe_segment = metric(preds_segmentReason, trues)
        print('ITS - mse:{}, mae:{}'.format(mse_segment, mae_segment))

        # Context manager guarantees the handle is closed even if a write fails.
        with open("result_long_term_forecast.txt", 'a') as f:
            f.write(setting + "  \n")
            f.write('ITS - mse:{}, mae:{},'.format(mse_segment, mae_segment))
            f.write('\n')
            f.write('\n')

        return