import os
import time
import warnings
import torch
import numpy as np
import torch.nn as nn
import torch.distributed as dist
from matplotlib import pyplot as plt
from torch import optim
from torch.nn.parallel import DistributedDataParallel as DDP
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, LargeScheduler
from utils.metrics import metric
from ...ITS.ITS import ITS
from tqdm import tqdm

warnings.filterwarnings('ignore')


class Exp_Forecast(Exp_Basic):
    """Forecasting experiment that evaluates a model with multi-sample ITS fusion.

    The class builds the model selected by ``args.model``, fetches data through
    ``data_provider`` and, in :meth:`test`, rolls the model forward in steps of
    ``args.pred_len`` while drawing ``args.n_samples`` forward passes per step
    that are fused by an ``ITS`` instance.
    """

    def __init__(self, args):
        super(Exp_Forecast, self).__init__(args)

    def _build_model(self):
        """Instantiate the model registered under ``self.args.model``.

        Returns:
            The model wrapped in ``DistributedDataParallel`` when both
            ``use_multi_gpu`` and ``use_gpu`` are set, otherwise the bare model.
        """
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = self.model_dict[self.args.model].Model(self.args)
            model = DDP(model.cuda(), device_ids=[self.args.local_rank], find_unused_parameters=True)
        else:
            # The device is published on args *before* the model is constructed;
            # presumably the model constructor reads it -- keep this ordering.
            self.args.device = self.device
            model = self.model_dict[self.args.model].Model(self.args)
        return model

    def _get_data(self, flag):
        """Return the ``(dataset, dataloader)`` pair from ``data_provider`` for ``flag``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def enable_dropout_only(self, model):
        """Switch every ``torch.nn.Dropout`` sub-module of ``model`` to training mode.

        Other module types are left in whatever mode they are currently in, so
        the caller is expected to have put the model in eval mode beforehand.
        """
        for m in model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()

    @staticmethod
    def _inverse_scale(dataset, values):
        """Undo dataset scaling on ``values`` and return a tensor of the same shape."""
        shape = values.shape
        restored = dataset.inverse_transform(values.squeeze(0)).reshape(shape)
        # The return type of inverse_transform is not visible from this file; it
        # may be an array rather than a tensor, so normalise for torch.cat later.
        return torch.as_tensor(restored)

    def test(self, setting, test=0):
        """Evaluate on the test split using a multi-sample rollout fused by ITS.

        For every test batch the current input window is tiled
        ``self.args.n_samples`` times along the batch axis and passed through
        the model with dropout layers active. The resulting samples are handed
        to ``ITS.run_inference`` to obtain one fused prediction per step, and
        steps of ``pred_len`` are rolled forward until ``output_len`` is
        covered. MSE and MAE are printed and appended to
        ``result_long_term_forecast.txt``.

        Args:
            setting: Identifier used in the results folder name and log entry.
            test: Unused; kept for interface compatibility.

        Returns:
            None.
        """
        test_data, test_loader = self._get_data(flag='test')

        n_samples = self.args.n_samples
        pred_len = self.args.pred_len
        output_len = self.args.output_len

        print("info:", self.args.seq_len, self.args.pred_len, self.args.label_len, self.args.output_len)
        print("loading model from {}".format(self.args.pretrain_model_path))
        print("loading backward_model from {}".format(self.args.backward_pretrain_model_path))

        if self.args.adaptation:
            load_item = torch.load(self.args.pretrain_model_path, map_location=self.device)
            self.model.load_state_dict({k.replace('module.', ''): v for k, v in load_item.items()}, strict=False)

        self.model = self.model.to(self.device)

        its = ITS(
            args=self.args,
            backward_checkpoints=self.args.backward_pretrain_model_path,
            device=self.device
        )

        trues = []
        preds_ITSReason = []

        folder_path = './test_results/' + setting + '/' + self.args.data_path + '/'
        if int(os.environ.get("LOCAL_RANK", "0")) == 0:
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(folder_path, exist_ok=True)

        print('Model parameters: ', sum(param.numel() for param in self.model.parameters()))

        # With 'MS' features only the last channel is evaluated.
        f_dim = -1 if self.args.features == 'MS' else 0

        # Number of pred_len-sized steps needed to cover output_len (ceil division).
        inference_steps = (output_len + pred_len - 1) // pred_len

        # Put the whole model in eval mode first, then re-enable only the dropout
        # layers; without eval() every layer would still be in training mode.
        self.model.eval()
        self.enable_dropout_only(self.model)
        with torch.no_grad():
            test_bar = tqdm(test_loader, desc="ITS")
            for batch_x, batch_y, batch_x_mark, batch_y_mark in test_bar:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                batch_size = batch_x.shape[0]
                current_batch_x = batch_x

                # These do not change across rollout steps, so tile them once.
                # batch_y is tiled as well so that all four model inputs share
                # the same (n_samples * batch) leading dimension.
                expanded_batch_y = batch_y.repeat(n_samples, 1, 1)
                expanded_batch_y_mark = batch_y_mark.repeat(n_samples, 1, 1)

                ITSReason_pred = []
                for j in range(inference_steps):
                    if ITSReason_pred:
                        # Slide the window: drop the oldest pred_len points and
                        # append the previous fused prediction.
                        current_batch_x = torch.cat([current_batch_x[:, pred_len:, :], ITSReason_pred[-1]], dim=1)
                        tmp = batch_y_mark[:, j - 1:j, :]
                        batch_x_mark = torch.cat([batch_x_mark[:, 1:, :], tmp], dim=1)

                    expanded_batch_x = current_batch_x.repeat(n_samples, 1, 1)
                    expanded_batch_x_mark = batch_x_mark.repeat(n_samples, 1, 1)
                    # autocast is a no-op when enabled=False, so one code path
                    # serves both the AMP and the full-precision case.
                    with torch.amp.autocast('cuda', enabled=bool(self.args.use_amp)):
                        expanded_outputs = self.model(
                            expanded_batch_x,
                            expanded_batch_x_mark,
                            expanded_batch_y,
                            expanded_batch_y_mark,
                        )

                    expanded_outputs = expanded_outputs[:, -pred_len:, :]

                    # [n_samples, batch, pred_len, feat_dim]
                    step_samples_tensor = expanded_outputs.reshape(n_samples, batch_size, pred_len, -1)

                    fusion_candidate = its.run_inference(
                        self.args, step_samples_tensor, current_batch_x[:, -pred_len:, :]
                    )
                    ITSReason_pred.append(fusion_candidate[:, -pred_len:, :])

                # The last step may overshoot when output_len is not a multiple
                # of pred_len; keep exactly output_len points.
                ITSReason_pred = torch.cat(ITSReason_pred, dim=1)[:, :output_len, :]

                if not self.args.nonautoregressive:
                    batch_y = batch_y[:, self.args.label_len:self.args.label_len + output_len, :]
                else:
                    batch_y = batch_y[:, :output_len, :]

                outputs = ITSReason_pred.detach().cpu()
                batch_y = batch_y.detach().cpu()

                if test_data.scale and self.args.inverse:
                    outputs = self._inverse_scale(test_data, outputs)
                    batch_y = self._inverse_scale(test_data, batch_y)

                preds_ITSReason.append(outputs[:, :, f_dim:])
                trues.append(batch_y[:, :, f_dim:])

        trues_np = torch.cat(trues, dim=0).numpy()
        preds_ITSReason_np = torch.cat(preds_ITSReason, dim=0).numpy()
        if f_dim == -1:
            # Drop the singleton channel axis left by the f_dim slice.
            trues_np = trues_np[:, :, -1]
            preds_ITSReason_np = preds_ITSReason_np[:, :, -1]

        ITSReason_mae, ITSReason_mse, ITSReason_rmse, ITSReason_mape, ITSReason_mspe, ITSReason_smape = metric(preds_ITSReason_np, trues_np)
        print('ITS - mse:{}, mae:{}'.format(ITSReason_mse, ITSReason_mae))

        with open("result_long_term_forecast.txt", 'a') as f:
            f.write(setting + "_" +
                str(self.args.model) + "_" +
                str(self.args.data_path) + "_" +
                str(self.args.output_len) + "\n")
            f.write('ITS - mse:{}, mae:{}\n'.format(ITSReason_mse, ITSReason_mae))
            f.write('\n')

        return