import os
import random
import time
import warnings

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from ...ITS.ITS import ITS
from utils.metrics import metric

warnings.filterwarnings('ignore')


class Exp_Imputation(Exp_Basic):
    """Imputation experiment that scores ITS-fused reconstructions on masked patches.

    For every test batch a random patch-wise mask is drawn. The model then
    produces one deterministic reconstruction (all layers in eval mode) and
    ``args.n_samples`` stochastic reconstructions (only dropout active). The
    stochastic samples are handed to ``ITS.run_inference``. Metrics are
    computed on the masked (mask == 0) positions only.
    """

    def __init__(self, args):
        super(Exp_Imputation, self).__init__(args)

    def _build_model(self):
        """Instantiate ``args.model`` as float32; wrap it in DataParallel for multi-GPU runs."""
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for the split named by ``flag``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def enable_dropout_only(self, model):
        """Put every ``torch.nn.Dropout`` submodule of ``model`` into training mode.

        No other module's mode is touched, so calling this after
        ``model.eval()`` turns on MC-dropout sampling while the remaining
        layers stay in eval mode.
        """
        for m in model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()

    def _is_cuda(self):
        """Return True when ``self.device`` (str or torch.device) is a CUDA device."""
        return torch.device(self.device).type == 'cuda'

    def _load_pretrained(self):
        """Load ``args.pretrain_model_path`` into the model (non-strict).

        A leading ``'module.'`` (added when a DataParallel model is saved) is
        stripped from each key. The weights are loaded into the unwrapped
        module, so the keys match whether or not ``self.model`` is itself a
        DataParallel wrapper. Keys that fail to match are reported rather than
        silently ignored.
        """
        prefix = 'module.'
        state = torch.load(self.args.pretrain_model_path, map_location=self.device)
        # Strip only the leading prefix: str.replace would also rewrite inner
        # names such as 'encoder.submodule.w' -> 'encoder.subw'.
        state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
        target = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        incompatible = target.load_state_dict(state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print("warning: missing keys: {}; unexpected keys: {}".format(
                incompatible.missing_keys, incompatible.unexpected_keys))

    def _random_patch_mask(self, B, T, N):
        """Draw a float mask of shape ``(B, T, N)``: 1 = observed, 0 = masked.

        One uniform draw is made per (patch, channel) cell. A cell is masked
        when its draw is ``<= args.mask_rate``, and all ``patch_len`` time
        steps of that patch share the value. The first patch is always
        observed.

        Raises:
            ValueError: if ``T`` is not a multiple of ``args.patch_len``.
        """
        patch_len = self.args.patch_len
        if T % patch_len != 0:
            raise ValueError(
                "sequence length {} is not a multiple of patch_len {}".format(T, patch_len))
        draws = torch.rand((B, T // patch_len, N)).to(self.device)
        mask = (draws > self.args.mask_rate).float()
        mask = mask.unsqueeze(2).repeat(1, 1, patch_len, 1).reshape(B, T, N)
        mask[:, :patch_len, :] = 1  # first patch is always observed
        return mask

    def _forward(self, inp, x_mark, mask):
        """Run the model in imputation mode, under CUDA autocast when ``args.use_amp`` is set."""
        if self.args.use_amp:
            with torch.amp.autocast('cuda'):
                return self.model(inp, x_mark, None, None, mask)
        return self.model(inp, x_mark, None, None, mask)

    def test(self, setting, test=0):
        """Evaluate ITS-fused imputation on the test split and append the result.

        Args:
            setting: experiment identifier. It is used in the result folder
                paths and written as a header line to
                ``result_imputation_<data>.txt``.
            test: unused; kept for signature compatibility.

        Raises:
            ValueError: if the test loader yields no batches, or a sequence
                length is not a multiple of ``args.patch_len``.
        """
        _, test_loader = self._get_data(flag='test')

        print("info:", self.args.seq_len)
        print("loading model from {}".format(self.args.pretrain_model_path))
        print("loading backward_model from {}".format(self.args.backward_pretrain_model_path))

        if self.args.adaptation:
            self._load_pretrained()

        self.model = self.model.to(self.device)

        its = ITS(
            args=self.args,
            backward_checkpoints=self.args.backward_pretrain_model_path,
            device=self.device
        )

        n_samples = self.args.n_samples
        f_dim = -1 if self.args.features == 'MS' else 0
        # fork_rng must save/restore the CUDA RNG of the device the model runs
        # on. Passing a CPU device would make it query the CUDA RNG instead.
        is_cuda = self._is_cuda()
        rng_devices = [torch.device(self.device)] if is_cuda else []

        torch.manual_seed(self.args.seed)
        torch.cuda.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)

        trues = []
        masks = []
        preds = []            # deterministic reconstructions (baseline)
        preds_ITSReason = []  # ITS-fused reconstructions

        os.makedirs('./test_results/' + setting + '/', exist_ok=True)

        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(tqdm(test_loader, desc="ITS", leave=True)):
                batch_x = batch_x.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)

                B, T, N = batch_x.shape
                mask = self._random_patch_mask(B, T, N)
                inp = batch_x.masked_fill(mask == 0, 0)

                # Deterministic pass with every layer (dropout included) in eval
                # mode. eval() also undoes the previous batch's enable_dropout_only.
                self.model.eval()
                outputs = self._forward(inp, batch_x_mark, mask)
                preds.append(outputs[:, :, f_dim:].float().cpu().numpy())

                # MC-dropout pass: re-enable only dropout, and seed inside
                # fork_rng. Neither the sampling nor ITS then perturbs the
                # global RNG stream that drives mask generation for later batches.
                self.enable_dropout_only(self.model)
                with torch.random.fork_rng(devices=rng_devices):
                    sample_seed = self.args.seed + 100 + i
                    torch.manual_seed(sample_seed)
                    if is_cuda:
                        torch.cuda.manual_seed(sample_seed)

                    # Tile the batch n_samples times along dim 0 so all samples
                    # run in one forward pass; copy k occupies rows k*B:(k+1)*B.
                    outputs_expanded = self._forward(
                        inp.repeat(n_samples, 1, 1),
                        batch_x_mark.repeat(n_samples, 1, 1),
                        mask.repeat(n_samples, 1, 1),
                    )
                    step_samples_tensor = outputs_expanded.reshape(n_samples, B, T, N)

                    # NOTE(review): ITS receives the unmasked batch_x. Confirm
                    # that run_inference only reads observed entries (mask == 1);
                    # otherwise the masked targets leak into the evaluation.
                    fusion_candidate = its.run_inference(
                        args=self.args,
                        pred=step_samples_tensor,
                        batch_x=batch_x,
                        mask=mask,
                    )

                if isinstance(fusion_candidate, torch.Tensor):
                    fusion_candidate = fusion_candidate.detach().cpu().numpy()
                preds_ITSReason.append(fusion_candidate[:, :, f_dim:])

                trues.append(batch_x[:, :, f_dim:].cpu().numpy())
                masks.append(mask[:, :, f_dim:].cpu().numpy())

        if not trues:
            raise ValueError("test loader yielded no batches; nothing to evaluate")

        trues = np.concatenate(trues, 0)
        masks = np.concatenate(masks, 0)
        preds = np.concatenate(preds, 0)
        preds_ITSReason = np.concatenate(preds_ITSReason, 0)

        # Score only the positions that were hidden from the model.
        target = masks == 0
        mae, mse, _, _, _, _ = metric(preds[target], trues[target])
        print('Model - mse:{}, mae:{}'.format(mse, mae))
        mae_back, mse_back, _, _, _, _ = metric(preds_ITSReason[target], trues[target])
        print('ITS - mse:{}, mae:{}'.format(mse_back, mae_back))

        folder_path = './test_results/' + self.args.task_name + '/' + self.args.model + '/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        result = "result_imputation_" + self.args.data + ".txt"

        with open(result, 'a') as f:
            f.write(setting + "  \n")
            f.write('ITS - mse:{}, mae:{}\n'.format(mse_back, mae_back))
            f.write('\n')

        return
