import random
import time
from tqdm import tqdm
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from ...ITS.ITS import ITS
from utils.metrics import metric
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings('ignore')


class Exp_Imputation(Exp_Basic):
    """Imputation experiment evaluated with Monte-Carlo dropout + ITS fusion.

    At test time each batch is run ``args.n_samples`` times through the model
    with dropout active, and the resulting stack of samples is passed to
    ``ITS.run_inference`` to produce the final imputation that is scored.
    """

    def __init__(self, args):
        super(Exp_Imputation, self).__init__(args)

    def _build_model(self):
        """Instantiate ``args.model`` from ``self.model_dict`` as a float32 model.

        The model is wrapped in ``nn.DataParallel`` when both
        ``args.use_multi_gpu`` and ``args.use_gpu`` are set.
        """
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for split ``flag`` via ``data_provider``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def enable_dropout_only(self, model):
        """Put every ``torch.nn.Dropout`` submodule of ``model`` into train mode.

        Other submodules keep whatever mode they are in, so the caller is
        expected to call ``model.eval()`` first to get MC-dropout behaviour.
        Note: only ``nn.Dropout`` instances are matched (not e.g. ``Dropout2d``).
        """
        for m in model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()

    def _make_patch_mask(self, batch_x):
        """Build a random patch-wise observation mask shaped like ``batch_x``.

        ``batch_x`` is unpacked as ``(B, T, N)``. Time is split into patches of
        ``args.impatch_len`` steps; each (patch, channel) pair is dropped
        independently with probability ``args.mask_rate``. The first patch is
        always kept observed.

        Returns:
            Tensor of shape ``(B, T, N)`` with 1 = observed and 0 = held out.
        """
        B, T, N = batch_x.shape
        patch_len = self.args.impatch_len
        # Ceil-divide so a trailing partial patch is also covered. When T is a
        # multiple of patch_len this draws exactly the same random numbers as
        # floor division, so results stay reproducible.
        n_patches = -(-T // patch_len)
        mask = torch.rand((B, n_patches, N)).to(self.device)
        mask = mask.unsqueeze(2).repeat(1, 1, patch_len, 1)
        mask = mask.view(B, n_patches * patch_len, N)[:, :T, :]
        mask = (mask > self.args.mask_rate).to(mask.dtype)
        mask[:, :patch_len, :] = 1
        return mask

    def test(self, setting):
        """Evaluate the checkpoint of ``setting`` on the test split with ITS fusion.

        Loads ``./checkpoints/<task_name>/<model>/<setting>/checkpoint.pth``.
        For every batch it masks random patches, draws ``args.n_samples``
        MC-dropout predictions and fuses them with ``ITS.run_inference``. MSE
        and MAE are computed on the masked (held-out) positions only. The
        scores are printed and appended to ``result_imputation_<data>.txt`` in
        the current working directory.
        """
        _, test_loader = self._get_data(flag='test')

        print('loading model')
        checkpoint_path = os.path.join(
            './checkpoints/' + self.args.task_name + '/' + self.args.model + '/' + setting,
            'checkpoint.pth',
        )
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)

        its = ITS(
            args=self.args,
            backward_checkpoints=self.args.backward_pretrain_model_path,
            device=self.args.device,
        )

        n_samples = self.args.n_samples
        use_cuda_rng = self.args.device != 'cpu'
        # For 'MS' only the last (target) channel is evaluated.
        f_dim = -1 if self.args.features == 'MS' else 0

        torch.manual_seed(self.args.seed)
        torch.cuda.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)

        folder_path = './test_results/' + self.args.task_name + '/' + self.args.model + '/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        trues = []
        masks = []
        preds_ITSReason = []

        with torch.no_grad():
            for i, (batch_x, _, batch_x_mark, _) in enumerate(tqdm(test_loader, desc="ITS", leave=True)):
                batch_x = batch_x.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                B, T, N = batch_x.shape

                # The mask is drawn from the global RNG (outside the fork below).
                mask = self._make_patch_mask(batch_x)
                inp = batch_x.masked_fill(mask == 0, 0)

                # MC dropout: everything in eval mode except the dropout layers.
                self.model.eval()
                self.enable_dropout_only(self.model)

                # Per-batch deterministic seed for the dropout samples; fork_rng
                # restores the outer RNG state afterwards.
                with torch.random.fork_rng(devices=[self.device] if use_cuda_rng else []):
                    torch.manual_seed(self.args.seed + 100 + i)
                    if use_cuda_rng:
                        torch.cuda.manual_seed(self.args.seed + 100 + i)

                    # Tile the batch n_samples times along dim 0 so all samples
                    # go through the model in a single forward pass.
                    inp_expanded = inp.repeat(n_samples, 1, 1)
                    batch_x_mark_expanded = batch_x_mark.repeat(n_samples, 1, 1)
                    mask_expanded = mask.repeat(n_samples, 1, 1)

                    if self.args.use_amp:
                        with torch.amp.autocast('cuda'):
                            outputs_expanded = self.model(inp_expanded, batch_x_mark_expanded, None, None, mask_expanded)
                    else:
                        outputs_expanded = self.model(inp_expanded, batch_x_mark_expanded, None, None, mask_expanded)

                    # Undo the tiling: (n_samples * B, T, N) -> (n_samples, B, T, N).
                    step_samples_tensor = outputs_expanded.reshape(n_samples, B, T, N)

                    fusion_candidate = its.run_inference(
                        args=self.args,
                        pred=step_samples_tensor,
                        batch_x=batch_x,
                        mask=mask,
                    )

                    # run_inference may return a tensor or an array-like.
                    if isinstance(fusion_candidate, torch.Tensor):
                        fusion_candidate_f = fusion_candidate[:, :, f_dim:].detach().cpu().numpy()
                    else:
                        fusion_candidate_f = fusion_candidate[:, :, f_dim:]
                    preds_ITSReason.append(fusion_candidate_f)

                trues.append(batch_x[:, :, f_dim:].detach().cpu().numpy())
                masks.append(mask[:, :, f_dim:].detach().cpu().numpy())

        trues = np.concatenate(trues, 0)
        masks = np.concatenate(masks, 0)
        preds_ITSReason = np.concatenate(preds_ITSReason, 0)

        # Score only the positions that were hidden from the model.
        mae_its, mse_its, rmse_its, mape_its, mspe_its = metric(preds_ITSReason[masks == 0], trues[masks == 0])
        print('ITS - mse:{}, mae:{}'.format(mse_its, mae_its))

        result = "result_imputation_" + self.args.data + ".txt"
        with open(result, 'a') as f:
            f.write(setting + "  \n")
            f.write('ITS - mse:{}, mae:{}\n'.format(mse_its, mae_its))
            f.write('\n')

        return
